Retrieval-Augmented Generation (RAG) mit Milvus und BentoML
Einführung
Dieser Leitfaden zeigt, wie man ein Open-Source-Einbettungsmodell und ein Großsprachenmodell auf BentoCloud mit der Vektordatenbank Milvus verwendet, um eine RAG-Anwendung (Retrieval Augmented Generation) zu erstellen. BentoCloud ist eine KI-Inferenzplattform für schnell arbeitende KI-Teams, die eine vollständig verwaltete Infrastruktur bietet, die auf die Modellinferenz zugeschnitten ist. Sie arbeitet mit BentoML zusammen, einem Open-Source-Framework für die Modellbereitstellung, um die einfache Erstellung und Bereitstellung von Hochleistungsmodelldiensten zu ermöglichen. In dieser Demo verwenden wir Milvus Lite als Vektordatenbank, eine schlanke Version von Milvus, die in Ihre Python-Anwendung eingebettet werden kann.
Bevor Sie beginnen
Milvus Lite ist auf PyPI verfügbar. Sie können es über pip für Python 3.8+ installieren:
$ pip install -U pymilvus bentoml
Wenn Sie Google Colab verwenden, müssen Sie möglicherweise die Runtime neu starten, um die soeben installierten Abhängigkeiten zu aktivieren (klicken Sie auf das Menü "Runtime" am oberen Bildschirmrand und wählen Sie "Restart session" aus dem Dropdown-Menü).
Nachdem Sie sich bei BentoCloud angemeldet haben, können Sie mit den bereitgestellten BentoCloud-Diensten in Deployments interagieren, und der entsprechende END_POINT und die API befinden sich in Playground -> Python. Sie können die Stadtdaten hier herunterladen.
Servieren von Einbettungen mit BentoML/BentoCloud
Um diesen Endpunkt zu verwenden, importieren Sie bentoml
und richten Sie einen HTTP-Client ein, der SyncHTTPClient
verwendet, indem Sie den Endpunkt und optional das Token angeben (wenn Sie Endpoint Authorization
auf BentoCloud aktivieren). Alternativ können Sie dasselbe Modell auch über BentoML mit seinem Sentence Transformers Embeddings Repository verwenden.
import bentoml
BENTO_EMBEDDING_MODEL_END_POINT = "BENTO_EMBEDDING_MODEL_END_POINT"
BENTO_API_TOKEN = "BENTO_API_TOKEN"
embedding_client = bentoml.SyncHTTPClient(
BENTO_EMBEDDING_MODEL_END_POINT, token=BENTO_API_TOKEN
)
Sobald wir eine Verbindung mit dem embedding_client hergestellt haben, müssen wir unsere Daten verarbeiten. Wir haben mehrere Funktionen zur Verfügung gestellt, um die Daten aufzuteilen und einzubetten.
Dateien einlesen und den Text in eine Liste von Zeichenketten vorverarbeiten.
# naively chunk on newlines
def chunk_text(filename: str) -> list:
with open(filename, "r") as f:
text = f.read()
sentences = text.split("\n")
return sentences
Zunächst müssen wir die Städtedaten herunterladen.
import os
import requests
import urllib.request
# set up the data source
repo = "ytang07/bento_octo_milvus_RAG"
directory = "data"
save_dir = "./city_data"
api_url = f"https://api.github.com/repos/{repo}/contents/{directory}"
response = requests.get(api_url)
data = response.json()
if not os.path.exists(save_dir):
os.makedirs(save_dir)
for item in data:
if item["type"] == "file":
file_url = item["download_url"]
file_path = os.path.join(save_dir, item["name"])
urllib.request.urlretrieve(file_url, file_path)
Als Nächstes verarbeiten wir jede der Dateien, die wir haben.
# please upload your data directory under this file's folder
cities = os.listdir("city_data")
# store chunked text for each of the cities in a list of dicts
city_chunks = []
for city in cities:
chunked = chunk_text(f"city_data/{city}")
cleaned = []
for chunk in chunked:
if len(chunk) > 7:
cleaned.append(chunk)
mapped = {"city_name": city.split(".")[0], "chunks": cleaned}
city_chunks.append(mapped)
Zerlegt eine Liste von Zeichenfolgen in eine Liste von Einbettungen, die jeweils 25 Textzeichenfolgen gruppieren.
def get_embeddings(texts: list) -> list:
if len(texts) > 25:
splits = [texts[x : x + 25] for x in range(0, len(texts), 25)]
embeddings = []
for split in splits:
embedding_split = embedding_client.encode(sentences=split)
embeddings += embedding_split
return embeddings
return embedding_client.encode(
sentences=texts,
)
Nun müssen wir die Einbettungen und die Textabschnitte zuordnen. Da die Liste der Einbettungen und die Liste der Sätze nach Index übereinstimmen sollten, können wir enumerate
durch beide Listen gehen, um sie abzugleichen.
entries = []
for city_dict in city_chunks:
# No need for the embeddings list if get_embeddings already returns a list of lists
embedding_list = get_embeddings(city_dict["chunks"]) # returns a list of lists
# Now match texts with embeddings and city name
for i, embedding in enumerate(embedding_list):
entry = {
"embedding": embedding,
"sentence": city_dict["chunks"][
i
], # Assume "chunks" has the corresponding texts for the embeddings
"city": city_dict["city_name"],
}
entries.append(entry)
print(entries)
Einfügen von Daten in eine Vektordatenbank für den Abruf
Nachdem wir unsere Einbettungen und Daten vorbereitet haben, können wir die Vektoren zusammen mit den Metadaten in Milvus Lite für die spätere Vektorsuche einfügen. Der erste Schritt in diesem Abschnitt besteht darin, einen Client zu starten, der eine Verbindung zu Milvus Lite herstellt. Wir importieren einfach das Modul MilvusClient
und initialisieren einen Milvus-Lite-Client, der sich mit Ihrer Milvus-Lite-Vektordatenbank verbindet. Die Größe der Dimensionen ergibt sich aus der Größe des Einbettungsmodells, z.B. erzeugt das Sentence Transformer Modell all-MiniLM-L6-v2
Vektoren mit 384 Dimensionen.
from pymilvus import MilvusClient
COLLECTION_NAME = "Bento_Milvus_RAG" # random name for your collection
DIMENSION = 384
# Initialize a Milvus Lite client
milvus_client = MilvusClient("milvus_demo.db")
Für das Argument von MilvusClient
gilt Folgendes:
- Die Einstellung von
uri
als lokale Datei, z. B../milvus.db
, ist die bequemste Methode, da sie automatisch Milvus Lite verwendet, um alle Daten in dieser Datei zu speichern. - Wenn Sie große Datenmengen haben, können Sie einen leistungsfähigeren Milvus-Server auf Docker oder Kubernetes einrichten. Bei dieser Einrichtung verwenden Sie bitte die Server-Uri, z. B.
http://localhost:19530
, alsuri
. - Wenn Sie Zilliz Cloud, den vollständig verwalteten Cloud-Service für Milvus, nutzen möchten, passen Sie
uri
undtoken
an, die dem Public Endpoint und dem Api-Schlüssel in Zilliz Cloud entsprechen.
Oder mit der alten connections.connect API (nicht empfohlen):
from pymilvus import connections
connections.connect(uri="milvus_demo.db")
Erstellen Ihrer Milvus-Lite-Sammlung
Die Erstellung einer Sammlung mit Milvus Lite umfasst zwei Schritte: erstens die Definition des Schemas und zweitens die Definition des Indexes. Für diesen Abschnitt benötigen wir ein Modul: DataType sagt uns, welche Art von Daten in einem Feld enthalten sein wird. Außerdem müssen wir zwei Funktionen verwenden, um ein Schema zu erstellen und Felder hinzuzufügen. create_schema(): erstellt ein Sammlungsschema, add_field(): fügt ein Feld zum Schema einer Sammlung hinzu.
from pymilvus import MilvusClient, DataType, Collection
# Create schema
schema = MilvusClient.create_schema(
auto_id=True,
enable_dynamic_field=True,
)
# 3.2. Add fields to schema
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=DIMENSION)
Nachdem wir nun unser Schema erstellt und erfolgreich Datenfelder definiert haben, müssen wir den Index definieren. In Bezug auf die Suche definiert ein "Index", wie wir unsere Daten für den Abruf abbilden werden. Wir verwenden die Standardeinstellung AUTOINDEX, um unsere Daten für dieses Projekt zu indizieren.
Als nächstes erstellen wir die Sammlung mit dem zuvor angegebenen Namen, Schema und Index. Schließlich fügen wir die zuvor verarbeiteten Daten ein.
# prepare index parameters
index_params = milvus_client.prepare_index_params()
# add index
index_params.add_index(
field_name="embedding",
index_type="AUTOINDEX", # use autoindex instead of other complex indexing method
metric_type="COSINE", # L2, COSINE, or IP
)
# create collection
if milvus_client.has_collection(collection_name=COLLECTION_NAME):
milvus_client.drop_collection(collection_name=COLLECTION_NAME)
milvus_client.create_collection(
collection_name=COLLECTION_NAME, schema=schema, index_params=index_params
)
# Outside the loop, now you upsert all the entries at once
milvus_client.insert(collection_name=COLLECTION_NAME, data=entries)
Richten Sie Ihren LLM für RAG ein
Um eine RAG-Anwendung zu erstellen, müssen wir einen LLM auf BentoCloud bereitstellen. Wir verwenden den neuesten Llama3 LLM. Sobald er einsatzbereit ist, kopieren Sie einfach den Endpunkt und das Token dieses Modelldienstes und richten einen Client für ihn ein.
BENTO_LLM_END_POINT = "BENTO_LLM_END_POINT"
llm_client = bentoml.SyncHTTPClient(BENTO_LLM_END_POINT, token=BENTO_API_TOKEN)
LLM-Anweisungen
Jetzt richten wir die LLM-Anweisungen mit dem Prompt, dem Kontext und der Frage ein. Hier ist die Funktion, die sich wie ein LLM verhält und dann die Ausgabe vom Client in einem String-Format zurückgibt.
def dorag(question: str, context: str):
prompt = (
f"You are a helpful assistant. The user has a question. Answer the user question based only on the context: {context}. \n"
f"The user question is {question}"
)
results = llm_client.generate(
max_tokens=1024,
prompt=prompt,
)
res = ""
for result in results:
res += result
return res
Ein RAG-Beispiel
Jetzt sind wir bereit, eine Frage zu stellen. Diese Funktion nimmt einfach eine Frage und führt dann RAG aus, um den relevanten Kontext aus den Hintergrundinformationen zu generieren. Dann übergeben wir den Kontext und die Frage an dorag() und erhalten das Ergebnis.
question = "What state is Cambridge in?"
def ask_a_question(question):
embeddings = get_embeddings([question])
res = milvus_client.search(
collection_name=COLLECTION_NAME,
data=embeddings, # search for the one (1) embedding returned as a list of lists
anns_field="embedding", # Search across embeddings
limit=5, # get me the top 5 results
output_fields=["sentence"], # get the sentence/chunk and city
)
sentences = []
for hits in res:
for hit in hits:
print(hit)
sentences.append(hit["entity"]["sentence"])
context = ". ".join(sentences)
return context
context = ask_a_question(question=question)
print(context)
RAG implementieren
print(dorag(question=question, context=context))
Für die Beispielfrage, in der gefragt wird, in welchem Zustand sich Cambridge befindet, können wir die gesamte Antwort aus BentoML drucken. Wenn wir uns jedoch die Zeit nehmen, sie zu parsen, sieht es einfach schöner aus und sollte uns sagen, dass Cambridge in Massachusetts liegt.