使用 ColPali 與 Milvus 進行多模式檢索
現代的檢索模型通常使用單一的嵌入來表示文字或影像。然而,ColBERT 是一種神經模型,它利用每個資料實例的嵌入清單,並採用「MaxSim」運算來計算兩個文字之間的相似度。除了文字資料之外,圖形、表格和圖表也包含豐富的資訊,這些資訊在以文字為基礎的資訊檢索中往往被忽略。
MaxSim 功能是透過查看查詢與文件 (您要搜尋的內容) 的代號嵌入 (token embeddings) 來比較它們。對於查詢中的每個單字,它會從文件中挑出最相似的單字 (使用余弦相似度或平方 L2 距離),然後將查詢中所有單字的最大相似度相加。
ColPali 是一種結合 ColBERT 的多向量表示法與 PaliGemma(多模態大語言模型)的方法,以利用其強大的理解能力。這種方法可以使用統一的多向量嵌入來表示包含文字和圖像的頁面。這個多向量表達中的嵌入可以捕捉到詳細的資訊,提高多模態資料的檢索增強生成 (RAG) 效能。
在本筆記簿中,為了一般性起見,我們將此種多向量表示法稱為「ColBERT 內嵌」。然而,實際使用的模型是ColPali 模型。我們將示範如何使用 Milvus 進行多向量檢索。在此基礎上,我們將介紹如何使用 ColPali 根據給定的查詢來檢索網頁。
準備工作
$ pip install pdf2image
$ pip pymilvus
$ pip install colpali_engine
$ pip install tqdm
$ pip instal pillow
準備資料
我們將以 PDF RAG 為例。您可以下載ColBERTpaper 並將其放入./pdf
。ColPali 並不直接處理文字,而是將整個頁面光柵化為影像。ColPali 模型擅長理解這些圖像中包含的文字資訊。因此,我們會將每個 PDF 頁面轉換成影像來處理。
from pdf2image import convert_from_path
pdf_path = "pdfs/2004.12832v2.pdf"
images = convert_from_path(pdf_path)
for i, image in enumerate(images):
image.save(f"pages/page_{i + 1}.png", "PNG")
接下來,我們將使用 Milvus Lite 來初始化資料庫。您可以透過設定 uri 到您的 Milvus 服務託管的適當位址,輕鬆切換到完整的 Milvus 實例。
from pymilvus import MilvusClient, DataType
import numpy as np
import concurrent.futures
client = MilvusClient(uri="milvus.db")
- 如果您只需要一個本機向量資料庫來進行小規模的資料或原型設計,將 uri 設定為一個本機檔案,例如
./milvus.db
,是最方便的方法,因為它會自動利用Milvus Lite將所有資料儲存在這個檔案中。 - 如果您有大規模的資料,例如超過一百萬個向量,您可以在Docker 或 Kubernetes 上架設效能更高的 Milvus 伺服器。在此設定中,請使用伺服器位址和連接埠作為您的 uri,例如
http://localhost:19530
。如果您啟用 Milvus 的驗證功能,請使用「<your_username>:<your_password>」作為令牌,否則請勿設定令牌。 - 如果您使用Zilliz Cloud(Milvus 的完全管理雲端服務),請調整
uri
和token
,它們對應於 Zilliz Cloud 中的Public Endpoint 和 API key。
我們將定義一個 MilvusColbertRetriever 類別,用來包圍 Milvus 用戶端進行多向量資料擷取。該實作會將 ColBERT embeddings 平面化,並將它們插入一個集合,其中每一行代表 ColBERT embedding 清單中的一個個別 embedding。它還記錄了 doc_id 和 seq_id,以便追蹤每個內嵌的來源。
使用 ColBERT 嵌入列表進行搜尋時,會進行多次搜尋,每次搜尋一個 ColBERT 嵌入。擷取的 doc_ids 將被重複。將執行重新排序過程,在此過程中,每個 doc_id 的完整內嵌被擷取,並計算 MaxSim 得分,以產生最終的排序結果。
class MilvusColbertRetriever:
def __init__(self, milvus_client, collection_name, dim=128):
# Initialize the retriever with a Milvus client, collection name, and dimensionality of the vector embeddings.
# If the collection exists, load it.
self.collection_name = collection_name
self.client = milvus_client
if self.client.has_collection(collection_name=self.collection_name):
self.client.load_collection(collection_name)
self.dim = dim
def create_collection(self):
# Create a new collection in Milvus for storing embeddings.
# Drop the existing collection if it already exists and define the schema for the collection.
if self.client.has_collection(collection_name=self.collection_name):
self.client.drop_collection(collection_name=self.collection_name)
schema = self.client.create_schema(
auto_id=True,
enable_dynamic_fields=True,
)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
schema.add_field(
field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
)
schema.add_field(field_name="seq_id", datatype=DataType.INT16)
schema.add_field(field_name="doc_id", datatype=DataType.INT64)
schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)
self.client.create_collection(
collection_name=self.collection_name, schema=schema
)
def create_index(self):
# Create an index on the vector field to enable fast similarity search.
# Releases and drops any existing index before creating a new one with specified parameters.
self.client.release_collection(collection_name=self.collection_name)
self.client.drop_index(
collection_name=self.collection_name, index_name="vector"
)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_name="vector_index",
index_type="HNSW", # or any other index type you want
metric_type="IP", # or the appropriate metric type
params={
"M": 16,
"efConstruction": 500,
}, # adjust these parameters as needed
)
self.client.create_index(
collection_name=self.collection_name, index_params=index_params, sync=True
)
def create_scalar_index(self):
# Create a scalar index for the "doc_id" field to enable fast lookups by document ID.
self.client.release_collection(collection_name=self.collection_name)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="doc_id",
index_name="int32_index",
index_type="INVERTED", # or any other index type you want
)
self.client.create_index(
collection_name=self.collection_name, index_params=index_params, sync=True
)
def search(self, data, topk):
# Perform a vector search on the collection to find the top-k most similar documents.
search_params = {"metric_type": "IP", "params": {}}
results = self.client.search(
self.collection_name,
data,
limit=int(50),
output_fields=["vector", "seq_id", "doc_id"],
search_params=search_params,
)
doc_ids = set()
for r_id in range(len(results)):
for r in range(len(results[r_id])):
doc_ids.add(results[r_id][r]["entity"]["doc_id"])
scores = []
def rerank_single_doc(doc_id, data, client, collection_name):
# Rerank a single document by retrieving its embeddings and calculating the similarity with the query.
doc_colbert_vecs = client.query(
collection_name=collection_name,
filter=f"doc_id in [{doc_id}]",
output_fields=["seq_id", "vector", "doc"],
limit=1000,
)
doc_vecs = np.vstack(
[doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
)
score = np.dot(data, doc_vecs.T).max(1).sum()
return (score, doc_id)
with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
futures = {
executor.submit(
rerank_single_doc, doc_id, data, client, self.collection_name
): doc_id
for doc_id in doc_ids
}
for future in concurrent.futures.as_completed(futures):
score, doc_id = future.result()
scores.append((score, doc_id))
scores.sort(key=lambda x: x[0], reverse=True)
if len(scores) >= topk:
return scores[:topk]
else:
return scores
def insert(self, data):
# Insert ColBERT embeddings and metadata for a document into the collection.
colbert_vecs = [vec for vec in data["colbert_vecs"]]
seq_length = len(colbert_vecs)
doc_ids = [data["doc_id"] for i in range(seq_length)]
seq_ids = list(range(seq_length))
docs = [""] * seq_length
docs[0] = data["filepath"]
# Insert the data as multiple vectors (one for each sequence) along with the corresponding metadata.
self.client.insert(
self.collection_name,
[
{
"vector": colbert_vecs[i],
"seq_id": seq_ids[i],
"doc_id": doc_ids[i],
"doc": docs[i],
}
for i in range(seq_length)
],
)
我們將使用colpali_engine來提取兩個查詢的嵌入列表,並從 PDF 頁面中擷取相關資訊。
from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
import torch
from typing import List, cast
device = get_torch_device("cpu")
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
queries = [
"How to end-to-end retrieval with ColBert?",
"Where is ColBERT performance table?",
]
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
qs: List[torch.Tensor] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
embeddings_query = model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
此外,我們還要抽取每個頁面的嵌入列表,它顯示每個頁面有 1030 個 128 維嵌入。
from tqdm import tqdm
from PIL import Image
import os
images = [Image.open("./pages/" + name) for name in os.listdir("./pages")]
dataloader = DataLoader(
dataset=ListDataset[str](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
ds: List[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
print(ds[0].shape)
0%| | 0/10 [00:00<?, ?it/s]
100%|██████████| 10/10 [01:22<00:00, 8.24s/it]
torch.Size([1030, 128])
我們將使用 MilvusColbertRetriever 建立一個名為「colpali」的集合。
retriever = MilvusColbertRetriever(collection_name="colpali", milvus_client=client)
retriever.create_collection()
retriever.create_index()
我們將插入嵌入清單到 Milvus 資料庫。
filepaths = ["./pages/" + name for name in os.listdir("./pages")]
for i in range(len(filepaths)):
data = {
"colbert_vecs": ds[i].float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
現在我們可以使用查詢嵌入清單搜尋最相關的頁面。
for query in qs:
query = query.float().numpy()
result = retriever.search(query, topk=1)
print(filepaths[result[0][1]])
./pages/page_5.png
./pages/page_7.png
最後,我們擷取原始的頁面名稱。透過 ColPali,我們可以擷取多模態文件,而不需要複雜的處理技術來擷取文件中的文字和影像。透過利用大型視覺模型,可以分析更多的資訊,例如表格和圖表,而不會造成重大的資訊損失。