milvus-logo
LFAI
Home
  • Tutoriels

Recherche d'images avec Milvus

Open In Colab

Dans ce carnet, nous allons vous montrer comment utiliser Milvus pour rechercher des images similaires dans un ensemble de données. Nous utiliserons un sous-ensemble de l'ensemble de données ImageNet, puis nous rechercherons une image d'un chien de chasse afghan pour en faire la démonstration.

Préparation du jeu de données

Tout d'abord, nous devons charger l'ensemble de données et le désextraire en vue d'un traitement ultérieur.

!wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
!unzip -q -o reverse_image_search.zip

Conditions préalables

Pour exécuter ce bloc-notes, les dépendances suivantes doivent être installées :

  • pymilvus>=2.4.2
  • timm
  • torch
  • numpy
  • sklearn
  • oreiller

Pour exécuter Colab, nous fournissons les commandes pratiques pour installer les dépendances nécessaires.

$ pip install pymilvus --upgrade
$ pip install timm

Si vous utilisez Google Colab, pour activer les dépendances qui viennent d'être installées, vous devrez peut-être redémarrer le moteur d'exécution. (Cliquez sur le menu "Runtime" en haut de l'écran, et sélectionnez "Restart session" dans le menu déroulant).

Définir l'extracteur de caractéristiques

Nous devons ensuite définir un extracteur de caractéristiques qui extrait l'intégration d'une image à l'aide du modèle ResNet-34 de Timm.

import torch
from PIL import Image
import timm
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform


class FeatureExtractor:
    def __init__(self, modelname):
        # Load the pre-trained model
        self.model = timm.create_model(
            modelname, pretrained=True, num_classes=0, global_pool="avg"
        )
        self.model.eval()

        # Get the input size required by the model
        self.input_size = self.model.default_cfg["input_size"]

        config = resolve_data_config({}, model=modelname)
        # Get the preprocessing function provided by TIMM for the model
        self.preprocess = create_transform(**config)

    def __call__(self, imagepath):
        # Preprocess the input image
        input_image = Image.open(imagepath).convert("RGB")  # Convert to RGB if needed
        input_image = self.preprocess(input_image)

        # Convert the image to a PyTorch tensor and add a batch dimension
        input_tensor = input_image.unsqueeze(0)

        # Perform inference
        with torch.no_grad():
            output = self.model(input_tensor)

        # Extract the feature vector
        feature_vector = output.squeeze().numpy()

        return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()

Créer une collection Milvus

Ensuite, nous devons créer une collection Milvus pour stocker les embeddings de l'image

from pymilvus import MilvusClient

# Set up a Milvus client
client = MilvusClient(uri="example.db")
# Create a collection in quick setup mode
if client.has_collection(collection_name="image_embeddings"):
    client.drop_collection(collection_name="image_embeddings")
client.create_collection(
    collection_name="image_embeddings",
    vector_field_name="vector",
    dimension=512,
    auto_id=True,
    enable_dynamic_field=True,
    metric_type="COSINE",
)

Comme pour l'argument de MilvusClient:

  • Définir uri comme un fichier local, par exemple./milvus.db, est la méthode la plus pratique, car elle utilise automatiquement Milvus Lite pour stocker toutes les données dans ce fichier.
  • Si vous avez des données à grande échelle, vous pouvez configurer un serveur Milvus plus performant sur docker ou kubernetes. Dans cette configuration, veuillez utiliser l'uri du serveur, par exemplehttp://localhost:19530, comme votre uri.
  • Si vous souhaitez utiliser Zilliz Cloud, le service cloud entièrement géré pour Milvus, ajustez les adresses uri et token, qui correspondent au point de terminaison public et à la clé Api dans Zilliz Cloud.

Insérer les embeddings dans Milvus

Nous allons extraire les embeddings de chaque image à l'aide du modèle ResNet34 et insérer les images de l'ensemble d'entraînement dans Milvus.

import os

extractor = FeatureExtractor("resnet34")

root = "./train"
insert = True
if insert is True:
    for dirpath, foldername, filenames in os.walk(root):
        for filename in filenames:
            if filename.endswith(".JPEG"):
                filepath = dirpath + "/" + filename
                image_embedding = extractor(filepath)
                client.insert(
                    "image_embeddings",
                    {"vector": image_embedding, "filename": filepath},
                )
from IPython.display import display

query_image = "./test/Afghan_hound/n02088094_4261.JPEG"

results = client.search(
    "image_embeddings",
    data=[extractor(query_image)],
    output_fields=["filename"],
    search_params={"metric_type": "COSINE"},
)
images = []
for result in results:
    for hit in result[:10]:
        filename = hit["entity"]["filename"]
        img = Image.open(filename)
        img = img.resize((150, 150))
        images.append(img)

width = 150 * 5
height = 150 * 2
concatenated_image = Image.new("RGB", (width, height))

for idx, img in enumerate(images):
    x = idx % 5
    y = idx // 5
    concatenated_image.paste(img, (x * 150, y * 150))
display("query")
display(Image.open(query_image).resize((150, 150)))
display("results")
display(concatenated_image)
'query'

png png

'results'

Results Résultats

Nous pouvons constater que la plupart des images appartiennent à la même catégorie que l'image recherchée, à savoir le chien de chasse afghan. Cela signifie que nous avons trouvé des images similaires à l'image recherchée.

Déploiement rapide

Pour savoir comment démarrer une démo en ligne avec ce tutoriel, veuillez vous référer à l 'exemple d'application.