ColPali für multimodales Retrieval mit Milvus verwenden
Moderne Retrieval-Modelle verwenden in der Regel eine einzige Einbettung, um Text oder Bilder darzustellen. ColBERT ist jedoch ein neuronales Modell, das eine Liste von Einbettungen für jede Dateninstanz verwendet und eine "MaxSim"-Operation zur Berechnung der Ähnlichkeit zwischen zwei Texten einsetzt. Neben Textdaten enthalten auch Abbildungen, Tabellen und Diagramme reichhaltige Informationen, die beim textbasierten Information Retrieval oft unberücksichtigt bleiben.
Die MaxSim-Funktion vergleicht eine Abfrage mit einem Dokument (in dem Sie suchen), indem sie deren Token-Einbettungen betrachtet. Für jedes Wort in der Abfrage wählt sie das ähnlichste Wort aus dem Dokument aus (unter Verwendung der Kosinusähnlichkeit oder der quadrierten L2-Distanz) und summiert diese maximalen Ähnlichkeiten über alle Wörter in der Abfrage
ColPali ist eine Methode, die die Multi-Vektor-Darstellung von ColBERT mit PaliGemma (einem multimodalen großen Sprachmodell) kombiniert, um dessen starke Verständnisfähigkeiten zu nutzen. Dieser Ansatz ermöglicht es, eine Seite, die sowohl Text als auch Bilder enthält, durch eine einheitliche Multi-Vektor-Einbettung darzustellen. Die Einbettungen innerhalb dieser Multi-Vektor-Darstellung können detaillierte Informationen erfassen und so die Leistung der Retrieval-augmented Generation (RAG) für multimodale Daten verbessern.
In diesem Notizbuch bezeichnen wir diese Art von Multi-Vektor-Darstellung aus Gründen der Allgemeinheit als "ColBERT-Einbettungen". Das tatsächliche Modell, das verwendet wird, ist jedoch das ColPali-Modell. Wir werden demonstrieren, wie Milvus für Multi-Vektor-Retrieval verwendet werden kann. Darauf aufbauend wird gezeigt, wie ColPali für das Abrufen von Seiten auf der Grundlage einer gegebenen Anfrage verwendet werden kann.
Vorbereitung
$ pip install pdf2image
$ pip pymilvus
$ pip install colpali_engine
$ pip install tqdm
$ pip instal pillow
Vorbereiten der Daten
Wir werden PDF RAG als unser Beispiel verwenden. Sie können ColBERT paper herunterladen und in ./pdf
einfügen. ColPali verarbeitet den Text nicht direkt, sondern die gesamte Seite wird in ein Bild gerastert. Das ColPali-Modell ist hervorragend in der Lage, die in diesen Bildern enthaltenen Textinformationen zu verstehen. Daher wird jede PDF-Seite für die Verarbeitung in ein Bild umgewandelt.
from pdf2image import convert_from_path
pdf_path = "pdfs/2004.12832v2.pdf"
images = convert_from_path(pdf_path)
for i, image in enumerate(images):
image.save(f"pages/page_{i + 1}.png", "PNG")
Als nächstes werden wir eine Datenbank mit Milvus Lite initialisieren. Sie können leicht zu einer vollständigen Milvus-Instanz wechseln, indem Sie die uri auf die entsprechende Adresse setzen, unter der Ihr Milvus-Dienst gehostet wird.
from pymilvus import MilvusClient, DataType
import numpy as np
import concurrent.futures
client = MilvusClient(uri="milvus.db")
- Wenn Sie nur eine lokale Vektordatenbank für kleine Datenmengen oder Prototypen benötigen, ist es am bequemsten, die uri auf eine lokale Datei, z. B.
./milvus.db
, zu setzen, da Milvus Lite automatisch alle Daten in dieser Datei speichert. - Wenn Sie große Datenmengen haben, z. B. mehr als eine Million Vektoren, können Sie einen leistungsfähigeren Milvus-Server auf Docker oder Kubernetes einrichten. Bei dieser Einrichtung verwenden Sie bitte die Serveradresse und den Port als Uri, z. B.
http://localhost:19530
. Wenn Sie die Authentifizierungsfunktion auf Milvus aktivieren, verwenden Sie "<Ihr_Benutzername>:<Ihr_Passwort>" als Token, andernfalls setzen Sie das Token nicht. - Wenn Sie Zilliz Cloud, den vollständig verwalteten Cloud-Service für Milvus, verwenden, passen Sie die
uri
undtoken
an, die dem öffentlichen Endpunkt und dem API-Schlüssel in Zilliz Cloud entsprechen.
Wir werden eine MilvusColbertRetriever-Klasse definieren, die den Milvus-Client für den Abruf von Multivektordaten umhüllt. Die Implementierung flacht ColBERT-Einbettungen ab und fügt sie in eine Sammlung ein, wobei jede Zeile eine einzelne Einbettung aus der ColBERT-Einbettungsliste darstellt. Sie zeichnet auch die doc_id und seq_id auf, um den Ursprung jeder Einbettung zu ermitteln.
Bei der Suche mit einer ColBERT-Einbettungsliste werden mehrere Suchen durchgeführt - eine für jede ColBERT-Einbettung. Die gefundenen doc_ids werden dann dedupliziert. Es wird ein Reranking-Prozess durchgeführt, bei dem die vollständigen Einbettungen für jede doc_id abgerufen werden und der MaxSim-Score berechnet wird, um die endgültige Rangfolge der Ergebnisse zu ermitteln.
class MilvusColbertRetriever:
def __init__(self, milvus_client, collection_name, dim=128):
# Initialize the retriever with a Milvus client, collection name, and dimensionality of the vector embeddings.
# If the collection exists, load it.
self.collection_name = collection_name
self.client = milvus_client
if self.client.has_collection(collection_name=self.collection_name):
self.client.load_collection(collection_name)
self.dim = dim
def create_collection(self):
# Create a new collection in Milvus for storing embeddings.
# Drop the existing collection if it already exists and define the schema for the collection.
if self.client.has_collection(collection_name=self.collection_name):
self.client.drop_collection(collection_name=self.collection_name)
schema = self.client.create_schema(
auto_id=True,
enable_dynamic_fields=True,
)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
schema.add_field(
field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
)
schema.add_field(field_name="seq_id", datatype=DataType.INT16)
schema.add_field(field_name="doc_id", datatype=DataType.INT64)
schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)
self.client.create_collection(
collection_name=self.collection_name, schema=schema
)
def create_index(self):
# Create an index on the vector field to enable fast similarity search.
# Releases and drops any existing index before creating a new one with specified parameters.
self.client.release_collection(collection_name=self.collection_name)
self.client.drop_index(
collection_name=self.collection_name, index_name="vector"
)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_name="vector_index",
index_type="HNSW", # or any other index type you want
metric_type="IP", # or the appropriate metric type
params={
"M": 16,
"efConstruction": 500,
}, # adjust these parameters as needed
)
self.client.create_index(
collection_name=self.collection_name, index_params=index_params, sync=True
)
def create_scalar_index(self):
# Create a scalar index for the "doc_id" field to enable fast lookups by document ID.
self.client.release_collection(collection_name=self.collection_name)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="doc_id",
index_name="int32_index",
index_type="INVERTED", # or any other index type you want
)
self.client.create_index(
collection_name=self.collection_name, index_params=index_params, sync=True
)
def search(self, data, topk):
# Perform a vector search on the collection to find the top-k most similar documents.
search_params = {"metric_type": "IP", "params": {}}
results = self.client.search(
self.collection_name,
data,
limit=int(50),
output_fields=["vector", "seq_id", "doc_id"],
search_params=search_params,
)
doc_ids = set()
for r_id in range(len(results)):
for r in range(len(results[r_id])):
doc_ids.add(results[r_id][r]["entity"]["doc_id"])
scores = []
def rerank_single_doc(doc_id, data, client, collection_name):
# Rerank a single document by retrieving its embeddings and calculating the similarity with the query.
doc_colbert_vecs = client.query(
collection_name=collection_name,
filter=f"doc_id in [{doc_id}]",
output_fields=["seq_id", "vector", "doc"],
limit=1000,
)
doc_vecs = np.vstack(
[doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
)
score = np.dot(data, doc_vecs.T).max(1).sum()
return (score, doc_id)
with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
futures = {
executor.submit(
rerank_single_doc, doc_id, data, client, self.collection_name
): doc_id
for doc_id in doc_ids
}
for future in concurrent.futures.as_completed(futures):
score, doc_id = future.result()
scores.append((score, doc_id))
scores.sort(key=lambda x: x[0], reverse=True)
if len(scores) >= topk:
return scores[:topk]
else:
return scores
def insert(self, data):
# Insert ColBERT embeddings and metadata for a document into the collection.
colbert_vecs = [vec for vec in data["colbert_vecs"]]
seq_length = len(colbert_vecs)
doc_ids = [data["doc_id"] for i in range(seq_length)]
seq_ids = list(range(seq_length))
docs = [""] * seq_length
docs[0] = data["filepath"]
# Insert the data as multiple vectors (one for each sequence) along with the corresponding metadata.
self.client.insert(
self.collection_name,
[
{
"vector": colbert_vecs[i],
"seq_id": seq_ids[i],
"doc_id": doc_ids[i],
"doc": docs[i],
}
for i in range(seq_length)
],
)
Wir werden die colpali_engine verwenden, um Einbettungslisten für zwei Abfragen zu extrahieren und die relevanten Informationen aus den PDF-Seiten abzurufen.
from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
import torch
from typing import List, cast
device = get_torch_device("cpu")
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
queries = [
"How to end-to-end retrieval with ColBert?",
"Where is ColBERT performance table?",
]
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
qs: List[torch.Tensor] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
embeddings_query = model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
Zusätzlich müssen wir die Einbettungsliste für jede Seite extrahieren, und es zeigt sich, dass es 1030 128-dimensionale Einbettungen für jede Seite gibt.
from tqdm import tqdm
from PIL import Image
import os
images = [Image.open("./pages/" + name) for name in os.listdir("./pages")]
dataloader = DataLoader(
dataset=ListDataset[str](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
ds: List[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
print(ds[0].shape)
0%| | 0/10 [00:00<?, ?it/s]
100%|██████████| 10/10 [01:22<00:00, 8.24s/it]
torch.Size([1030, 128])
Wir werden mit MilvusColbertRetriever eine Sammlung namens "colpali" erstellen.
retriever = MilvusColbertRetriever(collection_name="colpali", milvus_client=client)
retriever.create_collection()
retriever.create_index()
Wir fügen die Einbettungslisten in die Milvus-Datenbank ein.
filepaths = ["./pages/" + name for name in os.listdir("./pages")]
for i in range(len(filepaths)):
data = {
"colbert_vecs": ds[i].float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
Jetzt können wir die relevanteste Seite mit Hilfe der Abfrage-Einbettungsliste suchen.
for query in qs:
query = query.float().numpy()
result = retriever.search(query, topk=1)
print(filepaths[result[0][1]])
./pages/page_5.png
./pages/page_7.png
Schließlich rufen wir den ursprünglichen Seitennamen ab. Mit ColPali können wir multimodale Dokumente abrufen, ohne dass wir komplexe Verarbeitungstechniken benötigen, um Text und Bilder aus den Dokumenten zu extrahieren. Durch den Einsatz von großen Bildverarbeitungsmodellen können mehr Informationen - wie z. B. Tabellen und Abbildungen - ohne signifikanten Informationsverlust analysiert werden.