MilvusとSentenceTransformersを使った映画検索
この例では、MilvusとSentenceTransformersライブラリを使ったWikipediaの記事検索について説明します。検索するデータセットはKaggleにあるWikipedia-Movie-Plots Datasetです。この例では、公開されているgoogleドライブにデータを再ホストしています。
それでは始めましょう。
要件のインストール
この例では、pymilvus
を使ってMilvusに接続し、sentencetransformers
を使ってベクトル埋め込みを生成し、gdown
を使ってサンプルデータセットをダウンロードします。
pip install pymilvus sentence-transformers gdown
データの取得
gdown
を使ってGoogle Driveからzipを取得し、ビルトインzipfile
ライブラリを使って解凍します。
import gdown
url = 'https://drive.google.com/uc?id=11ISS45aO2ubNCGaC3Lvd3D7NT8Y7MeO8'
output = './movies.zip'
gdown.download(url, output)
import zipfile
with zipfile.ZipFile("./movies.zip","r") as zip_ref:
zip_ref.extractall("./movies")
グローバル・パラメーター
ここでは、独自のアカウントで実行するために変更する必要がある主な引数を見つけることができる。それぞれの横にはその説明があります。
# Milvus Setup Arguments
COLLECTION_NAME = 'movies_db' # Collection name
DIMENSION = 384 # Embeddings size
COUNT = 1000 # Number of vectors to insert
MILVUS_HOST = 'localhost'
MILVUS_PORT = '19530'
# Inference Arguments
BATCH_SIZE = 128
# Search Arguments
TOP_K = 3
Milvusのセットアップ
この時点でMilvusのセットアップを開始する。手順は以下の通りです:
提供されたURIを使用してMilvusインスタンスに接続する。
from pymilvus import connections # Connect to Milvus Database connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
コレクションが既に存在する場合は、それを削除する。
from pymilvus import utility # Remove any previous collections with the same name if utility.has_collection(COLLECTION_NAME): utility.drop_collection(COLLECTION_NAME)
id、映画のタイトル、プロットテキストの埋め込みを保持するコレクションを作成する。
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection # Create collection which includes the id, title, and embedding. fields = [ FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name='title', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION) ] schema = CollectionSchema(fields=fields) collection = Collection(name=COLLECTION_NAME, schema=schema)
新しく作成されたコレクションにインデックスを作成し、メモリにロードする。
# Create an IVF_FLAT index for collection. index_params = { 'metric_type':'L2', 'index_type':"IVF_FLAT", 'params':{'nlist': 1536} } collection.create_index(field_name="embedding", index_params=index_params) collection.load()
これらのステップが完了すると、コレクションに挿入して検索する準備ができます。追加されたデータは自動的にインデックスが作成され、すぐに検索できるようになります。データが非常に新しい場合、まだインデックスが作成されていないデータに対して総当たり検索が使用されるため、検索が遅くなる可能性があります。
データの挿入
この例では、SentenceTransformers miniLMモデルを使ってプロットテキストの埋め込みを作成します。このモデルは384次元の埋め込みを返します。
次のいくつかのステップでは
- データをロードする。
- SentenceTransformersを使ってプロットテキストデータを埋め込む。
- データをMilvusに挿入する。
import csv
from sentence_transformers import SentenceTransformer
transformer = SentenceTransformer('all-MiniLM-L6-v2')
# Extract the book titles
def csv_load(file):
with open(file, newline='') as f:
reader = csv.reader(f, delimiter=',')
for row in reader:
if '' in (row[1], row[7]):
continue
yield (row[1], row[7])
# Extract embedding from text using OpenAI
def embed_insert(data):
embeds = transformer.encode(data[1])
ins = [
data[0],
[x for x in embeds]
]
collection.insert(ins)
import time
data_batch = [[],[]]
count = 0
for title, plot in csv_load('./movies/plots.csv'):
if count <= COUNT:
data_batch[0].append(title)
data_batch[1].append(plot)
if len(data_batch[0]) % BATCH_SIZE == 0:
embed_insert(data_batch)
data_batch = [[],[]]
count += 1
else:
break
# Embed and insert the remainder
if len(data_batch[0]) != 0:
embed_insert(data_batch)
# Call a flush to index any unsealed segments.
collection.flush()
埋め込みに時間がかかるため、上記の操作は比較的時間がかかります。この時間を許容範囲内に抑えるには、Global parametersのCOUNT
を適切な値に設定してください。コーヒーでも飲みながら一休みしてください!
検索の実行
すべてのデータがMilvusに挿入されたので、検索を開始することができます。この例では、プロットに基づいて映画を検索します。バッチ検索を行うため、検索時間は映画検索全体で共有されます。
# Search for titles that closest match these phrases.
search_terms = ['A movie about cars', 'A movie about monsters']
# Search the database based on input text
def embed_search(data):
embeds = transformer.encode(data)
return [x for x in embeds]
search_data = embed_search(search_terms)
start = time.time()
res = collection.search(
data=search_data, # Embeded search value
anns_field="embedding", # Search across embeddings
param={},
limit = TOP_K, # Limit to top_k results per search
output_fields=['title'] # Include title field in result
)
end = time.time()
for hits_i, hits in enumerate(res):
print('Title:', search_terms[hits_i])
print('Search Time:', end-start)
print('Results:')
for hit in hits:
print( hit.entity.get('title'), '----', hit.distance)
print()
出力は以下のようになります:
Title: A movie about cars
Search Time: 0.08636689186096191
Results:
Youth's Endearing Charm ---- 1.0954499244689941
From Leadville to Aspen: A Hold-Up in the Rockies ---- 1.1019384860992432
Gentlemen of Nerve ---- 1.1331942081451416
Title: A movie about monsters
Search Time: 0.08636689186096191
Results:
The Suburbanite ---- 1.0666425228118896
Youth's Endearing Charm ---- 1.1072258949279785
The Godless Girl ---- 1.1511223316192627