Recherche d'images avec Milvus
Sur cette page, nous allons étudier un exemple simple de recherche d'images à l'aide de Milvus. L'ensemble de données que nous recherchons est l'ensemble de données Impressionist-Classifier trouvé sur Kaggle. Pour cet exemple, nous avons réhébergé les données dans un Google Drive public.
Pour cet exemple, nous utilisons simplement le modèle Resnet50 pré-entraîné de Torchvision pour les embeddings. Commençons à travailler !
Installer les prérequis
Pour cet exemple, nous allons utiliser pymilvus
pour nous connecter à Milvus, torch
pour exécuter le modèle d'intégration, torchvision
pour le modèle proprement dit et le prétraitement, gdown
pour télécharger l'ensemble de données d'exemple et tqdm
pour charger les barres.
pip install pymilvus torch gdown torchvision tqdm
Récupérer les données
Nous allons utiliser gdown
pour récupérer le fichier zip de Google Drive et le décompresser avec la bibliothèque intégrée zipfile
.
import gdown
import zipfile
url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)
with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
zip_ref.extractall("./paintings")
La taille du jeu de données est de 2,35 Go, et le temps passé à le télécharger dépend de l'état de votre réseau.
Arguments globaux
Voici quelques-uns des principaux arguments globaux que nous utiliserons pour faciliter le suivi et la mise à jour.
# Milvus Setup Arguments
COLLECTION_NAME = 'image_search' # Collection name
DIMENSION = 2048 # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3
Mise en place de Milvus
À ce stade, nous allons commencer à configurer Milvus. Les étapes sont les suivantes :
Connectez-vous à l'instance Milvus à l'aide de l'URI fourni.
from pymilvus import connections # Connect to the instance connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
Si la collection existe déjà, la supprimer.
from pymilvus import utility # Remove any previous collections with the same name if utility.has_collection(COLLECTION_NAME): utility.drop_collection(COLLECTION_NAME)
Créez la collection qui contient l'ID, le chemin de fichier de l'image et son intégration.
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection # Create collection which includes the id, filepath of the image, and image embedding fields = [ FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION) ] schema = CollectionSchema(fields=fields) collection = Collection(name=COLLECTION_NAME, schema=schema)
Créez un index sur la collection nouvellement créée et chargez-la en mémoire.
# Create an AutoIndex index for collection index_params = { 'metric_type':'L2', 'index_type':"IVF_FLAT", 'params':{'nlist': 16384} } collection.create_index(field_name="image_embedding", index_params=index_params) collection.load()
Une fois ces étapes effectuées, la collection est prête à être insérée et à faire l'objet de recherches. Toutes les données ajoutées seront indexées automatiquement et pourront être recherchées immédiatement. Si les données sont très récentes, la recherche peut être plus lente car une recherche par force brute sera utilisée sur les données qui sont encore en cours d'indexation.
Insérer les données
Pour cet exemple, nous allons utiliser le modèle ResNet50 fourni par torch
et son hub de modèles. Pour obtenir les embeddings, nous enlevons la dernière couche de classification, ce qui fait que le modèle nous donne des embeddings de 2048 dimensions. Tous les modèles de vision trouvés sur torch
utilisent le même prétraitement que celui que nous avons inclus ici.
Au cours des prochaines étapes, nous allons
Charger les données.
import glob # Get the filepaths of the images paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True) len(paths)
Prétraitement des données en lots.
import torch # Load the embedding model with the last layer removed model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) model = torch.nn.Sequential(*(list(model.children())[:-1])) model.eval()
Intégrer les données.
from torchvision import transforms # Preprocessing for images preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ])
Insérer les données.
from PIL import Image from tqdm import tqdm # Embed function that embeds the batch and inserts it def embed(data): with torch.no_grad(): output = model(torch.stack(data[0])).squeeze() collection.insert([data[1], output.tolist()]) data_batch = [[],[]] # Read the images into batches for embedding and insertion for path in tqdm(paths): im = Image.open(path).convert('RGB') data_batch[0].append(preprocess(im)) data_batch[1].append(path) if len(data_batch[0]) % BATCH_SIZE == 0: embed(data_batch) data_batch = [[],[]] # Embed and insert the remainder if len(data_batch[0]) != 0: embed(data_batch) # Call a flush to index any unsealed segments. collection.flush()
- Cette étape est relativement longue car l'incorporation prend du temps. Prenez une gorgée de café et détendez-vous.
- PyTorch peut ne pas fonctionner correctement avec Python 3.9 et les versions antérieures. Envisagez plutôt d'utiliser Python 3.10 et les versions ultérieures.
Effectuer la recherche
Une fois toutes les données insérées dans Milvus, nous pouvons commencer à effectuer nos recherches. Dans cet exemple, nous allons rechercher deux images d'exemple. Comme nous effectuons une recherche par lot, le temps de recherche est partagé entre les images du lot.
import glob
# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt
# Embed the search images
def embed(data):
with torch.no_grad():
ret = model(torch.stack(data))
# If more than one image, use squeeze
if len(ret) > 1:
return ret.squeeze().tolist()
# Squeeze would remove batch for single image, so using flatten
else:
return torch.flatten(ret, start_dim=1).tolist()
data_batch = [[],[]]
for path in search_paths:
im = Image.open(path).convert('RGB')
data_batch[0].append(preprocess(im))
data_batch[1].append(path)
embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
for hits_i, hits in enumerate(res):
axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
axarr[hits_i][0].set_axis_off()
axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
for hit_i, hit in enumerate(hits):
axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
axarr[hits_i][hit_i + 1].set_axis_off()
axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))
# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')
Le résultat de la recherche devrait ressembler à l'image suivante :
Résultat de la recherche d'images