milvus-logo
LFAI
Home
  • Integrationen

Retrieval-Augmented Generation (RAG) mit Milvus und BentoML

Open In Colab GitHub Repository

Einführung

Dieser Leitfaden zeigt, wie man ein Open-Source-Einbettungsmodell und ein Großsprachenmodell auf BentoCloud mit der Vektordatenbank Milvus verwendet, um eine RAG-Anwendung (Retrieval Augmented Generation) zu erstellen. BentoCloud ist eine KI-Inferenzplattform für schnell arbeitende KI-Teams, die eine vollständig verwaltete Infrastruktur bietet, die auf die Modellinferenz zugeschnitten ist. Sie arbeitet mit BentoML zusammen, einem Open-Source-Framework für die Modellbereitstellung, um die einfache Erstellung und Bereitstellung von Hochleistungsmodelldiensten zu ermöglichen. In dieser Demo verwenden wir Milvus Lite als Vektordatenbank, eine schlanke Version von Milvus, die in Ihre Python-Anwendung eingebettet werden kann.

Bevor Sie beginnen

Milvus Lite ist auf PyPI verfügbar. Sie können es über pip für Python 3.8+ installieren:

$ pip install -U pymilvus bentoml

Wenn Sie Google Colab verwenden, müssen Sie möglicherweise die Runtime neu starten, um die soeben installierten Abhängigkeiten zu aktivieren (klicken Sie auf das Menü "Runtime" am oberen Rand des Bildschirms und wählen Sie "Restart session" aus dem Dropdown-Menü).

Nachdem Sie sich in BentoCloud angemeldet haben, können Sie mit den bereitgestellten BentoCloud-Diensten in Deployments interagieren, und der entsprechende END_POINT und die API befinden sich in Playground -> Python. Sie können die Stadtdaten hier herunterladen.

Servieren von Einbettungen mit BentoML/BentoCloud

Um diesen Endpunkt zu verwenden, importieren Sie bentoml und richten Sie einen HTTP-Client ein, der SyncHTTPClient verwendet, indem Sie den Endpunkt und optional das Token angeben (wenn Sie Endpoint Authorization auf BentoCloud aktivieren). Alternativ können Sie dasselbe Modell auch über BentoML mit seinem Sentence Transformers Embeddings Repository verwenden.

import bentoml

BENTO_EMBEDDING_MODEL_END_POINT = "BENTO_EMBEDDING_MODEL_END_POINT"
BENTO_API_TOKEN = "BENTO_API_TOKEN"

embedding_client = bentoml.SyncHTTPClient(
    BENTO_EMBEDDING_MODEL_END_POINT, token=BENTO_API_TOKEN
)

Sobald wir eine Verbindung mit dem embedding_client hergestellt haben, müssen wir unsere Daten verarbeiten. Wir haben mehrere Funktionen zur Verfügung gestellt, um die Daten aufzuteilen und einzubetten.

Dateien einlesen und den Text in eine Liste von Zeichenketten vorverarbeiten.

# naively chunk on newlines
def chunk_text(filename: str) -> list:
    with open(filename, "r") as f:
        text = f.read()
    sentences = text.split("\n")
    return sentences

Zunächst müssen wir die Städtedaten herunterladen.

import os
import requests
import urllib.request

# set up the data source
repo = "ytang07/bento_octo_milvus_RAG"
directory = "data"
save_dir = "./city_data"
api_url = f"https://api.github.com/repos/{repo}/contents/{directory}"


response = requests.get(api_url)
data = response.json()

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

for item in data:
    if item["type"] == "file":
        file_url = item["download_url"]
        file_path = os.path.join(save_dir, item["name"])
        urllib.request.urlretrieve(file_url, file_path)

Als Nächstes verarbeiten wir jede der Dateien, die wir haben.

# please upload your data directory under this file's folder
cities = os.listdir("city_data")
# store chunked text for each of the cities in a list of dicts
city_chunks = []
for city in cities:
    chunked = chunk_text(f"city_data/{city}")
    cleaned = []
    for chunk in chunked:
        if len(chunk) > 7:
            cleaned.append(chunk)
    mapped = {"city_name": city.split(".")[0], "chunks": cleaned}
    city_chunks.append(mapped)

Zerlegt eine Liste von Zeichenfolgen in eine Liste von Einbettungen, die jeweils 25 Textzeichenfolgen gruppieren.

def get_embeddings(texts: list) -> list:
    if len(texts) > 25:
        splits = [texts[x : x + 25] for x in range(0, len(texts), 25)]
        embeddings = []
        for split in splits:
            embedding_split = embedding_client.encode(sentences=split)
            embeddings += embedding_split
        return embeddings
    return embedding_client.encode(
        sentences=texts,
    )

Nun müssen wir die Einbettungen und die Textabschnitte zuordnen. Da die Liste der Einbettungen und die Liste der Sätze nach Index übereinstimmen sollten, können wir enumerate durch beide Listen gehen, um sie abzugleichen.

entries = []
for city_dict in city_chunks:
    # No need for the embeddings list if get_embeddings already returns a list of lists
    embedding_list = get_embeddings(city_dict["chunks"])  # returns a list of lists
    # Now match texts with embeddings and city name
    for i, embedding in enumerate(embedding_list):
        entry = {
            "embedding": embedding,
            "sentence": city_dict["chunks"][
                i
            ],  # Assume "chunks" has the corresponding texts for the embeddings
            "city": city_dict["city_name"],
        }
        entries.append(entry)
    print(entries)

Einfügen von Daten in eine Vektordatenbank für den Abruf

Nachdem wir unsere Einbettungen und Daten vorbereitet haben, können wir die Vektoren zusammen mit den Metadaten in Milvus Lite für die spätere Vektorsuche einfügen. Der erste Schritt in diesem Abschnitt besteht darin, einen Client zu starten, der eine Verbindung zu Milvus Lite herstellt. Wir importieren einfach das Modul MilvusClient und initialisieren einen Milvus Lite-Client, der eine Verbindung zu Ihrer Milvus Lite-Vektordatenbank herstellt. Die Größe der Dimensionen ergibt sich aus der Größe des Einbettungsmodells, z. B. erzeugt das Modell Sentence Transformer all-MiniLM-L6-v2 Vektoren der Dimension 384.

from pymilvus import MilvusClient

COLLECTION_NAME = "Bento_Milvus_RAG"  # random name for your collection
DIMENSION = 384

# Initialize a Milvus Lite client
milvus_client = MilvusClient("milvus_demo.db")

Wie bei dem Argument von MilvusClient:

  • Die Einstellung von uri als lokale Datei, z. B../milvus.db, ist die bequemste Methode, da sie automatisch Milvus Lite verwendet, um alle Daten in dieser Datei zu speichern.
  • Wenn Sie große Datenmengen haben, können Sie einen leistungsfähigeren Milvus-Server auf Docker oder Kubernetes einrichten. Bei dieser Einrichtung verwenden Sie bitte die Server-Uri, z. B.http://localhost:19530, als uri.
  • Wenn Sie Zilliz Cloud, den vollständig verwalteten Cloud-Service für Milvus, nutzen möchten, passen Sie uri und token an, die dem Public Endpoint und dem Api-Schlüssel in Zilliz Cloud entsprechen.

Oder mit der alten connections.connect API (nicht empfohlen):

from pymilvus import connections

connections.connect(uri="milvus_demo.db")

Erstellen Ihrer Milvus-Lite-Sammlung

Die Erstellung einer Sammlung mit Milvus Lite umfasst zwei Schritte: erstens die Definition des Schemas und zweitens die Definition des Indexes. Für diesen Abschnitt benötigen wir ein Modul: DataType sagt uns, welche Art von Daten in einem Feld enthalten sein wird. Außerdem müssen wir zwei Funktionen verwenden, um ein Schema zu erstellen und Felder hinzuzufügen. create_schema(): erstellt ein Sammlungsschema, add_field(): fügt ein Feld zum Schema einer Sammlung hinzu.

from pymilvus import MilvusClient, DataType, Collection

# Create schema
schema = MilvusClient.create_schema(
    auto_id=True,
    enable_dynamic_field=True,
)

# 3.2. Add fields to schema
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=DIMENSION)

Nachdem wir nun unser Schema erstellt und erfolgreich Datenfelder definiert haben, müssen wir den Index definieren. In Bezug auf die Suche definiert ein "Index", wie wir unsere Daten für den Abruf abbilden werden. Wir verwenden die Standardeinstellung AUTOINDEX, um unsere Daten für dieses Projekt zu indizieren.

Als nächstes erstellen wir die Sammlung mit dem zuvor angegebenen Namen, Schema und Index. Schließlich fügen wir die zuvor verarbeiteten Daten ein.

# prepare index parameters
index_params = milvus_client.prepare_index_params()

# add index
index_params.add_index(
    field_name="embedding",
    index_type="AUTOINDEX",  # use autoindex instead of other complex indexing method
    metric_type="COSINE",  # L2, COSINE, or IP
)

# create collection
if milvus_client.has_collection(collection_name=COLLECTION_NAME):
    milvus_client.drop_collection(collection_name=COLLECTION_NAME)
milvus_client.create_collection(
    collection_name=COLLECTION_NAME, schema=schema, index_params=index_params
)

# Outside the loop, now you upsert all the entries at once
milvus_client.insert(collection_name=COLLECTION_NAME, data=entries)

Richten Sie Ihren LLM für RAG ein

Um eine RAG-Anwendung zu erstellen, müssen wir einen LLM auf BentoCloud bereitstellen. Wir verwenden den neuesten Llama3 LLM. Sobald er einsatzbereit ist, kopieren Sie einfach den Endpunkt und das Token dieses Modelldienstes und richten einen Client für ihn ein.

BENTO_LLM_END_POINT = "BENTO_LLM_END_POINT"

llm_client = bentoml.SyncHTTPClient(BENTO_LLM_END_POINT, token=BENTO_API_TOKEN)

LLM-Anweisungen

Jetzt richten wir die LLM-Anweisungen mit dem Prompt, dem Kontext und der Frage ein. Hier ist die Funktion, die sich wie ein LLM verhält und dann die Ausgabe vom Client in einem String-Format zurückgibt.

def dorag(question: str, context: str):

    prompt = (
        f"You are a helpful assistant. The user has a question. Answer the user question based only on the context: {context}. \n"
        f"The user question is {question}"
    )

    results = llm_client.generate(
        max_tokens=1024,
        prompt=prompt,
    )

    res = ""
    for result in results:
        res += result

    return res

Ein RAG-Beispiel

Jetzt sind wir bereit, eine Frage zu stellen. Diese Funktion nimmt einfach eine Frage und führt dann RAG aus, um den relevanten Kontext aus den Hintergrundinformationen zu generieren. Dann übergeben wir den Kontext und die Frage an dorag() und erhalten das Ergebnis.

question = "What state is Cambridge in?"


def ask_a_question(question):
    embeddings = get_embeddings([question])
    res = milvus_client.search(
        collection_name=COLLECTION_NAME,
        data=embeddings,  # search for the one (1) embedding returned as a list of lists
        anns_field="embedding",  # Search across embeddings
        limit=5,  # get me the top 5 results
        output_fields=["sentence"],  # get the sentence/chunk and city
    )

    sentences = []
    for hits in res:
        for hit in hits:
            print(hit)
            sentences.append(hit["entity"]["sentence"])
    context = ". ".join(sentences)
    return context


context = ask_a_question(question=question)
print(context)

RAG implementieren

print(dorag(question=question, context=context))

Für die Beispielfrage, in der gefragt wird, in welchem Zustand sich Cambridge befindet, können wir die gesamte Antwort aus BentoML drucken. Wenn wir uns jedoch die Zeit nehmen, sie zu parsen, sieht es einfach schöner aus und sollte uns sagen, dass Cambridge in Massachusetts liegt.

Übersetzt vonDeepL

Try Managed Milvus for Free

Zilliz Cloud is hassle-free, powered by Milvus and 10x faster.

Get Started
Feedback

War diese Seite hilfreich?