Utiliser ColPali pour la recherche multimodale avec Milvus
Les modèles de recherche modernes utilisent généralement un seul ancrage pour représenter le texte ou les images. ColBERT, en revanche, est un modèle neuronal qui utilise une liste d'enchâssements pour chaque instance de données et emploie une opération "MaxSim" pour calculer la similarité entre deux textes. Au-delà des données textuelles, les figures, les tableaux et les diagrammes contiennent également des informations riches, qui sont souvent ignorées dans la recherche d'informations basée sur le texte.
La fonction MaxSim compare une requête avec un document (ce que vous recherchez) en examinant leurs enchâssements de jetons. Pour chaque mot de la requête, elle choisit le mot le plus similaire du document (en utilisant la similarité cosinus ou la distance L2 au carré) et additionne ces similarités maximales pour tous les mots de la requête.
ColPali est une méthode qui combine la représentation multi-vectorielle de ColBERT avec PaliGemma (un modèle de langage large et multimodal) pour tirer parti de ses fortes capacités de compréhension. Cette approche permet de représenter une page contenant à la fois du texte et des images à l'aide d'une représentation multi-vectorielle unifiée. Les encastrements dans cette représentation multi-vectorielle peuvent capturer des informations détaillées, améliorant ainsi les performances de la génération augmentée de recherche (RAG) pour les données multimodales.
Dans ce manuel, nous appelons ce type de représentation multi-vectorielle "encastrements ColBERT" pour des raisons de généralité. Cependant, le modèle utilisé est le modèle ColPali. Nous montrerons comment utiliser Milvus pour la recherche multi-vectorielle. Sur cette base, nous présenterons l'utilisation de ColPali pour la recherche de pages sur la base d'une requête donnée.
Préparation
$ pip install pdf2image
$ pip pymilvus
$ pip install colpali_engine
$ pip install tqdm
$ pip instal pillow
Préparer les données
Nous utiliserons le PDF RAG comme exemple. Vous pouvez télécharger le document ColBERT et le placer dans ./pdf
. ColPali ne traite pas directement le texte, mais la page entière est transformée en image. Le modèle ColPali excelle dans la compréhension des informations textuelles contenues dans ces images. Par conséquent, nous convertirons chaque page PDF en une image à traiter.
from pdf2image import convert_from_path
pdf_path = "pdfs/2004.12832v2.pdf"
images = convert_from_path(pdf_path)
for i, image in enumerate(images):
image.save(f"pages/page_{i + 1}.png", "PNG")
Ensuite, nous initialiserons une base de données à l'aide de Milvus Lite. Vous pouvez facilement passer à une instance Milvus complète en définissant l'uri à l'adresse appropriée où votre service Milvus est hébergé.
from pymilvus import MilvusClient, DataType
import numpy as np
import concurrent.futures
client = MilvusClient(uri="milvus.db")
- Si vous n'avez besoin d'une base de données vectorielle locale que pour des données à petite échelle ou pour le prototypage, définir l'uri comme un fichier local, par exemple
./milvus.db
, est la méthode la plus pratique, car elle utilise automatiquement Milvus Lite pour stocker toutes les données dans ce fichier. - Si vous disposez de données à grande échelle, par exemple plus d'un million de vecteurs, vous pouvez configurer un serveur Milvus plus performant sur Docker ou Kubernetes. Dans cette configuration, veuillez utiliser l'adresse et le port du serveur comme uri, par exemple
http://localhost:19530
. Si vous activez la fonction d'authentification sur Milvus, utilisez "<votre_nom_d'utilisateur>:<votre_mot_de_passe>" comme jeton, sinon ne définissez pas le jeton. - Si vous utilisez Zilliz Cloud, le service en nuage entièrement géré pour Milvus, ajustez les valeurs
uri
ettoken
, qui correspondent au point de terminaison public et à la clé API dans Zilliz Cloud.
Nous allons définir une classe MilvusColbertRetriever pour envelopper le client Milvus afin de récupérer des données multi-vectorielles. L'implémentation aplatit les embeddings ColBERT et les insère dans une collection, où chaque ligne représente un embedding individuel de la liste des embeddings ColBERT. Elle enregistre également le doc_id et le seq_id pour retracer l'origine de chaque embedding.
Lors d'une recherche à l'aide d'une liste d'éléments intégrés de ColBERT, plusieurs recherches seront effectuées, une pour chaque élément intégré de ColBERT. Les doc_id récupérés seront ensuite dédupliqués. Un processus de reclassement sera effectué, dans lequel les embeddings complets pour chaque doc_id seront récupérés, et le score MaxSim sera calculé pour produire les résultats finaux classés.
class MilvusColbertRetriever:
def __init__(self, milvus_client, collection_name, dim=128):
# Initialize the retriever with a Milvus client, collection name, and dimensionality of the vector embeddings.
# If the collection exists, load it.
self.collection_name = collection_name
self.client = milvus_client
if self.client.has_collection(collection_name=self.collection_name):
self.client.load_collection(collection_name)
self.dim = dim
def create_collection(self):
# Create a new collection in Milvus for storing embeddings.
# Drop the existing collection if it already exists and define the schema for the collection.
if self.client.has_collection(collection_name=self.collection_name):
self.client.drop_collection(collection_name=self.collection_name)
schema = self.client.create_schema(
auto_id=True,
enable_dynamic_fields=True,
)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
schema.add_field(
field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
)
schema.add_field(field_name="seq_id", datatype=DataType.INT16)
schema.add_field(field_name="doc_id", datatype=DataType.INT64)
schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)
self.client.create_collection(
collection_name=self.collection_name, schema=schema
)
def create_index(self):
# Create an index on the vector field to enable fast similarity search.
# Releases and drops any existing index before creating a new one with specified parameters.
self.client.release_collection(collection_name=self.collection_name)
self.client.drop_index(
collection_name=self.collection_name, index_name="vector"
)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_name="vector_index",
index_type="HNSW", # or any other index type you want
metric_type="IP", # or the appropriate metric type
params={
"M": 16,
"efConstruction": 500,
}, # adjust these parameters as needed
)
self.client.create_index(
collection_name=self.collection_name, index_params=index_params, sync=True
)
def create_scalar_index(self):
# Create a scalar index for the "doc_id" field to enable fast lookups by document ID.
self.client.release_collection(collection_name=self.collection_name)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="doc_id",
index_name="int32_index",
index_type="INVERTED", # or any other index type you want
)
self.client.create_index(
collection_name=self.collection_name, index_params=index_params, sync=True
)
def search(self, data, topk):
# Perform a vector search on the collection to find the top-k most similar documents.
search_params = {"metric_type": "IP", "params": {}}
results = self.client.search(
self.collection_name,
data,
limit=int(50),
output_fields=["vector", "seq_id", "doc_id"],
search_params=search_params,
)
doc_ids = set()
for r_id in range(len(results)):
for r in range(len(results[r_id])):
doc_ids.add(results[r_id][r]["entity"]["doc_id"])
scores = []
def rerank_single_doc(doc_id, data, client, collection_name):
# Rerank a single document by retrieving its embeddings and calculating the similarity with the query.
doc_colbert_vecs = client.query(
collection_name=collection_name,
filter=f"doc_id in [{doc_id}]",
output_fields=["seq_id", "vector", "doc"],
limit=1000,
)
doc_vecs = np.vstack(
[doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
)
score = np.dot(data, doc_vecs.T).max(1).sum()
return (score, doc_id)
with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
futures = {
executor.submit(
rerank_single_doc, doc_id, data, client, self.collection_name
): doc_id
for doc_id in doc_ids
}
for future in concurrent.futures.as_completed(futures):
score, doc_id = future.result()
scores.append((score, doc_id))
scores.sort(key=lambda x: x[0], reverse=True)
if len(scores) >= topk:
return scores[:topk]
else:
return scores
def insert(self, data):
# Insert ColBERT embeddings and metadata for a document into the collection.
colbert_vecs = [vec for vec in data["colbert_vecs"]]
seq_length = len(colbert_vecs)
doc_ids = [data["doc_id"] for i in range(seq_length)]
seq_ids = list(range(seq_length))
docs = [""] * seq_length
docs[0] = data["filepath"]
# Insert the data as multiple vectors (one for each sequence) along with the corresponding metadata.
self.client.insert(
self.collection_name,
[
{
"vector": colbert_vecs[i],
"seq_id": seq_ids[i],
"doc_id": doc_ids[i],
"doc": docs[i],
}
for i in range(seq_length)
],
)
Nous utiliserons le moteur colpali pour extraire les listes d'intégration pour deux requêtes et récupérer les informations pertinentes des pages PDF.
from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
import torch
from typing import List, cast
device = get_torch_device("cpu")
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
queries = [
"How to end-to-end retrieval with ColBert?",
"Where is ColBERT performance table?",
]
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
qs: List[torch.Tensor] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
embeddings_query = model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
En outre, nous devrons extraire la liste d'encastrements pour chaque page, qui contient 1030 encastrements à 128 dimensions.
from tqdm import tqdm
from PIL import Image
import os
images = [Image.open("./pages/" + name) for name in os.listdir("./pages")]
dataloader = DataLoader(
dataset=ListDataset[str](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
ds: List[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
print(ds[0].shape)
0%| | 0/10 [00:00<?, ?it/s]
100%|██████████| 10/10 [01:22<00:00, 8.24s/it]
torch.Size([1030, 128])
Nous allons créer une collection appelée "colpali" à l'aide de MilvusColbertRetriever.
retriever = MilvusColbertRetriever(collection_name="colpali", milvus_client=client)
retriever.create_collection()
retriever.create_index()
Nous insérerons les listes d'encastrements dans la base de données Milvus.
filepaths = ["./pages/" + name for name in os.listdir("./pages")]
for i in range(len(filepaths)):
data = {
"colbert_vecs": ds[i].float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
Nous pouvons maintenant rechercher la page la plus pertinente à l'aide de la liste d'intégration des requêtes.
for query in qs:
query = query.float().numpy()
result = retriever.search(query, topk=1)
print(filepaths[result[0][1]])
./pages/page_5.png
./pages/page_7.png
Enfin, nous récupérons le nom de la page originale. Avec ColPali, nous pouvons récupérer des documents multimodaux sans avoir besoin de techniques de traitement complexes pour extraire le texte et les images des documents. En tirant parti de modèles de vision de grande taille, il est possible d'analyser davantage d'informations, telles que les tableaux et les figures, sans perte significative d'informations.