This video is an easy, step-by-step tutorial on building a multi-modal RAG pipeline over your own custom data, including images and text, using LlamaIndex.
Code:
conda create -n multirag python=3.11 -y && conda activate multirag
pip install --upgrade git+https://github.com/huggingface/transformers.git
pip install torch torchvision pillow
pip install -q llama-index-vector-stores-qdrant
pip install llama-index ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git
pip install matplotlib scikit-image
pip install -U qdrant_client
pip install llama-index-embeddings-clip
export OPENAI_API_KEY=""
conda install jupyter -y
pip uninstall charset_normalizer -y
pip install charset_normalizer
jupyter notebook
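Once Jupyter is up, a quick optional sanity check (just a sketch, assuming the installs above succeeded) confirms that PyTorch and CLIP are importable and that a CLIP checkpoint loads:

import torch
import clip

# Load a small CLIP checkpoint; the weights download on first run
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
print("CLIP loaded on", device)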
import os

data_path = "/home/Ubuntu/multidata/"

# Pair each .txt file with its matching .jpg (or, failing that, .png) image
image_metadata_dict = {}
for file in os.listdir(data_path):
    if file.endswith(".txt"):
        filename = file
        img_path = data_path + file.replace(".txt", ".jpg")
        if os.path.exists(img_path):
            image_metadata_dict[len(image_metadata_dict)] = {
                "filename": filename,
                "img_path": img_path,
            }
        else:
            img_path = data_path + file.replace(".txt", ".png")
            if os.path.exists(img_path):
                image_metadata_dict[len(image_metadata_dict)] = {
                    "filename": filename,
                    "img_path": img_path,
                }
print(image_metadata_dict)
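For reference, with a hypothetical pair fahd_mirza.txt / fahd_mirza.jpg in the data folder, the printed dictionary would look roughly like this (illustrative only, not actual output):

# {0: {'filename': 'fahd_mirza.txt',
#      'img_path': '/home/Ubuntu/multidata/fahd_mirza.jpg'}}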
import qdrant_client
from llama_index.core import SimpleDirectoryReader
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.indices import MultiModalVectorStoreIndex
# Create a local Qdrant vector store
client = qdrant_client.QdrantClient(path="qdrant_d_0")
text_store = QdrantVectorStore(
    client=client, collection_name="text_collection_0"
)
image_store = QdrantVectorStore(
    client=client, collection_name="image_collection_0"
)
storage_context = StorageContext.from_defaults(
    vector_store=text_store, image_store=image_store
)
# Create the MultiModal index
documents = SimpleDirectoryReader(data_path).load_data()
index = MultiModalVectorStoreIndex.from_documents(
    documents,
    storage_context=storage_context,
)
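With the index built, you can go beyond plain retrieval and generate answers over both modalities. The following is a minimal sketch, assuming GPT-4 Vision access and the llama-index-multi-modal-llms-openai package; depending on your llama-index version, the keyword argument may be multi_modal_llm instead of llm:

from llama_index.multi_modal_llms.openai import OpenAIMultiModal

# Sketch: put a multi-modal LLM on top of the index for answer generation
openai_mm_llm = OpenAIMultiModal(model="gpt-4-vision-preview", max_new_tokens=300)
query_engine = index.as_query_engine(llm=openai_mm_llm)
response = query_engine.query("Describe the people and places in the data")
print(response)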
from PIL import Image
import matplotlib.pyplot as plt
import os
# Display the collected source images in an 8x8 grid
def plot_images(image_metadata_dict):
    original_images_urls = []
    images_shown = 0
    for image_id in image_metadata_dict:
        img_path = image_metadata_dict[image_id]["img_path"]
        if os.path.isfile(img_path):
            filename = image_metadata_dict[image_id]["filename"]
            image = Image.open(img_path).convert("RGB")
            plt.subplot(8, 8, len(original_images_urls) + 1)
            plt.imshow(image)
            plt.xticks([])
            plt.yticks([])
            original_images_urls.append(filename)
            images_shown += 1
            if images_shown >= 64:
                break
    plt.tight_layout()

plot_images(image_metadata_dict)
# Redefine plot_images to display the retrieved images in a 2x3 grid
def plot_images(image_paths):
    images_shown = 0
    plt.figure(figsize=(16, 9))
    for img_path in image_paths:
        if os.path.isfile(img_path):
            image = Image.open(img_path)
            plt.subplot(2, 3, images_shown + 1)
            plt.imshow(image)
            plt.xticks([])
            plt.yticks([])
            images_shown += 1
            # A 2x3 grid only has six cells, so stop after six images
            if images_shown >= 6:
                break
from llama_index.core.response.notebook_utils import display_source_node
from llama_index.core.schema import ImageNode
test_query = "Who is Fahd Mirza?"
# generate retrieval results
retriever = index.as_retriever(similarity_top_k=1, image_similarity_top_k=1)
retrieval_results = retriever.retrieve(test_query)
retrieved_image = []
for res_node in retrieval_results:
    if isinstance(res_node.node, ImageNode):
        # Collect image hits to plot below; text hits are displayed inline
        retrieved_image.append(res_node.node.metadata["file_path"])
    else:
        display_source_node(res_node, source_length=200)
plot_images(retrieved_image)
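To sanity-check what came back, you can also print each result's similarity score; retrieve() returns NodeWithScore objects, which expose a score attribute:

# Optional: inspect how strongly each node matched the query
for res_node in retrieval_results:
    print(type(res_node.node).__name__, res_node.score)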
test_query = "What is outback?"
# generate retrieval results
retriever = index.as_retriever(similarity_top_k=1, image_similarity_top_k=1)
retrieval_results = retriever.retrieve(test_query)
retrieved_image = []
for res_node in retrieval_results:
    if isinstance(res_node.node, ImageNode):
        retrieved_image.append(res_node.node.metadata["file_path"])
    else:
        display_source_node(res_node, source_length=200)
plot_images(retrieved_image)
test_query = "Where is Meenakshi Temple?"
# generate retrieval results
retriever = index.as_retriever(similarity_top_k=1, image_similarity_top_k=1)
retrieval_results = retriever.retrieve(test_query)
retrieved_image = []
for res_node in retrieval_results:
    if isinstance(res_node.node, ImageNode):
        retrieved_image.append(res_node.node.metadata["file_path"])
    else:
        display_source_node(res_node, source_length=200)
plot_images(retrieved_image)