Construire RAG avec Milvus, vLLM, et Llama 3.1
L'Université de Californie - Berkeley a fait don de vLLM, une bibliothèque rapide et facile à utiliser pour l'inférence et le service LLM, à la Fondation LF AI & Data en tant que projet en phase d'incubation en juillet 2024. En tant que projet membre, nous souhaitons la bienvenue à vLLM qui rejoint la famille LF AI & Data ! 🎉
Les grands modèles de langage(LLM) et les bases de données vectorielles sont généralement associés pour construire Retrieval Augmented Generation(RAG), une architecture d'application d'IA populaire pour répondre aux hallucinations de l'IA. Ce blog vous montrera comment construire et exécuter un RAG avec Milvus, vLLM, et Llama 3.1. Plus précisément, je vous montrerai comment intégrer et stocker des informations textuelles sous forme d'embeddings vectoriels dans Milvus et utiliser ce stockage vectoriel comme base de connaissances pour récupérer efficacement des morceaux de texte pertinents pour les questions de l'utilisateur. Enfin, nous utiliserons vLLM pour servir le modèle Llama 3.1-8B de Meta afin de générer des réponses enrichies par le texte récupéré. Plongeons dans l'aventure !
Introduction à Milvus, vLLM et Meta's Llama 3.1
Base de données vectorielles Milvus
Milvus est une base de données vectorielles distribuée à code source ouvert, conçue spécialement pour le stockage, l'indexation et la recherche de vecteurs pour les charges de travail de l'IA générative (GenAI). Sa capacité à effectuer une recherche hybride, un filtrage des métadonnées, un reclassement et à gérer efficacement des trillions de vecteurs fait de Milvus un choix de premier ordre pour les charges de travail d'IA et d'apprentissage automatique. Milvus peut être exécuté localement, sur un cluster ou hébergé dans le Zilliz Cloud entièrement géré.
vLLM
vLLM est un projet open-source lancé au SkyLab de l'Université de Berkeley et axé sur l'optimisation des performances de service LLM. Il utilise une gestion efficace de la mémoire avec PagedAttention, une mise en lot continue et des noyaux CUDA optimisés. Par rapport aux méthodes traditionnelles, vLLM améliore les performances de service jusqu'à 24 fois tout en réduisant de moitié l'utilisation de la mémoire du GPU.
Selon l'article "Efficient Memory Management for Large Language Model Serving with PagedAttention", le cache KV utilise environ 30 % de la mémoire du GPU, ce qui peut entraîner des problèmes de mémoire. Le cache KV est stocké dans une mémoire contiguë, mais une modification de sa taille peut entraîner une fragmentation de la mémoire, ce qui est inefficace pour les calculs.
Image 1. Gestion de la mémoire cache KV dans les systèmes existants (2023 Paged Attention paper)
En utilisant la mémoire virtuelle pour le cache KV, vLLM n'alloue la mémoire physique du GPU qu'en cas de besoin, ce qui élimine la fragmentation de la mémoire et évite la pré-allocation. Lors des tests, vLLM a surpassé HuggingFace Transformers (HF) et Text Generation Inference (TGI), atteignant un débit jusqu'à 24 fois plus élevé que HF et 3,5 fois plus élevé que TGI sur les GPU NVIDIA A10G et A100.
Image 2. Débit de service lorsque chaque requête demande trois sorties parallèles. vLLM atteint un débit 8,5x-15x plus élevé que HF et 3,3x-3,5x plus élevé que TGI (2023 vLLM blog).
Le lama de Meta 3.1
Meta's Llama 3.1 a été annoncé le 23 juillet 2024. Le modèle 405B offre des performances de pointe sur plusieurs benchmarks publics et dispose d'une fenêtre contextuelle de 128 000 jetons d'entrée avec diverses utilisations commerciales autorisées. Parallèlement au modèle à 405 milliards de paramètres, Meta a publié une version mise à jour du Llama3 70B (70 milliards de paramètres) et 8B (8 milliards de paramètres). Les poids des modèles peuvent être téléchargés sur le site web de Meta.
L'une des principales conclusions est qu'un réglage fin des données générées peut améliorer les performances, mais que des exemples de mauvaise qualité peuvent les dégrader. L'équipe Llama a beaucoup travaillé pour identifier et supprimer ces mauvais exemples en utilisant le modèle lui-même, des modèles auxiliaires et d'autres outils.
Construire et exécuter le RAG-Retrieval avec Milvus
Préparez votre ensemble de données.
Pour cette démonstration, j'ai utilisé la documentation officielle de Milvus, que j'ai téléchargée et sauvegardée localement.
from langchain.document_loaders import DirectoryLoader
# Load HTML files already saved in a local directory
path = "../../RAG/rtdocs_new/"
global_pattern = '*.html'
loader = DirectoryLoader(path=path, glob=global_pattern)
docs = loader.load()
# Print num documents and a preview.
print(f"loaded {len(docs)} documents")
print(docs[0].page_content)
pprint.pprint(docs[0].metadata)
loaded 22 documents
Why Milvus Docs Tutorials Tools Blog Community Stars0 Try Managed Milvus FREE Search Home v2.4.x About ...
{'source': 'https://milvus.io/docs/quickstart.md'}
Téléchargez un modèle d'intégration.
Ensuite, téléchargez un modèle d'intégration gratuit et open-source à partir de HuggingFace.
import torch
from sentence_transformers import SentenceTransformer
# Initialize torch settings for device-agnostic code.
N_GPU = torch.cuda.device_count()
DEVICE = torch.device('cuda:N_GPU' if torch.cuda.is_available() else 'cpu')
# Download the model from huggingface model hub.
model_name = "BAAI/bge-large-en-v1.5"
encoder = SentenceTransformer(model_name, device=DEVICE)
# Get the model parameters and save for later.
EMBEDDING_DIM = encoder.get_sentence_embedding_dimension()
MAX_SEQ_LENGTH_IN_TOKENS = encoder.get_max_seq_length()
# Inspect model parameters.
print(f"model_name: {model_name}")
print(f"EMBEDDING_DIM: {EMBEDDING_DIM}")
print(f"MAX_SEQ_LENGTH: {MAX_SEQ_LENGTH}")
model_name: BAAI/bge-large-en-v1.5
EMBEDDING_DIM: 1024
MAX_SEQ_LENGTH: 512
Décomposez et encodez vos données personnalisées sous forme de vecteurs.
J'utiliserai une longueur fixe de 512 caractères avec un chevauchement de 10 %.
from langchain.text_splitter import RecursiveCharacterTextSplitter
CHUNK_SIZE = 512
chunk_overlap = np.round(CHUNK_SIZE * 0.10, 0)
print(f"chunk_size: {CHUNK_SIZE}, chunk_overlap: {chunk_overlap}")
# Define the splitter.
child_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=chunk_overlap)
# Chunk the docs.
chunks = child_splitter.split_documents(docs)
print(f"{len(docs)} docs split into {len(chunks)} child documents.")
# Encoder input is doc.page_content as strings.
list_of_strings = [doc.page_content for doc in chunks if hasattr(doc, 'page_content')]
# Embedding inference using HuggingFace encoder.
embeddings = torch.tensor(encoder.encode(list_of_strings))
# Normalize the embeddings.
embeddings = np.array(embeddings / np.linalg.norm(embeddings))
# Milvus expects a list of `numpy.ndarray` of `numpy.float32` numbers.
converted_values = list(map(np.float32, embeddings))
# Create dict_list for Milvus insertion.
dict_list = []
for chunk, vector in zip(chunks, converted_values):
# Assemble embedding vector, original text chunk, metadata.
chunk_dict = {
'chunk': chunk.page_content,
'source': chunk.metadata.get('source', ""),
'vector': vector,
}
dict_list.append(chunk_dict)
chunk_size: 512, chunk_overlap: 51.0
22 docs split into 355 child documents.
Enregistrez les vecteurs dans Milvus.
Intégrez les vecteurs encodés dans la base de données vectorielle de Milvus.
# Connect a client to the Milvus Lite server.
from pymilvus import MilvusClient
mc = MilvusClient("milvus_demo.db")
# Create a collection with flexible schema and AUTOINDEX.
COLLECTION_NAME = "MilvusDocs"
mc.create_collection(COLLECTION_NAME,
EMBEDDING_DIM,
consistency_level="Eventually",
auto_id=True,
overwrite=True)
# Insert data into the Milvus collection.
print("Start inserting entities")
start_time = time.time()
mc.insert(
COLLECTION_NAME,
data=dict_list,
progress_bar=True)
end_time = time.time()
print(f"Milvus insert time for {len(dict_list)} vectors: ", end="")
print(f"{round(end_time - start_time, 2)} seconds")
Start inserting entities
Milvus insert time for 355 vectors: 0.2 seconds
Effectuer une recherche de vecteurs.
Posez une question et recherchez les morceaux les plus proches de votre base de connaissances dans Milvus.
SAMPLE_QUESTION = "What do the parameters for HNSW mean?"
# Embed the question using the same encoder.
query_embeddings = torch.tensor(encoder.encode(SAMPLE_QUESTION))
# Normalize embeddings to unit length.
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
# Convert the embeddings to list of list of np.float32.
query_embeddings = list(map(np.float32, query_embeddings))
# Define metadata fields you can filter on.
OUTPUT_FIELDS = list(dict_list[0].keys())
OUTPUT_FIELDS.remove('vector')
# Define how many top-k results you want to retrieve.
TOP_K = 2
# Run semantic vector search using your query and the vector database.
results = mc.search(
COLLECTION_NAME,
data=query_embeddings,
output_fields=OUTPUT_FIELDS,
limit=TOP_K,
consistency_level="Eventually")
Le résultat de la recherche est illustré ci-dessous.
Retrieved result #1
distance = 0.7001987099647522
('Chunk text: layer, finds the node closest to the target in this layer, and'
...
'outgoing')
source: https://milvus.io/docs/index.md
Retrieved result #2
distance = 0.6953287124633789
('Chunk text: this value can improve recall rate at the cost of increased'
...
'to the target')
source: https://milvus.io/docs/index.md
Construire et réaliser la génération RAG avec vLLM et Llama 3.1-8B
Installer vLLM et les modèles de HuggingFace
Par défaut, vLLM télécharge de grands modèles linguistiques à partir de HuggingFace. En général, chaque fois que vous voulez utiliser un nouveau modèle sur HuggingFace, vous devriez faire un pip install --upgrade ou -U. De plus, vous aurez besoin d'un GPU pour exécuter l'inférence des modèles Llama 3.1 de Meta avec vLLM.
Pour une liste complète de tous les modèles supportés par vLLM, voir cette page de documentation.
# (Recommended) Create a new conda environment.
conda create -n myenv python=3.11 -y
conda activate myenv
# Install vLLM with CUDA 12.1.
pip install -U vllm transformers torch
import vllm, torch
from vllm import LLM, SamplingParams
# Clear the GPU memory cache.
torch.cuda.empty_cache()
# Check the GPU.
!nvidia-smi
Pour en savoir plus sur l'installation de vLLM, voir sa page d'installation.
Obtenir un jeton HuggingFace.
Certains modèles sur HuggingFace, comme Meta Llama 3.1, requièrent que l'utilisateur accepte leur licence avant de pouvoir télécharger les poids. Par conséquent, vous devez créer un compte HuggingFace, accepter la licence du modèle et générer un jeton.
En visitant cette page Llama3.1 sur HuggingFace, vous recevrez un message vous demandant d'accepter les termes. Cliquez sur "Accepter la licence" pour accepter les termes de Meta avant de télécharger les poids du modèle. L'approbation prend généralement moins d'une journée.
Après avoir reçu l'approbation, vous devez générer un nouveau jeton HuggingFace. Vos anciens jetons ne fonctionneront pas avec les nouvelles autorisations.
Avant d'installer vLLM, connectez-vous à HuggingFace avec votre nouveau jeton. Ci-dessous, j'ai utilisé Colab secrets pour stocker le jeton.
# Login to HuggingFace using your new token.
from huggingface_hub import login
from google.colab import userdata
hf_token = userdata.get('HF_TOKEN')
login(token = hf_token, add_to_git_credential=True)
Exécuter la génération RAG
Dans la démo, nous exécutons le modèle Llama-3.1-8B
, qui nécessite un GPU et une mémoire importante pour tourner. L'exemple suivant a été exécuté sur Google Colab Pro (10 $/mois) avec un GPU A100. Pour en savoir plus sur l'exécution de vLLM, vous pouvez consulter la documentation Quickstart.
# 1. Choose a model
MODELTORUN = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# 2. Clear the GPU memory cache, you're going to need it all!
torch.cuda.empty_cache()
# 3. Instantiate a vLLM model instance.
llm = LLM(model=MODELTORUN,
enforce_eager=True,
dtype=torch.bfloat16,
gpu_memory_utilization=0.5,
max_model_len=1000,
seed=415,
max_num_batched_tokens=3000)
Rédigez une invite en utilisant les contextes et les sources récupérés dans Milvus.
# Separate all the context together by space.
contexts_combined = ' '.join(contexts)
# Lance Martin, LangChain, says put the best contexts at the end.
contexts_combined = ' '.join(reversed(contexts))
# Separate all the unique sources together by comma.
source_combined = ' '.join(reversed(list(dict.fromkeys(sources))))
SYSTEM_PROMPT = f"""First, check if the provided Context is relevant to
the user's question. Second, only if the provided Context is strongly relevant, answer the question using the Context. Otherwise, if the Context is not strongly relevant, answer the question without using the Context.
Be clear, concise, relevant. Answer clearly, in fewer than 2 sentences.
Grounding sources: {source_combined}
Context: {contexts_combined}
User's question: {SAMPLE_QUESTION}
"""
prompts = [SYSTEM_PROMPT]
Maintenant, générez une réponse en utilisant les morceaux récupérés et la question originale insérée dans l'invite.
# Sampling parameters
sampling_params = SamplingParams(temperature=0.2, top_p=0.95)
# Invoke the vLLM model.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
# !r calls repr(), which prints a string inside quotes.
print()
print(f"Question: {SAMPLE_QUESTION!r}")
pprint.pprint(f"Generated text: {generated_text!r}")
Question: 'What do the parameters for HNSW MEAN!?'
Generated text: 'Answer: The parameters for HNSW (Hiera(rchical Navigable Small World Graph) are: '
'* M: The maximum degree of nodes on each layer oof the graph, which can improve '
'recall rate at the cost of increased search time. * efConstruction and ef: '
'These parameters specify a search range when building or searching an index.'
La réponse ci-dessus me semble parfaite !
Si cette démo vous intéresse, n'hésitez pas à l'essayer vous-même et à nous faire part de vos impressions. Vous êtes également invités à rejoindre notre communauté Milvus sur Discord pour discuter directement avec tous les développeurs GenAI.
Références
2023 présentation de vLLM au Ray Summit
Blog vLLM : vLLM : Easy, Fast, and Cheap LLM Serving with PagedAttention (en anglais)
Blog utile sur le fonctionnement du serveur vLLM : Déploiement de vLLM : un guide étape par étape