milvus-logo
LFAI
フロントページへ
  • 統合

MilvusとSentenceTransformersを使った映画検索

この例では、MilvusとSentenceTransformersライブラリを使って映画のあらすじを検索します。使用するデータセットはHuggingFaceでホストされているWikipedia Movie Plots with Summariesです。

それでは始めましょう!

必要なライブラリ

この例では、Milvusに接続するためにpymilvus 、ベクトル埋め込みを生成するためにsentence-transformers 、サンプルデータセットをダウンロードするためにdatasets

pip install pymilvus sentence-transformers datasets tqdm
from datasets import load_dataset
from pymilvus import MilvusClient
from pymilvus import FieldSchema, CollectionSchema, DataType
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

いくつかのグローバルパラメータを定義します、

embedding_dim = 384
collection_name = "movie_embeddings"

データセットのダウンロードとオープン

datasets 、1行でデータセットをダウンロードし、開くことができる。ライブラリはデータセットをローカルにキャッシュし、次回実行時にはそのコピーを使う。各行には、ウィキペディアの記事が付随している映画の詳細が含まれている。TitlePlotSummaryRelease YearOrigin/Ethnicity のカラムを利用する。

ds = load_dataset("vishnupriyavr/wiki-movie-plots-with-summaries", split="train")
print(ds)

データベースへの接続

この時点で、Milvusのセットアップを開始します。手順は以下の通りである:

  1. ローカルファイルにMilvus Liteデータベースを作成する。(このURIをMilvus StandaloneおよびMilvus Distributedのサーバアドレスに置き換えてください)。
client = MilvusClient(uri="./sentence_transformers_example.db")
  1. データスキーマを作成する。ベクトル埋め込み次元を含む要素を構成するフィールドを指定する。
fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=256),
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),
    FieldSchema(name="year", dtype=DataType.INT64),
    FieldSchema(name="origin", dtype=DataType.VARCHAR, max_length=64),
]

schema = CollectionSchema(fields=fields, enable_dynamic_field=False)
client.create_collection(collection_name=collection_name, schema=schema)
  1. ベクトル検索インデックスアルゴリズムを定義します。Milvus LiteはFLATインデックスタイプをサポートしていますが、Milvus StandaloneとMilvus DistributedはIVF、HNSW、DiskANNなど様々な方法を実装しています。このデモのデータ規模が小さい場合は、どの検索インデックスタイプでも十分なので、ここでは最も単純なFLATを使用する。
index_params = client.prepare_index_params()
index_params.add_index(field_name="embedding", index_type="FLAT", metric_type="IP")
client.create_index(collection_name, index_params)

これらのステップが完了すると、コレクションにデータを挿入し、検索を実行する準備が整います。追加されたデータは自動的にインデックス化され、すぐに検索できるようになります。データが非常に新しい場合は、まだインデックスが作成されていないデータに対して総当たり検索が使用されるため、検索が遅くなる可能性があります。

データの挿入

この例では、SentenceTransformers miniLMモデルを使ってプロットテキストの埋め込みを作成します。このモデルは384次元の埋め込みを返します。

model = SentenceTransformer("all-MiniLM-L12-v2")

データの行をループし、プロットサマリフィールドを埋め込み、エンティティをベクターデータベースに挿入します。一般的に、埋め込みモデルのCPUまたはGPUのスループットを最大にするために、データ項目のバッチに対してこのステップを実行する必要があります。

for batch in tqdm(ds.batch(batch_size=512)):
    embeddings = model.encode(batch["PlotSummary"])
    data = [
        {"title": title, "embedding": embedding, "year": year, "origin": origin}
        for title, embedding, year, origin in zip(
            batch["Title"], embeddings, batch["Release Year"], batch["Origin/Ethnicity"]
        )
    ]
    res = client.insert(collection_name=collection_name, data=data)

埋め込みには時間がかかるため、上記の操作は比較的時間がかかります。このステップは、2023 MacBook ProのCPUを使うと約2分かかりますが、専用GPUを使えばもっと速くなります。コーヒーでも飲みながら一休みしましょう!

Milvusにすべてのデータが挿入されたので、検索を開始することができます。この例では、Wikipediaのあらすじから映画を検索します。バッチ検索を行うため、検索時間は映画検索全体で共有されます。(クエリの説明文から、どんな映画を検索しようと考えていたかわかりますか?)

queries = [
    'A shark terrorizes an LA beach.',
    'An archaeologist searches for ancient artifacts while fighting Nazis.',
    'Teenagers in detention learn about themselves.',
    'A teenager fakes illness to get off school and have adventures with two friends.',
    'A young couple with a kid look after a hotel during winter and the husband goes insane.',
    'Four turtles fight bad guys.'
    ]

# Search the database based on input text
def embed_query(data):
    vectors = model.encode(data)
    return [x for x in vectors]


query_vectors = embed_query(queries)

res = client.search(
    collection_name=collection_name,
    data=query_vectors,
    filter='origin == "American" and year > 1945 and year < 2000',
    anns_field="embedding",
    limit=3,
    output_fields=["title"],
)

for idx, hits in enumerate(res):
    print("Query:", queries[idx])
    print("Results:")
    for hit in hits:
        print(hit["entity"].get("title"), "(", round(hit["distance"], 2), ")")
    print()

結果は以下の通り:

Query: An archaeologist searches for ancient artifacts while fighting Nazis.
Results:
Love Slaves of the Amazons ( 0.4 )
A Time to Love and a Time to Die ( 0.39 )
The Fifth Element ( 0.39 )

Query: Teenagers in detention learn about themselves.
Results:
The Breakfast Club ( 0.54 )
Up the Academy ( 0.46 )
Fame ( 0.43 )

Query: A teenager fakes illness to get off school and have adventures with two friends.
Results:
Ferris Bueller's Day Off ( 0.48 )
Fever Lake ( 0.47 )
Losin' It ( 0.39 )

Query: A young couple with a kid look after a hotel during winter and the husband goes insane.
Results:
The Shining ( 0.48 )
The Four Seasons ( 0.42 )
Highball ( 0.41 )

Query: Four turtles fight bad guys.
Results:
Teenage Mutant Ninja Turtles II: The Secret of the Ooze ( 0.47 )
Devil May Hare ( 0.43 )
Attack of the Giant Leeches ( 0.42 )

翻訳DeepL

Try Managed Milvus for Free

Zilliz Cloud is hassle-free, powered by Milvus and 10x faster.

Get Started
フィードバック

このページは役に立ちましたか ?