milvus-logo
LFAI
フロントページへ
  • 統合

MilvusとBentoMLによる検索支援型生成(RAG)

Open In Colab

はじめに

このガイドでは、BentoCloud上のオープンソースの埋め込みモデルと大規模言語モデルとMilvusベクトルデータベースを使用して、RAG(Retrieval Augmented Generation)アプリケーションを構築する方法を説明します。 BentoCloudは、モデル推論用に調整されたフルマネージドインフラストラクチャを提供する、動きの速いAIチームのためのAI推論プラットフォームです。オープンソースのモデル・サービング・フレームワークである BentoML と連携し、高性能なモデル・サービスの簡単な作成とデプロイを容易にします。このデモでは、Pythonアプリケーションに組み込むことができるMilvusの軽量版であるMilvus Liteをベクターデータベースとして使用します。

始める前に

Milvus LiteはPyPIから入手可能です。Python 3.8+ではpip経由でインストールできます:

$ pip install -U pymilvus bentoml

Google Colabを使用している場合、インストールした依存関係を有効にするために、ランタイムを再起動する必要があるかもしれません(画面上部の "Runtime "メニューをクリックし、ドロップダウンメニューから "Restart session "を選択してください)。

BentoCloudにサインインした後、DeploymentsでデプロイされたBentoCloudサービスと対話することができます。 対応するEND_POINTとAPIは、Playground -> Pythonにあります。 都市データはここからダウンロードできます。

BentoML/BentoCloud でエンベッディングを提供する

このエンドポイントを使うには、bentoml をインポートし、SyncHTTPClient を使って HTTP クライアントをセットアップします。エンドポイントと、オプションでトークン(BentoCloud でEndpoint Authorization をオンにした場合)を指定します。あるいは、BentoML のSentence Transformers Embeddingsリポジトリを使って、同じモデルを BentoML で提供することもできます。

import bentoml

BENTO_EMBEDDING_MODEL_END_POINT = "BENTO_EMBEDDING_MODEL_END_POINT"
BENTO_API_TOKEN = "BENTO_API_TOKEN"

embedding_client = bentoml.SyncHTTPClient(
    BENTO_EMBEDDING_MODEL_END_POINT, token=BENTO_API_TOKEN
)

embedding_client に接続したら、データを処理する必要があります。データの分割と埋め込みを行うための関数をいくつか用意しました。

ファイルを読み込み、テキストを文字列のリストに前処理する。

# naively chunk on newlines
def chunk_text(filename: str) -> list:
    with open(filename, "r") as f:
        text = f.read()
    sentences = text.split("\n")
    return sentences

まず、都市データをダウンロードします。

import os
import requests
import urllib.request

# set up the data source
repo = "ytang07/bento_octo_milvus_RAG"
directory = "data"
save_dir = "./city_data"
api_url = f"https://api.github.com/repos/{repo}/contents/{directory}"


response = requests.get(api_url)
data = response.json()

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

for item in data:
    if item["type"] == "file":
        file_url = item["download_url"]
        file_path = os.path.join(save_dir, item["name"])
        urllib.request.urlretrieve(file_url, file_path)

次に、持っているファイルをそれぞれ処理します。

# please upload your data directory under this file's folder
cities = os.listdir("city_data")
# store chunked text for each of the cities in a list of dicts
city_chunks = []
for city in cities:
    chunked = chunk_text(f"city_data/{city}")
    cleaned = []
    for chunk in chunked:
        if len(chunk) > 7:
            cleaned.append(chunk)
    mapped = {"city_name": city.split(".")[0], "chunks": cleaned}
    city_chunks.append(mapped)

文字列のリストを、25個の文字列をグループ化した埋め込みリストに分割する。

def get_embeddings(texts: list) -> list:
    if len(texts) > 25:
        splits = [texts[x : x + 25] for x in range(0, len(texts), 25)]
        embeddings = []
        for split in splits:
            embedding_split = embedding_client.encode(sentences=split)
            embeddings += embedding_split
        return embeddings
    return embedding_client.encode(
        sentences=texts,
    )

ここで、埋め込みとテキストチャンクをマッチングさせる必要がある。埋め込みリストと文章リストはインデックスで一致するはずなので、enumerate

entries = []
for city_dict in city_chunks:
    # No need for the embeddings list if get_embeddings already returns a list of lists
    embedding_list = get_embeddings(city_dict["chunks"])  # returns a list of lists
    # Now match texts with embeddings and city name
    for i, embedding in enumerate(embedding_list):
        entry = {
            "embedding": embedding,
            "sentence": city_dict["chunks"][
                i
            ],  # Assume "chunks" has the corresponding texts for the embeddings
            "city": city_dict["city_name"],
        }
        entries.append(entry)
    print(entries)

データをベクターデータベースに挿入して検索する

埋め込みとデータの準備ができたら、Milvus Liteにメタデータとともにベクトルを挿入し、後でベクトル検索を行う。ここではまず、Milvus Liteに接続してクライアントを起動します。MilvusClient モジュールをインポートして、Milvus Lite ベクトルデータベースに接続するMilvus Lite クライアントを初期化します。次元サイズは埋め込みモデルのサイズに由来します。例えば、Sentence Transformerモデルall-MiniLM-L6-v2 は384次元のベクトルを生成します。

from pymilvus import MilvusClient

COLLECTION_NAME = "Bento_Milvus_RAG"  # random name for your collection
DIMENSION = 384

# Initialize a Milvus Lite client
milvus_client = MilvusClient("milvus_demo.db")

MilvusClient の引数については、次のとおりです:

  • uri の引数をローカルファイル、例えば./milvus.db に設定するのが最も便利な方法です。
  • データ規模が大きい場合は、dockerやkubernetes上に、よりパフォーマンスの高いMilvusサーバを構築することができます。このセットアップでは、http://localhost:19530 などのサーバ uri をuri として使用してください。
  • MilvusのフルマネージドクラウドサービスであるZilliz Cloudを利用する場合は、Zilliz CloudのPublic EndpointとApi keyに対応するuritoken を調整してください。

または、古いconnection.connect APIを使用してください(推奨しません):

from pymilvus import connections

connections.connect(uri="milvus_demo.db")

Milvus Liteコレクションの作成

Milvus Liteを使用してコレクションを作成するには2つのステップがあります。このセクションでは、1つのモジュールが必要です:DataTypeはフィールドにどのようなデータタイプが入るかを示します。create_schema():コレクションのスキーマを作成し、add_field():コレクションのスキーマにフィールドを追加します。

from pymilvus import MilvusClient, DataType, Collection

# Create schema
schema = MilvusClient.create_schema(
    auto_id=True,
    enable_dynamic_field=True,
)

# 3.2. Add fields to schema
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=DIMENSION)

スキーマを作成し、データ・フィールドをうまく定義できたので、インデックスを定義する必要があります。検索に関しては、"インデックス "はデータを検索するためにどのようにマッピングするかを定義します。このプロジェクトでは、デフォルトのAUTOINDEXを使用してインデックスを作成します。

次に、先に指定した名前、スキーマ、インデックスでコレクションを作成します。最後に、前に処理したデータを挿入します。

# prepare index parameters
index_params = milvus_client.prepare_index_params()

# add index
index_params.add_index(
    field_name="embedding",
    index_type="AUTOINDEX",  # use autoindex instead of other complex indexing method
    metric_type="COSINE",  # L2, COSINE, or IP
)

# create collection
if milvus_client.has_collection(collection_name=COLLECTION_NAME):
    milvus_client.drop_collection(collection_name=COLLECTION_NAME)
milvus_client.create_collection(
    collection_name=COLLECTION_NAME, schema=schema, index_params=index_params
)

# Outside the loop, now you upsert all the entries at once
milvus_client.insert(collection_name=COLLECTION_NAME, data=entries)

RAG用にLLMをセットアップする

RAG アプリをビルドするには、BentoCloud に LLM をデプロイする必要があります。最新の Llama3 LLM を使ってみましょう。LLM が稼働したら、このモデルサービスのエンドポイントとトークンをコピーして、クライアントをセットアップするだけです。

BENTO_LLM_END_POINT = "BENTO_LLM_END_POINT"

llm_client = bentoml.SyncHTTPClient(BENTO_LLM_END_POINT, token=BENTO_API_TOKEN)

LLMの使い方

次に、プロンプト、コンテキスト、質問でLLM命令をセットアップします。以下は、LLMとして動作し、クライアントからの出力を文字列形式で返す関数です。

def dorag(question: str, context: str):

    prompt = (
        f"You are a helpful assistant. The user has a question. Answer the user question based only on the context: {context}. \n"
        f"The user question is {question}"
    )

    results = llm_client.generate(
        max_tokens=1024,
        prompt=prompt,
    )

    res = ""
    for result in results:
        res += result

    return res

RAGの例

質問をする準備ができました。この関数は単に質問を受け取り、背景情報から関連するコンテキストを生成するためにRAGを行います。そして、コンテキストと質問をdorag()に渡し、結果を取得します。

question = "What state is Cambridge in?"


def ask_a_question(question):
    embeddings = get_embeddings([question])
    res = milvus_client.search(
        collection_name=COLLECTION_NAME,
        data=embeddings,  # search for the one (1) embedding returned as a list of lists
        anns_field="embedding",  # Search across embeddings
        limit=5,  # get me the top 5 results
        output_fields=["sentence"],  # get the sentence/chunk and city
    )

    sentences = []
    for hits in res:
        for hit in hits:
            print(hit)
            sentences.append(hit["entity"]["sentence"])
    context = ". ".join(sentences)
    return context


context = ask_a_question(question=question)
print(context)

RAGの実装

print(dorag(question=question, context=context))

ケンブリッジがどの州にあるかという質問例では、BentoMLから回答全体を表示することができます。しかし、時間をかけて解析すれば、より見栄えが良くなり、ケンブリッジがマサチューセッツ州にあることを教えてくれるはずです。