milvus-logo
LFAI
Home
  • Anleitungen

Trichtersuche mit Matryoshka-Embeddings

Beim Aufbau effizienter Vektorsuchsysteme besteht eine der größten Herausforderungen darin, die Speicherkosten zu bewältigen und gleichzeitig eine akzeptable Latenzzeit und Wiederauffindbarkeit zu gewährleisten. Moderne Einbettungsmodelle geben Vektoren mit Hunderten oder Tausenden von Dimensionen aus, was zu einem erheblichen Speicher- und Rechenaufwand für den Rohvektor und den Index führt.

Traditionell wird der Speicherbedarf durch die Anwendung einer Quantisierungs- oder Dimensionalitätsreduzierungsmethode unmittelbar vor dem Aufbau des Index reduziert. Wir können beispielsweise Speicherplatz einsparen, indem wir die Genauigkeit mithilfe der Produktquantisierung (PQ) oder die Anzahl der Dimensionen mithilfe der Hauptkomponentenanalyse (PCA) verringern. Diese Methoden analysieren die gesamte Vektormenge, um eine kompaktere Menge zu finden, die die semantischen Beziehungen zwischen den Vektoren beibehält.

Diese Standardansätze sind zwar effektiv, aber sie reduzieren die Präzision oder Dimensionalität nur einmal und auf einer einzigen Ebene. Aber was wäre, wenn wir mehrere Detailschichten gleichzeitig beibehalten könnten, wie eine Pyramide von immer präziseren Darstellungen?

Das ist die Matrjoschka-Einbettung. Diese cleveren Konstrukte, die nach den russischen Schachtelpuppen benannt sind (siehe Abbildung), betten mehrere Darstellungsebenen in einen einzigen Vektor ein. Im Gegensatz zu herkömmlichen Nachbearbeitungsmethoden erlernen Matryoshka-Einbettungen diese Multiskalenstruktur während des anfänglichen Trainingsprozesses. Das Ergebnis ist bemerkenswert: Die vollständige Einbettung erfasst nicht nur die Semantik der Eingabe, sondern jedes verschachtelte Untergruppenpräfix (erste Hälfte, erstes Viertel usw.) liefert eine kohärente, wenn auch weniger detaillierte Darstellung.

In diesem Notizbuch untersuchen wir, wie man Matryoshka-Einbettungen mit Milvus für die semantische Suche verwenden kann. Wir veranschaulichen einen Algorithmus namens "Trichtersuche", der es uns ermöglicht, eine Ähnlichkeitssuche über eine kleine Teilmenge unserer Einbettungsdimensionen durchzuführen, ohne dass es zu einem drastischen Abfall der Wiedererkennbarkeit kommt.

import functools

from datasets import load_dataset
import numpy as np
import pandas as pd
import pymilvus
from pymilvus import MilvusClient
from pymilvus import FieldSchema, CollectionSchema, DataType
from sentence_transformers import SentenceTransformer
import torch
import torch.nn.functional as F
from tqdm import tqdm

Matrjoschka-Einbettungsmodell laden

Anstatt ein Standard-Einbettungsmodell wie sentence-transformers/all-MiniLM-L12-v2verwenden wir ein Modell von Nomic, das speziell für die Erstellung von Matryoshka-Einbettungen trainiert wurde.

model = SentenceTransformer(
    # Remove 'device='mps' if running on non-Mac device
    "nomic-ai/nomic-embed-text-v1.5",
    trust_remote_code=True,
    device="mps",
)
<All keys matched successfully>

Laden des Datensatzes, Einbetten der Elemente und Aufbau der Vektordatenbank

Der folgende Code ist eine Abwandlung des Codes aus der Dokumentationsseite "Movie Search with Sentence Transformers and Milvus". Zuerst laden wir den Datensatz von HuggingFace. Er enthält etwa 35k Einträge, die jeweils einem Film mit einem Wikipedia-Artikel entsprechen. Wir werden in diesem Beispiel die Felder Title und PlotSummary verwenden.

ds = load_dataset("vishnupriyavr/wiki-movie-plots-with-summaries", split="train")
print(ds)
Dataset({
    features: ['Release Year', 'Title', 'Origin/Ethnicity', 'Director', 'Cast', 'Genre', 'Wiki Page', 'Plot', 'PlotSummary'],
    num_rows: 34886
})

Als nächstes stellen wir eine Verbindung zu einer Milvus-Lite-Datenbank her, geben das Datenschema an und erstellen eine Sammlung mit diesem Schema. Wir werden sowohl die nicht normalisierte Einbettung als auch das erste Sechstel der Einbettung in separaten Feldern speichern. Der Grund dafür ist, dass wir das erste Sechstel des Matrjoschka-Embeddings für die Durchführung einer Ähnlichkeitssuche benötigen und die restlichen 5 Sechstel des Embeddings für das Reranking und die Verbesserung der Suchergebnisse.

embedding_dim = 768
search_dim = 128
collection_name = "movie_embeddings"

client = MilvusClient(uri="./wiki-movie-plots-matryoshka.db")

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=256),
    # First sixth of unnormalized embedding vector
    FieldSchema(name="head_embedding", dtype=DataType.FLOAT_VECTOR, dim=search_dim),
    # Entire unnormalized embedding vector
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),
]

schema = CollectionSchema(fields=fields, enable_dynamic_field=False)
client.create_collection(collection_name=collection_name, schema=schema)

Milvus unterstützt derzeit nicht die Suche über Teilmengen von Einbettungen, daher unterteilen wir die Einbettungen in zwei Teile: Der Kopf stellt die anfängliche Teilmenge des Vektors dar, die indiziert und durchsucht werden soll, und der Schwanz ist der Rest. Das Modell wurde für die Ähnlichkeitssuche über die Kosinusdistanz trainiert, daher normalisieren wir die Kopfeinbettungen. Um jedoch später Ähnlichkeiten für größere Teilmengen berechnen zu können, müssen wir die Norm der Kopfeinbettung speichern, damit wir sie vor dem Zusammenfügen mit dem Schwanz unnormalisieren können.

Um die Suche über das erste 1/6 der Einbettung durchzuführen, müssen wir einen Vektorsuchindex über das Feld head_embedding erstellen. Später werden wir die Ergebnisse der "Trichtersuche" mit einer normalen Vektorsuche vergleichen und daher auch einen Suchindex über die gesamte Einbettung erstellen.

Wichtig ist, dass wir die Distanzmetrik COSINE und nicht IP verwenden, da wir sonst die Einbettungsnormen im Auge behalten müssten, was die Implementierung verkomplizieren würde (dies wird sinnvoller, wenn der Trichtersuchalgorithmus beschrieben wurde).

index_params = client.prepare_index_params()
index_params.add_index(
    field_name="head_embedding", index_type="FLAT", metric_type="COSINE"
)
index_params.add_index(field_name="embedding", index_type="FLAT", metric_type="COSINE")
client.create_index(collection_name, index_params)

Schließlich kodieren wir die Handlungszusammenfassungen für alle 35k Filme und geben die entsprechenden Einbettungen in die Datenbank ein.

for batch in tqdm(ds.batch(batch_size=512)):
    # This particular model requires us to prefix 'search_document:' to stored entities
    plot_summary = ["search_document: " + x.strip() for x in batch["PlotSummary"]]

    # Output of embedding model is unnormalized
    embeddings = model.encode(plot_summary, convert_to_tensor=True)
    head_embeddings = embeddings[:, :search_dim]

    data = [
        {
            "title": title,
            "head_embedding": head.cpu().numpy(),
            "embedding": embedding.cpu().numpy(),
        }
        for title, head, embedding in zip(batch["Title"], head_embeddings, embeddings)
    ]
    res = client.insert(collection_name=collection_name, data=data)
100%|██████████| 69/69 [05:57<00:00,  5.18s/it]

Nun wollen wir eine "Trichtersuche" mit den ersten 1/6 der Matryoshka-Einbettungsdimensionen durchführen. Ich habe drei Filme im Sinn, die abgerufen werden sollen, und habe meine eigene Zusammenfassung für die Abfrage der Datenbank erstellt. Wir betten die Abfragen ein und führen dann eine Vektorsuche auf dem Feld head_embedding durch, wobei wir 128 Ergebniskandidaten erhalten.

queries = [
    "An archaeologist searches for ancient artifacts while fighting Nazis.",
    "A teenager fakes illness to get off school and have adventures with two friends.",
    "A young couple with a kid look after a hotel during winter and the husband goes insane.",
]


# Search the database based on input text
def embed_search(data):
    embeds = model.encode(data)
    return [x for x in embeds]


# This particular model requires us to prefix 'search_query:' to queries
instruct_queries = ["search_query: " + q.strip() for q in queries]
search_data = embed_search(instruct_queries)

# Normalize head embeddings
head_search = [x[:search_dim] for x in search_data]

# Perform standard vector search on first sixth of embedding dimensions
res = client.search(
    collection_name=collection_name,
    data=head_search,
    anns_field="head_embedding",
    limit=128,
    output_fields=["title", "head_embedding", "embedding"],
)

An diesem Punkt haben wir eine Suche über einen viel kleineren Vektorraum durchgeführt und haben daher wahrscheinlich eine geringere Latenz und geringere Speicheranforderungen für den Index im Vergleich zu einer Suche über den gesamten Raum. Untersuchen wir nun die ersten 5 Treffer für jede Abfrage:

for query, hits in zip(queries, res):
    rows = [x["entity"] for x in hits][:5]

    print("Query:", query)
    print("Results:")
    for row in rows:
        print(row["title"].strip())
    print()
Query: An archaeologist searches for ancient artifacts while fighting Nazis.
Results:
"Pimpernel" Smith
Black Hunters
The Passage
Counterblast
Dominion: Prequel to the Exorcist

Query: A teenager fakes illness to get off school and have adventures with two friends.
Results:
How to Deal
Shorts
Blackbird
Valentine
Unfriended

Query: A young couple with a kid look after a hotel during winter and the husband goes insane.
Results:
Ghostkeeper
Our Vines Have Tender Grapes
The Ref
Impact
The House in Marsh Road

Wie wir sehen können, hat die Wiederauffindbarkeit als Folge des Abschneidens der Einbettungen während der Suche gelitten. Die Trichtersuche behebt dieses Problem mit einem cleveren Trick: Wir können die verbleibenden Einbettungsdimensionen verwenden, um unsere Kandidatenliste neu zu ordnen und zu beschneiden, um die Abrufleistung wiederherzustellen, ohne zusätzliche teure Vektorsuchen durchzuführen.

Um die Darstellung des Trichtersuchalgorithmus zu vereinfachen, konvertieren wir die Milvus-Suchtreffer für jede Anfrage in einen Pandas-Datenrahmen.

def hits_to_dataframe(hits: pymilvus.client.abstract.Hits) -> pd.DataFrame:
    """
    Convert a Milvus search result to a Pandas dataframe. This function is specific to our data schema.

    """
    rows = [x["entity"] for x in hits]
    rows_dict = [
        {"title": x["title"], "embedding": torch.tensor(x["embedding"])} for x in rows
    ]
    return pd.DataFrame.from_records(rows_dict)


dfs = [hits_to_dataframe(hits) for hits in res]

Um die Trichtersuche durchzuführen, iterieren wir nun über die immer größer werdenden Teilmengen der Einbettungen. Bei jeder Iteration ordnen wir die Kandidaten entsprechend der neuen Ähnlichkeiten neu ein und streichen einen Teil der am niedrigsten eingestuften Kandidaten.

Um dies zu verdeutlichen, haben wir im vorigen Schritt 128 Kandidaten mit 1/6 der Einbettungs- und Abfragedimensionen gefunden. Der erste Schritt bei der Durchführung der Trichtersuche besteht darin, die Ähnlichkeiten zwischen den Suchanfragen und den Kandidaten unter Verwendung des ersten Drittels der Dimensionen neu zu berechnen. Die untersten 64 Kandidaten werden aussortiert. Dann wiederholen wir diesen Vorgang mit den ersten 2/3 der Dimensionen und dann mit allen Dimensionen, wobei wir nacheinander 32 und 16 Kandidaten aussondern.

# An optimized implementation would vectorize the calculation of similarity scores across rows (using a matrix)
def calculate_score(row, query_emb=None, dims=768):
    emb = F.normalize(row["embedding"][:dims], dim=-1)
    return (emb @ query_emb).item()


# You could also add a top-K parameter as a termination condition
def funnel_search(
    df: pd.DataFrame, query_emb, scales=[256, 512, 768], prune_ratio=0.5
) -> pd.DataFrame:
    # Loop over increasing prefixes of the embeddings
    for dims in scales:
        # Query vector must be normalized for each new dimensionality
        emb = torch.tensor(query_emb[:dims] / np.linalg.norm(query_emb[:dims]))

        # Score
        scores = df.apply(
            functools.partial(calculate_score, query_emb=emb, dims=dims), axis=1
        )
        df["scores"] = scores

        # Re-rank
        df = df.sort_values(by="scores", ascending=False)

        # Prune (in our case, remove half of candidates at each step)
        df = df.head(int(prune_ratio * len(df)))

    return df


dfs_results = [
    {"query": query, "results": funnel_search(df, query_emb)}
    for query, df, query_emb in zip(queries, dfs, search_data)
]
for d in dfs_results:
    print(d["query"], "\n", d["results"][:5]["title"], "\n")
An archaeologist searches for ancient artifacts while fighting Nazis. 
 0           "Pimpernel" Smith
1               Black Hunters
29    Raiders of the Lost Ark
34             The Master Key
51            My Gun Is Quick
Name: title, dtype: object 

A teenager fakes illness to get off school and have adventures with two friends. 
 21               How I Live Now
32     On the Edge of Innocence
77             Bratz: The Movie
4                    Unfriended
108                  Simon Says
Name: title, dtype: object 

A young couple with a kid look after a hotel during winter and the husband goes insane. 
 9         The Shining
0         Ghostkeeper
11     Fast and Loose
7      Killing Ground
12         Home Alone
Name: title, dtype: object 

Es ist uns gelungen, die Trefferquote ohne zusätzliche Vektorsuche wiederherzustellen! Qualitativ scheinen diese Ergebnisse für "Raiders of the Lost Ark" und "The Shining" eine höhere Trefferquote zu haben als die Standard-Vektorsuche aus dem Tutorial "Filmsuche mit Milvus und Satztransformatoren", die ein anderes Einbettungsmodell verwendet. Allerdings ist es nicht in der Lage, "Ferris Bueller's Day Off" zu finden, auf den wir später im Notizbuch zurückkommen werden. (Weitere quantitative Experimente und Benchmarking finden Sie im Dokument Matryoshka Representation Learning ).

Vergleichen wir die Ergebnisse unserer Trichtersuche mit einer normalen Vektorsuche auf demselben Datensatz mit demselben Einbettungsmodell. Wir führen eine Suche über die vollständigen Einbettungen durch.

# Search on entire embeddings
res = client.search(
    collection_name=collection_name,
    data=search_data,
    anns_field="embedding",
    limit=5,
    output_fields=["title", "embedding"],
)
for query, hits in zip(queries, res):
    rows = [x["entity"] for x in hits]

    print("Query:", query)
    print("Results:")
    for row in rows:
        print(row["title"].strip())
    print()
Query: An archaeologist searches for ancient artifacts while fighting Nazis.
Results:
"Pimpernel" Smith
Black Hunters
Raiders of the Lost Ark
The Master Key
My Gun Is Quick

Query: A teenager fakes illness to get off school and have adventures with two friends.
Results:
A Walk to Remember
Ferris Bueller's Day Off
How I Live Now
On the Edge of Innocence
Bratz: The Movie

Query: A young couple with a kid look after a hotel during winter and the husband goes insane.
Results:
The Shining
Ghostkeeper
Fast and Loose
Killing Ground
Home Alone

Mit Ausnahme der Ergebnisse für "Ein Jugendlicher täuscht eine Krankheit vor, um die Schule zu schwänzen..." sind die Ergebnisse der Trichtersuche fast identisch mit denen der Vollsuche, obwohl die Trichtersuche auf einem Suchraum von 128 Dimensionen gegenüber 768 Dimensionen für die reguläre Suche durchgeführt wurde.

Untersuchung des Misserfolgs der Trichtersuche bei Ferris Bueller's Day Off

Warum gelang es der Trichtersuche nicht, "Ferris Bueller's Day Off" abzurufen? Untersuchen wir, ob er in der ursprünglichen Kandidatenliste enthalten war oder irrtümlich herausgefiltert wurde.

queries2 = [
    "A teenager fakes illness to get off school and have adventures with two friends."
]


# Search the database based on input text
def embed_search(data):
    embeds = model.encode(data)
    return [x for x in embeds]


instruct_queries = ["search_query: " + q.strip() for q in queries2]
search_data2 = embed_search(instruct_queries)
head_search2 = [x[:search_dim] for x in search_data2]

# Perform standard vector search on subset of embeddings
res = client.search(
    collection_name=collection_name,
    data=head_search2,
    anns_field="head_embedding",
    limit=256,
    output_fields=["title", "head_embedding", "embedding"],
)
for query, hits in zip(queries, res):
    rows = [x["entity"] for x in hits]

    print("Query:", queries2[0])
    for idx, row in enumerate(rows):
        if row["title"].strip() == "Ferris Bueller's Day Off":
            print(f"Row {idx}: Ferris Bueller's Day Off")
Query: A teenager fakes illness to get off school and have adventures with two friends.
Row 228: Ferris Bueller's Day Off

Wir sehen, dass das Problem darin bestand, dass die ursprüngliche Kandidatenliste nicht groß genug war, oder besser gesagt, dass der gewünschte Treffer der Anfrage auf der höchsten Granularitätsebene nicht ähnlich genug ist. Eine Änderung von 128 auf 256 führt zu einem erfolgreichen Abruf. Wir sollten eine Faustregel aufstellen, um die Anzahl der Kandidaten in einer Warteschleife festzulegen, um den Kompromiss zwischen Wiederauffindbarkeit und Latenzzeit empirisch zu bewerten.

dfs = [hits_to_dataframe(hits) for hits in res]

dfs_results = [
    {"query": query, "results": funnel_search(df, query_emb)}
    for query, df, query_emb in zip(queries2, dfs, search_data2)
]

for d in dfs_results:
    print(d["query"], "\n", d["results"][:7]["title"].to_string(index=False), "\n")
A teenager fakes illness to get off school and have adventures with two friends. 
       A Walk to Remember
Ferris Bueller's Day Off
          How I Live Now
On the Edge of Innocence
        Bratz: The Movie
              Unfriended
              Simon Says 

Spielt die Reihenfolge eine Rolle? Präfix- vs. Suffixeinbettungen.

Das Modell wurde so trainiert, dass es rekursiv kleinere Präfixe der Einbettungen gut abbildet. Spielt die Reihenfolge der Dimensionen, die wir verwenden, eine Rolle? Könnten wir zum Beispiel auch Teilmengen der Einbettungen nehmen, die Suffixe sind? In diesem Experiment kehren wir die Reihenfolge der Dimensionen in den Matryoshka-Einbettungen um und führen eine Trichtersuche durch.

client = MilvusClient(uri="./wikiplots-matryoshka-flipped.db")

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=256),
    FieldSchema(name="head_embedding", dtype=DataType.FLOAT_VECTOR, dim=search_dim),
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),
]

schema = CollectionSchema(fields=fields, enable_dynamic_field=False)
client.create_collection(collection_name=collection_name, schema=schema)

index_params = client.prepare_index_params()
index_params.add_index(
    field_name="head_embedding", index_type="FLAT", metric_type="COSINE"
)
client.create_index(collection_name, index_params)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
    - Avoid using `tokenizers` before the fork if possible
    - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
for batch in tqdm(ds.batch(batch_size=512)):
    plot_summary = ["search_document: " + x.strip() for x in batch["PlotSummary"]]

    # Encode and flip embeddings
    embeddings = model.encode(plot_summary, convert_to_tensor=True)
    embeddings = torch.flip(embeddings, dims=[-1])
    head_embeddings = embeddings[:, :search_dim]

    data = [
        {
            "title": title,
            "head_embedding": head.cpu().numpy(),
            "embedding": embedding.cpu().numpy(),
        }
        for title, head, embedding in zip(batch["Title"], head_embeddings, embeddings)
    ]
    res = client.insert(collection_name=collection_name, data=data)
100%|██████████| 69/69 [05:50<00:00,  5.08s/it]
# Normalize head embeddings

flip_search_data = [
    torch.flip(torch.tensor(x), dims=[-1]).cpu().numpy() for x in search_data
]
flip_head_search = [x[:search_dim] for x in flip_search_data]

# Perform standard vector search on subset of embeddings
res = client.search(
    collection_name=collection_name,
    data=flip_head_search,
    anns_field="head_embedding",
    limit=128,
    output_fields=["title", "head_embedding", "embedding"],
)
dfs = [hits_to_dataframe(hits) for hits in res]

dfs_results = [
    {"query": query, "results": funnel_search(df, query_emb)}
    for query, df, query_emb in zip(queries, dfs, flip_search_data)
]

for d in dfs_results:
    print(
        d["query"],
        "\n",
        d["results"][:7]["title"].to_string(index=False, header=False),
        "\n",
    )
An archaeologist searches for ancient artifacts while fighting Nazis. 
       "Pimpernel" Smith
          Black Hunters
Raiders of the Lost Ark
         The Master Key
        My Gun Is Quick
            The Passage
        The Mole People 

A teenager fakes illness to get off school and have adventures with two friends. 
                       A Walk to Remember
                          How I Live Now
                              Unfriended
Cirque du Freak: The Vampire's Assistant
                             Last Summer
                                 Contest
                                 Day One 

A young couple with a kid look after a hotel during winter and the husband goes insane. 
         Ghostkeeper
     Killing Ground
Leopard in the Snow
              Stone
          Afterglow
         Unfaithful
     Always a Bride 

Der Rückruf ist erwartungsgemäß viel schlechter als bei der Trichtersuche oder der regulären Suche (das Einbettungsmodell wurde durch kontrastives Lernen auf Präfixen der Einbettungsdimensionen trainiert, nicht auf Suffixen).

Zusammenfassung

Hier ist ein Vergleich unserer Suchergebnisse zwischen den Methoden:

Wir haben gezeigt, wie man Matryoshka-Einbettungen mit Milvus verwendet, um einen effizienteren semantischen Suchalgorithmus namens "Trichtersuche" durchzuführen. Wir haben auch die Bedeutung der Reranking- und Pruning-Schritte des Algorithmus sowie einen Fehlermodus untersucht, wenn die anfängliche Kandidatenliste zu klein ist. Schließlich haben wir erörtert, wie wichtig die Reihenfolge der Dimensionen bei der Bildung von Sub-Embeddings ist - sie muss die gleiche sein, für die das Modell trainiert wurde. Oder besser gesagt, nur weil das Modell auf eine bestimmte Weise trainiert wurde, sind die Präfixe der Einbettungen sinnvoll. Jetzt wissen Sie, wie Sie Matryoshka-Einbettungen und die Trichtersuche implementieren können, um die Speicherkosten der semantischen Suche zu senken, ohne allzu große Einbußen bei der Abrufleistung hinnehmen zu müssen!