milvus-logo
LFAI
Home
  • Integraciones

Generación mejorada por recuperación (RAG) con Milvus y BentoML

Open In Colab

Introducción

Esta guía demuestra cómo utilizar un modelo de incrustación de código abierto y un modelo de lenguaje grande en BentoCloud con la base de datos vectorial Milvus para construir una aplicación RAG (Retrieval Augmented Generation). BentoCloud es una plataforma de inferencia de IA para equipos de IA de rápido movimiento, que ofrece una infraestructura totalmente gestionada y adaptada para la inferencia de modelos. Funciona conjuntamente con BentoML, un marco de trabajo de código abierto para el servicio de modelos, para facilitar la creación y el despliegue de servicios de modelos de alto rendimiento. En esta demostración, utilizamos Milvus Lite como base de datos vectorial, que es la versión ligera de Milvus que puede incrustarse en su aplicación Python.

Antes de empezar

Milvus Lite está disponible en PyPI. Puede instalarlo a través de pip para Python 3.8+:

$ pip install -U pymilvus bentoml

Si estás utilizando Google Colab, para habilitar las dependencias que acabas de instalar, puede que necesites reiniciar el runtime (Haz clic en el menú "Runtime" en la parte superior de la pantalla, y selecciona "Restart session" en el menú desplegable).

Después de iniciar sesión en BentoCloud, podemos interactuar con los Servicios BentoCloud desplegados en Deployments, y el END_POINT y API correspondientes se encuentran en Playground -> Python. Puede descargar los datos de la ciudad aquí.

Sirviendo Embeddings con BentoML/BentoCloud

Para utilizar este endpoint, importe bentoml y configure un cliente HTTP utilizando SyncHTTPClient especificando el endpoint y opcionalmente el token (si activa Endpoint Authorization en BentoCloud). Alternativamente, puede utilizar el mismo modelo servido a través de BentoML utilizando su repositorio Sentence Transformers Embeddings.

import bentoml

BENTO_EMBEDDING_MODEL_END_POINT = "BENTO_EMBEDDING_MODEL_END_POINT"
BENTO_API_TOKEN = "BENTO_API_TOKEN"

embedding_client = bentoml.SyncHTTPClient(
    BENTO_EMBEDDING_MODEL_END_POINT, token=BENTO_API_TOKEN
)

Una vez que nos conectamos al embedding_client, necesitamos procesar nuestros datos. Proporcionamos varias funciones para realizar la división e incrustación de datos.

Leer archivos y preprocesar el texto en una lista de cadenas.

# naively chunk on newlines
def chunk_text(filename: str) -> list:
    with open(filename, "r") as f:
        text = f.read()
    sentences = text.split("\n")
    return sentences

Primero tenemos que descargar los datos de la ciudad.

import os
import requests
import urllib.request

# set up the data source
repo = "ytang07/bento_octo_milvus_RAG"
directory = "data"
save_dir = "./city_data"
api_url = f"https://api.github.com/repos/{repo}/contents/{directory}"


response = requests.get(api_url)
data = response.json()

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

for item in data:
    if item["type"] == "file":
        file_url = item["download_url"]
        file_path = os.path.join(save_dir, item["name"])
        urllib.request.urlretrieve(file_url, file_path)

A continuación, procesamos cada uno de los archivos que tenemos.

# please upload your data directory under this file's folder
cities = os.listdir("city_data")
# store chunked text for each of the cities in a list of dicts
city_chunks = []
for city in cities:
    chunked = chunk_text(f"city_data/{city}")
    cleaned = []
    for chunk in chunked:
        if len(chunk) > 7:
            cleaned.append(chunk)
    mapped = {"city_name": city.split(".")[0], "chunks": cleaned}
    city_chunks.append(mapped)

Divide una lista de cadenas en una lista de incrustaciones, cada una agrupa 25 cadenas de texto.

def get_embeddings(texts: list) -> list:
    if len(texts) > 25:
        splits = [texts[x : x + 25] for x in range(0, len(texts), 25)]
        embeddings = []
        for split in splits:
            embedding_split = embedding_client.encode(sentences=split)
            embeddings += embedding_split
        return embeddings
    return embedding_client.encode(
        sentences=texts,
    )

Ahora, tenemos que emparejar los embeddings y los trozos de texto. Como la lista de incrustaciones y la lista de frases deben coincidir por índice, podemos enumerate a través de cualquiera de las listas para emparejarlas.

entries = []
for city_dict in city_chunks:
    # No need for the embeddings list if get_embeddings already returns a list of lists
    embedding_list = get_embeddings(city_dict["chunks"])  # returns a list of lists
    # Now match texts with embeddings and city name
    for i, embedding in enumerate(embedding_list):
        entry = {
            "embedding": embedding,
            "sentence": city_dict["chunks"][
                i
            ],  # Assume "chunks" has the corresponding texts for the embeddings
            "city": city_dict["city_name"],
        }
        entries.append(entry)
    print(entries)

Inserción de datos en una base de datos vectorial para su recuperación

Con nuestras incrustaciones y datos preparados, podemos insertar los vectores junto con los metadatos en Milvus Lite para la búsqueda de vectores más adelante. El primer paso en esta sección es iniciar un cliente conectándose a Milvus Lite. Simplemente importamos el módulo MilvusClient e inicializamos un cliente Milvus Lite que se conecta a su base de datos de vectores Milvus Lite. El tamaño de la dimensión proviene del tamaño del modelo de incrustación, por ejemplo, el modelo del transformador de frases all-MiniLM-L6-v2 produce vectores de 384 dimensiones.

from pymilvus import MilvusClient

COLLECTION_NAME = "Bento_Milvus_RAG"  # random name for your collection
DIMENSION = 384

# Initialize a Milvus Lite client
milvus_client = MilvusClient("milvus_demo.db")

En cuanto al argumento de MilvusClient:

  • Establecer el uri como un archivo local, por ejemplo./milvus.db, es el método más conveniente, ya que utiliza automáticamente Milvus Lite para almacenar todos los datos en este archivo.
  • Si tiene una gran escala de datos, puede configurar un servidor Milvus más eficiente en docker o kubernetes. En esta configuración, por favor utilice la uri del servidor, por ejemplohttp://localhost:19530, como su uri.
  • Si desea utilizar Zilliz Cloud, el servicio en la nube totalmente gestionado para Milvus, ajuste el uri y token, que corresponden al punto final público y la clave Api en Zilliz Cloud.

O con la antigua API connections.connect (no recomendado):

from pymilvus import connections

connections.connect(uri="milvus_demo.db")

Creación de su colección Milvus Lite

Crear una colección usando Milvus Lite implica dos pasos: primero, definir el esquema, y segundo, definir el índice. Para esta sección, necesitamos un módulo: DataType nos dice qué tipo de datos habrá en un campo. También necesitamos utilizar dos funciones para crear el esquema y añadir campos. create_schema(): crea el esquema de una colección, add_field(): añade un campo al esquema de una colección.

from pymilvus import MilvusClient, DataType, Collection

# Create schema
schema = MilvusClient.create_schema(
    auto_id=True,
    enable_dynamic_field=True,
)

# 3.2. Add fields to schema
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=DIMENSION)

Ahora que hemos creado nuestro esquema y definido correctamente el campo de datos, necesitamos definir el índice. En términos de búsqueda, un "índice" define cómo vamos a mapear nuestros datos para su recuperación. En este proyecto utilizaremos la opción AUTOINDEX por defecto para indexar nuestros datos.

A continuación, creamos la colección con el nombre, esquema e índice dados anteriormente. Finalmente, insertamos los datos previamente procesados.

# prepare index parameters
index_params = milvus_client.prepare_index_params()

# add index
index_params.add_index(
    field_name="embedding",
    index_type="AUTOINDEX",  # use autoindex instead of other complex indexing method
    metric_type="COSINE",  # L2, COSINE, or IP
)

# create collection
if milvus_client.has_collection(collection_name=COLLECTION_NAME):
    milvus_client.drop_collection(collection_name=COLLECTION_NAME)
milvus_client.create_collection(
    collection_name=COLLECTION_NAME, schema=schema, index_params=index_params
)

# Outside the loop, now you upsert all the entries at once
milvus_client.insert(collection_name=COLLECTION_NAME, data=entries)

Configura tu LLM para RAG

Para construir una aplicación RAG, necesitamos desplegar un LLM en BentoCloud. Vamos a utilizar el último LLM Llama3. Una vez que esté funcionando, simplemente copie el endpoint y el token de este servicio modelo y configure un cliente para él.

BENTO_LLM_END_POINT = "BENTO_LLM_END_POINT"

llm_client = bentoml.SyncHTTPClient(BENTO_LLM_END_POINT, token=BENTO_API_TOKEN)

Instrucciones LLM

Ahora, configuramos las instrucciones LLM con el prompt, el contexto y la pregunta. Aquí está la función que se comporta como un LLM y luego devuelve la salida del cliente en un formato de cadena.

def dorag(question: str, context: str):

    prompt = (
        f"You are a helpful assistant. The user has a question. Answer the user question based only on the context: {context}. \n"
        f"The user question is {question}"
    )

    results = llm_client.generate(
        max_tokens=1024,
        prompt=prompt,
    )

    res = ""
    for result in results:
        res += result

    return res

Un ejemplo RAG

Ahora estamos listos para hacer una pregunta. Esta función simplemente toma una pregunta y luego hace RAG para generar el contexto relevante a partir de la información de fondo. A continuación, pasamos el contexto y la pregunta a dorag() y obtenemos el resultado.

question = "What state is Cambridge in?"


def ask_a_question(question):
    embeddings = get_embeddings([question])
    res = milvus_client.search(
        collection_name=COLLECTION_NAME,
        data=embeddings,  # search for the one (1) embedding returned as a list of lists
        anns_field="embedding",  # Search across embeddings
        limit=5,  # get me the top 5 results
        output_fields=["sentence"],  # get the sentence/chunk and city
    )

    sentences = []
    for hits in res:
        for hit in hits:
            print(hit)
            sentences.append(hit["entity"]["sentence"])
    context = ". ".join(sentences)
    return context


context = ask_a_question(question=question)
print(context)

Implementación de RAG

print(dorag(question=question, context=context))

Para la pregunta de ejemplo que pregunta en qué estado se encuentra Cambridge, podemos imprimir toda la respuesta desde BentoML. Sin embargo, si nos tomamos la molestia de analizarla, tendrá un aspecto más agradable y nos dirá que Cambridge se encuentra en Massachusetts.