milvus-logo
LFAI
Home
  • Anleitungen

ColPali für multimodales Retrieval mit Milvus verwenden

Open In Colab GitHub Repository

Moderne Retrieval-Modelle verwenden in der Regel eine einzige Einbettung, um Text oder Bilder darzustellen. ColBERT hingegen ist ein neuronales Modell, das eine Liste von Einbettungen für jede Dateninstanz verwendet und eine "MaxSim"-Operation zur Berechnung der Ähnlichkeit zwischen zwei Texten einsetzt. Neben Textdaten enthalten auch Abbildungen, Tabellen und Diagramme reichhaltige Informationen, die beim textbasierten Information Retrieval oft unberücksichtigt bleiben.

Die MaxSim-Funktion vergleicht eine Abfrage mit einem Dokument (in dem Sie suchen), indem sie deren Token-Einbettungen betrachtet. Für jedes Wort in der Abfrage wählt sie das ähnlichste Wort aus dem Dokument aus (unter Verwendung der Kosinusähnlichkeit oder der quadrierten L2-Distanz) und summiert diese maximalen Ähnlichkeiten über alle Wörter in der Abfrage

ColPali ist eine Methode, die die Multi-Vektor-Darstellung von ColBERT mit PaliGemma (einem multimodalen großen Sprachmodell) kombiniert, um dessen starke Verständnisfähigkeiten zu nutzen. Dieser Ansatz ermöglicht es, eine Seite, die sowohl Text als auch Bilder enthält, durch eine einheitliche Multi-Vektor-Einbettung darzustellen. Die Einbettungen innerhalb dieser Multi-Vektor-Darstellung können detaillierte Informationen erfassen und so die Leistung der Retrieval-augmented Generation (RAG) für multimodale Daten verbessern.

In diesem Notizbuch bezeichnen wir diese Art von Multi-Vektor-Darstellung aus Gründen der Allgemeinheit als "ColBERT-Einbettungen". Das tatsächliche Modell, das verwendet wird, ist jedoch das ColPali-Modell. Wir werden demonstrieren, wie Milvus für Multi-Vektor-Retrieval verwendet werden kann. Darauf aufbauend wird gezeigt, wie ColPali für das Abrufen von Seiten auf der Grundlage einer gegebenen Anfrage verwendet werden kann.

Vorbereitung

$ pip install pdf2image
$ pip pymilvus
$ pip install colpali_engine
$ pip install tqdm
$ pip instal pillow

Vorbereiten der Daten

Wir werden PDF RAG als unser Beispiel verwenden. Sie können ColBERT paper herunterladen und in ./pdf einfügen. ColPali verarbeitet den Text nicht direkt, sondern die gesamte Seite wird in ein Bild gerastert. Das ColPali-Modell ist hervorragend in der Lage, die in diesen Bildern enthaltenen Textinformationen zu verstehen. Daher wird jede PDF-Seite für die Verarbeitung in ein Bild umgewandelt.

from pdf2image import convert_from_path

pdf_path = "pdfs/2004.12832v2.pdf"
images = convert_from_path(pdf_path)

for i, image in enumerate(images):
    image.save(f"pages/page_{i + 1}.png", "PNG")

Als nächstes werden wir eine Datenbank mit Milvus Lite initialisieren. Sie können leicht zu einer vollständigen Milvus-Instanz wechseln, indem Sie die uri auf die entsprechende Adresse setzen, unter der Ihr Milvus-Dienst gehostet wird.

from pymilvus import MilvusClient, DataType
import numpy as np
import concurrent.futures

client = MilvusClient(uri="milvus.db")
  • Wenn Sie nur eine lokale Vektordatenbank für kleine Datenmengen oder Prototypen benötigen, ist es am bequemsten, die uri auf eine lokale Datei, z. B../milvus.db, zu setzen, da Milvus Lite automatisch alle Daten in dieser Datei speichert.
  • Wenn Sie große Datenmengen haben, z. B. mehr als eine Million Vektoren, können Sie einen leistungsfähigeren Milvus-Server auf Docker oder Kubernetes einrichten. Bei dieser Einrichtung verwenden Sie bitte die Serveradresse und den Port als Uri, z. B.http://localhost:19530. Wenn Sie die Authentifizierungsfunktion auf Milvus aktivieren, verwenden Sie "<Ihr_Benutzername>:<Ihr_Passwort>" als Token, andernfalls setzen Sie das Token nicht.
  • Wenn Sie Zilliz Cloud, den vollständig verwalteten Cloud-Service für Milvus, verwenden, passen Sie die uri und token an, die dem öffentlichen Endpunkt und dem API-Schlüssel in Zilliz Cloud entsprechen.

Wir werden eine MilvusColbertRetriever-Klasse definieren, die den Milvus-Client für den Abruf von Multivektordaten umhüllt. Die Implementierung flacht ColBERT-Einbettungen ab und fügt sie in eine Sammlung ein, wobei jede Zeile eine einzelne Einbettung aus der ColBERT-Einbettungsliste darstellt. Sie zeichnet auch die doc_id und seq_id auf, um den Ursprung jeder Einbettung zu ermitteln.

Bei der Suche mit einer ColBERT-Einbettungsliste werden mehrere Suchen durchgeführt - eine für jede ColBERT-Einbettung. Die gefundenen doc_ids werden dann dedupliziert. Es wird ein Reranking-Prozess durchgeführt, bei dem die vollständigen Einbettungen für jede doc_id abgerufen werden und der MaxSim-Score berechnet wird, um die endgültige Rangfolge der Ergebnisse zu ermitteln.

class MilvusColbertRetriever:
    def __init__(self, milvus_client, collection_name, dim=128):
        # Initialize the retriever with a Milvus client, collection name, and dimensionality of the vector embeddings.
        # If the collection exists, load it.
        self.collection_name = collection_name
        self.client = milvus_client
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name)
        self.dim = dim

    def create_collection(self):
        # Create a new collection in Milvus for storing embeddings.
        # Drop the existing collection if it already exists and define the schema for the collection.
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.drop_collection(collection_name=self.collection_name)
        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_fields=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        # Create an index on the vector field to enable fast similarity search.
        # Releases and drops any existing index before creating a new one with specified parameters.
        self.client.release_collection(collection_name=self.collection_name)
        self.client.drop_index(
            collection_name=self.collection_name, index_name="vector"
        )
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            index_type="HNSW",  # or any other index type you want
            metric_type="IP",  # or the appropriate metric type
            params={
                "M": 16,
                "efConstruction": 500,
            },  # adjust these parameters as needed
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def create_scalar_index(self):
        # Create a scalar index for the "doc_id" field to enable fast lookups by document ID.
        self.client.release_collection(collection_name=self.collection_name)

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="doc_id",
            index_name="int32_index",
            index_type="INVERTED",  # or any other index type you want
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def search(self, data, topk):
        # Perform a vector search on the collection to find the top-k most similar documents.
        search_params = {"metric_type": "IP", "params": {}}
        results = self.client.search(
            self.collection_name,
            data,
            limit=int(50),
            output_fields=["vector", "seq_id", "doc_id"],
            search_params=search_params,
        )
        doc_ids = set()
        for r_id in range(len(results)):
            for r in range(len(results[r_id])):
                doc_ids.add(results[r_id][r]["entity"]["doc_id"])

        scores = []

        def rerank_single_doc(doc_id, data, client, collection_name):
            # Rerank a single document by retrieving its embeddings and calculating the similarity with the query.
            doc_colbert_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id in [{doc_id}]",
                output_fields=["seq_id", "vector", "doc"],
                limit=1000,
            )
            doc_vecs = np.vstack(
                [doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
            )
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id)

        with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
            futures = {
                executor.submit(
                    rerank_single_doc, doc_id, data, client, self.collection_name
                ): doc_id
                for doc_id in doc_ids
            }
            for future in concurrent.futures.as_completed(futures):
                score, doc_id = future.result()
                scores.append((score, doc_id))

        scores.sort(key=lambda x: x[0], reverse=True)
        if len(scores) >= topk:
            return scores[:topk]
        else:
            return scores

    def insert(self, data):
        # Insert ColBERT embeddings and metadata for a document into the collection.
        colbert_vecs = [vec for vec in data["colbert_vecs"]]
        seq_length = len(colbert_vecs)
        doc_ids = [data["doc_id"] for i in range(seq_length)]
        seq_ids = list(range(seq_length))
        docs = [""] * seq_length
        docs[0] = data["filepath"]

        # Insert the data as multiple vectors (one for each sequence) along with the corresponding metadata.
        self.client.insert(
            self.collection_name,
            [
                {
                    "vector": colbert_vecs[i],
                    "seq_id": seq_ids[i],
                    "doc_id": doc_ids[i],
                    "doc": docs[i],
                }
                for i in range(seq_length)
            ],
        )

Wir werden die colpali_engine verwenden, um Einbettungslisten für zwei Abfragen zu extrahieren und die relevanten Informationen aus den PDF-Seiten abzurufen.

from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
import torch
from typing import List, cast

device = get_torch_device("cpu")
model_name = "vidore/colpali-v1.2"

model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()

queries = [
    "How to end-to-end retrieval with ColBert?",
    "Where is ColBERT performance table?",
]

processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))

dataloader = DataLoader(
    dataset=ListDataset[str](queries),
    batch_size=1,
    shuffle=False,
    collate_fn=lambda x: processor.process_queries(x),
)

qs: List[torch.Tensor] = []
for batch_query in dataloader:
    with torch.no_grad():
        batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

Zusätzlich müssen wir die Einbettungsliste für jede Seite extrahieren, und es zeigt sich, dass es 1030 128-dimensionale Einbettungen für jede Seite gibt.

from tqdm import tqdm
from PIL import Image
import os

images = [Image.open("./pages/" + name) for name in os.listdir("./pages")]

dataloader = DataLoader(
    dataset=ListDataset[str](images),
    batch_size=1,
    shuffle=False,
    collate_fn=lambda x: processor.process_images(x),
)

ds: List[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
    with torch.no_grad():
        batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
        embeddings_doc = model(**batch_doc)
    ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

print(ds[0].shape)
  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [01:22<00:00,  8.24s/it]

torch.Size([1030, 128])

Wir werden mit MilvusColbertRetriever eine Sammlung namens "colpali" erstellen.

retriever = MilvusColbertRetriever(collection_name="colpali", milvus_client=client)
retriever.create_collection()
retriever.create_index()

Wir fügen die Einbettungslisten in die Milvus-Datenbank ein.

filepaths = ["./pages/" + name for name in os.listdir("./pages")]
for i in range(len(filepaths)):
    data = {
        "colbert_vecs": ds[i].float().numpy(),
        "doc_id": i,
        "filepath": filepaths[i],
    }
    retriever.insert(data)

Nun können wir die relevanteste Seite mit Hilfe der Abfrage-Einbettungsliste suchen.

for query in qs:
    query = query.float().numpy()
    result = retriever.search(query, topk=1)
    print(filepaths[result[0][1]])
./pages/page_5.png
./pages/page_7.png

Schließlich rufen wir den ursprünglichen Seitennamen ab. Mit ColPali können wir multimodale Dokumente abrufen, ohne dass wir komplexe Verarbeitungstechniken benötigen, um Text und Bilder aus den Dokumenten zu extrahieren. Durch den Einsatz von großen Bildverarbeitungsmodellen können mehr Informationen - wie z. B. Tabellen und Abbildungen - ohne signifikanten Informationsverlust analysiert werden.

Übersetzt vonDeepL

Try Managed Milvus for Free

Zilliz Cloud is hassle-free, powered by Milvus and 10x faster.

Get Started
Feedback

War diese Seite hilfreich?