Utilizar ColPali para la recuperación multimodal con Milvus
Los modelos de recuperación modernos suelen utilizar una única incrustación para representar texto o imágenes. ColBERT, sin embargo, es un modelo neuronal que utiliza una lista de incrustaciones para cada instancia de datos y emplea una operación "MaxSim" para calcular la similitud entre dos textos. Más allá de los datos textuales, las figuras, tablas y diagramas también contienen abundante información, que a menudo no se tiene en cuenta en la recuperación de información basada en texto.
La función MaxSim compara una consulta con un documento (lo que se está buscando) observando sus incrustaciones de tokens. Para cada palabra de la consulta, elige la palabra más similar del documento (utilizando la similitud coseno o la distancia L2 al cuadrado) y suma estas similitudes máximas entre todas las palabras de la consulta.
ColPali es un método que combina la representación multivectorial de ColBERT con PaliGemma (un gran modelo de lenguaje multimodal) para aprovechar sus grandes capacidades de comprensión. Este enfoque permite representar una página con texto e imágenes mediante una incrustación multivectorial unificada. Las incrustaciones dentro de esta representación multivectorial pueden capturar información detallada, mejorando el rendimiento de la generación aumentada de recuperación (RAG) para datos multimodales.
En este cuaderno, nos referimos a este tipo de representación multivectorial como "incrustaciones ColBERT" por razones de generalidad. Sin embargo, el modelo real que se utiliza es el modelo ColPali. Demostraremos cómo utilizar Milvus para la recuperación multivectorial. A partir de ahí, presentaremos cómo utilizar ColPali para recuperar páginas a partir de una consulta determinada.
Preparación
$ pip install pdf2image
$ pip pymilvus
$ pip install colpali_engine
$ pip install tqdm
$ pip instal pillow
Preparación de los datos
Utilizaremos PDF RAG como ejemplo. Puede descargar el documento ColBERT e introducirlo en ./pdf
. ColPali no procesa el texto directamente, sino que rasteriza toda la página en una imagen. El modelo ColPali destaca en la comprensión de la información textual contenida en estas imágenes. Por lo tanto, convertiremos cada página PDF en una imagen para su procesamiento.
from pdf2image import convert_from_path
pdf_path = "pdfs/2004.12832v2.pdf"
images = convert_from_path(pdf_path)
for i, image in enumerate(images):
image.save(f"pages/page_{i + 1}.png", "PNG")
A continuación, inicializaremos una base de datos utilizando Milvus Lite. Puede cambiar fácilmente a una instancia completa de Milvus estableciendo la uri en la dirección apropiada donde esté alojado su servicio Milvus.
from pymilvus import MilvusClient, DataType
import numpy as np
import concurrent.futures
client = MilvusClient(uri="milvus.db")
- Si sólo necesita una base de datos vectorial local para datos a pequeña escala o prototipos, establecer la uri como un archivo local, por ejemplo
./milvus.db
, es el método más conveniente, ya que utiliza automáticamente Milvus Lite para almacenar todos los datos en este archivo. - Si tiene una gran escala de datos, digamos más de un millón de vectores, puede configurar un servidor Milvus más eficiente en Docker o Kubernetes. En esta configuración, por favor utilice la dirección del servidor y el puerto como su uri, por ejemplo
http://localhost:19530
. Si habilita la función de autenticación en Milvus, utilice "<su_nombre_de_usuario>:<su_contraseña>" como token, de lo contrario no configure el token. - Si utilizas Zilliz Cloud, el servicio en la nube totalmente gestionado para Milvus, ajusta los
uri
ytoken
, que se corresponden con el Public Endpoint y la API key en Zilliz Cloud.
Definiremos una clase MilvusColbertRetriever para envolver el cliente Milvus para la recuperación de datos multivector. La implementación aplana las incrustaciones ColBERT y las inserta en una colección, donde cada fila representa una incrustación individual de la lista de incrustaciones ColBERT. También registra el doc_id y el seq_id para rastrear el origen de cada incrustación.
Al buscar con una lista de incrustaciones ColBERT, se realizarán varias búsquedas, una por cada incrustación ColBERT. Los doc_ids recuperados se deduplicarán. Se realizará un proceso de reordenación en el que se obtendrán las incrustaciones completas para cada doc_id y se calculará la puntuación MaxSim para obtener los resultados finales ordenados.
class MilvusColbertRetriever:
def __init__(self, milvus_client, collection_name, dim=128):
# Initialize the retriever with a Milvus client, collection name, and dimensionality of the vector embeddings.
# If the collection exists, load it.
self.collection_name = collection_name
self.client = milvus_client
if self.client.has_collection(collection_name=self.collection_name):
self.client.load_collection(collection_name)
self.dim = dim
def create_collection(self):
# Create a new collection in Milvus for storing embeddings.
# Drop the existing collection if it already exists and define the schema for the collection.
if self.client.has_collection(collection_name=self.collection_name):
self.client.drop_collection(collection_name=self.collection_name)
schema = self.client.create_schema(
auto_id=True,
enable_dynamic_fields=True,
)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
schema.add_field(
field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
)
schema.add_field(field_name="seq_id", datatype=DataType.INT16)
schema.add_field(field_name="doc_id", datatype=DataType.INT64)
schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)
self.client.create_collection(
collection_name=self.collection_name, schema=schema
)
def create_index(self):
# Create an index on the vector field to enable fast similarity search.
# Releases and drops any existing index before creating a new one with specified parameters.
self.client.release_collection(collection_name=self.collection_name)
self.client.drop_index(
collection_name=self.collection_name, index_name="vector"
)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_name="vector_index",
index_type="HNSW", # or any other index type you want
metric_type="IP", # or the appropriate metric type
params={
"M": 16,
"efConstruction": 500,
}, # adjust these parameters as needed
)
self.client.create_index(
collection_name=self.collection_name, index_params=index_params, sync=True
)
def create_scalar_index(self):
# Create a scalar index for the "doc_id" field to enable fast lookups by document ID.
self.client.release_collection(collection_name=self.collection_name)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="doc_id",
index_name="int32_index",
index_type="INVERTED", # or any other index type you want
)
self.client.create_index(
collection_name=self.collection_name, index_params=index_params, sync=True
)
def search(self, data, topk):
# Perform a vector search on the collection to find the top-k most similar documents.
search_params = {"metric_type": "IP", "params": {}}
results = self.client.search(
self.collection_name,
data,
limit=int(50),
output_fields=["vector", "seq_id", "doc_id"],
search_params=search_params,
)
doc_ids = set()
for r_id in range(len(results)):
for r in range(len(results[r_id])):
doc_ids.add(results[r_id][r]["entity"]["doc_id"])
scores = []
def rerank_single_doc(doc_id, data, client, collection_name):
# Rerank a single document by retrieving its embeddings and calculating the similarity with the query.
doc_colbert_vecs = client.query(
collection_name=collection_name,
filter=f"doc_id in [{doc_id}, {doc_id + 1}]",
output_fields=["seq_id", "vector", "doc"],
limit=1000,
)
doc_vecs = np.vstack(
[doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
)
score = np.dot(data, doc_vecs.T).max(1).sum()
return (score, doc_id)
with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
futures = {
executor.submit(
rerank_single_doc, doc_id, data, client, self.collection_name
): doc_id
for doc_id in doc_ids
}
for future in concurrent.futures.as_completed(futures):
score, doc_id = future.result()
scores.append((score, doc_id))
scores.sort(key=lambda x: x[0], reverse=True)
if len(scores) >= topk:
return scores[:topk]
else:
return scores
def insert(self, data):
# Insert ColBERT embeddings and metadata for a document into the collection.
colbert_vecs = [vec for vec in data["colbert_vecs"]]
seq_length = len(colbert_vecs)
doc_ids = [data["doc_id"] for i in range(seq_length)]
seq_ids = list(range(seq_length))
docs = [""] * seq_length
docs[0] = data["filepath"]
# Insert the data as multiple vectors (one for each sequence) along with the corresponding metadata.
self.client.insert(
self.collection_name,
[
{
"vector": colbert_vecs[i],
"seq_id": seq_ids[i],
"doc_id": doc_ids[i],
"doc": docs[i],
}
for i in range(seq_length)
],
)
Utilizaremos colpali_engine para extraer las listas de incrustación de dos consultas y recuperar la información relevante de las páginas PDF.
from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
import torch
from typing import List, cast
device = get_torch_device("cpu")
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
queries = [
"How to end-to-end retrieval with ColBert?",
"Where is ColBERT performance table?",
]
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
qs: List[torch.Tensor] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
embeddings_query = model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
Además, tendremos que extraer la lista de incrustación para cada página y muestra que hay 1030 128-dimensional incrustaciones para cada página.
from tqdm import tqdm
from PIL import Image
import os
images = [Image.open("./pages/" + name) for name in os.listdir("./pages")]
dataloader = DataLoader(
dataset=ListDataset[str](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
ds: List[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
print(ds[0].shape)
0%| | 0/10 [00:00<?, ?it/s]
100%|██████████| 10/10 [01:22<00:00, 8.24s/it]
torch.Size([1030, 128])
Crearemos una colección llamada "colpali" utilizando MilvusColbertRetriever.
retriever = MilvusColbertRetriever(collection_name="colpali", milvus_client=client)
retriever.create_collection()
retriever.create_index()
Insertaremos las listas de incrustación en la base de datos Milvus.
filepaths = ["./pages/" + name for name in os.listdir("./pages")]
for i in range(len(filepaths)):
data = {
"colbert_vecs": ds[i].float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
Ahora podemos buscar la página más relevante utilizando la lista de incrustación de consulta.
for query in qs:
query = query.float().numpy()
result = retriever.search(query, topk=1)
print(filepaths[result[0][1]])
./pages/page_5.png
./pages/page_7.png
Por último, recuperamos el nombre de la página original. Con ColPali, podemos recuperar documentos multimodales sin necesidad de recurrir a complejas técnicas de procesamiento para extraer texto e imágenes de los documentos. Al aprovechar grandes modelos de visión, se puede analizar más información, como tablas y figuras, sin una pérdida significativa de información.