
Full Text Search with Milvus

With the release of Milvus 2.5, full text search enables efficient keyword- and phrase-based retrieval, providing powerful text matching capabilities. This feature improves search accuracy and can be seamlessly combined with embedding-based retrieval for a hybrid search that returns both semantic and keyword-based results in a single query. In this notebook, we demonstrate the basic usage of full text search in Milvus.

Preparation

Download the Dataset

The following commands download the example data used in the original Anthropic demo.

$ wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/codebase_chunks.json
$ wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/evaluation_set.jsonl

Install Milvus 2.5

Refer to the official installation guide for more details.

Install PyMilvus

Run the following command to install PyMilvus:

pip install "pymilvus[model]" -U 
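Before building the collection, you may want to verify that a Milvus 2.5+ instance is reachable. A minimal sketch, assuming Milvus is running at http://localhost:19530 (adjust the URI to your deployment):

from pymilvus import MilvusClient

# Quick connectivity check; assumes a local Milvus server.
client = MilvusClient(uri="http://localhost:19530")
print(client.list_collections())  # returns a (possibly empty) list if the server is up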

Define the Retriever

import json

from pymilvus import (
    MilvusClient,
    DataType,
    Function,
    FunctionType,
    AnnSearchRequest,
    RRFRanker,
)

from pymilvus.model.hybrid import BGEM3EmbeddingFunction


class HybridRetriever:
    def __init__(self, uri, collection_name="hybrid", dense_embedding_function=None):
        self.uri = uri
        self.collection_name = collection_name
        self.embedding_function = dense_embedding_function
        self.use_reranker = True
        self.use_sparse = True
        self.client = MilvusClient(uri=uri)

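    # Build the collection: schema, the BM25 function, and the vector indexes.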
    def build_collection(self):
        if isinstance(self.embedding_function.dim, dict):
            dense_dim = self.embedding_function.dim["dense"]
        else:
            dense_dim = self.embedding_function.dim

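        # Analyzer settings: standard tokenizer with lowercase, length, English stemmer, and stop-word filters.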
        tokenizer_params = {
            "tokenizer": "standard",
            "filter": [
                "lowercase",
                {
                    "type": "length",
                    "max": 200,
                },
                {"type": "stemmer", "language": "english"},
                {
                    "type": "stop",
                    "stop_words": [
                        "a",
                        "an",
                        "and",
                        "are",
                        "as",
                        "at",
                        "be",
                        "but",
                        "by",
                        "for",
                        "if",
                        "in",
                        "into",
                        "is",
                        "it",
                        "no",
                        "not",
                        "of",
                        "on",
                        "or",
                        "such",
                        "that",
                        "the",
                        "their",
                        "then",
                        "there",
                        "these",
                        "they",
                        "this",
                        "to",
                        "was",
                        "will",
                        "with",
                    ],
                },
            ],
        }

        schema = MilvusClient.create_schema()
        schema.add_field(
            field_name="pk",
            datatype=DataType.VARCHAR,
            is_primary=True,
            auto_id=True,
            max_length=100,
        )
        schema.add_field(
            field_name="content",
            datatype=DataType.VARCHAR,
            max_length=65535,
            analyzer_params=tokenizer_params,
            enable_match=True,
            enable_analyzer=True,
        )
        schema.add_field(
            field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
        )
        schema.add_field(
            field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim
        )
        schema.add_field(
            field_name="original_uuid", datatype=DataType.VARCHAR, max_length=128
        )
        schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=64)
        schema.add_field(
            field_name="chunk_id", datatype=DataType.VARCHAR, max_length=64
        )
        schema.add_field(field_name="original_index", datatype=DataType.INT32)

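        # BM25 function: Milvus generates the sparse_vector field server-side from the raw text in "content".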
        bm25_function = Function(
            name="bm25",
            function_type=FunctionType.BM25,
            input_field_names=["content"],
            output_field_names=["sparse_vector"],
        )

        schema.add_function(bm25_function)

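        # Index the BM25-generated sparse vector for full-text search and the dense vector with inner product.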
        index_params = MilvusClient.prepare_index_params()
        index_params.add_index(
            field_name="sparse_vector",
            index_type="SPARSE_INVERTED_INDEX",
            metric_type="BM25",
        )
        index_params.add_index(
            field_name="dense_vector", index_type="FLAT", metric_type="IP"
        )

        self.client.create_collection(
            collection_name=self.collection_name,
            schema=schema,
            index_params=index_params,
        )

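    # Embed a chunk and insert it with its metadata; the sparse vector is filled in by the BM25 function.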
    def insert_data(self, chunk, metadata):
        embedding = self.embedding_function([chunk])
        if isinstance(embedding, dict) and "dense" in embedding:
            dense_vec = embedding["dense"][0]
        else:
            dense_vec = embedding[0]
        self.client.insert(
            self.collection_name, {"dense_vector": dense_vec, **metadata}
        )

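    # Search in sparse (BM25), dense, or hybrid (RRF-fused) mode and return the top-k chunks.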
    def search(self, query: str, k: int = 20, mode="hybrid"):

        output_fields = [
            "content",
            "original_uuid",
            "doc_id",
            "chunk_id",
            "original_index",
        ]
        if mode in ["dense", "hybrid"]:
            embedding = self.embedding_function([query])
            if isinstance(embedding, dict) and "dense" in embedding:
                dense_vec = embedding["dense"][0]
            else:
                dense_vec = embedding[0]

        if mode == "sparse":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[query],
                anns_field="sparse_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "dense":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[dense_vec],
                anns_field="dense_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "hybrid":
            full_text_search_params = {"metric_type": "BM25"}
            full_text_search_req = AnnSearchRequest(
                [query], "sparse_vector", full_text_search_params, limit=k
            )

            dense_search_params = {"metric_type": "IP"}
            dense_req = AnnSearchRequest(
                [dense_vec], "dense_vector", dense_search_params, limit=k
            )

            results = self.client.hybrid_search(
                self.collection_name,
                [full_text_search_req, dense_req],
                ranker=RRFRanker(),
                limit=k,
                output_fields=output_fields,
            )
        else:
            raise ValueError("Invalid mode")
        return [
            {
                "doc_id": doc["entity"]["doc_id"],
                "chunk_id": doc["entity"]["chunk_id"],
                "content": doc["entity"]["content"],
                "score": doc["distance"],
            }
            for doc in results[0]
        ]

dense_ef = BGEM3EmbeddingFunction()
standard_retriever = HybridRetriever(
    uri="http://localhost:19530",
    collection_name="milvus_hybrid",
    dense_embedding_function=dense_ef,
)
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 108848.72it/s]

Insert the Data

path = "codebase_chunks.json"
with open(path, "r") as f:
    dataset = json.load(f)

is_insert = True
if is_insert:
    standard_retriever.build_collection()
    for doc in dataset:
        doc_content = doc["content"]
        for chunk in doc["chunks"]:
            metadata = {
                "doc_id": doc["doc_id"],
                "original_uuid": doc["original_uuid"],
                "chunk_id": chunk["chunk_id"],
                "original_index": chunk["original_index"],
                "content": chunk["content"],
            }
            chunk_content = chunk["content"]
            standard_retriever.insert_data(chunk_content, metadata)
results = standard_retriever.search("create a logger?", mode="sparse", k=3)
print(results)
[{'doc_id': 'doc_10', 'chunk_id': 'doc_10_chunk_0', 'content': 'use {\n    crate::args::LogArgs,\n    anyhow::{anyhow, Result},\n    simplelog::{Config, LevelFilter, WriteLogger},\n    std::fs::File,\n};\n\npub struct Logger;\n\nimpl Logger {\n    pub fn init(args: &impl LogArgs) -> Result<()> {\n        let filter: LevelFilter = args.log_level().into();\n        if filter != LevelFilter::Off {\n            let logfile = File::create(args.log_file())\n                .map_err(|e| anyhow!("Failed to open log file: {e:}"))?;\n            WriteLogger::init(filter, Config::default(), logfile)\n                .map_err(|e| anyhow!("Failed to initalize logger: {e:}"))?;\n        }\n        Ok(())\n    }\n}\n', 'score': 9.12518310546875}, {'doc_id': 'doc_87', 'chunk_id': 'doc_87_chunk_3', 'content': '\t\tLoggerPtr INF = Logger::getLogger(LOG4CXX_TEST_STR("INF"));\n\t\tINF->setLevel(Level::getInfo());\n\n\t\tLoggerPtr INF_ERR = Logger::getLogger(LOG4CXX_TEST_STR("INF.ERR"));\n\t\tINF_ERR->setLevel(Level::getError());\n\n\t\tLoggerPtr DEB = Logger::getLogger(LOG4CXX_TEST_STR("DEB"));\n\t\tDEB->setLevel(Level::getDebug());\n\n\t\t// Note: categories with undefined level\n\t\tLoggerPtr INF_UNDEF = Logger::getLogger(LOG4CXX_TEST_STR("INF.UNDEF"));\n\t\tLoggerPtr INF_ERR_UNDEF = Logger::getLogger(LOG4CXX_TEST_STR("INF.ERR.UNDEF"));\n\t\tLoggerPtr UNDEF = Logger::getLogger(LOG4CXX_TEST_STR("UNDEF"));\n\n', 'score': 7.0077056884765625}, {'doc_id': 'doc_89', 'chunk_id': 'doc_89_chunk_3', 'content': 'using namespace log4cxx;\nusing namespace log4cxx::helpers;\n\nLOGUNIT_CLASS(FMTTestCase)\n{\n\tLOGUNIT_TEST_SUITE(FMTTestCase);\n\tLOGUNIT_TEST(test1);\n\tLOGUNIT_TEST(test1_expanded);\n\tLOGUNIT_TEST(test10);\n//\tLOGUNIT_TEST(test_date);\n\tLOGUNIT_TEST_SUITE_END();\n\n\tLoggerPtr root;\n\tLoggerPtr logger;\n\npublic:\n\tvoid setUp()\n\t{\n\t\troot = Logger::getRootLogger();\n\t\tMDC::clear();\n\t\tlogger = Logger::getLogger(LOG4CXX_TEST_STR("java.org.apache.log4j.PatternLayoutTest"));\n\t}\n\n', 'score': 6.750633716583252}]
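For comparison, the same query can be issued in dense or hybrid mode by switching the mode argument. A quick sketch; note that the scores are not comparable across modes, since BM25, inner product, and RRF operate on different scales:

# Same query, hybrid mode: BM25 and dense results fused with Reciprocal Rank Fusion.
hybrid_results = standard_retriever.search("create a logger?", mode="hybrid", k=3)
for hit in hybrid_results:
    print(hit["doc_id"], hit["chunk_id"], round(hit["score"], 4))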

Evaluation

Now that we have inserted the dataset into Milvus, we can use dense, sparse, or hybrid search to retrieve the top 5 results. You can change the mode and evaluate each one. We use the Pass@5 metric: for each query, the top 5 results are retrieved and the recall against the golden chunks is computed.

def load_jsonl(file_path: str):
    """Load JSONL file and return a list of dictionaries."""
    with open(file_path, "r") as file:
        return [json.loads(line) for line in file]


dataset = load_jsonl("evaluation_set.jsonl")
k = 5

# mode can be "dense", "sparse" or "hybrid".
mode = "hybrid"

total_query_score = 0
num_queries = 0

for query_item in dataset:

    query = query_item["query"]

    golden_chunk_uuids = query_item["golden_chunk_uuids"]

    chunks_found = 0
    golden_contents = []
    for doc_uuid, chunk_index in golden_chunk_uuids:
        golden_doc = next(
            (doc for doc in query_item["golden_documents"] if doc["uuid"] == doc_uuid),
            None,
        )
        if golden_doc:
            golden_chunk = next(
                (
                    chunk
                    for chunk in golden_doc["chunks"]
                    if chunk["index"] == chunk_index
                ),
                None,
            )
            if golden_chunk:
                golden_contents.append(golden_chunk["content"].strip())

    if not golden_contents:
        # Skip queries with no resolvable golden chunks to avoid division by zero.
        continue

    results = standard_retriever.search(query, mode=mode, k=k)

    for golden_content in golden_contents:
        for doc in results[:k]:
            retrieved_content = doc["content"].strip()
            if retrieved_content == golden_content:
                chunks_found += 1
                break

    query_score = chunks_found / len(golden_contents)

    total_query_score += query_score
    num_queries += 1
print("Pass@5: ", total_query_score / num_queries)
Pass@5:  0.7911386328725037
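To compare all three retrieval modes in a single run, the loop above can be wrapped in a helper. The evaluate function below is an illustrative sketch, not part of the original notebook:

def evaluate(mode: str, k: int = 5) -> float:
    """Compute Pass@k over the evaluation set for a single retrieval mode."""
    total_score, num_queries = 0.0, 0
    for query_item in dataset:
        # Collect the golden chunk contents for this query, as in the loop above.
        golden_contents = []
        for doc_uuid, chunk_index in query_item["golden_chunk_uuids"]:
            golden_doc = next(
                (d for d in query_item["golden_documents"] if d["uuid"] == doc_uuid),
                None,
            )
            if golden_doc is None:
                continue
            golden_chunk = next(
                (c for c in golden_doc["chunks"] if c["index"] == chunk_index),
                None,
            )
            if golden_chunk is not None:
                golden_contents.append(golden_chunk["content"].strip())
        if not golden_contents:
            continue
        results = standard_retriever.search(query_item["query"], mode=mode, k=k)
        found = sum(
            1
            for golden in golden_contents
            if any(doc["content"].strip() == golden for doc in results[:k])
        )
        total_score += found / len(golden_contents)
        num_queries += 1
    return total_score / num_queries


for m in ["sparse", "dense", "hybrid"]:
    print(f"{m} Pass@5:", evaluate(m))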

