milvus-logo
LFAI
홈페이지
  • 통합

Milvus와 SentenceTransformers를 사용한 영화 검색

이 예제에서는 Milvus와 SentenceTransformers 라이브러리를 사용하여 영화 줄거리 요약을 검색하겠습니다. 우리가 사용할 데이터 세트는 HuggingFace에서 호스팅되는 요약이 포함된 Wikipedia 영화 플롯입니다.

시작해 보겠습니다!

필요한 라이브러리

이 예제에서는 Milvus를 사용하기 위해 pymilvus, 벡터 임베딩을 생성하기 위해 sentence-transformers, 예제 데이터 세트를 다운로드하기 위해 datasets 을 사용합니다.

pip install pymilvus sentence-transformers datasets tqdm
from datasets import load_dataset
from pymilvus import MilvusClient
from pymilvus import FieldSchema, CollectionSchema, DataType
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

몇 가지 전역 매개변수를 정의하겠습니다,

embedding_dim = 384
collection_name = "movie_embeddings"

데이터 세트 다운로드 및 열기

datasets 에서 한 줄로 데이터 세트를 다운로드하고 열 수 있습니다. 라이브러리는 데이터 집합을 로컬에 캐시하고 다음에 실행할 때 해당 복사본을 사용합니다. 각 행에는 Wikipedia 문서와 함께 제공되는 영화에 대한 세부 정보가 포함되어 있습니다. Title , PlotSummary, Release Year, Origin/Ethnicity 열을 사용합니다.

ds = load_dataset("vishnupriyavr/wiki-movie-plots-with-summaries", split="train")
print(ds)

데이터베이스에 연결하기

이제 Milvus 설정을 시작하겠습니다. 단계는 다음과 같습니다:

  1. 로컬 파일에 Milvus Lite 데이터베이스를 만듭니다. (이 URI를 Milvus 독립형 및 Milvus 배포용 서버 주소로 바꿉니다.).
client = MilvusClient(uri="./sentence_transformers_example.db")
  1. 데이터 스키마를 생성합니다. 여기에는 벡터 임베딩의 차원을 포함하여 요소를 구성하는 필드가 지정됩니다.
fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=256),
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),
    FieldSchema(name="year", dtype=DataType.INT64),
    FieldSchema(name="origin", dtype=DataType.VARCHAR, max_length=64),
]

schema = CollectionSchema(fields=fields, enable_dynamic_field=False)
client.create_collection(collection_name=collection_name, schema=schema)
  1. 벡터 검색 인덱싱 알고리즘을 정의합니다. Milvus Lite는 FLAT 인덱스 유형을 지원하는 반면, Milvus Standalone과 Milvus Distributed는 IVF, HNSW, DiskANN과 같은 다양한 방법을 구현합니다. 이 데모의 데이터 규모가 작기 때문에 어떤 검색 인덱스 유형이든 충분하므로 여기서는 가장 간단한 FLAT을 사용합니다.
index_params = client.prepare_index_params()
index_params.add_index(field_name="embedding", index_type="FLAT", metric_type="IP")
client.create_index(collection_name, index_params)

이 단계가 완료되면 컬렉션에 데이터를 삽입하고 검색을 수행할 준비가 된 것입니다. 추가된 모든 데이터는 자동으로 색인되며 즉시 검색에 사용할 수 있습니다. 데이터가 매우 새 데이터인 경우, 색인화 작업이 진행 중인 데이터에 대해 무차별 대입 검색이 사용되므로 검색 속도가 느려질 수 있습니다.

데이터 삽입하기

이 예제에서는 SentenceTransformers miniLM 모델을 사용하여 플롯 텍스트의 임베딩을 만들겠습니다. 이 모델은 384차원 임베딩을 반환합니다.

model = SentenceTransformer("all-MiniLM-L12-v2")

데이터의 행을 반복하고, 플롯 요약 필드를 임베드하고, 엔티티를 벡터 데이터베이스에 삽입합니다. 일반적으로 이 단계는 임베딩 모델의 CPU 또는 GPU 처리량을 최대화하려면 여기처럼 데이터 항목의 배치에 대해 수행해야 합니다.

for batch in tqdm(ds.batch(batch_size=512)):
    embeddings = model.encode(batch["PlotSummary"])
    data = [
        {"title": title, "embedding": embedding, "year": year, "origin": origin}
        for title, embedding, year, origin in zip(
            batch["Title"], embeddings, batch["Release Year"], batch["Origin/Ethnicity"]
        )
    ]
    res = client.insert(collection_name=collection_name, data=data)

임베딩에는 시간이 걸리기 때문에 위의 작업은 상대적으로 시간이 많이 걸립니다. 이 단계는 2023 MacBook Pro의 CPU를 사용하면 약 2분이 소요되며, 전용 GPU를 사용하면 훨씬 더 빨라집니다. 잠시 휴식을 취하며 커피 한 잔을 즐기세요!

Milvus에 모든 데이터를 삽입했으면 검색을 시작할 수 있습니다. 이 예에서는 Wikipedia의 줄거리 요약을 기반으로 영화를 검색해 보겠습니다. 일괄 검색을 수행하기 때문에 검색 시간은 영화 검색에 걸쳐 공유됩니다. (쿼리 설명 텍스트를 기반으로 어떤 영화를 검색하려고 했는지 짐작할 수 있을까요?)

queries = [
    'A shark terrorizes an LA beach.',
    'An archaeologist searches for ancient artifacts while fighting Nazis.',
    'Teenagers in detention learn about themselves.',
    'A teenager fakes illness to get off school and have adventures with two friends.',
    'A young couple with a kid look after a hotel during winter and the husband goes insane.',
    'Four turtles fight bad guys.'
    ]

# Search the database based on input text
def embed_query(data):
    vectors = model.encode(data)
    return [x for x in vectors]


query_vectors = embed_query(queries)

res = client.search(
    collection_name=collection_name,
    data=query_vectors,
    filter='origin == "American" and year > 1945 and year < 2000',
    anns_field="embedding",
    limit=3,
    output_fields=["title"],
)

for idx, hits in enumerate(res):
    print("Query:", queries[idx])
    print("Results:")
    for hit in hits:
        print(hit["entity"].get("title"), "(", round(hit["distance"], 2), ")")
    print()

결과는 다음과 같습니다:

Query: An archaeologist searches for ancient artifacts while fighting Nazis.
Results:
Love Slaves of the Amazons ( 0.4 )
A Time to Love and a Time to Die ( 0.39 )
The Fifth Element ( 0.39 )

Query: Teenagers in detention learn about themselves.
Results:
The Breakfast Club ( 0.54 )
Up the Academy ( 0.46 )
Fame ( 0.43 )

Query: A teenager fakes illness to get off school and have adventures with two friends.
Results:
Ferris Bueller's Day Off ( 0.48 )
Fever Lake ( 0.47 )
Losin' It ( 0.39 )

Query: A young couple with a kid look after a hotel during winter and the husband goes insane.
Results:
The Shining ( 0.48 )
The Four Seasons ( 0.42 )
Highball ( 0.41 )

Query: Four turtles fight bad guys.
Results:
Teenage Mutant Ninja Turtles II: The Secret of the Ooze ( 0.47 )
Devil May Hare ( 0.43 )
Attack of the Giant Leeches ( 0.42 )

번역DeepLogo

Try Managed Milvus for Free

Zilliz Cloud is hassle-free, powered by Milvus and 10x faster.

Get Started
피드백

이 페이지가 도움이 되었나요?