Full Text Search with Milvus
With the release of Milvus 2.5, full text search lets users efficiently search text by keywords or phrases, providing powerful text retrieval capabilities. This feature improves search accuracy and can be seamlessly combined with embedding-based retrieval for hybrid search, returning both semantic and keyword-based matches in a single query. In this notebook, we demonstrate the basic usage of full text search in Milvus.
Preparation
Download the Dataset
The following commands download the example data used in the original Anthropic demo.
$ wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/codebase_chunks.json
$ wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/evaluation_set.jsonl
Install Milvus 2.5
Refer to the official installation guide for details.
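If you do not already have a running instance, one common option is to start Milvus standalone with Docker via the official helper script (the script URL and commands below follow the Milvus documentation and may change between releases):
$ curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed.sh
$ bash standalone_embed.sh start
Note that full text search requires Milvus 2.5 or later; this notebook connects to a server at http://localhost:19530.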
Install PyMilvus
Run the following command to install PyMilvus:
pip install "pymilvus[model]" -U
Define the Retriever
import json
from pymilvus import (
MilvusClient,
DataType,
Function,
FunctionType,
AnnSearchRequest,
RRFRanker,
)
from pymilvus.model.hybrid import BGEM3EmbeddingFunction
class HybridRetriever:
def __init__(self, uri, collection_name="hybrid", dense_embedding_function=None):
self.uri = uri
self.collection_name = collection_name
self.embedding_function = dense_embedding_function
self.use_reranker = True
self.use_sparse = True
self.client = MilvusClient(uri=uri)
def build_collection(self):
if isinstance(self.embedding_function.dim, dict):
dense_dim = self.embedding_function.dim["dense"]
else:
dense_dim = self.embedding_function.dim
tokenizer_params = {
"tokenizer": "standard",
"filter": [
"lowercase",
{
"type": "length",
"max": 200,
},
{"type": "stemmer", "language": "english"},
{
"type": "stop",
"stop_words": [
"a",
"an",
"and",
"are",
"as",
"at",
"be",
"but",
"by",
"for",
"if",
"in",
"into",
"is",
"it",
"no",
"not",
"of",
"on",
"or",
"such",
"that",
"the",
"their",
"then",
"there",
"these",
"they",
"this",
"to",
"was",
"will",
"with",
],
},
],
}
schema = MilvusClient.create_schema()
schema.add_field(
field_name="pk",
datatype=DataType.VARCHAR,
is_primary=True,
auto_id=True,
max_length=100,
)
schema.add_field(
field_name="content",
datatype=DataType.VARCHAR,
max_length=65535,
analyzer_params=tokenizer_params,
enable_match=True,
enable_analyzer=True,
)
schema.add_field(
field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
)
schema.add_field(
field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim
)
schema.add_field(
field_name="original_uuid", datatype=DataType.VARCHAR, max_length=128
)
schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=64)
        schema.add_field(
            field_name="chunk_id", datatype=DataType.VARCHAR, max_length=64
        )
schema.add_field(field_name="original_index", datatype=DataType.INT32)
        # The BM25 function tells Milvus to derive "sparse_vector" from the raw
        # "content" text on the server, so no sparse embedding is computed
        # client-side.
        bm25_function = Function(
            name="bm25",
            function_type=FunctionType.BM25,
            input_field_names=["content"],
            output_field_names=["sparse_vector"],
        )
        schema.add_function(bm25_function)
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
field_name="sparse_vector",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
index_params.add_index(
field_name="dense_vector", index_type="FLAT", metric_type="IP"
)
self.client.create_collection(
collection_name=self.collection_name,
schema=schema,
index_params=index_params,
)
def insert_data(self, chunk, metadata):
embedding = self.embedding_function([chunk])
if isinstance(embedding, dict) and "dense" in embedding:
dense_vec = embedding["dense"][0]
else:
dense_vec = embedding[0]
self.client.insert(
self.collection_name, {"dense_vector": dense_vec, **metadata}
)
def search(self, query: str, k: int = 20, mode="hybrid"):
output_fields = [
"content",
"original_uuid",
"doc_id",
"chunk_id",
"original_index",
]
if mode in ["dense", "hybrid"]:
embedding = self.embedding_function([query])
if isinstance(embedding, dict) and "dense" in embedding:
dense_vec = embedding["dense"][0]
else:
dense_vec = embedding[0]
if mode == "sparse":
results = self.client.search(
collection_name=self.collection_name,
data=[query],
anns_field="sparse_vector",
limit=k,
output_fields=output_fields,
)
elif mode == "dense":
results = self.client.search(
collection_name=self.collection_name,
data=[dense_vec],
anns_field="dense_vector",
limit=k,
output_fields=output_fields,
)
elif mode == "hybrid":
full_text_search_params = {"metric_type": "BM25"}
full_text_search_req = AnnSearchRequest(
[query], "sparse_vector", full_text_search_params, limit=k
)
dense_search_params = {"metric_type": "IP"}
dense_req = AnnSearchRequest(
[dense_vec], "dense_vector", dense_search_params, limit=k
)
results = self.client.hybrid_search(
self.collection_name,
[full_text_search_req, dense_req],
ranker=RRFRanker(),
limit=k,
output_fields=output_fields,
)
else:
raise ValueError("Invalid mode")
return [
{
"doc_id": doc["entity"]["doc_id"],
"chunk_id": doc["entity"]["chunk_id"],
"content": doc["entity"]["content"],
"score": doc["distance"],
}
for doc in results[0]
]
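Note that HybridRetriever never computes sparse embeddings on the client: the BM25 Function declared in the schema maps the raw content text to sparse_vector on the server. This is why insert_data supplies only the dense vector, and why search can pass the raw query string directly as data for the sparse and full-text requests.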
dense_ef = BGEM3EmbeddingFunction()
standard_retriever = HybridRetriever(
uri="http://localhost:19530",
collection_name="milvus_hybrid",
dense_embedding_function=dense_ef,
)
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 108848.72it/s]
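As an optional sanity check (not part of the original walkthrough), you can inspect what the embedding function returns: BGE-M3 produces a dict with a "dense" entry (and typically a "sparse" one as well), and the retriever only uses the dense part because the sparse vectors are generated server-side by the BM25 function.
sample = dense_ef(["full text search with Milvus"])
print(type(sample))  # dict; the retriever reads the "dense" entry
print(len(sample["dense"][0]))  # dense dimension, 1024 for BGE-M3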
Insert the Data
path = "codebase_chunks.json"
with open(path, "r") as f:
dataset = json.load(f)
is_insert = True
if is_insert:
standard_retriever.build_collection()
for doc in dataset:
doc_content = doc["content"]
for chunk in doc["chunks"]:
metadata = {
"doc_id": doc["doc_id"],
"original_uuid": doc["original_uuid"],
"chunk_id": chunk["chunk_id"],
"original_index": chunk["original_index"],
"content": chunk["content"],
}
chunk_content = chunk["content"]
standard_retriever.insert_data(chunk_content, metadata)
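Insertion embeds every chunk, so it can take a while. As an optional check (not in the original notebook), you can flush the collection and confirm how many entities landed in it:
standard_retriever.client.flush("milvus_hybrid")
print(standard_retriever.client.get_collection_stats("milvus_hybrid"))  # e.g. {'row_count': ...}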
Test Sparse Search
results = standard_retriever.search("create a logger?", mode="sparse", k=3)
print(results)
[{'doc_id': 'doc_10', 'chunk_id': 'doc_10_chunk_0', 'content': 'use {\n crate::args::LogArgs,\n anyhow::{anyhow, Result},\n simplelog::{Config, LevelFilter, WriteLogger},\n std::fs::File,\n};\n\npub struct Logger;\n\nimpl Logger {\n pub fn init(args: &impl LogArgs) -> Result<()> {\n let filter: LevelFilter = args.log_level().into();\n if filter != LevelFilter::Off {\n let logfile = File::create(args.log_file())\n .map_err(|e| anyhow!("Failed to open log file: {e:}"))?;\n WriteLogger::init(filter, Config::default(), logfile)\n .map_err(|e| anyhow!("Failed to initalize logger: {e:}"))?;\n }\n Ok(())\n }\n}\n', 'score': 9.12518310546875}, {'doc_id': 'doc_87', 'chunk_id': 'doc_87_chunk_3', 'content': '\t\tLoggerPtr INF = Logger::getLogger(LOG4CXX_TEST_STR("INF"));\n\t\tINF->setLevel(Level::getInfo());\n\n\t\tLoggerPtr INF_ERR = Logger::getLogger(LOG4CXX_TEST_STR("INF.ERR"));\n\t\tINF_ERR->setLevel(Level::getError());\n\n\t\tLoggerPtr DEB = Logger::getLogger(LOG4CXX_TEST_STR("DEB"));\n\t\tDEB->setLevel(Level::getDebug());\n\n\t\t// Note: categories with undefined level\n\t\tLoggerPtr INF_UNDEF = Logger::getLogger(LOG4CXX_TEST_STR("INF.UNDEF"));\n\t\tLoggerPtr INF_ERR_UNDEF = Logger::getLogger(LOG4CXX_TEST_STR("INF.ERR.UNDEF"));\n\t\tLoggerPtr UNDEF = Logger::getLogger(LOG4CXX_TEST_STR("UNDEF"));\n\n', 'score': 7.0077056884765625}, {'doc_id': 'doc_89', 'chunk_id': 'doc_89_chunk_3', 'content': 'using namespace log4cxx;\nusing namespace log4cxx::helpers;\n\nLOGUNIT_CLASS(FMTTestCase)\n{\n\tLOGUNIT_TEST_SUITE(FMTTestCase);\n\tLOGUNIT_TEST(test1);\n\tLOGUNIT_TEST(test1_expanded);\n\tLOGUNIT_TEST(test10);\n//\tLOGUNIT_TEST(test_date);\n\tLOGUNIT_TEST_SUITE_END();\n\n\tLoggerPtr root;\n\tLoggerPtr logger;\n\npublic:\n\tvoid setUp()\n\t{\n\t\troot = Logger::getRootLogger();\n\t\tMDC::clear();\n\t\tlogger = Logger::getLogger(LOG4CXX_TEST_STR("java.org.apache.log4j.PatternLayoutTest"));\n\t}\n\n', 'score': 6.750633716583252}]
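The same retriever answers dense or hybrid queries simply by switching mode. For example, a hybrid query over the same text (a usage sketch built on the class above):
results = standard_retriever.search("create a logger?", mode="hybrid", k=3)
for hit in results:
    print(hit["doc_id"], hit["chunk_id"], round(hit["score"], 4))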
Evaluation
Now that we have inserted the dataset into Milvus, we can use dense, sparse, or hybrid search to retrieve the top 5 results. You can switch the mode and evaluate each one. We use the Pass@5 metric, which retrieves the top 5 results for each query and computes the recall.
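For example, if a query has two golden chunks and only one of them appears in the top 5 results, that query contributes a score of 0.5 to the average.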
def load_jsonl(file_path: str):
"""Load JSONL file and return a list of dictionaries."""
with open(file_path, "r") as file:
return [json.loads(line) for line in file]
dataset = load_jsonl("evaluation_set.jsonl")
k = 5
# mode can be "dense", "sparse" or "hybrid".
mode = "hybrid"
total_query_score = 0
num_queries = 0
for query_item in dataset:
query = query_item["query"]
golden_chunk_uuids = query_item["golden_chunk_uuids"]
chunks_found = 0
golden_contents = []
for doc_uuid, chunk_index in golden_chunk_uuids:
golden_doc = next(
(doc for doc in query_item["golden_documents"] if doc["uuid"] == doc_uuid),
None,
)
if golden_doc:
golden_chunk = next(
(
chunk
for chunk in golden_doc["chunks"]
if chunk["index"] == chunk_index
),
None,
)
if golden_chunk:
golden_contents.append(golden_chunk["content"].strip())
results = standard_retriever.search(query, mode=mode, k=5)
for golden_content in golden_contents:
for doc in results[:k]:
retrieved_content = doc["content"].strip()
if retrieved_content == golden_content:
chunks_found += 1
break
query_score = chunks_found / len(golden_contents)
total_query_score += query_score
num_queries += 1
print("Pass@5: ", total_query_score / num_queries)
Pass@5: 0.7911386328725037