Milvus로 전체 텍스트 검색
Milvus 2.5의 출시로 사용자는 키워드 또는 구문을 기반으로 텍스트를 효율적으로 검색할 수 있는 전체 텍스트 검색을 통해 강력한 텍스트 검색 기능을 사용할 수 있습니다. 이 기능은 검색 정확도를 향상시키고 임베딩 기반 검색과 원활하게 결합하여 하나의 쿼리로 의미론적 검색과 키워드 기반 검색 결과를 모두 얻을 수 있는 하이브리드 검색을 가능하게 합니다. 이 노트북에서는 Milvus에서 전체 텍스트 검색의 기본적인 사용법을 보여드리겠습니다.
준비하기
데이터 세트 다운로드
다음 명령은 원래의 Anthropic 데모에 사용된 예제 데이터를 다운로드합니다.
$ wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/codebase_chunks.json
$ wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/evaluation_set.jsonl
Milvus 2.5 설치
자세한 내용은 공식 설치 가이드를 참조하세요.
파이밀버스 설치
다음 명령어를 실행하여 PyMilvus를 설치합니다:
pip install "pymilvus[model]" -U
리트리버 정의
import json
from pymilvus import (
MilvusClient,
DataType,
Function,
FunctionType,
AnnSearchRequest,
RRFRanker,
)
from pymilvus.model.hybrid import BGEM3EmbeddingFunction
class HybridRetriever:
def __init__(self, uri, collection_name="hybrid", dense_embedding_function=None):
self.uri = uri
self.collection_name = collection_name
self.embedding_function = dense_embedding_function
self.use_reranker = True
self.use_sparse = True
self.client = MilvusClient(uri=uri)
def build_collection(self):
if isinstance(self.embedding_function.dim, dict):
dense_dim = self.embedding_function.dim["dense"]
else:
dense_dim = self.embedding_function.dim
tokenizer_params = {
"tokenizer": "standard",
"filter": [
"lowercase",
{
"type": "length",
"max": 200,
},
{"type": "stemmer", "language": "english"},
{
"type": "stop",
"stop_words": [
"a",
"an",
"and",
"are",
"as",
"at",
"be",
"but",
"by",
"for",
"if",
"in",
"into",
"is",
"it",
"no",
"not",
"of",
"on",
"or",
"such",
"that",
"the",
"their",
"then",
"there",
"these",
"they",
"this",
"to",
"was",
"will",
"with",
],
},
],
}
schema = MilvusClient.create_schema()
schema.add_field(
field_name="pk",
datatype=DataType.VARCHAR,
is_primary=True,
auto_id=True,
max_length=100,
)
schema.add_field(
field_name="content",
datatype=DataType.VARCHAR,
max_length=65535,
analyzer_params=tokenizer_params,
enable_match=True,
enable_analyzer=True,
)
schema.add_field(
field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
)
schema.add_field(
field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim
)
schema.add_field(
field_name="original_uuid", datatype=DataType.VARCHAR, max_length=128
)
schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=64)
schema.add_field(
field_name="chunk_id", datatype=DataType.VARCHAR, max_length=64
),
schema.add_field(field_name="original_index", datatype=DataType.INT32)
functions = Function(
name="bm25",
function_type=FunctionType.BM25,
input_field_names=["content"],
output_field_names="sparse_vector",
)
schema.add_function(functions)
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
field_name="sparse_vector",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
index_params.add_index(
field_name="dense_vector", index_type="FLAT", metric_type="IP"
)
self.client.create_collection(
collection_name=self.collection_name,
schema=schema,
index_params=index_params,
)
def insert_data(self, chunk, metadata):
embedding = self.embedding_function([chunk])
if isinstance(embedding, dict) and "dense" in embedding:
dense_vec = embedding["dense"][0]
else:
dense_vec = embedding[0]
self.client.insert(
self.collection_name, {"dense_vector": dense_vec, **metadata}
)
def search(self, query: str, k: int = 20, mode="hybrid"):
output_fields = [
"content",
"original_uuid",
"doc_id",
"chunk_id",
"original_index",
]
if mode in ["dense", "hybrid"]:
embedding = self.embedding_function([query])
if isinstance(embedding, dict) and "dense" in embedding:
dense_vec = embedding["dense"][0]
else:
dense_vec = embedding[0]
if mode == "sparse":
results = self.client.search(
collection_name=self.collection_name,
data=[query],
anns_field="sparse_vector",
limit=k,
output_fields=output_fields,
)
elif mode == "dense":
results = self.client.search(
collection_name=self.collection_name,
data=[dense_vec],
anns_field="dense_vector",
limit=k,
output_fields=output_fields,
)
elif mode == "hybrid":
full_text_search_params = {"metric_type": "BM25"}
full_text_search_req = AnnSearchRequest(
[query], "sparse_vector", full_text_search_params, limit=k
)
dense_search_params = {"metric_type": "IP"}
dense_req = AnnSearchRequest(
[dense_vec], "dense_vector", dense_search_params, limit=k
)
results = self.client.hybrid_search(
self.collection_name,
[full_text_search_req, dense_req],
ranker=RRFRanker(),
limit=k,
output_fields=output_fields,
)
else:
raise ValueError("Invalid mode")
return [
{
"doc_id": doc["entity"]["doc_id"],
"chunk_id": doc["entity"]["chunk_id"],
"content": doc["entity"]["content"],
"score": doc["distance"],
}
for doc in results[0]
]
dense_ef = BGEM3EmbeddingFunction()
standard_retriever = HybridRetriever(
uri="http://localhost:19530",
collection_name="milvus_hybrid",
dense_embedding_function=dense_ef,
)
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 108848.72it/s]
데이터 삽입
path = "codebase_chunks.json"
with open(path, "r") as f:
dataset = json.load(f)
is_insert = True
if is_insert:
standard_retriever.build_collection()
for doc in dataset:
doc_content = doc["content"]
for chunk in doc["chunks"]:
metadata = {
"doc_id": doc["doc_id"],
"original_uuid": doc["original_uuid"],
"chunk_id": chunk["chunk_id"],
"original_index": chunk["original_index"],
"content": chunk["content"],
}
chunk_content = chunk["content"]
standard_retriever.insert_data(chunk_content, metadata)
스파스 검색 테스트
results = standard_retriever.search("create a logger?", mode="sparse", k=3)
print(results)
[{'doc_id': 'doc_10', 'chunk_id': 'doc_10_chunk_0', 'content': 'use {\n crate::args::LogArgs,\n anyhow::{anyhow, Result},\n simplelog::{Config, LevelFilter, WriteLogger},\n std::fs::File,\n};\n\npub struct Logger;\n\nimpl Logger {\n pub fn init(args: &impl LogArgs) -> Result<()> {\n let filter: LevelFilter = args.log_level().into();\n if filter != LevelFilter::Off {\n let logfile = File::create(args.log_file())\n .map_err(|e| anyhow!("Failed to open log file: {e:}"))?;\n WriteLogger::init(filter, Config::default(), logfile)\n .map_err(|e| anyhow!("Failed to initalize logger: {e:}"))?;\n }\n Ok(())\n }\n}\n', 'score': 9.12518310546875}, {'doc_id': 'doc_87', 'chunk_id': 'doc_87_chunk_3', 'content': '\t\tLoggerPtr INF = Logger::getLogger(LOG4CXX_TEST_STR("INF"));\n\t\tINF->setLevel(Level::getInfo());\n\n\t\tLoggerPtr INF_ERR = Logger::getLogger(LOG4CXX_TEST_STR("INF.ERR"));\n\t\tINF_ERR->setLevel(Level::getError());\n\n\t\tLoggerPtr DEB = Logger::getLogger(LOG4CXX_TEST_STR("DEB"));\n\t\tDEB->setLevel(Level::getDebug());\n\n\t\t// Note: categories with undefined level\n\t\tLoggerPtr INF_UNDEF = Logger::getLogger(LOG4CXX_TEST_STR("INF.UNDEF"));\n\t\tLoggerPtr INF_ERR_UNDEF = Logger::getLogger(LOG4CXX_TEST_STR("INF.ERR.UNDEF"));\n\t\tLoggerPtr UNDEF = Logger::getLogger(LOG4CXX_TEST_STR("UNDEF"));\n\n', 'score': 7.0077056884765625}, {'doc_id': 'doc_89', 'chunk_id': 'doc_89_chunk_3', 'content': 'using namespace log4cxx;\nusing namespace log4cxx::helpers;\n\nLOGUNIT_CLASS(FMTTestCase)\n{\n\tLOGUNIT_TEST_SUITE(FMTTestCase);\n\tLOGUNIT_TEST(test1);\n\tLOGUNIT_TEST(test1_expanded);\n\tLOGUNIT_TEST(test10);\n//\tLOGUNIT_TEST(test_date);\n\tLOGUNIT_TEST_SUITE_END();\n\n\tLoggerPtr root;\n\tLoggerPtr logger;\n\npublic:\n\tvoid setUp()\n\t{\n\t\troot = Logger::getRootLogger();\n\t\tMDC::clear();\n\t\tlogger = Logger::getLogger(LOG4CXX_TEST_STR("java.org.apache.log4j.PatternLayoutTest"));\n\t}\n\n', 'score': 6.750633716583252}]
평가
이제 Milvus에 데이터 세트를 삽입했으므로 밀도, 스파스 또는 하이브리드 검색을 사용하여 상위 5개의 결과를 검색할 수 있습니다. mode
을 변경하고 각각을 평가할 수 있습니다. 각 쿼리에 대해 상위 5개 결과를 검색하고 Recall을 계산하는 Pass@5 메트릭을 제시합니다.
def load_jsonl(file_path: str):
"""Load JSONL file and return a list of dictionaries."""
with open(file_path, "r") as file:
return [json.loads(line) for line in file]
dataset = load_jsonl("evaluation_set.jsonl")
k = 5
# mode can be "dense", "sparse" or "hybrid".
mode = "hybrid"
total_query_score = 0
num_queries = 0
for query_item in dataset:
query = query_item["query"]
golden_chunk_uuids = query_item["golden_chunk_uuids"]
chunks_found = 0
golden_contents = []
for doc_uuid, chunk_index in golden_chunk_uuids:
golden_doc = next(
(doc for doc in query_item["golden_documents"] if doc["uuid"] == doc_uuid),
None,
)
if golden_doc:
golden_chunk = next(
(
chunk
for chunk in golden_doc["chunks"]
if chunk["index"] == chunk_index
),
None,
)
if golden_chunk:
golden_contents.append(golden_chunk["content"].strip())
results = standard_retriever.search(query, mode=mode, k=5)
for golden_content in golden_contents:
for doc in results[:k]:
retrieved_content = doc["content"].strip()
if retrieved_content == golden_content:
chunks_found += 1
break
query_score = chunks_found / len(golden_contents)
total_query_score += query_score
num_queries += 1
print("Pass@5: ", total_query_score / num_queries)
Pass@5: 0.7911386328725037