Milvus를 사용한 문맥 검색
이미지 문맥 검색은 현재의 검색 증강 세대(RAG) 솔루션에서 발생하는 청크의 의미적 고립 문제를 해결하기 위해 Anthropic에서 제안한 고급 검색 방법입니다. 현재의 실용적인 RAG 패러다임에서는 문서를 여러 청크로 나누고 벡터 데이터베이스를 사용해 쿼리를 검색하여 가장 관련성이 높은 청크를 검색합니다. 그런 다음 LLM은 이렇게 검색된 청크를 사용하여 쿼리에 응답합니다. 그러나 이 청크 처리 과정에서는 문맥 정보가 손실되어 검색기가 관련성을 판단하기 어려울 수 있습니다.
문맥 검색은 임베딩 또는 색인화 전에 각 문서 청크에 관련 문맥을 추가하여 정확도를 높이고 검색 오류를 줄임으로써 기존 검색 시스템을 개선합니다. 하이브리드 검색 및 재순위 지정과 같은 기술과 결합하면 특히 대규모 지식 베이스의 검색 증강 생성(RAG) 시스템을 향상시킬 수 있습니다. 또한, 신속한 캐싱과 함께 사용하면 지연 시간과 운영 비용을 크게 줄여주는 비용 효율적인 솔루션을 제공하며, 문맥화된 청크는 문서 토큰 백만 개당 약 1.02달러의 비용이 듭니다. 따라서 대규모 지식 베이스를 처리하기 위한 확장 가능하고 효율적인 접근 방식입니다. Anthropic의 솔루션은 두 가지 측면에서 인사이트를 제공합니다:
Document Enhancement
: 쿼리 재작성은 최신 정보 검색에서 중요한 기술로, 종종 보조 정보를 사용해 쿼리를 더 유익하게 만드는 데 사용됩니다. 마찬가지로, RAG에서 더 나은 성능을 얻으려면 색인 전에 LLM으로 문서를 전처리(예: 데이터 소스 정리, 손실된 정보 보완, 요약 등)하면 관련 문서를 검색할 가능성이 크게 향상될 수 있습니다. 즉, 이 전처리 단계는 관련성 측면에서 문서를 쿼리에 더 가깝게 만드는 데 도움이 됩니다.Low-Cost Processing by Caching Long Context
: LLM을 사용해 문서를 처리할 때 흔히 우려하는 것 중 하나는 비용입니다. KVCache는 동일한 이전 컨텍스트에 대해 중간 결과를 재사용할 수 있는 인기 있는 솔루션입니다. 대부분의 호스팅 LLM 공급업체는 이 기능을 사용자에게 투명하게 제공하지만, Anthropic은 사용자가 캐싱 프로세스를 제어할 수 있습니다. 캐시 히트가 발생하면 대부분의 계산을 저장할 수 있습니다(긴 컨텍스트는 동일하게 유지되지만 각 쿼리에 대한 명령어가 변경되는 경우가 일반적입니다). 자세한 내용은 여기를 클릭하세요.
이 노트북에서는 밀집-희소 하이브리드 검색과 재랭커를 결합하여 점점 더 강력한 검색 시스템을 만들기 위해 밀버스와 LLM을 사용해 문맥 검색을 수행하는 방법을 보여드리겠습니다. 데이터와 실험 설정은 문맥 검색을 기반으로 합니다.
준비 사항
설치 종속성
$ pip install "pymilvus[model]"
$ pip install tqdm
$ pip install anthropic
Google Colab을 사용하는 경우 방금 설치한 종속 요소를 사용하려면 런타임을 다시 시작해야 할 수 있습니다(화면 상단의 "런타임" 메뉴를 클릭하고 드롭다운 메뉴에서 "세션 다시 시작"을 선택).
코드를 실행하려면 Cohere, Voyage, Anthropic의 API 키가 필요합니다.
데이터 다운로드
다음 명령은 원래 Anthropic 데모에 사용된 예제 데이터를 다운로드합니다.
$ wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/codebase_chunks.json
$ wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/evaluation_set.jsonl
리트리버 정의
이 클래스는 필요에 따라 다양한 검색 모드 중에서 선택할 수 있도록 유연하게 설계되었습니다. 초기화 방법에서 옵션을 지정하여 문맥 검색, 하이브리드 검색(밀도 검색과 희소 검색 방법의 결합) 또는 향상된 결과를 위한 재랭커를 사용할지 여부를 결정할 수 있습니다.
from pymilvus.model.dense import VoyageEmbeddingFunction
from pymilvus.model.hybrid import BGEM3EmbeddingFunction
from pymilvus.model.reranker import CohereRerankFunction
from typing import List, Dict, Any
from typing import Callable
from pymilvus import (
MilvusClient,
DataType,
AnnSearchRequest,
RRFRanker,
)
from tqdm import tqdm
import json
import anthropic
class MilvusContextualRetriever:
def __init__(
self,
uri="milvus.db",
collection_name="contexual_bgem3",
dense_embedding_function=None,
use_sparse=False,
sparse_embedding_function=None,
use_contextualize_embedding=False,
anthropic_client=None,
use_reranker=False,
rerank_function=None,
):
self.collection_name = collection_name
# For Milvus-lite, uri is a local path like "./milvus.db"
# For Milvus standalone service, uri is like "http://localhost:19530"
# For Zilliz Clond, please set `uri` and `token`, which correspond to the [Public Endpoint and API key](https://docs.zilliz.com/docs/on-zilliz-cloud-console#cluster-details) in Zilliz Cloud.
self.client = MilvusClient(uri)
self.embedding_function = dense_embedding_function
self.use_sparse = use_sparse
self.sparse_embedding_function = None
self.use_contextualize_embedding = use_contextualize_embedding
self.anthropic_client = anthropic_client
self.use_reranker = use_reranker
self.rerank_function = rerank_function
if use_sparse is True and sparse_embedding_function:
self.sparse_embedding_function = sparse_embedding_function
elif sparse_embedding_function is False:
raise ValueError(
"Sparse embedding function cannot be None if use_sparse is False"
)
else:
pass
def build_collection(self):
schema = self.client.create_schema(
auto_id=True,
enable_dynamic_field=True,
)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
schema.add_field(
field_name="dense_vector",
datatype=DataType.FLOAT_VECTOR,
dim=self.embedding_function.dim,
)
if self.use_sparse is True:
schema.add_field(
field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="dense_vector", index_type="FLAT", metric_type="IP"
)
if self.use_sparse is True:
index_params.add_index(
field_name="sparse_vector",
index_type="SPARSE_INVERTED_INDEX",
metric_type="IP",
)
self.client.create_collection(
collection_name=self.collection_name,
schema=schema,
index_params=index_params,
enable_dynamic_field=True,
)
def insert_data(self, chunk, metadata):
dense_vec = self.embedding_function([chunk])[0]
if self.use_sparse is True:
sparse_result = self.sparse_embedding_function.encode_documents([chunk])
if type(sparse_result) == dict:
sparse_vec = sparse_result["sparse"][[0]]
else:
sparse_vec = sparse_result[[0]]
self.client.insert(
collection_name=self.collection_name,
data={
"dense_vector": dense_vec,
"sparse_vector": sparse_vec,
**metadata,
},
)
else:
self.client.insert(
collection_name=self.collection_name,
data={"dense_vector": dense_vec, **metadata},
)
def insert_contextualized_data(self, doc, chunk, metadata):
contextualized_text, usage = self.situate_context(doc, chunk)
metadata["context"] = contextualized_text
text_to_embed = f"{chunk}\n\n{contextualized_text}"
dense_vec = self.embedding_function([text_to_embed])[0]
if self.use_sparse is True:
sparse_vec = self.sparse_embedding_function.encode_documents(
[text_to_embed]
)["sparse"][[0]]
self.client.insert(
collection_name=self.collection_name,
data={
"dense_vector": dense_vec,
"sparse_vector": sparse_vec,
**metadata,
},
)
else:
self.client.insert(
collection_name=self.collection_name,
data={"dense_vector": dense_vec, **metadata},
)
def situate_context(self, doc: str, chunk: str):
DOCUMENT_CONTEXT_PROMPT = """
<document>
{doc_content}
</document>
"""
CHUNK_CONTEXT_PROMPT = """
Here is the chunk we want to situate within the whole document
<chunk>
{chunk_content}
</chunk>
Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
Answer only with the succinct context and nothing else.
"""
response = self.anthropic_client.beta.prompt_caching.messages.create(
model="claude-3-haiku-20240307",
max_tokens=1000,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
"cache_control": {
"type": "ephemeral"
}, # we will make use of prompt caching for the full documents
},
{
"type": "text",
"text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
},
],
},
],
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
)
return response.content[0].text, response.usage
def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
dense_vec = self.embedding_function([query])[0]
if self.use_sparse is True:
sparse_vec = self.sparse_embedding_function.encode_queries([query])[
"sparse"
][[0]]
req_list = []
if self.use_reranker:
k = k * 10
if self.use_sparse is True:
req_list = []
dense_search_param = {
"data": [dense_vec],
"anns_field": "dense_vector",
"param": {"metric_type": "IP"},
"limit": k * 2,
}
dense_req = AnnSearchRequest(**dense_search_param)
req_list.append(dense_req)
sparse_search_param = {
"data": [sparse_vec],
"anns_field": "sparse_vector",
"param": {"metric_type": "IP"},
"limit": k * 2,
}
sparse_req = AnnSearchRequest(**sparse_search_param)
req_list.append(sparse_req)
docs = self.client.hybrid_search(
self.collection_name,
req_list,
RRFRanker(),
k,
output_fields=[
"content",
"original_uuid",
"doc_id",
"chunk_id",
"original_index",
"context",
],
)
else:
docs = self.client.search(
self.collection_name,
data=[dense_vec],
anns_field="dense_vector",
limit=k,
output_fields=[
"content",
"original_uuid",
"doc_id",
"chunk_id",
"original_index",
"context",
],
)
if self.use_reranker and self.use_contextualize_embedding:
reranked_texts = []
reranked_docs = []
for i in range(k):
if self.use_contextualize_embedding:
reranked_texts.append(
f"{docs[0][i]['entity']['content']}\n\n{docs[0][i]['entity']['context']}"
)
else:
reranked_texts.append(f"{docs[0][i]['entity']['content']}")
results = self.rerank_function(query, reranked_texts)
for result in results:
reranked_docs.append(docs[0][result.index])
docs[0] = reranked_docs
return docs
def evaluate_retrieval(
queries: List[Dict[str, Any]], retrieval_function: Callable, db, k: int = 20
) -> Dict[str, float]:
total_score = 0
total_queries = len(queries)
for query_item in tqdm(queries, desc="Evaluating retrieval"):
query = query_item["query"]
golden_chunk_uuids = query_item["golden_chunk_uuids"]
# Find all golden chunk contents
golden_contents = []
for doc_uuid, chunk_index in golden_chunk_uuids:
golden_doc = next(
(
doc
for doc in query_item["golden_documents"]
if doc["uuid"] == doc_uuid
),
None,
)
if not golden_doc:
print(f"Warning: Golden document not found for UUID {doc_uuid}")
continue
golden_chunk = next(
(
chunk
for chunk in golden_doc["chunks"]
if chunk["index"] == chunk_index
),
None,
)
if not golden_chunk:
print(
f"Warning: Golden chunk not found for index {chunk_index} in document {doc_uuid}"
)
continue
golden_contents.append(golden_chunk["content"].strip())
if not golden_contents:
print(f"Warning: No golden contents found for query: {query}")
continue
retrieved_docs = retrieval_function(query, db, k=k)
# Count how many golden chunks are in the top k retrieved documents
chunks_found = 0
for golden_content in golden_contents:
for doc in retrieved_docs[0][:k]:
retrieved_content = doc["entity"]["content"].strip()
if retrieved_content == golden_content:
chunks_found += 1
break
query_score = chunks_found / len(golden_contents)
total_score += query_score
average_score = total_score / total_queries
pass_at_n = average_score * 100
return {
"pass_at_n": pass_at_n,
"average_score": average_score,
"total_queries": total_queries,
}
def retrieve_base(query: str, db, k: int = 20) -> List[Dict[str, Any]]:
return db.search(query, k=k)
def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
"""Load JSONL file and return a list of dictionaries."""
with open(file_path, "r") as file:
return [json.loads(line) for line in file]
def evaluate_db(db, original_jsonl_path: str, k):
# Load the original JSONL data for queries and ground truth
original_data = load_jsonl(original_jsonl_path)
# Evaluate retrieval
results = evaluate_retrieval(original_data, retrieve_base, db, k)
print(f"Pass@{k}: {results['pass_at_n']:.2f}%")
print(f"Total Score: {results['average_score']}")
print(f"Total queries: {results['total_queries']}")
이제 다음 실험을 위해 이러한 모델을 초기화해야 합니다. PyMilvus 모델 라이브러리를 사용하여 다른 모델로 쉽게 전환할 수 있습니다.
dense_ef = VoyageEmbeddingFunction(api_key="your-voyage-api-key", model_name="voyage-2")
sparse_ef = BGEM3EmbeddingFunction()
cohere_rf = CohereRerankFunction(api_key="your-cohere-api-key")
Fetching 30 files: 0%| | 0/30 [00:00<?, ?it/s]
path = "codebase_chunks.json"
with open(path, "r") as f:
dataset = json.load(f)
실험 1: 표준 검색
표준 검색은 밀도가 높은 임베딩만 사용하여 관련 문서를 검색합니다. 이 실험에서는 Pass@5를 사용하여 원본 리포지토리의 결과를 재현합니다.
standard_retriever = MilvusContextualRetriever(
uri="standard.db", collection_name="standard", dense_embedding_function=dense_ef
)
standard_retriever.build_collection()
for doc in dataset:
doc_content = doc["content"]
for chunk in doc["chunks"]:
metadata = {
"doc_id": doc["doc_id"],
"original_uuid": doc["original_uuid"],
"chunk_id": chunk["chunk_id"],
"original_index": chunk["original_index"],
"content": chunk["content"],
}
chunk_content = chunk["content"]
standard_retriever.insert_data(chunk_content, metadata)
evaluate_db(standard_retriever, "evaluation_set.jsonl", 5)
Evaluating retrieval: 100%|██████████| 248/248 [01:29<00:00, 2.77it/s]
Pass@5: 80.92%
Total Score: 0.8091877880184332
Total queries: 248
실험 II: 하이브리드 검색
Voyage 임베딩으로 유망한 결과를 얻었으므로 이제 강력한 스파스 임베딩을 생성하는 BGE-M3 모델을 사용해 하이브리드 검색을 수행해 보겠습니다. 밀도 검색과 희소 검색의 결과를 상호 순위 융합(RRF) 방법을 사용해 결합하여 하이브리드 결과를 생성합니다.
hybrid_retriever = MilvusContextualRetriever(
uri="hybrid.db",
collection_name="hybrid",
dense_embedding_function=dense_ef,
use_sparse=True,
sparse_embedding_function=sparse_ef,
)
hybrid_retriever.build_collection()
for doc in dataset:
doc_content = doc["content"]
for chunk in doc["chunks"]:
metadata = {
"doc_id": doc["doc_id"],
"original_uuid": doc["original_uuid"],
"chunk_id": chunk["chunk_id"],
"original_index": chunk["original_index"],
"content": chunk["content"],
}
chunk_content = chunk["content"]
hybrid_retriever.insert_data(chunk_content, metadata)
evaluate_db(hybrid_retriever, "evaluation_set.jsonl", 5)
Evaluating retrieval: 100%|██████████| 248/248 [02:09<00:00, 1.92it/s]
Pass@5: 84.69%
Total Score: 0.8469182027649771
Total queries: 248
실험 III: 문맥 검색
하이브리드 검색은 개선된 결과를 보여주지만, 문맥 검색 방법을 적용하면 결과를 더욱 향상시킬 수 있습니다. 이를 위해 Anthropic의 언어 모델을 사용하여 각 청크에 대해 전체 문서에서 문맥을 미리 추가합니다.
anthropic_client = anthropic.Anthropic(
api_key="your-anthropic-api-key",
)
contextual_retriever = MilvusContextualRetriever(
uri="contextual.db",
collection_name="contextual",
dense_embedding_function=dense_ef,
use_sparse=True,
sparse_embedding_function=sparse_ef,
use_contextualize_embedding=True,
anthropic_client=anthropic_client,
)
contextual_retriever.build_collection()
for doc in dataset:
doc_content = doc["content"]
for chunk in doc["chunks"]:
metadata = {
"doc_id": doc["doc_id"],
"original_uuid": doc["original_uuid"],
"chunk_id": chunk["chunk_id"],
"original_index": chunk["original_index"],
"content": chunk["content"],
}
chunk_content = chunk["content"]
contextual_retriever.insert_contextualized_data(
doc_content, chunk_content, metadata
)
evaluate_db(contextual_retriever, "evaluation_set.jsonl", 5)
Evaluating retrieval: 100%|██████████| 248/248 [01:55<00:00, 2.15it/s]
Pass@5: 87.14%
Total Score: 0.8713517665130568
Total queries: 248
실험 4: 리랭커를 사용한 문맥 검색
Cohere 리랭커를 추가하면 결과를 더욱 개선할 수 있습니다. 리랭커가 포함된 새 리트리버를 별도로 초기화하지 않고도 기존 리트리버가 리랭커를 사용하도록 간단히 구성하여 성능을 향상시킬 수 있습니다.
contextual_retriever.use_reranker = True
contextual_retriever.rerank_function = cohere_rf
evaluate_db(contextual_retriever, "evaluation_set.jsonl", 5)
Evaluating retrieval: 100%|██████████| 248/248 [02:02<00:00, 2.00it/s]
Pass@5: 90.91%
Total Score: 0.9090821812596005
Total queries: 248
검색 성능을 개선하는 몇 가지 방법을 시연해 보았습니다. 시나리오에 맞게 임시로 설계된 문맥 검색은 저렴한 비용으로 문서를 전처리할 수 있는 상당한 잠재력을 보여주며, 더 나은 RAG 시스템으로 이어집니다.