Recuperação contextual com Milvus
A recuperação contextual de imagens é um método avançado de recuperação proposto pela Anthropic para resolver o problema do isolamento semântico dos fragmentos, que surge nas actuais soluções de geração aumentada de recuperação (RAG). No atual paradigma prático de RAG, os documentos são divididos em vários fragmentos e é utilizada uma base de dados vetorial para procurar a consulta, recuperando os fragmentos mais relevantes. Em seguida, um LLM responde à consulta usando esses pedaços recuperados. No entanto, este processo de recuperação de fragmentos pode resultar na perda de informação contextual, tornando difícil para o recuperador determinar a relevância.
A recuperação contextual melhora os sistemas de recuperação tradicionais, acrescentando contexto relevante a cada fragmento de documento antes da incorporação ou indexação, aumentando a precisão e reduzindo os erros de recuperação. Combinada com técnicas como a recuperação híbrida e o reranking, melhora os sistemas RAG (Retrieval-Augmented Generation), especialmente para grandes bases de conhecimento. Além disso, oferece uma solução económica quando combinada com o caching imediato, reduzindo significativamente a latência e os custos operacionais, com os pedaços contextualizados a custarem aproximadamente 1,02 dólares por milhão de tokens de documentos. Isso o torna uma abordagem escalável e eficiente para lidar com grandes bases de conhecimento. A solução da Anthropic mostra dois aspectos perspicazes:
Document Enhancement
: A reescrita de consultas é uma técnica crucial na recuperação de informação moderna, utilizando frequentemente informações auxiliares para tornar a consulta mais informativa. Do mesmo modo, para obter um melhor desempenho em RAG, o pré-processamento de documentos com um LLM (por exemplo, limpar a fonte de dados, complementar a informação perdida, resumir, etc.) antes da indexação pode melhorar significativamente as hipóteses de recuperar documentos relevantes. Por outras palavras, este passo de pré-processamento ajuda a aproximar os documentos das consultas em termos de relevância.Low-Cost Processing by Caching Long Context
: Uma preocupação comum quando se utilizam LLMs para processar documentos é o custo. A KVCache é uma solução popular que permite a reutilização de resultados intermédios para o mesmo contexto anterior. Enquanto a maioria dos fornecedores de LLMs hospedados torna esta caraterística transparente para o utilizador, o Anthropic dá-lhe controlo sobre o processo de cache. Quando ocorre um acerto no cache, a maioria dos cálculos pode ser salva (isso é comum quando o contexto longo permanece o mesmo, mas a instrução para cada consulta muda). Para obter mais detalhes, clique aqui.
Neste caderno, vamos demonstrar como efetuar a recuperação contextual usando o Milvus com um LLM, combinando a recuperação híbrida densa e esparsa e um reranker para criar um sistema de recuperação progressivamente mais poderoso. Os dados e a configuração experimental são baseados na recuperação contextual.
Preparação
Instalar dependências
$ pip install "pymilvus[model]"
$ pip install tqdm
$ pip install anthropic
Se estiver a utilizar o Google Colab, para ativar as dependências que acabou de instalar, poderá ter de reiniciar o tempo de execução (clique no menu "Tempo de execução" na parte superior do ecrã e selecione "Reiniciar sessão" no menu pendente).
Necessitará de chaves de API do Cohere, Voyage e Anthropic para executar o código.
Descarregar dados
O comando a seguir fará o download dos dados de exemplo usados na demonstração original do Anthropic.
$ wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/codebase_chunks.json
$ wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/evaluation_set.jsonl
Definir o Recuperador
Esta classe foi concebida para ser flexível, permitindo-lhe escolher entre diferentes modos de recuperação com base nas suas necessidades. Especificando opções no método de inicialização, pode determinar se pretende utilizar a recuperação contextual, a pesquisa híbrida (combinando métodos de recuperação densos e esparsos) ou um reranker para obter resultados melhorados.
from pymilvus.model.dense import VoyageEmbeddingFunction
from pymilvus.model.hybrid import BGEM3EmbeddingFunction
from pymilvus.model.reranker import CohereRerankFunction
from typing import List, Dict, Any
from typing import Callable
from pymilvus import (
MilvusClient,
DataType,
AnnSearchRequest,
RRFRanker,
)
from tqdm import tqdm
import json
import anthropic
class MilvusContextualRetriever:
def __init__(
self,
uri="milvus.db",
collection_name="contexual_bgem3",
dense_embedding_function=None,
use_sparse=False,
sparse_embedding_function=None,
use_contextualize_embedding=False,
anthropic_client=None,
use_reranker=False,
rerank_function=None,
):
self.collection_name = collection_name
# For Milvus-lite, uri is a local path like "./milvus.db"
# For Milvus standalone service, uri is like "http://localhost:19530"
# For Zilliz Clond, please set `uri` and `token`, which correspond to the [Public Endpoint and API key](https://docs.zilliz.com/docs/on-zilliz-cloud-console#cluster-details) in Zilliz Cloud.
self.client = MilvusClient(uri)
self.embedding_function = dense_embedding_function
self.use_sparse = use_sparse
self.sparse_embedding_function = None
self.use_contextualize_embedding = use_contextualize_embedding
self.anthropic_client = anthropic_client
self.use_reranker = use_reranker
self.rerank_function = rerank_function
if use_sparse is True and sparse_embedding_function:
self.sparse_embedding_function = sparse_embedding_function
elif sparse_embedding_function is False:
raise ValueError(
"Sparse embedding function cannot be None if use_sparse is False"
)
else:
pass
def build_collection(self):
schema = self.client.create_schema(
auto_id=True,
enable_dynamic_field=True,
)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
schema.add_field(
field_name="dense_vector",
datatype=DataType.FLOAT_VECTOR,
dim=self.embedding_function.dim,
)
if self.use_sparse is True:
schema.add_field(
field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="dense_vector", index_type="FLAT", metric_type="IP"
)
if self.use_sparse is True:
index_params.add_index(
field_name="sparse_vector",
index_type="SPARSE_INVERTED_INDEX",
metric_type="IP",
)
self.client.create_collection(
collection_name=self.collection_name,
schema=schema,
index_params=index_params,
enable_dynamic_field=True,
)
def insert_data(self, chunk, metadata):
dense_vec = self.embedding_function([chunk])[0]
if self.use_sparse is True:
sparse_result = self.sparse_embedding_function.encode_documents([chunk])
if type(sparse_result) == dict:
sparse_vec = sparse_result["sparse"][[0]]
else:
sparse_vec = sparse_result[[0]]
self.client.insert(
collection_name=self.collection_name,
data={
"dense_vector": dense_vec,
"sparse_vector": sparse_vec,
**metadata,
},
)
else:
self.client.insert(
collection_name=self.collection_name,
data={"dense_vector": dense_vec, **metadata},
)
def insert_contextualized_data(self, doc, chunk, metadata):
contextualized_text, usage = self.situate_context(doc, chunk)
metadata["context"] = contextualized_text
text_to_embed = f"{chunk}\n\n{contextualized_text}"
dense_vec = self.embedding_function([text_to_embed])[0]
if self.use_sparse is True:
sparse_vec = self.sparse_embedding_function.encode_documents(
[text_to_embed]
)["sparse"][[0]]
self.client.insert(
collection_name=self.collection_name,
data={
"dense_vector": dense_vec,
"sparse_vector": sparse_vec,
**metadata,
},
)
else:
self.client.insert(
collection_name=self.collection_name,
data={"dense_vector": dense_vec, **metadata},
)
def situate_context(self, doc: str, chunk: str):
DOCUMENT_CONTEXT_PROMPT = """
<document>
{doc_content}
</document>
"""
CHUNK_CONTEXT_PROMPT = """
Here is the chunk we want to situate within the whole document
<chunk>
{chunk_content}
</chunk>
Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
Answer only with the succinct context and nothing else.
"""
response = self.anthropic_client.beta.prompt_caching.messages.create(
model="claude-3-haiku-20240307",
max_tokens=1000,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
"cache_control": {
"type": "ephemeral"
}, # we will make use of prompt caching for the full documents
},
{
"type": "text",
"text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
},
],
},
],
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
)
return response.content[0].text, response.usage
def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
dense_vec = self.embedding_function([query])[0]
if self.use_sparse is True:
sparse_vec = self.sparse_embedding_function.encode_queries([query])[
"sparse"
][[0]]
req_list = []
if self.use_reranker:
k = k * 10
if self.use_sparse is True:
req_list = []
dense_search_param = {
"data": [dense_vec],
"anns_field": "dense_vector",
"param": {"metric_type": "IP"},
"limit": k * 2,
}
dense_req = AnnSearchRequest(**dense_search_param)
req_list.append(dense_req)
sparse_search_param = {
"data": [sparse_vec],
"anns_field": "sparse_vector",
"param": {"metric_type": "IP"},
"limit": k * 2,
}
sparse_req = AnnSearchRequest(**sparse_search_param)
req_list.append(sparse_req)
docs = self.client.hybrid_search(
self.collection_name,
req_list,
RRFRanker(),
k,
output_fields=[
"content",
"original_uuid",
"doc_id",
"chunk_id",
"original_index",
"context",
],
)
else:
docs = self.client.search(
self.collection_name,
data=[dense_vec],
anns_field="dense_vector",
limit=k,
output_fields=[
"content",
"original_uuid",
"doc_id",
"chunk_id",
"original_index",
"context",
],
)
if self.use_reranker and self.use_contextualize_embedding:
reranked_texts = []
reranked_docs = []
for i in range(k):
if self.use_contextualize_embedding:
reranked_texts.append(
f"{docs[0][i]['entity']['content']}\n\n{docs[0][i]['entity']['context']}"
)
else:
reranked_texts.append(f"{docs[0][i]['entity']['content']}")
results = self.rerank_function(query, reranked_texts)
for result in results:
reranked_docs.append(docs[0][result.index])
docs[0] = reranked_docs
return docs
def evaluate_retrieval(
queries: List[Dict[str, Any]], retrieval_function: Callable, db, k: int = 20
) -> Dict[str, float]:
total_score = 0
total_queries = len(queries)
for query_item in tqdm(queries, desc="Evaluating retrieval"):
query = query_item["query"]
golden_chunk_uuids = query_item["golden_chunk_uuids"]
# Find all golden chunk contents
golden_contents = []
for doc_uuid, chunk_index in golden_chunk_uuids:
golden_doc = next(
(
doc
for doc in query_item["golden_documents"]
if doc["uuid"] == doc_uuid
),
None,
)
if not golden_doc:
print(f"Warning: Golden document not found for UUID {doc_uuid}")
continue
golden_chunk = next(
(
chunk
for chunk in golden_doc["chunks"]
if chunk["index"] == chunk_index
),
None,
)
if not golden_chunk:
print(
f"Warning: Golden chunk not found for index {chunk_index} in document {doc_uuid}"
)
continue
golden_contents.append(golden_chunk["content"].strip())
if not golden_contents:
print(f"Warning: No golden contents found for query: {query}")
continue
retrieved_docs = retrieval_function(query, db, k=k)
# Count how many golden chunks are in the top k retrieved documents
chunks_found = 0
for golden_content in golden_contents:
for doc in retrieved_docs[0][:k]:
retrieved_content = doc["entity"]["content"].strip()
if retrieved_content == golden_content:
chunks_found += 1
break
query_score = chunks_found / len(golden_contents)
total_score += query_score
average_score = total_score / total_queries
pass_at_n = average_score * 100
return {
"pass_at_n": pass_at_n,
"average_score": average_score,
"total_queries": total_queries,
}
def retrieve_base(query: str, db, k: int = 20) -> List[Dict[str, Any]]:
return db.search(query, k=k)
def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
"""Load JSONL file and return a list of dictionaries."""
with open(file_path, "r") as file:
return [json.loads(line) for line in file]
def evaluate_db(db, original_jsonl_path: str, k):
# Load the original JSONL data for queries and ground truth
original_data = load_jsonl(original_jsonl_path)
# Evaluate retrieval
results = evaluate_retrieval(original_data, retrieve_base, db, k)
print(f"Pass@{k}: {results['pass_at_n']:.2f}%")
print(f"Total Score: {results['average_score']}")
print(f"Total queries: {results['total_queries']}")
Agora é necessário inicializar estes modelos para as experiências seguintes. Pode facilmente mudar para outros modelos usando a biblioteca de modelos PyMilvus.
dense_ef = VoyageEmbeddingFunction(api_key="your-voyage-api-key", model_name="voyage-2")
sparse_ef = BGEM3EmbeddingFunction()
cohere_rf = CohereRerankFunction(api_key="your-cohere-api-key")
Fetching 30 files: 0%| | 0/30 [00:00<?, ?it/s]
path = "codebase_chunks.json"
with open(path, "r") as f:
dataset = json.load(f)
Experiência I: Recuperação padrão
A recuperação padrão usa apenas embeddings densos para recuperar documentos relacionados. Nesta experiência, utilizaremos Pass@5 para reproduzir os resultados do repositório original.
standard_retriever = MilvusContextualRetriever(
uri="standard.db", collection_name="standard", dense_embedding_function=dense_ef
)
standard_retriever.build_collection()
for doc in dataset:
doc_content = doc["content"]
for chunk in doc["chunks"]:
metadata = {
"doc_id": doc["doc_id"],
"original_uuid": doc["original_uuid"],
"chunk_id": chunk["chunk_id"],
"original_index": chunk["original_index"],
"content": chunk["content"],
}
chunk_content = chunk["content"]
standard_retriever.insert_data(chunk_content, metadata)
evaluate_db(standard_retriever, "evaluation_set.jsonl", 5)
Evaluating retrieval: 100%|██████████| 248/248 [01:29<00:00, 2.77it/s]
Pass@5: 80.92%
Total Score: 0.8091877880184332
Total queries: 248
Experiência II: Recuperação híbrida
Agora que obtivemos resultados promissores com o embedding Voyage, vamos passar à recuperação híbrida utilizando o modelo BGE-M3 que gera embeddings esparsos poderosos. Os resultados da recuperação densa e da recuperação esparsa serão combinados utilizando o método Reciprocal Rank Fusion (RRF) para produzir um resultado híbrido.
hybrid_retriever = MilvusContextualRetriever(
uri="hybrid.db",
collection_name="hybrid",
dense_embedding_function=dense_ef,
use_sparse=True,
sparse_embedding_function=sparse_ef,
)
hybrid_retriever.build_collection()
for doc in dataset:
doc_content = doc["content"]
for chunk in doc["chunks"]:
metadata = {
"doc_id": doc["doc_id"],
"original_uuid": doc["original_uuid"],
"chunk_id": chunk["chunk_id"],
"original_index": chunk["original_index"],
"content": chunk["content"],
}
chunk_content = chunk["content"]
hybrid_retriever.insert_data(chunk_content, metadata)
evaluate_db(hybrid_retriever, "evaluation_set.jsonl", 5)
Evaluating retrieval: 100%|██████████| 248/248 [02:09<00:00, 1.92it/s]
Pass@5: 84.69%
Total Score: 0.8469182027649771
Total queries: 248
Experiência III: Recuperação contextual
A recuperação híbrida mostra uma melhoria, mas os resultados podem ser melhorados aplicando um método de recuperação contextual. Para o efeito, utilizaremos o modelo de linguagem do Anthropic para acrescentar o contexto de todo o documento para cada fragmento.
anthropic_client = anthropic.Anthropic(
api_key="your-anthropic-api-key",
)
contextual_retriever = MilvusContextualRetriever(
uri="contextual.db",
collection_name="contextual",
dense_embedding_function=dense_ef,
use_sparse=True,
sparse_embedding_function=sparse_ef,
use_contextualize_embedding=True,
anthropic_client=anthropic_client,
)
contextual_retriever.build_collection()
for doc in dataset:
doc_content = doc["content"]
for chunk in doc["chunks"]:
metadata = {
"doc_id": doc["doc_id"],
"original_uuid": doc["original_uuid"],
"chunk_id": chunk["chunk_id"],
"original_index": chunk["original_index"],
"content": chunk["content"],
}
chunk_content = chunk["content"]
contextual_retriever.insert_contextualized_data(
doc_content, chunk_content, metadata
)
evaluate_db(contextual_retriever, "evaluation_set.jsonl", 5)
Evaluating retrieval: 100%|██████████| 248/248 [01:55<00:00, 2.15it/s]
Pass@5: 87.14%
Total Score: 0.8713517665130568
Total queries: 248
Experiência IV: Recuperação contextual com Reranker
Os resultados podem ser melhorados adicionando um reranker Cohere. Sem inicializar um novo recuperador com o reranker separadamente, podemos simplesmente configurar o recuperador existente para utilizar o reranker para melhorar o desempenho.
contextual_retriever.use_reranker = True
contextual_retriever.rerank_function = cohere_rf
evaluate_db(contextual_retriever, "evaluation_set.jsonl", 5)
Evaluating retrieval: 100%|██████████| 248/248 [02:02<00:00, 2.00it/s]
Pass@5: 90.91%
Total Score: 0.9090821812596005
Total queries: 248
Demonstrámos vários métodos para melhorar o desempenho da recuperação. Com uma conceção mais ad-hoc adaptada ao cenário, a recuperação contextual mostra um potencial significativo para pré-processar documentos a baixo custo, conduzindo a um melhor sistema RAG.