milvus-logo
LFAI
フロントページへ
  • チュートリアル

マトリョーシカ埋め込みによる漏斗探索

効率的なベクトル検索システムを構築する際の重要な課題の1つは、許容可能なレイテンシとリコールを維持しながらストレージコストを管理することです。最新の埋め込みモデルは数百から数千次元のベクトルを出力するため、生のベクトルとインデックスに多大なストレージと計算オーバーヘッドが発生します。

従来は、インデックスを構築する直前に量子化や次元削減を行うことで、ストレージの容量を削減していました。例えば、積量子化(PQ)を使って精度を下げたり、主成分分析(PCA)を使って次元数を下げることで、ストレージを節約することができます。これらの方法はベクトル集合全体を分析し、ベクトル間の意味的関係を維持したまま、よりコンパクトなものを見つける。

効果的ではあるが、これらの標準的なアプローチは精度や次元数を一度だけ、しかも単一のスケールで削減する。しかし、複数の詳細なレイヤーを同時に維持し、ピラミッドのように精度を高めていくことができるとしたらどうだろう?

マトリョーシカ埋め込みが登場する。ロシアの入れ子人形にちなんで名付けられたこの巧妙な構造は(図を参照)、1つのベクトル内に複数のスケールの表現を埋め込む。従来の後処理手法とは異なり、マトリョーシカ埋め込みは最初の学習過程でこのマルチスケール構造を学習する。その結果は驚くべきもので、完全な埋め込みが入力セマンティクスを捉えるだけでなく、入れ子になった各サブセットの接頭辞(前半、4分の1など)が、詳細ではないものの、首尾一貫した表現を提供します。

このノートブックでは、Milvusを使ったマトリョーシカ埋め込みを意味検索に使う方法を検討する。ファネル検索」と呼ばれるアルゴリズムにより、埋め込み次元の小さなサブセットで類似検索を行うことができます。

import functools

from datasets import load_dataset
import numpy as np
import pandas as pd
import pymilvus
from pymilvus import MilvusClient
from pymilvus import FieldSchema, CollectionSchema, DataType
from sentence_transformers import SentenceTransformer
import torch
import torch.nn.functional as F
from tqdm import tqdm

マトリョーシカ埋め込みモデルのロード

のような標準的な埋め込みモデルを使う代わりに、Nomic社のモデルを使います。 sentence-transformers/all-MiniLM-L12-v2のような標準的な埋め込みモデルを使う代わりに,マトリョーシカ埋め込みを生成するために特別に訓練されたNomicのモデルを使います.

model = SentenceTransformer(
    # Remove 'device='mps' if running on non-Mac device
    "nomic-ai/nomic-embed-text-v1.5",
    trust_remote_code=True,
    device="mps",
)
<All keys matched successfully>

データセットの読み込み、項目の埋め込み、ベクトルデータベースの構築

以下のコードは、ドキュメントページ"Movie Search with Sentence Transformers and Milvus "のコードを改変したものですまず、HuggingFaceからデータセットをロードする。このデータセットには約35kのエントリーが含まれており、それぞれがウィキペディアの記事を持つ映画に対応している。この例では、TitlePlotSummary フィールドを使用する。

ds = load_dataset("vishnupriyavr/wiki-movie-plots-with-summaries", split="train")
print(ds)
Dataset({
    features: ['Release Year', 'Title', 'Origin/Ethnicity', 'Director', 'Cast', 'Genre', 'Wiki Page', 'Plot', 'PlotSummary'],
    num_rows: 34886
})

次に、milvus Liteデータベースに接続し、データスキーマを指定し、このスキーマでコレクションを作成する。非正規化埋め込みと埋め込みの最初の6番目は別々のフィールドに格納する。その理由は、マトリョーシカ埋め込みの最初の1/6は類似検索のために必要であり、残りの5/6は再ランク付けと検索結果の改善のために必要だからです。

embedding_dim = 768
search_dim = 128
collection_name = "movie_embeddings"

client = MilvusClient(uri="./wiki-movie-plots-matryoshka.db")

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=256),
    # First sixth of unnormalized embedding vector
    FieldSchema(name="head_embedding", dtype=DataType.FLOAT_VECTOR, dim=search_dim),
    # Entire unnormalized embedding vector
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),
]

schema = CollectionSchema(fields=fields, enable_dynamic_field=False)
client.create_collection(collection_name=collection_name, schema=schema)

Milvusは現在、埋め込みデータの部分集合を検索することをサポートしていないため、埋め込みデータを2つの部分に分割する:頭部はインデックスと検索を行うベクトルの初期部分集合を表し、尾部は残りの部分である。このモデルは余弦距離類似検索のために学習されたものなので、head embeddingsを正規化します。しかし、後でより大きな部分集合の類似度を計算するために、先頭の埋込みのノルムを保存する必要があります。

埋込みの最初の1/6を検索するためには、head_embedding フィールド上のベクトル検索インデックスを作成する必要があります。後ほど、「ファネル検索」と通常のベクトル検索の結果を比較しますので、完全な埋め込みに対する検索インデックスも作成します。

重要なのは、IP の距離尺度ではなく、COSINE の距離尺度を使うことです。そうしないと、埋め込みノルムを追跡する必要があり、実装が複雑になるからです(これは、ファネル検索のアルゴリズムが説明されれば、より理解できるようになるでしょう)。

index_params = client.prepare_index_params()
index_params.add_index(
    field_name="head_embedding", index_type="FLAT", metric_type="COSINE"
)
index_params.add_index(field_name="embedding", index_type="FLAT", metric_type="COSINE")
client.create_index(collection_name, index_params)

最後に、全35k映画のプロット要約をエンコードし、対応する埋め込みをデータベースに入力する。

for batch in tqdm(ds.batch(batch_size=512)):
    # This particular model requires us to prefix 'search_document:' to stored entities
    plot_summary = ["search_document: " + x.strip() for x in batch["PlotSummary"]]

    # Output of embedding model is unnormalized
    embeddings = model.encode(plot_summary, convert_to_tensor=True)
    head_embeddings = embeddings[:, :search_dim]

    data = [
        {
            "title": title,
            "head_embedding": head.cpu().numpy(),
            "embedding": embedding.cpu().numpy(),
        }
        for title, head, embedding in zip(batch["Title"], head_embeddings, embeddings)
    ]
    res = client.insert(collection_name=collection_name, data=data)
100%|██████████| 69/69 [05:57<00:00,  5.18s/it]

それでは、マトリョーシカ埋め込み次元の最初の1/6を使って "漏斗検索 "を実行してみましょう。検索用に3つの映画を考えており、データベースへのクエリ用に私自身のプロット要約を作成した。クエリを埋め込み、head_embedding フィールドでベクトル検索を行い、128の結果候補を取り出す。

queries = [
    "An archaeologist searches for ancient artifacts while fighting Nazis.",
    "A teenager fakes illness to get off school and have adventures with two friends.",
    "A young couple with a kid look after a hotel during winter and the husband goes insane.",
]


# Search the database based on input text
def embed_search(data):
    embeds = model.encode(data)
    return [x for x in embeds]


# This particular model requires us to prefix 'search_query:' to queries
instruct_queries = ["search_query: " + q.strip() for q in queries]
search_data = embed_search(instruct_queries)

# Normalize head embeddings
head_search = [x[:search_dim] for x in search_data]

# Perform standard vector search on first sixth of embedding dimensions
res = client.search(
    collection_name=collection_name,
    data=head_search,
    anns_field="head_embedding",
    limit=128,
    output_fields=["title", "head_embedding", "embedding"],
)

この時点で、より小さなベクトル空間に対して検索を実行したので、全空間に対して検索を実行するよりも、待ち時間が短縮され、インデックスのストレージ要件も軽減されている可能性が高い。各クエリの上位5件を調べてみましょう:

for query, hits in zip(queries, res):
    rows = [x["entity"] for x in hits][:5]

    print("Query:", query)
    print("Results:")
    for row in rows:
        print(row["title"].strip())
    print()
Query: An archaeologist searches for ancient artifacts while fighting Nazis.
Results:
"Pimpernel" Smith
Black Hunters
The Passage
Counterblast
Dominion: Prequel to the Exorcist

Query: A teenager fakes illness to get off school and have adventures with two friends.
Results:
How to Deal
Shorts
Blackbird
Valentine
Unfriended

Query: A young couple with a kid look after a hotel during winter and the husband goes insane.
Results:
Ghostkeeper
Our Vines Have Tender Grapes
The Ref
Impact
The House in Marsh Road

見てわかるように、検索中に埋め込みが切り捨てられた結果、リコールが低下しています。ファネル検索は、この問題を巧妙なトリックで解決します。埋め込み次元の残りを使って候補リストを再ランク付けし、プルーニングすることで、高価なベクトル検索を追加で実行することなく、検索パフォーマンスを回復することができます。

ファネル検索アルゴリズムの説明を簡単にするために、各クエリのMilvus検索ヒット数をPandasデータフレームに変換します。

def hits_to_dataframe(hits: pymilvus.client.abstract.Hits) -> pd.DataFrame:
    """
    Convert a Milvus search result to a Pandas dataframe. This function is specific to our data schema.

    """
    rows = [x["entity"] for x in hits]
    rows_dict = [
        {"title": x["title"], "embedding": torch.tensor(x["embedding"])} for x in rows
    ]
    return pd.DataFrame.from_records(rows_dict)


dfs = [hits_to_dataframe(hits) for hits in res]

さて、ファネル検索を実行するために、我々は埋込みのますます大きなサブセットに対して反復します。各反復において、我々は新しい類似度に従って候補を再ランク付けし、最下位にランク付けされた候補の一部を削除する。

これを具体的にするために、前のステップでは、埋め込みとクエリの次元の1/6を使って128の候補を検索した。ファネル検索の最初のステップは、最初の1/3の次元を用いてクエリと候補の類似度を再計算することである。下位64個の候補は刈り込まれる。次に最初の2/3の次元、そして全ての次元でこのプロセスを繰り返し、32と16の候補に順次刈り込んでいく。

# An optimized implementation would vectorize the calculation of similarity scores across rows (using a matrix)
def calculate_score(row, query_emb=None, dims=768):
    emb = F.normalize(row["embedding"][:dims], dim=-1)
    return (emb @ query_emb).item()


# You could also add a top-K parameter as a termination condition
def funnel_search(
    df: pd.DataFrame, query_emb, scales=[256, 512, 768], prune_ratio=0.5
) -> pd.DataFrame:
    # Loop over increasing prefixes of the embeddings
    for dims in scales:
        # Query vector must be normalized for each new dimensionality
        emb = torch.tensor(query_emb[:dims] / np.linalg.norm(query_emb[:dims]))

        # Score
        scores = df.apply(
            functools.partial(calculate_score, query_emb=emb, dims=dims), axis=1
        )
        df["scores"] = scores

        # Re-rank
        df = df.sort_values(by="scores", ascending=False)

        # Prune (in our case, remove half of candidates at each step)
        df = df.head(int(prune_ratio * len(df)))

    return df


dfs_results = [
    {"query": query, "results": funnel_search(df, query_emb)}
    for query, df, query_emb in zip(queries, dfs, search_data)
]
for d in dfs_results:
    print(d["query"], "\n", d["results"][:5]["title"], "\n")
An archaeologist searches for ancient artifacts while fighting Nazis. 
 0           "Pimpernel" Smith
1               Black Hunters
29    Raiders of the Lost Ark
34             The Master Key
51            My Gun Is Quick
Name: title, dtype: object 

A teenager fakes illness to get off school and have adventures with two friends. 
 21               How I Live Now
32     On the Edge of Innocence
77             Bratz: The Movie
4                    Unfriended
108                  Simon Says
Name: title, dtype: object 

A young couple with a kid look after a hotel during winter and the husband goes insane. 
 9         The Shining
0         Ghostkeeper
11     Fast and Loose
7      Killing Ground
12         Home Alone
Name: title, dtype: object 

追加のベクトル検索を行うことなく、想起を回復させることができた!定性的には、これらの結果は "Raiders of the Lost Ark "と "The Shining "については、チュートリアルの"MilvusとSentence Transformersを使った映画検索 "の標準的なベクトル検索よりも高い想起率を示しているようです。しかし、このチュートリアルで後ほど紹介する "Ferris Bueller's Day Off "を見つけることはできない。(より定量的な実験とベンチマークについては論文Matryoshka Representation Learningを参照)

ファネル検索の結果を、同じ埋め込みモデルの同じデータセットに対する標準的なベクトル検索と比較してみましょう。完全な埋め込みに対して検索を行います。

# Search on entire embeddings
res = client.search(
    collection_name=collection_name,
    data=search_data,
    anns_field="embedding",
    limit=5,
    output_fields=["title", "embedding"],
)
for query, hits in zip(queries, res):
    rows = [x["entity"] for x in hits]

    print("Query:", query)
    print("Results:")
    for row in rows:
        print(row["title"].strip())
    print()
Query: An archaeologist searches for ancient artifacts while fighting Nazis.
Results:
"Pimpernel" Smith
Black Hunters
Raiders of the Lost Ark
The Master Key
My Gun Is Quick

Query: A teenager fakes illness to get off school and have adventures with two friends.
Results:
A Walk to Remember
Ferris Bueller's Day Off
How I Live Now
On the Edge of Innocence
Bratz: The Movie

Query: A young couple with a kid look after a hotel during winter and the husband goes insane.
Results:
The Shining
Ghostkeeper
Fast and Loose
Killing Ground
Home Alone

A teenager fakes illness to get off school... "の結果を除いて、ファネル検索の結果は完全埋め込み検索とほとんど同じである。

Ferris Bueller's Day Offのファネルサーチリコール失敗の調査

なぜファネルサーチはFerris Bueller's Day Offの検索に成功しなかったのだろうか?元の候補リストにあったのか、間違ってフィルタリングされたのかを調べてみましょう。

queries2 = [
    "A teenager fakes illness to get off school and have adventures with two friends."
]


# Search the database based on input text
def embed_search(data):
    embeds = model.encode(data)
    return [x for x in embeds]


instruct_queries = ["search_query: " + q.strip() for q in queries2]
search_data2 = embed_search(instruct_queries)
head_search2 = [x[:search_dim] for x in search_data2]

# Perform standard vector search on subset of embeddings
res = client.search(
    collection_name=collection_name,
    data=head_search2,
    anns_field="head_embedding",
    limit=256,
    output_fields=["title", "head_embedding", "embedding"],
)
for query, hits in zip(queries, res):
    rows = [x["entity"] for x in hits]

    print("Query:", queries2[0])
    for idx, row in enumerate(rows):
        if row["title"].strip() == "Ferris Bueller's Day Off":
            print(f"Row {idx}: Ferris Bueller's Day Off")
Query: A teenager fakes illness to get off school and have adventures with two friends.
Row 228: Ferris Bueller's Day Off

最初の候補リストが十分な大きさでなかったこと、つまり、目的のヒットが、最高レベルの粒度でクエリと十分類似していなかったことが問題であったことがわかります。128 から256 に変更すると、検索に成功する。リコールと待ち時間のトレードオフを経験的に評価するために、保留セットの候補数を設定する経験則を形成すべきである。

dfs = [hits_to_dataframe(hits) for hits in res]

dfs_results = [
    {"query": query, "results": funnel_search(df, query_emb)}
    for query, df, query_emb in zip(queries2, dfs, search_data2)
]

for d in dfs_results:
    print(d["query"], "\n", d["results"][:7]["title"].to_string(index=False), "\n")
A teenager fakes illness to get off school and have adventures with two friends. 
       A Walk to Remember
Ferris Bueller's Day Off
          How I Live Now
On the Edge of Innocence
        Bratz: The Movie
              Unfriended
              Simon Says 

順序は重要か?接頭辞埋め込みと接尾辞埋め込み。

再帰的に小さい接頭辞の埋め込みにうまくマッチングするように学習されたモデル。使用する次元の順番は重要でしょうか?例えば、埋め込み要素の接尾辞の部分集合を取ることもできるでしょうか?この実験では、マトリョーシカ埋め込みにおける次元の順序を逆にし、漏斗探索を行う。

client = MilvusClient(uri="./wikiplots-matryoshka-flipped.db")

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=256),
    FieldSchema(name="head_embedding", dtype=DataType.FLOAT_VECTOR, dim=search_dim),
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),
]

schema = CollectionSchema(fields=fields, enable_dynamic_field=False)
client.create_collection(collection_name=collection_name, schema=schema)

index_params = client.prepare_index_params()
index_params.add_index(
    field_name="head_embedding", index_type="FLAT", metric_type="COSINE"
)
client.create_index(collection_name, index_params)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
    - Avoid using `tokenizers` before the fork if possible
    - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
for batch in tqdm(ds.batch(batch_size=512)):
    plot_summary = ["search_document: " + x.strip() for x in batch["PlotSummary"]]

    # Encode and flip embeddings
    embeddings = model.encode(plot_summary, convert_to_tensor=True)
    embeddings = torch.flip(embeddings, dims=[-1])
    head_embeddings = embeddings[:, :search_dim]

    data = [
        {
            "title": title,
            "head_embedding": head.cpu().numpy(),
            "embedding": embedding.cpu().numpy(),
        }
        for title, head, embedding in zip(batch["Title"], head_embeddings, embeddings)
    ]
    res = client.insert(collection_name=collection_name, data=data)
100%|██████████| 69/69 [05:50<00:00,  5.08s/it]
# Normalize head embeddings

flip_search_data = [
    torch.flip(torch.tensor(x), dims=[-1]).cpu().numpy() for x in search_data
]
flip_head_search = [x[:search_dim] for x in flip_search_data]

# Perform standard vector search on subset of embeddings
res = client.search(
    collection_name=collection_name,
    data=flip_head_search,
    anns_field="head_embedding",
    limit=128,
    output_fields=["title", "head_embedding", "embedding"],
)
dfs = [hits_to_dataframe(hits) for hits in res]

dfs_results = [
    {"query": query, "results": funnel_search(df, query_emb)}
    for query, df, query_emb in zip(queries, dfs, flip_search_data)
]

for d in dfs_results:
    print(
        d["query"],
        "\n",
        d["results"][:7]["title"].to_string(index=False, header=False),
        "\n",
    )
An archaeologist searches for ancient artifacts while fighting Nazis. 
       "Pimpernel" Smith
          Black Hunters
Raiders of the Lost Ark
         The Master Key
        My Gun Is Quick
            The Passage
        The Mole People 

A teenager fakes illness to get off school and have adventures with two friends. 
                       A Walk to Remember
                          How I Live Now
                              Unfriended
Cirque du Freak: The Vampire's Assistant
                             Last Summer
                                 Contest
                                 Day One 

A young couple with a kid look after a hotel during winter and the husband goes insane. 
         Ghostkeeper
     Killing Ground
Leopard in the Snow
              Stone
          Afterglow
         Unfaithful
     Always a Bride 

リコールは、予想通り、ファネル検索や通常の検索よりもはるかに低い(埋め込みモデルは、埋め込み次元の接頭辞ではなく、接頭辞の対比学習によって学習された)。

まとめ

以下は、メソッド間の検索結果の比較である:

我々はMilvusとマトリョーシカ埋め込みを用いて、"漏斗探索 "と呼ばれるより効率的な意味検索アルゴリズムを実行する方法を示した。また、アルゴリズムの再ランク付けと枝刈りステップの重要性と、初期候補リストが小さすぎる場合の失敗モードについても検討した。最後に、サブエンベッディングを形成する際に、次元の順序がいかに重要であるかを議論した。というか、モデルが特定の方法で学習されたからこそ、エンベッディングの接頭辞が意味を持つのです。これで、検索性能をあまり犠牲にすることなく、意味検索のストレージコストを削減するために、マトリョーシカ埋め込みとファネル検索を実装する方法がわかりました!