milvus-logo
LFAI
Casa
  • Tutorial

Recupero contestuale con Milvus

Open In Colab GitHub Repository

image image Contextual Retrieval è un metodo di recupero avanzato proposto da Anthropic per risolvere il problema dell'isolamento semantico dei chunk, che si presenta nelle attuali soluzioni RAG (Retrieval-Augmented Generation). Nell'attuale paradigma pratico di RAG, i documenti sono divisi in diversi chunks e un database vettoriale viene utilizzato per cercare la query, recuperando i chunks più rilevanti. Un LLM risponde quindi alla query utilizzando questi chunks recuperati. Tuttavia, questo processo di chunking può comportare la perdita di informazioni contestuali, rendendo difficile per il retriever determinare la rilevanza.

Il Contextual Retrieval migliora i sistemi di recupero tradizionali aggiungendo il contesto pertinente a ogni chunk di documento prima dell'incorporazione o dell'indicizzazione, aumentando la precisione e riducendo gli errori di recupero. Combinato con tecniche come il recupero ibrido e il reranking, migliora i sistemi di Retrieval-Augmented Generation (RAG), soprattutto per le basi di conoscenza di grandi dimensioni. Inoltre, se abbinato al prompt caching, offre una soluzione economicamente vantaggiosa, riducendo in modo significativo la latenza e i costi operativi: i chunk contestualizzati costano circa 1,02 dollari per milione di token di documenti. Questo lo rende un approccio scalabile ed efficiente per la gestione di grandi basi di conoscenza. La soluzione di Anthropic mostra due aspetti interessanti:

  • Document Enhancement: La riscrittura delle query è una tecnica cruciale nel moderno information retrieval, che spesso utilizza informazioni ausiliarie per rendere la query più informativa. Allo stesso modo, per ottenere prestazioni migliori nella RAG, la preelaborazione dei documenti con un LLM (ad esempio, la pulizia della fonte dei dati, l'integrazione delle informazioni perse, la sintesi, ecc. In altre parole, questa fase di pre-elaborazione aiuta ad avvicinare i documenti alle query in termini di rilevanza.
  • Low-Cost Processing by Caching Long Context: Una preoccupazione comune quando si utilizzano gli LLM per elaborare i documenti è il costo. La KVCache è una soluzione molto diffusa che permette di riutilizzare i risultati intermedi per lo stesso contesto precedente. Mentre la maggior parte dei fornitori di LLM ospitati rende questa funzione trasparente all'utente, Anthropic offre all'utente il controllo sul processo di caching. Quando si verifica un hit della cache, la maggior parte dei calcoli può essere salvata (questo è comune quando il contesto lungo rimane lo stesso, ma l'istruzione per ogni query cambia). Per maggiori dettagli, fare clic qui.

In questo quaderno dimostreremo come eseguire il recupero contestuale utilizzando Milvus con un LLM, combinando il recupero ibrido denso-sparso e un reranker per creare un sistema di recupero progressivamente più potente. I dati e la configurazione sperimentale sono basati sul recupero contestuale.

Preparazione

Installare le dipendenze

$ pip install "pymilvus[model]"
$ pip install tqdm
$ pip install anthropic

Se si utilizza Google Colab, per abilitare le dipendenze appena installate potrebbe essere necessario riavviare il runtime (fare clic sul menu "Runtime" nella parte superiore dello schermo e selezionare "Restart session" dal menu a discesa).

Per eseguire il codice sono necessarie le chiavi API di Cohere, Voyage e Anthropic.

Scaricare i dati

Il comando seguente scarica i dati di esempio utilizzati nella demo originale di Anthropic.

$ wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/codebase_chunks.json
$ wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/evaluation_set.jsonl

Definizione di Retriever

Questa classe è stata progettata per essere flessibile, consentendo di scegliere tra diverse modalità di recupero in base alle proprie esigenze. Specificando le opzioni nel metodo di inizializzazione, è possibile stabilire se utilizzare il recupero contestuale, la ricerca ibrida (che combina metodi di recupero densi e radi) o un reranker per ottenere risultati migliori.

from pymilvus.model.dense import VoyageEmbeddingFunction
from pymilvus.model.hybrid import BGEM3EmbeddingFunction
from pymilvus.model.reranker import CohereRerankFunction

from typing import List, Dict, Any
from typing import Callable
from pymilvus import (
    MilvusClient,
    DataType,
    AnnSearchRequest,
    RRFRanker,
)
from tqdm import tqdm
import json
import anthropic


class MilvusContextualRetriever:
    def __init__(
        self,
        uri="milvus.db",
        collection_name="contexual_bgem3",
        dense_embedding_function=None,
        use_sparse=False,
        sparse_embedding_function=None,
        use_contextualize_embedding=False,
        anthropic_client=None,
        use_reranker=False,
        rerank_function=None,
    ):
        self.collection_name = collection_name

        # For Milvus-lite, uri is a local path like "./milvus.db"
        # For Milvus standalone service, uri is like "http://localhost:19530"
        # For Zilliz Clond, please set `uri` and `token`, which correspond to the [Public Endpoint and API key](https://docs.zilliz.com/docs/on-zilliz-cloud-console#cluster-details) in Zilliz Cloud.
        self.client = MilvusClient(uri)

        self.embedding_function = dense_embedding_function

        self.use_sparse = use_sparse
        self.sparse_embedding_function = None

        self.use_contextualize_embedding = use_contextualize_embedding
        self.anthropic_client = anthropic_client

        self.use_reranker = use_reranker
        self.rerank_function = rerank_function

        if use_sparse is True and sparse_embedding_function:
            self.sparse_embedding_function = sparse_embedding_function
        elif sparse_embedding_function is False:
            raise ValueError(
                "Sparse embedding function cannot be None if use_sparse is False"
            )
        else:
            pass

    def build_collection(self):
        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_field=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="dense_vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=self.embedding_function.dim,
        )
        if self.use_sparse is True:
            schema.add_field(
                field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
            )

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="dense_vector", index_type="FLAT", metric_type="IP"
        )
        if self.use_sparse is True:
            index_params.add_index(
                field_name="sparse_vector",
                index_type="SPARSE_INVERTED_INDEX",
                metric_type="IP",
            )

        self.client.create_collection(
            collection_name=self.collection_name,
            schema=schema,
            index_params=index_params,
            enable_dynamic_field=True,
        )

    def insert_data(self, chunk, metadata):
        dense_vec = self.embedding_function([chunk])[0]
        if self.use_sparse is True:
            sparse_result = self.sparse_embedding_function.encode_documents([chunk])
            if type(sparse_result) == dict:
                sparse_vec = sparse_result["sparse"][[0]]
            else:
                sparse_vec = sparse_result[[0]]
            self.client.insert(
                collection_name=self.collection_name,
                data={
                    "dense_vector": dense_vec,
                    "sparse_vector": sparse_vec,
                    **metadata,
                },
            )
        else:
            self.client.insert(
                collection_name=self.collection_name,
                data={"dense_vector": dense_vec, **metadata},
            )

    def insert_contextualized_data(self, doc, chunk, metadata):
        contextualized_text, usage = self.situate_context(doc, chunk)
        metadata["context"] = contextualized_text
        text_to_embed = f"{chunk}\n\n{contextualized_text}"
        dense_vec = self.embedding_function([text_to_embed])[0]
        if self.use_sparse is True:
            sparse_vec = self.sparse_embedding_function.encode_documents(
                [text_to_embed]
            )["sparse"][[0]]
            self.client.insert(
                collection_name=self.collection_name,
                data={
                    "dense_vector": dense_vec,
                    "sparse_vector": sparse_vec,
                    **metadata,
                },
            )
        else:
            self.client.insert(
                collection_name=self.collection_name,
                data={"dense_vector": dense_vec, **metadata},
            )

    def situate_context(self, doc: str, chunk: str):
        DOCUMENT_CONTEXT_PROMPT = """
        <document>
        {doc_content}
        </document>
        """

        CHUNK_CONTEXT_PROMPT = """
        Here is the chunk we want to situate within the whole document
        <chunk>
        {chunk_content}
        </chunk>

        Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
        Answer only with the succinct context and nothing else.
        """

        response = self.anthropic_client.beta.prompt_caching.messages.create(
            model="claude-3-haiku-20240307",
            max_tokens=1000,
            temperature=0.0,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                            "cache_control": {
                                "type": "ephemeral"
                            },  # we will make use of prompt caching for the full documents
                        },
                        {
                            "type": "text",
                            "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                        },
                    ],
                },
            ],
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
        )
        return response.content[0].text, response.usage

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        dense_vec = self.embedding_function([query])[0]
        if self.use_sparse is True:
            sparse_vec = self.sparse_embedding_function.encode_queries([query])[
                "sparse"
            ][[0]]

        req_list = []
        if self.use_reranker:
            k = k * 10
        if self.use_sparse is True:
            req_list = []
            dense_search_param = {
                "data": [dense_vec],
                "anns_field": "dense_vector",
                "param": {"metric_type": "IP"},
                "limit": k * 2,
            }
            dense_req = AnnSearchRequest(**dense_search_param)
            req_list.append(dense_req)

            sparse_search_param = {
                "data": [sparse_vec],
                "anns_field": "sparse_vector",
                "param": {"metric_type": "IP"},
                "limit": k * 2,
            }
            sparse_req = AnnSearchRequest(**sparse_search_param)

            req_list.append(sparse_req)

            docs = self.client.hybrid_search(
                self.collection_name,
                req_list,
                RRFRanker(),
                k,
                output_fields=[
                    "content",
                    "original_uuid",
                    "doc_id",
                    "chunk_id",
                    "original_index",
                    "context",
                ],
            )
        else:
            docs = self.client.search(
                self.collection_name,
                data=[dense_vec],
                anns_field="dense_vector",
                limit=k,
                output_fields=[
                    "content",
                    "original_uuid",
                    "doc_id",
                    "chunk_id",
                    "original_index",
                    "context",
                ],
            )
        if self.use_reranker and self.use_contextualize_embedding:
            reranked_texts = []
            reranked_docs = []
            for i in range(k):
                if self.use_contextualize_embedding:
                    reranked_texts.append(
                        f"{docs[0][i]['entity']['content']}\n\n{docs[0][i]['entity']['context']}"
                    )
                else:
                    reranked_texts.append(f"{docs[0][i]['entity']['content']}")
            results = self.rerank_function(query, reranked_texts)
            for result in results:
                reranked_docs.append(docs[0][result.index])
            docs[0] = reranked_docs
        return docs


def evaluate_retrieval(
    queries: List[Dict[str, Any]], retrieval_function: Callable, db, k: int = 20
) -> Dict[str, float]:
    total_score = 0
    total_queries = len(queries)
    for query_item in tqdm(queries, desc="Evaluating retrieval"):
        query = query_item["query"]
        golden_chunk_uuids = query_item["golden_chunk_uuids"]

        # Find all golden chunk contents
        golden_contents = []
        for doc_uuid, chunk_index in golden_chunk_uuids:
            golden_doc = next(
                (
                    doc
                    for doc in query_item["golden_documents"]
                    if doc["uuid"] == doc_uuid
                ),
                None,
            )
            if not golden_doc:
                print(f"Warning: Golden document not found for UUID {doc_uuid}")
                continue

            golden_chunk = next(
                (
                    chunk
                    for chunk in golden_doc["chunks"]
                    if chunk["index"] == chunk_index
                ),
                None,
            )
            if not golden_chunk:
                print(
                    f"Warning: Golden chunk not found for index {chunk_index} in document {doc_uuid}"
                )
                continue

            golden_contents.append(golden_chunk["content"].strip())

        if not golden_contents:
            print(f"Warning: No golden contents found for query: {query}")
            continue

        retrieved_docs = retrieval_function(query, db, k=k)

        # Count how many golden chunks are in the top k retrieved documents
        chunks_found = 0
        for golden_content in golden_contents:
            for doc in retrieved_docs[0][:k]:
                retrieved_content = doc["entity"]["content"].strip()
                if retrieved_content == golden_content:
                    chunks_found += 1
                    break

        query_score = chunks_found / len(golden_contents)
        total_score += query_score

    average_score = total_score / total_queries
    pass_at_n = average_score * 100
    return {
        "pass_at_n": pass_at_n,
        "average_score": average_score,
        "total_queries": total_queries,
    }


def retrieve_base(query: str, db, k: int = 20) -> List[Dict[str, Any]]:
    return db.search(query, k=k)


def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Load JSONL file and return a list of dictionaries."""
    with open(file_path, "r") as file:
        return [json.loads(line) for line in file]


def evaluate_db(db, original_jsonl_path: str, k):
    # Load the original JSONL data for queries and ground truth
    original_data = load_jsonl(original_jsonl_path)

    # Evaluate retrieval
    results = evaluate_retrieval(original_data, retrieve_base, db, k)
    print(f"Pass@{k}: {results['pass_at_n']:.2f}%")
    print(f"Total Score: {results['average_score']}")
    print(f"Total queries: {results['total_queries']}")

Ora è necessario inizializzare questi modelli per gli esperimenti successivi. È possibile passare facilmente ad altri modelli utilizzando la libreria di modelli PyMilvus.

dense_ef = VoyageEmbeddingFunction(api_key="your-voyage-api-key", model_name="voyage-2")
sparse_ef = BGEM3EmbeddingFunction()
cohere_rf = CohereRerankFunction(api_key="your-cohere-api-key")
Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]
path = "codebase_chunks.json"
with open(path, "r") as f:
    dataset = json.load(f)

Esperimento I: Recupero standard

Il recupero standard utilizza solo embeddings densi per recuperare documenti correlati. In questo esperimento, utilizzeremo Pass@5 per riprodurre i risultati del repo originale.

standard_retriever = MilvusContextualRetriever(
    uri="standard.db", collection_name="standard", dense_embedding_function=dense_ef
)

standard_retriever.build_collection()
for doc in dataset:
    doc_content = doc["content"]
    for chunk in doc["chunks"]:
        metadata = {
            "doc_id": doc["doc_id"],
            "original_uuid": doc["original_uuid"],
            "chunk_id": chunk["chunk_id"],
            "original_index": chunk["original_index"],
            "content": chunk["content"],
        }
        chunk_content = chunk["content"]
        standard_retriever.insert_data(chunk_content, metadata)
evaluate_db(standard_retriever, "evaluation_set.jsonl", 5)
Evaluating retrieval: 100%|██████████| 248/248 [01:29<00:00,  2.77it/s]

Pass@5: 80.92%
Total Score: 0.8091877880184332
Total queries: 248

Esperimento II: Recupero ibrido

Dopo aver ottenuto risultati promettenti con l'incorporamento Voyage, passiamo a eseguire un recupero ibrido utilizzando il modello BGE-M3, che genera potenti incorporazioni rade. I risultati del reperimento denso e del reperimento rado saranno combinati con il metodo Reciprocal Rank Fusion (RRF) per produrre un risultato ibrido.

hybrid_retriever = MilvusContextualRetriever(
    uri="hybrid.db",
    collection_name="hybrid",
    dense_embedding_function=dense_ef,
    use_sparse=True,
    sparse_embedding_function=sparse_ef,
)

hybrid_retriever.build_collection()
for doc in dataset:
    doc_content = doc["content"]
    for chunk in doc["chunks"]:
        metadata = {
            "doc_id": doc["doc_id"],
            "original_uuid": doc["original_uuid"],
            "chunk_id": chunk["chunk_id"],
            "original_index": chunk["original_index"],
            "content": chunk["content"],
        }
        chunk_content = chunk["content"]
        hybrid_retriever.insert_data(chunk_content, metadata)
evaluate_db(hybrid_retriever, "evaluation_set.jsonl", 5)
Evaluating retrieval: 100%|██████████| 248/248 [02:09<00:00,  1.92it/s]

Pass@5: 84.69%
Total Score: 0.8469182027649771
Total queries: 248

Esperimento III: recupero contestuale

Il recupero ibrido mostra un miglioramento, ma i risultati possono essere ulteriormente migliorati applicando un metodo di recupero contestuale. A tal fine, utilizzeremo il modello linguistico di Anthropic per aggiungere il contesto all'intero documento per ogni chunk.

anthropic_client = anthropic.Anthropic(
    api_key="your-anthropic-api-key",
)
contextual_retriever = MilvusContextualRetriever(
    uri="contextual.db",
    collection_name="contextual",
    dense_embedding_function=dense_ef,
    use_sparse=True,
    sparse_embedding_function=sparse_ef,
    use_contextualize_embedding=True,
    anthropic_client=anthropic_client,
)

contextual_retriever.build_collection()
for doc in dataset:
    doc_content = doc["content"]
    for chunk in doc["chunks"]:
        metadata = {
            "doc_id": doc["doc_id"],
            "original_uuid": doc["original_uuid"],
            "chunk_id": chunk["chunk_id"],
            "original_index": chunk["original_index"],
            "content": chunk["content"],
        }
        chunk_content = chunk["content"]
        contextual_retriever.insert_contextualized_data(
            doc_content, chunk_content, metadata
        )
evaluate_db(contextual_retriever, "evaluation_set.jsonl", 5)
 Evaluating retrieval: 100%|██████████| 248/248 [01:55<00:00,  2.15it/s]
Pass@5: 87.14%
Total Score: 0.8713517665130568
Total queries: 248 

Esperimento IV: Recupero contestuale con Reranker

I risultati possono essere ulteriormente migliorati aggiungendo un reranker di Cohere. Senza inizializzare separatamente un nuovo retriever con il reranker, possiamo semplicemente configurare il retriever esistente in modo che utilizzi il reranker per migliorare le prestazioni.

contextual_retriever.use_reranker = True
contextual_retriever.rerank_function = cohere_rf
evaluate_db(contextual_retriever, "evaluation_set.jsonl", 5)
Evaluating retrieval: 100%|██████████| 248/248 [02:02<00:00,  2.00it/s]
Pass@5: 90.91%
Total Score: 0.9090821812596005
Total queries: 248

Abbiamo dimostrato diversi metodi per migliorare le prestazioni di recupero. Con una progettazione ad hoc adattata allo scenario, il reperimento contestuale mostra un potenziale significativo per preelaborare i documenti a basso costo, portando a un sistema RAG migliore.

Tradotto daDeepL

Try Managed Milvus for Free

Zilliz Cloud is hassle-free, powered by Milvus and 10x faster.

Get Started
Feedback

Questa pagina è stata utile?