
Want to Search for Something With an Image and a Text Description? Try a Multimodal RAG

by Jiang Chen | 16m | November 27th, 2024

This article provides an in-depth guide on how to build a multimodal RAG system using Milvus and how to open up various possibilities for AI systems.


Being constrained to a single data format isn’t good enough anymore. As businesses rely more heavily on information to make crucial decisions, they need the ability to compare data in disparate formats. Fortunately, traditional AI systems constrained to a single data type have given way to multimodal systems that can understand and process complex information.


Multimodal search and multimodal retrieval-augmented generation (RAG) systems have recently shown great advancements in this field. These systems process multiple types of data, including text, images, and audio, to provide context-aware responses.


In this blog post, we'll discuss how developers can build their own multimodal RAG system using Milvus. We’ll walk you through building a system that handles text and image data, performs similarity searches, and uses a language model to refine the output. So, let’s get started.

What Is Milvus?

A vector database is a special type of database used to store, index, and retrieve vector embeddings, which are mathematical representations of data that allow you to compare data for not just equivalence but semantic similarity. Milvus is an open-source, high-performance vector database built for scale. You can find it on GitHub with an Apache-2.0 license and more than 30K stars.


Milvus gives developers a flexible solution for managing and querying large-scale vector data. Its efficiency makes it an ideal choice for applications built on deep learning models, such as retrieval-augmented generation (RAG), multimodal search, recommendation engines, and anomaly detection.


Milvus offers multiple deployment options to match developers' needs. Milvus Lite is a lightweight version that runs inside a Python application and is perfect for prototyping applications inside a local environment. Milvus Standalone and Milvus Distributed are scalable and production-ready options.

Multimodal RAG: Expanding Beyond Text

Before building the system, it’s important to understand traditional text-based RAG and its evolution to Multimodal RAG.


Retrieval Augmented Generation (RAG) is a method for retrieving contextual information from external sources and generating more accurate output from large language models (LLMs). Traditional RAG is a highly effective strategy for improving LLM output, but it remains limited to textual data. In many real-world applications, data extends beyond text—incorporating images, charts, and other modalities provides critical context.


Multimodal RAG addresses the above limitation by enabling the use of different data types, providing better context to LLMs.


Simply put, in a multimodal RAG system, the retrieval component searches for relevant information across different data modalities, and the generation component generates more accurate results based on the retrieved information.
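
As a rough sketch of this two-stage flow, the function below wires a retrieval step and a generation step together. The names `multimodal_rag`, `embed`, `search`, and `generate` are hypothetical placeholders chosen for illustration; they are not part of Milvus or any other library, and the stand-in lambdas only exist so the sketch runs without a model or a database.

```python
from typing import Callable, Sequence


def multimodal_rag(
    query_image: str,
    query_text: str,
    embed: Callable[[str, str], Sequence[float]],
    search: Callable[[Sequence[float], int], list],
    generate: Callable[[str, list], str],
    top_k: int = 3,
) -> str:
    """Retrieve candidates for an image+text query, then generate an answer."""
    query_vec = embed(query_image, query_text)  # encode the composed query
    candidates = search(query_vec, top_k)       # retrieval: nearest neighbors
    return generate(query_text, candidates)     # generation grounded in results


# Stand-in components (purely illustrative) so the sketch runs end to end.
answer = multimodal_rag(
    "leopard.jpg",
    "phone case with this image theme",
    embed=lambda image, text: [float(len(image)), float(len(text))],
    search=lambda vec, k: ["case_a.jpg", "case_b.jpg", "case_c.jpg"][:k],
    generate=lambda text, hits: f"Best match for '{text}': {hits[0]}",
    top_k=2,
)
print(answer)
```

In the system built later in this post, the encoder, the vector database, and the LLM take the place of these three callables.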

Vector embeddings and similarity search are two fundamental concepts of multimodal RAG. Let’s understand both of them.

Vector Embeddings

As discussed, vector embeddings are mathematical/numerical representations of data. Machines use this representation to understand the semantic meaning of different data types, such as text, images, and audio.


When using natural language processing (NLP), document chunks are transformed into vectors, and semantically similar words are mapped to nearby points in the vector space. The same goes for images, where embeddings represent the semantic features. This allows us to capture features like color, texture, and object shapes in a numerical format.


The main goal of using vector embeddings is to help preserve relationships and similarities between different pieces of data.

Similarity Search

Similarity search finds the items in a dataset that are most similar to a query. In the context of vector embeddings, it finds the vectors in the dataset that are closest to the query vector.


The following are a few methods that are commonly used to measure similarity between vectors:

  1. Euclidean Distance: Measures the straight-line distance between two points in the vector space.
  2. Cosine Similarity: Measures the cosine of the angle between two vectors (with a focus on their direction rather than magnitude).
  3. Dot Product: Multiplies corresponding elements of two vectors and sums the results, so it reflects both direction and magnitude.
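
To make the three measures concrete, here is a minimal sketch in plain Python. The function names and the toy two-dimensional vectors are assumptions made for illustration; real embeddings have hundreds of dimensions, and a vector database computes these measures internally.

```python
import math


def euclidean_distance(a, b):
    """Straight-line distance between two points."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def dot_product(a, b):
    """Sum of the products of corresponding elements."""
    return sum(x * y for x, y in zip(a, b))


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors (ignores magnitude)."""
    norm_a = math.sqrt(dot_product(a, a))
    norm_b = math.sqrt(dot_product(b, b))
    return dot_product(a, b) / (norm_a * norm_b)


a = [1.0, 0.0]
b = [2.0, 0.0]  # same direction as a, twice the length
c = [0.0, 1.0]  # orthogonal to a

print(euclidean_distance(a, b))  # 1.0
print(cosine_similarity(a, b))   # 1.0
print(dot_product(a, b))         # 2.0
print(cosine_similarity(a, c))   # 0.0
```

Note how `a` and `b` are a full unit apart by Euclidean distance yet identical by cosine similarity, because cosine only looks at direction.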


The choice of similarity measure usually depends on the application-specific data and how the developer approaches the problem.


Performing similarity search on large-scale datasets requires a lot of computation power and resources. This is where approximate nearest neighbor (ANN) algorithms come in. ANN algorithms trade a small amount of accuracy for a significant speedup, which makes them an appropriate choice for large-scale applications.
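
The trade-off can be illustrated with a toy example. The sketch below contrasts exact brute-force search with an inverted-file (IVF) style approximation that only scans the buckets nearest to the query. It is a simplified illustration, not how Milvus implements its indexes: the function names are made up, and the centroids are hand-picked here, whereas a real IVF index would learn them (typically with k-means).

```python
import math


def l2(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def exact_search(vectors, query, k):
    """Brute force: compare the query against every stored vector."""
    ranked = sorted(range(len(vectors)), key=lambda i: l2(vectors[i], query))
    return ranked[:k]


def build_buckets(vectors, centroids):
    """Assign each vector id to the bucket of its nearest centroid."""
    buckets = {c: [] for c in range(len(centroids))}
    for i, vec in enumerate(vectors):
        nearest = min(range(len(centroids)), key=lambda c: l2(vec, centroids[c]))
        buckets[nearest].append(i)
    return buckets


def approximate_search(vectors, centroids, buckets, query, k, nprobe=1):
    """Scan only the `nprobe` buckets whose centroids are closest to the query."""
    probe = sorted(range(len(centroids)), key=lambda c: l2(query, centroids[c]))[:nprobe]
    candidates = [i for c in probe for i in buckets[c]]
    candidates.sort(key=lambda i: l2(vectors[i], query))
    return candidates[:k]


vectors = [[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]]
centroids = [[0.3, 0.3], [10.3, 10.3]]  # hand-picked for this toy dataset
buckets = build_buckets(vectors, centroids)
query = [9.0, 9.5]

print(exact_search(vectors, query, k=2))                            # [3, 4]
print(approximate_search(vectors, centroids, buckets, query, k=2))  # [3, 4]
```

Here the approximate search examines three vectors instead of six and still returns the same neighbors. With a query near a bucket boundary and a small `nprobe`, it could miss a true neighbor, which is the accuracy being traded for speed.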


Milvus uses advanced ANN algorithms, including HNSW and DiskANN, to perform efficient similarity searches on large vector embedding datasets, allowing developers to quickly find relevant data points. In addition, Milvus supports other indexing algorithms, such as IVF and CAGRA, so developers can pick the index that suits their workload.


Building Multimodal RAG with Milvus

Now that we’ve learned the concepts, it’s time to build a multimodal RAG system using Milvus. For this example, we’ll use Milvus Lite (the lightweight version of Milvus, ideal for experimenting and prototyping) for vector storage and retrieval, BGE for precise image processing and embedding, and GPT-4o for advanced result reranking.

Prerequisites

First, you’ll need a Milvus instance to store your data. You can set up Milvus Lite using pip, run a local instance using Docker, or sign up for a free hosted Milvus account through Zilliz Cloud.


Second, you need an LLM for your RAG pipeline, so head over to OpenAI and get an API key. The free tier is sufficient to get this code working.


Next, create a new directory and a Python virtual environment (or take whatever steps you use to manage Python).


For this tutorial, you’ll also need to install the pymilvus library, which is Milvus's official Python SDK, and a handful of common tools.

Set up Milvus Lite

pip install -U pymilvus

Install Dependencies

pip install --upgrade pymilvus openai datasets opencv-python timm einops ftfy peft tqdm

git clone https://github.com/FlagOpen/FlagEmbedding.git
pip install -e FlagEmbedding

Download Data

The following command will download the example data and extract it to a local folder “./images_folder”, which includes:


  • Images: A subset of Amazon Reviews 2023 containing approximately 900 images from the categories "Appliance", "Cell_Phones_and_Accessories", and "Electronics".
  • An example query image: leopard.jpg


wget https://github.com/milvus-io/bootcamp/releases/download/data/amazon_reviews_2023_subset.tar.gz
tar -xzf amazon_reviews_2023_subset.tar.gz

Load the Embedding Model

We will use the Visualized BGE model “bge-visualized-base-en-v1.5” to generate embeddings for both images and text.


Now download the model weights from HuggingFace.


wget https://huggingface.co/BAAI/bge-visualized/resolve/main/Visualized_base_en_v1.5.pth


Then, let’s build an encoder.

import torch
from visual_bge.modeling import Visualized_BGE


class Encoder:
    def __init__(self, model_name: str, model_path: str):
        self.model = Visualized_BGE(model_name_bge=model_name, model_weight=model_path)
        self.model.eval()

    def encode_query(self, image_path: str, text: str) -> list[float]:
        with torch.no_grad():
            query_emb = self.model.encode(image=image_path, text=text)
        return query_emb.tolist()[0]

    def encode_image(self, image_path: str) -> list[float]:
        with torch.no_grad():
            query_emb = self.model.encode(image=image_path)
        return query_emb.tolist()[0]


model_name = "BAAI/bge-base-en-v1.5"
model_path = "./Visualized_base_en_v1.5.pth"  # Change to your own value if using a different model path
encoder = Encoder(model_name, model_path)

Generate Embeddings and Load Data into Milvus

This section shows you how to load the example images into our database with their corresponding embeddings.


Generate embeddings


First, we need to create embeddings for all the images in the dataset.


Load all images from the data directory and convert them to embeddings.


import os
from glob import glob

from tqdm import tqdm

data_dir = (
    "./images_folder"  # Change to your own value if using a different data directory
)
image_list = glob(
    os.path.join(data_dir, "images", "*.jpg")
)  # We will only use images ending with ".jpg"

image_dict = {}
for image_path in tqdm(image_list, desc="Generating image embeddings: "):
    try:
        image_dict[image_path] = encoder.encode_image(image_path)
    except Exception as e:
        # Report the cause so unreadable or corrupt images are easy to track down
        print(f"Failed to generate embedding for {image_path}: {e}. Skipped.")
        continue

print("Number of encoded images:", len(image_dict))
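
The search code later in this post refers to `milvus_client` and `collection_name`, which are not defined in the listing above. The following is a minimal sketch of how the embeddings could be loaded into Milvus Lite with the pymilvus `MilvusClient`. The helper names `build_rows` and `load_into_milvus`, the database file `./milvus_demo.db`, and the collection name `multimodal_rag_demo` are assumptions made for illustration, not values taken from this article.

```python
def build_rows(image_dict):
    """Turn {image_path: embedding} into the row dicts passed to insert()."""
    return [{"image_path": path, "vector": vec} for path, vec in image_dict.items()]


def load_into_milvus(
    image_dict,
    uri="./milvus_demo.db",                # assumed local file for Milvus Lite
    collection_name="multimodal_rag_demo",  # assumed collection name
):
    """Create a collection sized to the embeddings and insert every image row."""
    from pymilvus import MilvusClient  # lazy import: build_rows works without it

    rows = build_rows(image_dict)
    dim = len(rows[0]["vector"])

    client = MilvusClient(uri=uri)  # a local file path selects Milvus Lite
    if client.has_collection(collection_name):
        client.drop_collection(collection_name)  # start clean on reruns
    client.create_collection(
        collection_name=collection_name,
        dimension=dim,
        auto_id=True,               # let Milvus generate primary keys
        enable_dynamic_field=True,  # allows the extra "image_path" field
        metric_type="COSINE",       # matches the metric used at search time
    )
    client.insert(collection_name=collection_name, data=rows)
    return client
```

Under these assumptions, setting `collection_name = "multimodal_rag_demo"` and calling `milvus_client = load_into_milvus(image_dict, collection_name=collection_name)` would define the two names the search step expects.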

Perform Multimodal Search and Rerank the Results

In this section, we will first search for relevant images using a multimodal query and then use an LLM service to rerank the retrieved results and find the best one with an explanation.


Run multimodal search


Now we are ready to perform the multimodal search with a query composed of an image and a text instruction.


query_image = os.path.join(
    data_dir, "leopard.jpg"
)  # Change to your own query image path
query_text = "phone case with this image theme"

query_vec = encoder.encode_query(image_path=query_image, text=query_text)

search_results = milvus_client.search(
    collection_name=collection_name,
    data=[query_vec],
    output_fields=["image_path"],
    limit=9,  # Max number of search results to return
    search_params={"metric_type": "COSINE", "params": {}},  # Search parameters
)[0]

retrieved_images = [hit.get("entity").get("image_path") for hit in search_results]
print(retrieved_images)
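
Each hit also carries a similarity score, which can be useful for inspecting or filtering results before reranking. The helper below is a small sketch that assumes each hit behaves like a dict with `"distance"` and `"entity"` keys, as the `.get` calls above suggest, and that with the COSINE metric a larger score means a closer match. The function name `summarize_hits`, the threshold, and the sample hits are hypothetical.

```python
def summarize_hits(hits, min_score=0.0):
    """Return (image_path, score) pairs, best first, dropping weak matches."""
    pairs = [(hit["entity"]["image_path"], hit["distance"]) for hit in hits]
    pairs = [pair for pair in pairs if pair[1] >= min_score]
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)


# Made-up hits in the assumed shape, only to show the behavior.
sample_hits = [
    {"id": 1, "distance": 0.41, "entity": {"image_path": "a.jpg"}},
    {"id": 2, "distance": 0.87, "entity": {"image_path": "b.jpg"}},
    {"id": 3, "distance": 0.63, "entity": {"image_path": "c.jpg"}},
]
print(summarize_hits(sample_hits, min_score=0.5))
# [('b.jpg', 0.87), ('c.jpg', 0.63)]
```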


The result is shown below:


['./images_folder/images/518Gj1WQ-RL._AC_.jpg',
 './images_folder/images/41n00AOfWhL._AC_.jpg',
 ...]


Rerank results with GPT-4o


Now, we will use GPT-4o to rank the retrieved images and find the best-matched result. The LLM will also explain its ranking.


1. Create a panoramic view.


import numpy as np
import cv2
from PIL import Image  # used below to load the query and retrieved images

img_height = 300
img_width = 300
row_count = 3


def create_panoramic_view(query_image_path: str, retrieved_images: list) -> np.ndarray:
    """
    creates a 3x3 panoramic view of the retrieved images, with the query image
    in an extra column on the left

    args:
        query_image_path: path of the query image
        retrieved_images: list of paths of the retrieved images to be combined

    returns:
        np.ndarray: the panoramic view image
    """
    panoramic_width = img_width * row_count
    panoramic_height = img_height * row_count
    panoramic_image = np.full(
        (panoramic_height, panoramic_width, 3), 255, dtype=np.uint8
    )

    # create and resize the query image with a blue border
    query_image_null = np.full((panoramic_height, img_width, 3), 255, dtype=np.uint8)
    query_image = Image.open(query_image_path).convert("RGB")
    query_array = np.array(query_image)[:, :, ::-1]  # RGB -> BGR for OpenCV
    resized_image = cv2.resize(query_array, (img_width, img_height))

    border_size = 10
    blue = (255, 0, 0)  # blue color in BGR
    bordered_query_image = cv2.copyMakeBorder(
        resized_image,
        border_size,
        border_size,
        border_size,
        border_size,
        cv2.BORDER_CONSTANT,
        value=blue,
    )
    query_image_null[img_height * 2 : img_height * 3, 0:img_width] = cv2.resize(
        bordered_query_image, (img_width, img_height)
    )

    # add text "query" just above the query image; the image fills the bottom
    # cell, so a label placed below it would fall outside the canvas
    text = "query"
    font_scale = 1
    font_thickness = 2
    text_org = (10, img_height * 2 - 10)
    cv2.putText(
        query_image_null,
        text,
        text_org,
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        blue,
        font_thickness,
        cv2.LINE_AA,
    )

    # combine the rest of the images into the panoramic view
    retrieved_imgs = [
        np.array(Image.open(img).convert("RGB"))[:, :, ::-1] for img in retrieved_images
    ]
    for i, image in enumerate(retrieved_imgs):
        image = cv2.resize(image, (img_width - 4, img_height - 4))
        row = i // row_count
        col = i % row_count
        start_row = row * img_height
        start_col = col * img_width

        border_size = 2
        bordered_image = cv2.copyMakeBorder(
            image,
            border_size,
            border_size,
            border_size,
            border_size,
            cv2.BORDER_CONSTANT,
            value=(0, 0, 0),
        )
        panoramic_image[
            start_row : start_row + img_height, start_col : start_col + img_width
        ] = bordered_image

        # add red index numbers to each image
        text = str(i)
        org = (start_col + 50, start_row + 30)
        (font_width, font_height), baseline = cv2.getTextSize(
            text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2
        )

        top_left = (org[0] - 48, start_row + 2)
        bottom_right = (org[0] - 48 + font_width + 5, org[1] + baseline + 5)

        cv2.rectangle(
            panoramic_image, top_left, bottom_right, (255, 255, 255), cv2.FILLED
        )
        cv2.putText(
            panoramic_image,
            text,
            (start_col + 10, start_row + 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 0, 255),
            2,
            cv2.LINE_AA,
        )

    # combine the query image with the panoramic view
    panoramic_image = np.hstack([query_image_null, panoramic_image])
    return panoramic_image


2. Combine the query image and retrieved images with indices in a panoramic view.


from PIL import Image

combined_image_path = os.path.join(data_dir, "combined_image.jpg")
panoramic_image = create_panoramic_view(query_image, retrieved_images)
cv2.imwrite(combined_image_path, panoramic_image)

combined_image = Image.open(combined_image_path)
show_combined_image = combined_image.resize((300, 300))
show_combined_image.show()


Multimodal search results

3. Rerank the results and give explanation


We will send the combined image to the multimodal LLM service together with a proper prompt to rank the retrieved results with an explanation. Note: To enable GPT-4o as the LLM, you need to prepare your OpenAI API Key in advance.
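
One common way to prepare the key without pasting it into the source file is to read it from an environment variable. The helper below is a small sketch of that idea; the function name `load_openai_api_key` is hypothetical, and `OPENAI_API_KEY` is assumed as the variable name by convention.

```python
import os


def load_openai_api_key(env_var: str = "OPENAI_API_KEY") -> str:
    """Read the API key from the environment and fail early if it is missing."""
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(
            f"Set the {env_var} environment variable before running the reranker."
        )
    return key
```

The returned value could then be assigned to the `openai_api_key` variable used in the request headers.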


import requests
import base64

openai_api_key = "sk-***"  # Change to your OpenAI API Key


def generate_ranking_explanation(
    combined_image_path: str, caption: str, infos: dict = None
) -> tuple[list[int], str]:
    # encode the combined image so it can be embedded in the request body
    with open(combined_image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode("utf-8")

    information = (
        "You are responsible for ranking results for a Composed Image Retrieval. "
        "The user retrieves an image with an 'instruction' indicating their retrieval intent. "
        "For example, if the user queries a red car with the instruction 'change this car to blue,' a similar type of car in blue would be ranked higher in the results. "
        "Now you would receive instruction and query image with blue border. Every item has its red index number in its top left. Do not misunderstand it. "
        f"User instruction: {caption} \n\n"
    )

    # add additional information for each image
    if infos:
        for i, info in enumerate(infos["product"]):
            information += f"{i}. {info}\n"

    information += (
        "Provide a new ranked list of indices from most suitable to least suitable, followed by an explanation for the top 1 most suitable item only. "
        "The format of the response has to be 'Ranked list: []' with the indices in brackets as integers, followed by 'Reasons:' plus the explanation why this most fit user's query intent."
    )

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_api_key}",
    }

    payload = {
        "model": "gpt-4o",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": information},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                    },
                ],
            }
        ],
        "max_tokens": 300,
    }

    response = requests.post(
        "https://api.openai.com/v1/chat/completions", headers=headers, json=payload
    )
    response.raise_for_status()  # surface HTTP errors such as an invalid API key
    result = response.json()["choices"][0]["message"]["content"]

    # parse the ranked indices from the response
    start_idx = result.find("[")
    end_idx = result.find("]")
    ranked_indices_str = result[start_idx + 1 : end_idx].split(",")
    ranked_indices = [int(index.strip()) for index in ranked_indices_str]

    # extract explanation
    explanation = result[end_idx + 1 :].strip()

    return ranked_indices, explanation
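
The bracket-based parsing above assumes the model follows the requested format exactly. As a hedged alternative, the sketch below parses the reply with a regular expression, discards out-of-range or repeated indices, and falls back to the original retrieval order when no list is found. The function name `parse_ranking` and the sample reply are hypothetical and not part of the original tutorial.

```python
import re


def parse_ranking(result: str, num_candidates: int) -> tuple[list[int], str]:
    """Extract valid, de-duplicated indices and the explanation from the reply."""
    match = re.search(r"Ranked list:\s*\[([^\]]*)\]", result)
    if match is None:
        # No parseable list: fall back to the original retrieval order.
        return list(range(num_candidates)), result.strip()

    ranked = []
    for token in re.findall(r"\d+", match.group(1)):
        index = int(token)
        if index < num_candidates and index not in ranked:
            ranked.append(index)
    # Append any candidates the model left out so every index appears once.
    ranked += [i for i in range(num_candidates) if i not in ranked]

    explanation = result[match.end():].strip()
    return ranked, explanation


reply = "Ranked list: [2, 0, 7, 2]\nReasons: Item 2 matches the theme."
print(parse_ranking(reply, num_candidates=3))
# ([2, 0, 1], 'Reasons: Item 2 matches the theme.')
```

Guaranteeing that every returned index is in range means a later lookup such as `retrieved_images[ranked_indices[0]]` cannot fail on a malformed reply.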


Get the image indices after ranking and the reason for the best result:


ranked_indices, explanation = generate_ranking_explanation(
    combined_image_path, query_text
)


4. Display the best result with explanation


print(explanation)

best_index = ranked_indices[0]
best_img = Image.open(retrieved_images[best_index])
best_img = best_img.resize((150, 150))
best_img.show()


Results:


Reasons: The most suitable item for the user's query intent is index 6 because the instruction specifies a phone case with the theme of the image, which is a leopard. The phone case with index 6 has a thematic design resembling the leopard pattern, making it the closest match to the user's request for a phone case with the image theme.



Leopard print phone case - Best Result


Check out the full code in this notebook. To learn more about how to start an online demo with this tutorial, please refer to the example application.

Conclusion

In this blog post, we discussed building a multimodal RAG system using Milvus (an open-source vector database). We covered how developers can set up Milvus, load image data, perform similarity searches, and use an LLM to rerank the retrieved results for more accurate responses.


Multimodal RAG solutions open up various possibilities for AI systems that can understand and process multiple forms of data. Some common possibilities include improved image search engines, better context-driven results, and more.