Génération améliorée par la recherche (RAG) avec Milvus et BentoML
Introduction
Ce guide explique comment utiliser un modèle d'intégration open-source et un modèle de grande langue sur BentoCloud avec la base de données vectorielle Milvus pour construire une application RAG (Retrieval Augmented Generation). BentoCloud est une plateforme d'inférence d'IA destinée aux équipes d'IA qui évoluent rapidement, offrant une infrastructure entièrement gérée et adaptée à l'inférence de modèles. Il fonctionne en conjonction avec BentoML, un cadre de service de modèle open-source, pour faciliter la création et le déploiement de services de modèle de haute performance. Dans cette démo, nous utilisons Milvus Lite comme base de données vectorielle, qui est la version allégée de Milvus pouvant être intégrée dans votre application Python.
Avant de commencer
Milvus Lite est disponible sur PyPI. Vous pouvez l'installer via pip pour Python 3.8+ :
$ pip install -U pymilvus bentoml
Si vous utilisez Google Colab, pour activer les dépendances qui viennent d'être installées, vous devrez peut-être redémarrer le runtime (cliquez sur le menu "Runtime" en haut de l'écran, et sélectionnez "Restart session" dans le menu déroulant).
Après s'être connecté à BentoCloud, nous pouvons interagir avec les services BentoCloud déployés dans Deployments, et les END_POINT et API correspondants sont situés dans Playground -> Python. Vous pouvez télécharger les données de la ville ici.
Servir les Embeddings avec BentoML/BentoCloud
Pour utiliser ce point d'accès, importez bentoml
et configurez un client HTTP utilisant SyncHTTPClient
en spécifiant le point d'accès et éventuellement le jeton (si vous activez Endpoint Authorization
sur BentoCloud). Vous pouvez également utiliser le même modèle servi par BentoML à l'aide de son référentiel Sentence Transformers Embeddings.
import bentoml
BENTO_EMBEDDING_MODEL_END_POINT = "BENTO_EMBEDDING_MODEL_END_POINT"
BENTO_API_TOKEN = "BENTO_API_TOKEN"
embedding_client = bentoml.SyncHTTPClient(
BENTO_EMBEDDING_MODEL_END_POINT, token=BENTO_API_TOKEN
)
Une fois que nous nous connectons au client d'intégration, nous devons traiter nos données. Nous avons fourni plusieurs fonctions pour effectuer le découpage et l'intégration des données.
Lire les fichiers et prétraiter le texte en une liste de chaînes de caractères.
# naively chunk on newlines
def chunk_text(filename: str) -> list:
with open(filename, "r") as f:
text = f.read()
sentences = text.split("\n")
return sentences
Nous devons d'abord télécharger les données de la ville.
import os
import requests
import urllib.request
# set up the data source
repo = "ytang07/bento_octo_milvus_RAG"
directory = "data"
save_dir = "./city_data"
api_url = f"https://api.github.com/repos/{repo}/contents/{directory}"
response = requests.get(api_url)
data = response.json()
if not os.path.exists(save_dir):
os.makedirs(save_dir)
for item in data:
if item["type"] == "file":
file_url = item["download_url"]
file_path = os.path.join(save_dir, item["name"])
urllib.request.urlretrieve(file_url, file_path)
Ensuite, nous traitons chacun des fichiers que nous avons.
# please upload your data directory under this file's folder
cities = os.listdir("city_data")
# store chunked text for each of the cities in a list of dicts
city_chunks = []
for city in cities:
chunked = chunk_text(f"city_data/{city}")
cleaned = []
for chunk in chunked:
if len(chunk) > 7:
cleaned.append(chunk)
mapped = {"city_name": city.split(".")[0], "chunks": cleaned}
city_chunks.append(mapped)
Fractionne une liste de chaînes de caractères en une liste d'embeddings, chacun regroupant 25 chaînes de caractères.
def get_embeddings(texts: list) -> list:
if len(texts) > 25:
splits = [texts[x : x + 25] for x in range(0, len(texts), 25)]
embeddings = []
for split in splits:
embedding_split = embedding_client.encode(sentences=split)
embeddings += embedding_split
return embeddings
return embedding_client.encode(
sentences=texts,
)
Maintenant, nous devons faire correspondre les embeddings et les morceaux de texte. Étant donné que la liste des enchâssements et la liste des phrases doivent correspondre par index, nous pouvons enumerate
parcourir l'une ou l'autre des listes pour les faire correspondre.
entries = []
for city_dict in city_chunks:
# No need for the embeddings list if get_embeddings already returns a list of lists
embedding_list = get_embeddings(city_dict["chunks"]) # returns a list of lists
# Now match texts with embeddings and city name
for i, embedding in enumerate(embedding_list):
entry = {
"embedding": embedding,
"sentence": city_dict["chunks"][
i
], # Assume "chunks" has the corresponding texts for the embeddings
"city": city_dict["city_name"],
}
entries.append(entry)
print(entries)
Insertion des données dans une base de données vectorielle pour l'extraction
Une fois nos embeddings et nos données préparés, nous pouvons insérer les vecteurs avec les métadonnées dans Milvus Lite pour une recherche vectorielle ultérieure. La première étape de cette section consiste à démarrer un client en se connectant à Milvus Lite. Il suffit d'importer le module MilvusClient
et d'initialiser un client Milvus Lite qui se connecte à votre base de données vectorielles Milvus Lite. La taille de la dimension provient de la taille du modèle d'intégration, par exemple le modèle Sentence Transformer all-MiniLM-L6-v2
produit des vecteurs de 384 dimensions.
from pymilvus import MilvusClient
COLLECTION_NAME = "Bento_Milvus_RAG" # random name for your collection
DIMENSION = 384
# Initialize a Milvus Lite client
milvus_client = MilvusClient("milvus_demo.db")
Comme pour l'argument de MilvusClient
:
- La définition de
uri
comme fichier local, par exemple./milvus.db
, est la méthode la plus pratique, car elle utilise automatiquement Milvus Lite pour stocker toutes les données dans ce fichier. - Si vous avez des données à grande échelle, vous pouvez configurer un serveur Milvus plus performant sur docker ou kubernetes. Dans cette configuration, veuillez utiliser l'uri du serveur, par exemple
http://localhost:19530
, comme votreuri
. - Si vous souhaitez utiliser Zilliz Cloud, le service cloud entièrement géré pour Milvus, ajustez les adresses
uri
ettoken
, qui correspondent au point de terminaison public et à la clé Api dans Zilliz Cloud.
Ou avec l'ancienne API connections.connect (non recommandé) :
from pymilvus import connections
connections.connect(uri="milvus_demo.db")
Création de votre collection Milvus Lite
La création d'une collection à l'aide de Milvus Lite implique deux étapes : premièrement, la définition du schéma et deuxièmement, la définition de l'index. Pour cette section, nous avons besoin d'un module : DataType nous indique quel type de données sera contenu dans un champ. Nous devons également utiliser deux fonctions pour créer un schéma et ajouter des champs. create_schema() : crée un schéma de collection, add_field() : ajoute un champ au schéma d'une collection.
from pymilvus import MilvusClient, DataType, Collection
# Create schema
schema = MilvusClient.create_schema(
auto_id=True,
enable_dynamic_field=True,
)
# 3.2. Add fields to schema
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=DIMENSION)
Maintenant que nous avons créé notre schéma et défini avec succès un champ de données, nous devons définir l'index. En termes de recherche, un "index" définit la manière dont nous allons cartographier nos données pour les retrouver. Nous utilisons le choix par défaut AUTOINDEX pour indexer nos données dans le cadre de ce projet.
Ensuite, nous créons la collection avec le nom, le schéma et l'index donnés précédemment. Enfin, nous insérons les données précédemment traitées.
# prepare index parameters
index_params = milvus_client.prepare_index_params()
# add index
index_params.add_index(
field_name="embedding",
index_type="AUTOINDEX", # use autoindex instead of other complex indexing method
metric_type="COSINE", # L2, COSINE, or IP
)
# create collection
if milvus_client.has_collection(collection_name=COLLECTION_NAME):
milvus_client.drop_collection(collection_name=COLLECTION_NAME)
milvus_client.create_collection(
collection_name=COLLECTION_NAME, schema=schema, index_params=index_params
)
# Outside the loop, now you upsert all the entries at once
milvus_client.insert(collection_name=COLLECTION_NAME, data=entries)
Configurer votre LLM pour RAG
Pour construire une application RAG, nous devons déployer un LLM sur BentoCloud. Utilisons le dernier LLM Llama3. Une fois qu'il est opérationnel, il suffit de copier le point de terminaison et le jeton de ce service modèle et de configurer un client pour celui-ci.
BENTO_LLM_END_POINT = "BENTO_LLM_END_POINT"
llm_client = bentoml.SyncHTTPClient(BENTO_LLM_END_POINT, token=BENTO_API_TOKEN)
Instructions LLM
Maintenant, nous configurons les instructions LLM avec l'invite, le contexte et la question. Voici la fonction qui se comporte comme un LLM et qui renvoie la sortie du client dans un format de chaîne.
def dorag(question: str, context: str):
prompt = (
f"You are a helpful assistant. The user has a question. Answer the user question based only on the context: {context}. \n"
f"The user question is {question}"
)
results = llm_client.generate(
max_tokens=1024,
prompt=prompt,
)
res = ""
for result in results:
res += result
return res
Un exemple de RAG
Nous sommes maintenant prêts à poser une question. Cette fonction prend simplement une question et effectue un RAG pour générer le contexte pertinent à partir des informations d'arrière-plan. Ensuite, nous passons le contexte et la question à dorag() et nous obtenons le résultat.
question = "What state is Cambridge in?"
def ask_a_question(question):
embeddings = get_embeddings([question])
res = milvus_client.search(
collection_name=COLLECTION_NAME,
data=embeddings, # search for the one (1) embedding returned as a list of lists
anns_field="embedding", # Search across embeddings
limit=5, # get me the top 5 results
output_fields=["sentence"], # get the sentence/chunk and city
)
sentences = []
for hits in res:
for hit in hits:
print(hit)
sentences.append(hit["entity"]["sentence"])
context = ". ".join(sentences)
return context
context = ask_a_question(question=question)
print(context)
Mise en œuvre de RAG
print(dorag(question=question, context=context))
Pour l'exemple de la question demandant dans quel état se trouve Cambridge, nous pouvons imprimer l'intégralité de la réponse à partir de BentoML. Cependant, si nous prenons le temps de l'analyser, elle est plus jolie et devrait nous indiquer que Cambridge est situé dans le Massachusetts.