Saturday, July 27, 2024

How To Create Multi-Modal RAG Pipeline on Images and Text Locally - Step by Step Guide

This video is an easy, step-by-step tutorial on building a multi-modal RAG pipeline over your own custom data, including images and text, using LlamaIndex.


Code:

conda create -n multirag python=3.11 -y && conda activate multirag

pip install --upgrade git+https://github.com/huggingface/transformers.git
pip install torch torchvision pillow
pip install -q llama-index-vector-stores-qdrant
pip install llama_index ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git
pip install matplotlib scikit-image
pip install -U qdrant_client
pip install llama-index-embeddings-clip

export OPENAI_API_KEY=""
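
The OpenAI API key is needed because the index embeds the text documents with OpenAI's embedding model by default; the images are embedded locally with CLIP, so no key is needed for those.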

conda install jupyter -y
pip uninstall charset_normalizer -y
pip install charset_normalizer
jupyter notebook
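
The data folder should contain image files (.jpg or .png), each paired with a .txt description sharing the same base name, e.g. temple.jpg next to temple.txt (names here are illustrative). The snippet below, run in the notebook, pairs them into a metadata dictionary.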

data_path="/home/Ubuntu/multidata/"
import os
image_metadata_dict = {}

for file in os.listdir(data_path):
    if file.endswith(".txt"):
        filename = file
        img_path = data_path + file.replace(".txt", ".jpg")
        if os.path.exists(img_path):
            image_metadata_dict[len(image_metadata_dict)] = {
                "filename": filename,
                "img_path": img_path
            }
        else:
            img_path = data_path + file.replace(".txt", ".png")
            if os.path.exists(img_path):
                image_metadata_dict[len(image_metadata_dict)] = {
                    "filename": filename,
                    "img_path": img_path
                }

print(image_metadata_dict)
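
For a hypothetical data folder containing fahd.txt/fahd.jpg and temple.txt/temple.png (file names purely illustrative), the printed dictionary would look like:

{0: {'filename': 'fahd.txt', 'img_path': '/home/Ubuntu/multidata/fahd.jpg'},
 1: {'filename': 'temple.txt', 'img_path': '/home/Ubuntu/multidata/temple.png'}}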

 
import qdrant_client
from llama_index.core import SimpleDirectoryReader
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.indices import MultiModalVectorStoreIndex

# Create a local Qdrant vector store
client = qdrant_client.QdrantClient(path="qdrant_d_0")

text_store = QdrantVectorStore(
    client=client, collection_name="text_collection_0"
)
image_store = QdrantVectorStore(
    client=client, collection_name="image_collection_0"
)
storage_context = StorageContext.from_defaults(
    vector_store=text_store, image_store=image_store
)

# Create the MultiModal index
documents = SimpleDirectoryReader(data_path).load_data()
index = MultiModalVectorStoreIndex.from_documents(
    documents,
    storage_context=storage_context,
)
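
Under the hood, the text documents are embedded with the default OpenAI embedding model, while the images are embedded locally by CLIP via the llama-index-embeddings-clip package installed earlier. The index picks the CLIP model up by default, but it can also be passed explicitly; a minimal sketch ("ViT-B/32" is CLIP's default checkpoint, not something specific to this tutorial):

from llama_index.embeddings.clip import ClipEmbedding

# Explicitly select the local CLIP checkpoint used for image embeddings
image_embed_model = ClipEmbedding(model_name="ViT-B/32")

index = MultiModalVectorStoreIndex.from_documents(
    documents,
    storage_context=storage_context,
    image_embed_model=image_embed_model,
)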


from PIL import Image
import matplotlib.pyplot as plt
import os


def plot_images(image_metadata_dict):
    """Plot up to 64 of the dataset images in an 8x8 grid."""
    original_images_urls = []
    images_shown = 0
    plt.figure(figsize=(16, 16))
    for image_id in image_metadata_dict:
        img_path = image_metadata_dict[image_id]["img_path"]
        if os.path.isfile(img_path):
            filename = image_metadata_dict[image_id]["filename"]
            image = Image.open(img_path).convert("RGB")

            plt.subplot(8, 8, len(original_images_urls) + 1)
            plt.imshow(image)
            plt.xticks([])
            plt.yticks([])

            original_images_urls.append(filename)
            images_shown += 1
            if images_shown >= 64:
                break

    plt.tight_layout()
    plt.show()


plot_images(image_metadata_dict)

# Redefine plot_images as a helper for displaying the retrieved images
def plot_images(image_paths):
    images_shown = 0
    plt.figure(figsize=(16, 9))
    for img_path in image_paths:
        if os.path.isfile(img_path):
            image = Image.open(img_path)

            plt.subplot(2, 3, images_shown + 1)
            plt.imshow(image)
            plt.xticks([])
            plt.yticks([])

            images_shown += 1
            # Stop once the 2x3 grid is full
            if images_shown >= 6:
                break
    plt.show()

from llama_index.core.response.notebook_utils import display_source_node
from llama_index.core.schema import ImageNode
   
test_query = "Who is Fahd Mirza?"
# generate  retrieval results
retriever = index.as_retriever(similarity_top_k=1, image_similarity_top_k=1)
retrieval_results = retriever.retrieve(test_query)

retrieved_image = []
for res_node in retrieval_results:
    if isinstance(res_node.node, ImageNode):
        retrieved_image.append(res_node.node.metadata["file_path"])
    else:
        display_source_node(res_node, source_length=200)

plot_images(retrieved_image)

test_query = "What is outback?"
# generate  retrieval results
retriever = index.as_retriever(similarity_top_k=1, image_similarity_top_k=1)
retrieval_results = retriever.retrieve(test_query)

retrieved_image = []
for res_node in retrieval_results:
    if isinstance(res_node.node, ImageNode):
        retrieved_image.append(res_node.node.metadata["file_path"])
    else:
        display_source_node(res_node, source_length=200)

plot_images(retrieved_image)

test_query = "Where is Meenakshi Temple?"
# generate  retrieval results
retriever = index.as_retriever(similarity_top_k=1, image_similarity_top_k=1)
retrieval_results = retriever.retrieve(test_query)

retrieved_image = []
for res_node in retrieval_results:
    if isinstance(res_node.node, ImageNode):
        retrieved_image.append(res_node.node.metadata["file_path"])
    else:
        display_source_node(res_node, source_length=200)

plot_images(retrieved_image)
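
Retrieval is only half of a RAG pipeline; to generate an answer grounded in the retrieved text and images, the index can be wrapped in a query engine backed by a multi-modal LLM. A minimal sketch, assuming the llama-index-multi-modal-llms-openai package is installed (it is not part of the install steps above) and your OpenAI key has access to a vision-capable model:

from llama_index.multi_modal_llms.openai import OpenAIMultiModal

# A vision-capable GPT-4 model reads the retrieved text and images
openai_mm_llm = OpenAIMultiModal(model="gpt-4-vision-preview", max_new_tokens=300)

query_engine = index.as_query_engine(llm=openai_mm_llm)
response = query_engine.query("Who is Fahd Mirza?")
print(response)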
