milvus-logo
LFAI
フロントページへ
  • チュートリアル

Milvusでマルチモーダル検索にColPaliを使う

Open In Colab GitHub Repository

最新の検索モデルでは、テキストや画像を表現するために単一の埋め込みを使用するのが一般的です。しかしColBERTは、各データインスタンスに対して埋め込みリストを利用するニューラルモデルであり、2つのテキスト間の類似度を計算するために「MaxSim」演算を採用しています。テキストデータだけでなく、図、表、ダイアグラムにも豊富な情報が含まれているが、テキストベースの情報検索では軽視されがちである。

MaxSim関数は、クエリとドキュメント(検索対象)のトークン埋め込みを比較します。クエリ内の各単語について、ドキュメントから最も類似した単語を選び(コサイン類似度またはL2距離の2乗を使用)、クエリ内の全単語にわたってこれらの最大類似度を合計する。

ColPali は、ColBERT のマルチベクトル表現と PaliGemma(マルチモーダル大規模言語モデル)を組み合 わせ、その強力な理解能力を活用する手法である。このアプローチにより、テキストと画像の両方を含むページを、統一されたマルチベクター埋め込みを用いて表現することができる。このマルチベクトル表現内の埋め込みは詳細な情報を捉えることができ、マルチモーダルデータに対する検索支援生成(RAG)の性能を向上させる。

このノートブックでは、一般性のために、この種のマルチベクトル表現を「ColBERT埋め込み」と呼ぶ。しかし、実際に使われているモデルはColPaliモデルである。Milvusをマルチベクトル検索に利用する方法を紹介する。その上で、与えられたクエリに基づいてページを検索するためのColPaliの使い方を紹介する。

準備

$ pip install pdf2image
$ pip pymilvus
$ pip install colpali_engine
$ pip install tqdm
$ pip instal pillow

データの準備

PDF RAGを例として使用する。ColBERT論文をダウンロードし、./pdf 。ColPaliはテキストを直接処理するのではなく、ページ全体を画像にラスタライズする。ColPali モデルは、これらの画像に含まれるテキスト情報を理解することに優れています。したがって、各 PDF ページを画像に変換して処理します。

from pdf2image import convert_from_path

pdf_path = "pdfs/2004.12832v2.pdf"
images = convert_from_path(pdf_path)

for i, image in enumerate(images):
    image.save(f"pages/page_{i + 1}.png", "PNG")

次に、Milvus Liteを使ってデータベースを初期化します。Milvusサービスがホストされている適切なアドレスにuriを設定することで、簡単にフルMilvusインスタンスに切り替えることができます。

from pymilvus import MilvusClient, DataType
import numpy as np
import concurrent.futures

client = MilvusClient(uri="milvus.db")
  • 小規模なデータやプロトタイピングのためにローカルのベクターデータベースが必要なだけであれば、uriをローカルファイル、例えば./milvus.db に設定するのが最も便利な方法です。
  • もし、100万ベクトルを超えるような大規模なデータがある場合は、DockerやKubernetes上に、よりパフォーマンスの高いMilvusサーバを構築することができます。このセットアップでは、サーバのアドレスとポートをURIとして使用してください(例:http://localhost:19530 )。Milvusで認証機能を有効にしている場合、トークンには"<your_username>:<your_password>"を使用します。
  • MilvusのフルマネージドクラウドサービスであるMilvus Cloudを利用する場合は、Milvus CloudのPublic EndpointとAPI Keyに対応するuritoken を調整します。

MilvusColbertRetrieverクラスを定義し、Milvusクライアントをラップしてマルチベクターデータを取得できるようにします。この実装は、ColBERT埋め込みを平坦化してコレクションに挿入し、各行がColBERT埋め込みリストの個々の埋め込みを表す。また、各埋め込みの出所を追跡するために、doc_id と seq_id を記録する。

ColBERT 埋め込みリストで検索する場合、複数の検索が行われる。検索された doc_id は、重複排除される。再ランク付けプロセスが実行され、各 doc_id の完全な埋め込みが取得され、MaxSim スコアが計算され、最終的なランク付け結果が生成される。

class MilvusColbertRetriever:
    def __init__(self, milvus_client, collection_name, dim=128):
        # Initialize the retriever with a Milvus client, collection name, and dimensionality of the vector embeddings.
        # If the collection exists, load it.
        self.collection_name = collection_name
        self.client = milvus_client
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name)
        self.dim = dim

    def create_collection(self):
        # Create a new collection in Milvus for storing embeddings.
        # Drop the existing collection if it already exists and define the schema for the collection.
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.drop_collection(collection_name=self.collection_name)
        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_fields=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        # Create an index on the vector field to enable fast similarity search.
        # Releases and drops any existing index before creating a new one with specified parameters.
        self.client.release_collection(collection_name=self.collection_name)
        self.client.drop_index(
            collection_name=self.collection_name, index_name="vector"
        )
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            index_type="HNSW",  # or any other index type you want
            metric_type="IP",  # or the appropriate metric type
            params={
                "M": 16,
                "efConstruction": 500,
            },  # adjust these parameters as needed
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def create_scalar_index(self):
        # Create a scalar index for the "doc_id" field to enable fast lookups by document ID.
        self.client.release_collection(collection_name=self.collection_name)

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="doc_id",
            index_name="int32_index",
            index_type="INVERTED",  # or any other index type you want
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def search(self, data, topk):
        # Perform a vector search on the collection to find the top-k most similar documents.
        search_params = {"metric_type": "IP", "params": {}}
        results = self.client.search(
            self.collection_name,
            data,
            limit=int(50),
            output_fields=["vector", "seq_id", "doc_id"],
            search_params=search_params,
        )
        doc_ids = set()
        for r_id in range(len(results)):
            for r in range(len(results[r_id])):
                doc_ids.add(results[r_id][r]["entity"]["doc_id"])

        scores = []

        def rerank_single_doc(doc_id, data, client, collection_name):
            # Rerank a single document by retrieving its embeddings and calculating the similarity with the query.
            doc_colbert_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id in [{doc_id}]",
                output_fields=["seq_id", "vector", "doc"],
                limit=1000,
            )
            doc_vecs = np.vstack(
                [doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
            )
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id)

        with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
            futures = {
                executor.submit(
                    rerank_single_doc, doc_id, data, client, self.collection_name
                ): doc_id
                for doc_id in doc_ids
            }
            for future in concurrent.futures.as_completed(futures):
                score, doc_id = future.result()
                scores.append((score, doc_id))

        scores.sort(key=lambda x: x[0], reverse=True)
        if len(scores) >= topk:
            return scores[:topk]
        else:
            return scores

    def insert(self, data):
        # Insert ColBERT embeddings and metadata for a document into the collection.
        colbert_vecs = [vec for vec in data["colbert_vecs"]]
        seq_length = len(colbert_vecs)
        doc_ids = [data["doc_id"] for i in range(seq_length)]
        seq_ids = list(range(seq_length))
        docs = [""] * seq_length
        docs[0] = data["filepath"]

        # Insert the data as multiple vectors (one for each sequence) along with the corresponding metadata.
        self.client.insert(
            self.collection_name,
            [
                {
                    "vector": colbert_vecs[i],
                    "seq_id": seq_ids[i],
                    "doc_id": doc_ids[i],
                    "doc": docs[i],
                }
                for i in range(seq_length)
            ],
        )

colpali_engineを使用して、2つのクエリに対する埋め込みリストを抽出し、PDFページから関連する情報を取得します。

from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
import torch
from typing import List, cast

device = get_torch_device("cpu")
model_name = "vidore/colpali-v1.2"

model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()

queries = [
    "How to end-to-end retrieval with ColBert?",
    "Where is ColBERT performance table?",
]

processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))

dataloader = DataLoader(
    dataset=ListDataset[str](queries),
    batch_size=1,
    shuffle=False,
    collate_fn=lambda x: processor.process_queries(x),
)

qs: List[torch.Tensor] = []
for batch_query in dataloader:
    with torch.no_grad():
        batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

さらに、各ページの埋め込みリストを抽出する必要があり、それは各ページに1030個の128次元埋め込みがあることを示しています。

from tqdm import tqdm
from PIL import Image
import os

images = [Image.open("./pages/" + name) for name in os.listdir("./pages")]

dataloader = DataLoader(
    dataset=ListDataset[str](images),
    batch_size=1,
    shuffle=False,
    collate_fn=lambda x: processor.process_images(x),
)

ds: List[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
    with torch.no_grad():
        batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
        embeddings_doc = model(**batch_doc)
    ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

print(ds[0].shape)
  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [01:22<00:00,  8.24s/it]

torch.Size([1030, 128])

MilvusColbertRetrieverを使って "colpali "というコレクションを作成します。

retriever = MilvusColbertRetriever(collection_name="colpali", milvus_client=client)
retriever.create_collection()
retriever.create_index()

埋め込みリストをMilvusデータベースに挿入します。

filepaths = ["./pages/" + name for name in os.listdir("./pages")]
for i in range(len(filepaths)):
    data = {
        "colbert_vecs": ds[i].float().numpy(),
        "doc_id": i,
        "filepath": filepaths[i],
    }
    retriever.insert(data)

これで、クエリ埋め込みリストを使って、最も関連性の高いページを検索することができる。

for query in qs:
    query = query.float().numpy()
    result = retriever.search(query, topk=1)
    print(filepaths[result[0][1]])
./pages/page_5.png
./pages/page_7.png

最後に、元のページ名を取得します。ColPaliを使えば、文書からテキストや画像を抽出するための複雑な処理技術を必要とせずに、マルチモーダルな文書を検索することができる。大規模な視覚モデルを活用することで、表や図など、より多くの情報を大きな情報損失なしに解析することができる。

翻訳DeepL

Try Managed Milvus for Free

Zilliz Cloud is hassle-free, powered by Milvus and 10x faster.

Get Started
フィードバック

このページは役に立ちましたか ?