MilvusによるグラフRAG
大規模な言語モデルの広範な応用は、その応答の精度と関連性を向上させることの重要性を強調している。検索補強型生成(RAG)は、外部知識ベースでモデルを強化し、より多くの文脈情報を提供し、幻覚や知識不足のような問題を軽減します。しかし、単純なRAGパラダイムだけに頼ることには限界があり、特に複雑なエンティティ関係やマルチホップの質問を扱う場合には、モデルが正確な回答を提供するのに苦労することが多い。
知識グラフ(KG)をRAGシステムに導入することは、新しい解決策を提供する。KGは、より正確な検索情報を提供し、RAGが複雑な質問応答タスクをよりよく処理するのを助ける、構造化された方法でエンティティとその関係を提示する。KG-RAGはまだ初期段階にあり、KGから実体と関係を効果的に検索する方法や、ベクトル類似性検索とグラフ構造を統合する方法についてのコンセンサスは得られていない。
本ノートブックでは、このシナリオの性能を大幅に向上させるシンプルかつ強力なアプローチを紹介する。これは、多方向検索とその後の再ランク付けという単純なRAGパラダイムであるが、グラフRAGを論理的に実装し、マルチホップ問題の処理において最先端の性能を達成している。どのように実装されるか見てみましょう。
前提条件
このノートブックを実行する前に、以下の依存関係がインストールされていることを確認してください:
$ pip install --upgrade --quiet pymilvus numpy scipy langchain langchain-core langchain-openai tqdm
Google Colabを使用している場合、インストールしたばかりの依存関係を有効にするために、ランタイムを再起動する必要があるかもしれません(画面上部の "Runtime "メニューをクリックし、ドロップダウンメニューから "Restart session "を選択してください)。
OpenAIのモデルを使います。api key OPENAI_API_KEY
を環境変数として用意しておく。
import os
os.environ["OPENAI_API_KEY"] = "sk-***********"
必要なライブラリと依存関係をインポートする。
import numpy as np
from collections import defaultdict
from scipy.sparse import csr_matrix
from pymilvus import MilvusClient
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from tqdm import tqdm
Milvusクライアント、LLM、埋め込みモデルのインスタンスを初期化します。
milvus_client = MilvusClient(uri="./milvus.db")
llm = ChatOpenAI(
model="gpt-4o",
temperature=0,
)
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
MilvusClientのargsについて:
- 例えば、
./milvus.db
のように、uri
をローカルファイルとして設定する方法は、Milvus Liteを自動的に利用してすべてのデータをこのファイルに格納するため、最も便利な方法です。 - データ規模が大きい場合は、dockerやkubernetes上に、よりパフォーマンスの高いMilvusサーバを構築することができます。このセットアップでは、
http://localhost:19530
などのサーバ uri をuri
として使用してください。 - MilvusのフルマネージドクラウドサービスであるZilliz Cloudを利用する場合は、Zilliz CloudのPublic EndpointとApi keyに対応する
uri
とtoken
を調整してください。
オフラインデータのロード
データの準備
例として、ベルヌーイファミリーとオイラーの関係を紹介するナノデータセットを使用する。このナノデータセットには4つの文章とそれに対応するトリプレットのセットが含まれており、各トリプレットには主語、述語、目的語が含まれている。 実際には、独自のコーパスからトリプレットを抽出するためにどのようなアプローチを使用することもできる。
nano_dataset = [
{
"passage": "Jakob Bernoulli (1654–1705): Jakob was one of the earliest members of the Bernoulli family to gain prominence in mathematics. He made significant contributions to calculus, particularly in the development of the theory of probability. He is known for the Bernoulli numbers and the Bernoulli theorem, a precursor to the law of large numbers. He was the older brother of Johann Bernoulli, another influential mathematician, and the two had a complex relationship that involved both collaboration and rivalry.",
"triplets": [
["Jakob Bernoulli", "made significant contributions to", "calculus"],
[
"Jakob Bernoulli",
"made significant contributions to",
"the theory of probability",
],
["Jakob Bernoulli", "is known for", "the Bernoulli numbers"],
["Jakob Bernoulli", "is known for", "the Bernoulli theorem"],
["The Bernoulli theorem", "is a precursor to", "the law of large numbers"],
["Jakob Bernoulli", "was the older brother of", "Johann Bernoulli"],
],
},
{
"passage": "Johann Bernoulli (1667–1748): Johann, Jakob’s younger brother, was also a major figure in the development of calculus. He worked on infinitesimal calculus and was instrumental in spreading the ideas of Leibniz across Europe. Johann also contributed to the calculus of variations and was known for his work on the brachistochrone problem, which is the curve of fastest descent between two points.",
"triplets": [
[
"Johann Bernoulli",
"was a major figure of",
"the development of calculus",
],
["Johann Bernoulli", "was", "Jakob's younger brother"],
["Johann Bernoulli", "worked on", "infinitesimal calculus"],
["Johann Bernoulli", "was instrumental in spreading", "Leibniz's ideas"],
["Johann Bernoulli", "contributed to", "the calculus of variations"],
["Johann Bernoulli", "was known for", "the brachistochrone problem"],
],
},
{
"passage": "Daniel Bernoulli (1700–1782): The son of Johann Bernoulli, Daniel made major contributions to fluid dynamics, probability, and statistics. He is most famous for Bernoulli’s principle, which describes the behavior of fluid flow and is fundamental to the understanding of aerodynamics.",
"triplets": [
["Daniel Bernoulli", "was the son of", "Johann Bernoulli"],
["Daniel Bernoulli", "made major contributions to", "fluid dynamics"],
["Daniel Bernoulli", "made major contributions to", "probability"],
["Daniel Bernoulli", "made major contributions to", "statistics"],
["Daniel Bernoulli", "is most famous for", "Bernoulli’s principle"],
[
"Bernoulli’s principle",
"is fundamental to",
"the understanding of aerodynamics",
],
],
},
{
"passage": "Leonhard Euler (1707–1783) was one of the greatest mathematicians of all time, and his relationship with the Bernoulli family was significant. Euler was born in Basel and was a student of Johann Bernoulli, who recognized his exceptional talent and mentored him in mathematics. Johann Bernoulli’s influence on Euler was profound, and Euler later expanded upon many of the ideas and methods he learned from the Bernoullis.",
"triplets": [
[
"Leonhard Euler",
"had a significant relationship with",
"the Bernoulli family",
],
["leonhard Euler", "was born in", "Basel"],
["Leonhard Euler", "was a student of", "Johann Bernoulli"],
["Johann Bernoulli's influence", "was profound on", "Euler"],
],
},
]
エンティティと関係を次のように構築します:
- エンティティはトリプレット内の主語または目的語であるため、トリプレットから直接抽出します。
- ここでは、主語、述語、目的語をスペースを挟んで直接連結することで、関係の概念を構築します。
また、後で使用するために、エンティティIDと関係IDを対応付けるdictと、関係IDと通過IDを対応付けるdictを用意します。
entityid_2_relationids = defaultdict(list)
relationid_2_passageids = defaultdict(list)
entities = []
relations = []
passages = []
for passage_id, dataset_info in enumerate(nano_dataset):
passage, triplets = dataset_info["passage"], dataset_info["triplets"]
passages.append(passage)
for triplet in triplets:
if triplet[0] not in entities:
entities.append(triplet[0])
if triplet[2] not in entities:
entities.append(triplet[2])
relation = " ".join(triplet)
if relation not in relations:
relations.append(relation)
entityid_2_relationids[entities.index(triplet[0])].append(
len(relations) - 1
)
entityid_2_relationids[entities.index(triplet[2])].append(
len(relations) - 1
)
relationid_2_passageids[relations.index(relation)].append(passage_id)
データの挿入
エンティティ、リレーション、パスのMilvusコレクションを作成する。エンティティ・コレクションとリレーション・コレクションは本手法におけるグラフ構築の主要なコレクションとして使用され、パッセージ・コレクションは素朴なRAG検索比較や補助的な目的として使用される。
embedding_dim = len(embedding_model.embed_query("foo"))
def create_milvus_collection(collection_name: str):
if milvus_client.has_collection(collection_name=collection_name):
milvus_client.drop_collection(collection_name=collection_name)
milvus_client.create_collection(
collection_name=collection_name,
dimension=embedding_dim,
consistency_level="Strong",
)
entity_col_name = "entity_collection"
relation_col_name = "relation_collection"
passage_col_name = "passage_collection"
create_milvus_collection(entity_col_name)
create_milvus_collection(relation_col_name)
create_milvus_collection(passage_col_name)
データをメタデータ情報とともにMilvusコレクションに挿入する。メタデータ情報には、通過IDと隣接エンティティまたは関係IDが含まれる。
def milvus_insert(
collection_name: str,
text_list: list[str],
):
batch_size = 512
for row_id in tqdm(range(0, len(text_list), batch_size), desc="Inserting"):
batch_texts = text_list[row_id : row_id + batch_size]
batch_embeddings = embedding_model.embed_documents(batch_texts)
batch_ids = [row_id + j for j in range(len(batch_texts))]
batch_data = [
{
"id": id_,
"text": text,
"vector": vector,
}
for id_, text, vector in zip(batch_ids, batch_texts, batch_embeddings)
]
milvus_client.insert(
collection_name=collection_name,
data=batch_data,
)
milvus_insert(
collection_name=relation_col_name,
text_list=relations,
)
milvus_insert(
collection_name=entity_col_name,
text_list=entities,
)
milvus_insert(
collection_name=passage_col_name,
text_list=passages,
)
Inserting: 100%|███████████████████████████████████| 1/1 [00:00<00:00, 1.02it/s]
Inserting: 100%|███████████████████████████████████| 1/1 [00:00<00:00, 1.39it/s]
Inserting: 100%|███████████████████████████████████| 1/1 [00:00<00:00, 2.28it/s]
オンラインクエリ
類似検索
Milvusからの入力クエリに基づいて、上位K個の類似エンティティと関係を検索する。
エンティティ検索を行う際には、まずNER (Named-entity recognition)のような特定の方法を用いて、クエリテキストからクエリエンティティを抽出する必要がある。ここでは簡単のため、NERの結果を用意する。実際には、クエリからエンティティを抽出するために、他のモデルやアプローチを使用することができます。
query = "What contribution did the son of Euler's teacher make?"
query_ner_list = ["Euler"]
# query_ner_list = ner(query) # In practice, replace it with your custom NER approach
query_ner_embeddings = [
embedding_model.embed_query(query_ner) for query_ner in query_ner_list
]
top_k = 3
entity_search_res = milvus_client.search(
collection_name=entity_col_name,
data=query_ner_embeddings,
limit=top_k,
output_fields=["id"],
)
query_embedding = embedding_model.embed_query(query)
relation_search_res = milvus_client.search(
collection_name=relation_col_name,
data=[query_embedding],
limit=top_k,
output_fields=["id"],
)[0]
サブグラフの展開
検索されたエンティティと関係を使って部分グラフを展開し、関係候補を取得します。以下は、部分グラフの展開処理のフローチャートである:
ここでは、隣接行列を構築し、行列の乗算を用いて、数度以内の隣接マッピング情報を計算する。こうすることで、どのような拡張度の情報でも素早く得ることができる。
# Construct the adjacency matrix of entities and relations where the value of the adjacency matrix is 1 if an entity is related to a relation, otherwise 0.
entity_relation_adj = np.zeros((len(entities), len(relations)))
for entity_id, entity in enumerate(entities):
entity_relation_adj[entity_id, entityid_2_relationids[entity_id]] = 1
# Convert the adjacency matrix to a sparse matrix for efficient computation.
entity_relation_adj = csr_matrix(entity_relation_adj)
# Use the entity-relation adjacency matrix to construct 1 degree entity-entity and relation-relation adjacency matrices.
entity_adj_1_degree = entity_relation_adj @ entity_relation_adj.T
relation_adj_1_degree = entity_relation_adj.T @ entity_relation_adj
# Specify the target degree of the subgraph to be expanded.
# 1 or 2 is enough for most cases.
target_degree = 1
# Compute the target degree adjacency matrices using matrix multiplication.
entity_adj_target_degree = entity_adj_1_degree
for _ in range(target_degree - 1):
entity_adj_target_degree = entity_adj_target_degree * entity_adj_1_degree
relation_adj_target_degree = relation_adj_1_degree
for _ in range(target_degree - 1):
relation_adj_target_degree = relation_adj_target_degree * relation_adj_1_degree
entity_relation_adj_target_degree = entity_adj_target_degree @ entity_relation_adj
対象の次数展開行列から値を取り出し、検索された実体と関係から対応する次数を展開することで、部分グラフの全ての関係を簡単に得ることができる。
expanded_relations_from_relation = set()
expanded_relations_from_entity = set()
# You can set the similarity threshold here to guarantee the quality of the retrieved ones.
# entity_sim_filter_thresh = ...
# relation_sim_filter_thresh = ...
filtered_hit_relation_ids = [
relation_res["entity"]["id"]
for relation_res in relation_search_res
# if relation_res['distance'] > relation_sim_filter_thresh
]
for hit_relation_id in filtered_hit_relation_ids:
expanded_relations_from_relation.update(
relation_adj_target_degree[hit_relation_id].nonzero()[1].tolist()
)
filtered_hit_entity_ids = [
one_entity_res["entity"]["id"]
for one_entity_search_res in entity_search_res
for one_entity_res in one_entity_search_res
# if one_entity_res['distance'] > entity_sim_filter_thresh
]
for filtered_hit_entity_id in filtered_hit_entity_ids:
expanded_relations_from_entity.update(
entity_relation_adj_target_degree[filtered_hit_entity_id].nonzero()[1].tolist()
)
# Merge the expanded relations from the relation and entity retrieval ways.
relation_candidate_ids = list(
expanded_relations_from_relation | expanded_relations_from_entity
)
relation_candidate_texts = [
relations[relation_id] for relation_id in relation_candidate_ids
]
部分グラフの展開により関係候補が得られたので、次のステップでLLMによる再ランク付けを行う。
LLMによる再ランク付け
この段階では、LLMの強力な自己注意メカニズムを利用して、関係候補をさらにフィルタリングし、洗練させる。一発勝負のプロンプトを採用し、クエリと関係の候補セットをプロンプトに組み込み、LLMにクエリへの回答を助ける可能性のある関係の候補を選択するよう指示する。クエリの中には複雑なものもあるため、LLMが思考プロセスを明確に表現できるように、思考の連鎖(Chain-of-Thought)アプローチを採用する。LLMの応答は、解析に便利なjson形式である。
query_prompt_one_shot_input = """I will provide you with a list of relationship descriptions. Your task is to select 3 relationships that may be useful to answer the given question. Please return a JSON object containing your thought process and a list of the selected relationships in order of their relevance.
Question:
When was the mother of the leader of the Third Crusade born?
Relationship descriptions:
[1] Eleanor was born in 1122.
[2] Eleanor married King Louis VII of France.
[3] Eleanor was the Duchess of Aquitaine.
[4] Eleanor participated in the Second Crusade.
[5] Eleanor had eight children.
[6] Eleanor was married to Henry II of England.
[7] Eleanor was the mother of Richard the Lionheart.
[8] Richard the Lionheart was the King of England.
[9] Henry II was the father of Richard the Lionheart.
[10] Henry II was the King of England.
[11] Richard the Lionheart led the Third Crusade.
"""
query_prompt_one_shot_output = """{"thought_process": "To answer the question about the birth of the mother of the leader of the Third Crusade, I first need to identify who led the Third Crusade and then determine who his mother was. After identifying his mother, I can look for the relationship that mentions her birth.", "useful_relationships": ["[11] Richard the Lionheart led the Third Crusade", "[7] Eleanor was the mother of Richard the Lionheart", "[1] Eleanor was born in 1122"]}"""
query_prompt_template = """Question:
{question}
Relationship descriptions:
{relation_des_str}
"""
def rerank_relations(
query: str, relation_candidate_texts: list[str], relation_candidate_ids: list[str]
) -> list[int]:
relation_des_str = "\n".join(
map(
lambda item: f"[{item[0]}] {item[1]}",
zip(relation_candidate_ids, relation_candidate_texts),
)
).strip()
rerank_prompts = ChatPromptTemplate.from_messages(
[
HumanMessage(query_prompt_one_shot_input),
AIMessage(query_prompt_one_shot_output),
HumanMessagePromptTemplate.from_template(query_prompt_template),
]
)
rerank_chain = (
rerank_prompts
| llm.bind(response_format={"type": "json_object"})
| JsonOutputParser()
)
rerank_res = rerank_chain.invoke(
{"question": query, "relation_des_str": relation_des_str}
)
rerank_relation_ids = []
rerank_relation_lines = rerank_res["useful_relationships"]
id_2_lines = {}
for line in rerank_relation_lines:
id_ = int(line[line.find("[") + 1 : line.find("]")])
id_2_lines[id_] = line.strip()
rerank_relation_ids.append(id_)
return rerank_relation_ids
rerank_relation_ids = rerank_relations(
query,
relation_candidate_texts=relation_candidate_texts,
relation_candidate_ids=relation_candidate_ids,
)
最終結果の取得
再ランク付けされた関係から最終的な検索結果を得ることができる。
final_top_k = 2
final_passages = []
final_passage_ids = []
for relation_id in rerank_relation_ids:
for passage_id in relationid_2_passageids[relation_id]:
if passage_id not in final_passage_ids:
final_passage_ids.append(passage_id)
final_passages.append(passages[passage_id])
passages_from_our_method = final_passages[:final_top_k]
この結果を、素朴なRAGメソッドと比較することができます。素朴なRAGメソッドは、クエリの埋め込みに基づく上位K個のパッセージを、パッセージのコレクションから直接取得します。
naive_passage_res = milvus_client.search(
collection_name=passage_col_name,
data=[query_embedding],
limit=final_top_k,
output_fields=["text"],
)[0]
passages_from_naive_rag = [res["entity"]["text"] for res in naive_passage_res]
print(
f"Passages retrieved from naive RAG: \n{passages_from_naive_rag}\n\n"
f"Passages retrieved from our method: \n{passages_from_our_method}\n\n"
)
prompt = ChatPromptTemplate.from_messages(
[
(
"human",
"""Use the following pieces of retrieved context to answer the question. If there is not enough information in the retrieved context to answer the question, just say that you don't know.
Question: {question}
Context: {context}
Answer:""",
)
]
)
rag_chain = prompt | llm | StrOutputParser()
answer_from_naive_rag = rag_chain.invoke(
{"question": query, "context": "\n".join(passages_from_naive_rag)}
)
answer_from_our_method = rag_chain.invoke(
{"question": query, "context": "\n".join(passages_from_our_method)}
)
print(
f"Answer from naive RAG: {answer_from_naive_rag}\n\nAnswer from our method: {answer_from_our_method}"
)
Passages retrieved from naive RAG:
['Leonhard Euler (1707–1783) was one of the greatest mathematicians of all time, and his relationship with the Bernoulli family was significant. Euler was born in Basel and was a student of Johann Bernoulli, who recognized his exceptional talent and mentored him in mathematics. Johann Bernoulli’s influence on Euler was profound, and Euler later expanded upon many of the ideas and methods he learned from the Bernoullis.', 'Johann Bernoulli (1667–1748): Johann, Jakob’s younger brother, was also a major figure in the development of calculus. He worked on infinitesimal calculus and was instrumental in spreading the ideas of Leibniz across Europe. Johann also contributed to the calculus of variations and was known for his work on the brachistochrone problem, which is the curve of fastest descent between two points.']
Passages retrieved from our method:
['Leonhard Euler (1707–1783) was one of the greatest mathematicians of all time, and his relationship with the Bernoulli family was significant. Euler was born in Basel and was a student of Johann Bernoulli, who recognized his exceptional talent and mentored him in mathematics. Johann Bernoulli’s influence on Euler was profound, and Euler later expanded upon many of the ideas and methods he learned from the Bernoullis.', 'Daniel Bernoulli (1700–1782): The son of Johann Bernoulli, Daniel made major contributions to fluid dynamics, probability, and statistics. He is most famous for Bernoulli’s principle, which describes the behavior of fluid flow and is fundamental to the understanding of aerodynamics.']
Answer from naive RAG: I don't know. The retrieved context does not provide information about the contributions made by the son of Euler's teacher.
Answer from our method: The son of Euler's teacher, Daniel Bernoulli, made major contributions to fluid dynamics, probability, and statistics. He is most famous for Bernoulli’s principle, which describes the behavior of fluid flow and is fundamental to the understanding of aerodynamics.
見てわかるように、素朴なRAGから検索されたパッセージは、真実のパッセージを見逃しており、間違った答えを導いています。 私たちの方法から検索されたパッセージは正しく、問題に対する正確な答えを得るのに役立ちます。