milvus-logo
LFAI
Home
  • Intégrations

Recherche d'images avec Milvus

Sur cette page, nous allons étudier un exemple simple de recherche d'images à l'aide de Milvus. L'ensemble de données que nous recherchons est l'ensemble de données Impressionist-Classifier trouvé sur Kaggle. Pour cet exemple, nous avons réhébergé les données dans un Google Drive public.

Pour cet exemple, nous utilisons simplement le modèle Resnet50 pré-entraîné de Torchvision pour les embeddings. Commençons à travailler !

Installer les prérequis

Pour cet exemple, nous allons utiliser pymilvus pour nous connecter à Milvus, torch pour exécuter le modèle d'intégration, torchvision pour le modèle proprement dit et le prétraitement, gdown pour télécharger l'ensemble de données d'exemple et tqdm pour charger les barres.

pip install pymilvus torch gdown torchvision tqdm

Récupérer les données

Nous allons utiliser gdown pour récupérer le fichier zip de Google Drive et le décompresser avec la bibliothèque intégrée zipfile.

import gdown
import zipfile

url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)

with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
    zip_ref.extractall("./paintings")

La taille du jeu de données est de 2,35 Go, et le temps passé à le télécharger dépend de l'état de votre réseau.

Arguments globaux

Voici quelques-uns des principaux arguments globaux que nous utiliserons pour faciliter le suivi et la mise à jour.

# Milvus Setup Arguments
COLLECTION_NAME = 'image_search'  # Collection name
DIMENSION = 2048  # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3

Mise en place de Milvus

À ce stade, nous allons commencer à configurer Milvus. Les étapes sont les suivantes :

  1. Connectez-vous à l'instance Milvus à l'aide de l'URI fourni.

    from pymilvus import connections
    
    # Connect to the instance
    connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
    
  2. Si la collection existe déjà, la supprimer.

    from pymilvus import utility
    
    # Remove any previous collections with the same name
    if utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)
    
  3. Créez la collection qui contient l'ID, le chemin de fichier de l'image et son intégration.

    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
    
    # Create collection which includes the id, filepath of the image, and image embedding
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),  # VARCHARS need a maximum length, so for this example they are set to 200 characters
        FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    
  4. Créez un index sur la collection nouvellement créée et chargez-la en mémoire.

    # Create an AutoIndex index for collection
    index_params = {
    'metric_type':'L2',
    'index_type':"IVF_FLAT",
    'params':{'nlist': 16384}
    }
    collection.create_index(field_name="image_embedding", index_params=index_params)
    collection.load()
    

Une fois ces étapes effectuées, la collection est prête à être insérée et à faire l'objet de recherches. Toutes les données ajoutées seront indexées automatiquement et pourront être recherchées immédiatement. Si les données sont très récentes, la recherche peut être plus lente car une recherche par force brute sera utilisée sur les données qui sont encore en cours d'indexation.

Insérer les données

Pour cet exemple, nous allons utiliser le modèle ResNet50 fourni par torch et son hub de modèles. Pour obtenir les embeddings, nous enlevons la dernière couche de classification, ce qui fait que le modèle nous donne des embeddings de 2048 dimensions. Tous les modèles de vision trouvés sur torch utilisent le même prétraitement que celui que nous avons inclus ici.

Au cours des prochaines étapes, nous allons

  1. Charger les données.

    import glob
    
    # Get the filepaths of the images
    paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)
    len(paths)
    
  2. Prétraitement des données en lots.

    import torch
    
    # Load the embedding model with the last layer removed
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
    model = torch.nn.Sequential(*(list(model.children())[:-1]))
    model.eval()
    
  3. Intégrer les données.

    from torchvision import transforms
    
    # Preprocessing for images
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
  4. Insérer les données.

    from PIL import Image
    from tqdm import tqdm
    
    # Embed function that embeds the batch and inserts it
    def embed(data):
        with torch.no_grad():
            output = model(torch.stack(data[0])).squeeze()
            collection.insert([data[1], output.tolist()])
    
    data_batch = [[],[]]
    
    # Read the images into batches for embedding and insertion
    for path in tqdm(paths):
        im = Image.open(path).convert('RGB')
        data_batch[0].append(preprocess(im))
        data_batch[1].append(path)
        if len(data_batch[0]) % BATCH_SIZE == 0:
            embed(data_batch)
            data_batch = [[],[]]
    
    # Embed and insert the remainder
    if len(data_batch[0]) != 0:
        embed(data_batch)
    
    # Call a flush to index any unsealed segments.
    collection.flush()
    
    • Cette étape est relativement longue car l'incorporation prend du temps. Prenez une gorgée de café et détendez-vous.
    • PyTorch peut ne pas fonctionner correctement avec Python 3.9 et les versions antérieures. Envisagez plutôt d'utiliser Python 3.10 et les versions ultérieures.

Une fois toutes les données insérées dans Milvus, nous pouvons commencer à effectuer nos recherches. Dans cet exemple, nous allons rechercher deux images d'exemple. Comme nous effectuons une recherche par lot, le temps de recherche est partagé entre les images du lot.

import glob

# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt

# Embed the search images
def embed(data):
    with torch.no_grad():
        ret = model(torch.stack(data))
        # If more than one image, use squeeze
        if len(ret) > 1:
            return ret.squeeze().tolist()
        # Squeeze would remove batch for single image, so using flatten
        else:
            return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[],[]]

for path in search_paths:
    im = Image.open(path).convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)

for hits_i, hits in enumerate(res):
    axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
    axarr[hits_i][0].set_axis_off()
    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
    for hit_i, hit in enumerate(hits):
        axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
        axarr[hits_i][hit_i + 1].set_axis_off()
        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))

# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')

Le résultat de la recherche devrait ressembler à l'image suivante :

Image search output Résultat de la recherche d'images

Traduit parDeepLogo

Feedback

Cette page a-t - elle été utile ?