Recherche d'images avec Milvus
Dans ce carnet, nous allons vous montrer comment utiliser Milvus pour rechercher des images similaires dans un ensemble de données. Nous utiliserons un sous-ensemble de l'ensemble de données ImageNet, puis nous rechercherons une image d'un chien de chasse afghan.
Préparation du jeu de données
Tout d'abord, nous devons charger l'ensemble de données et le désextraire en vue d'un traitement ultérieur.
!wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
!unzip -q -o reverse_image_search.zip
Conditions préalables
Pour exécuter ce bloc-notes, les dépendances suivantes doivent être installées :
- pymilvus>=2.4.2
- timm
- torch
- numpy
- sklearn
- oreiller
Pour exécuter Colab, nous fournissons les commandes pratiques pour installer les dépendances nécessaires.
$ pip install pymilvus --upgrade
$ pip install timm
Si vous utilisez Google Colab, pour activer les dépendances qui viennent d'être installées, vous devrez peut-être redémarrer le moteur d'exécution. (Cliquez sur le menu "Runtime" en haut de l'écran, et sélectionnez "Restart session" dans le menu déroulant).
Définir l'extracteur de caractéristiques
Nous devons ensuite définir un extracteur de caractéristiques qui extrait l'intégration d'une image à l'aide du modèle ResNet-34 de Timm.
import torch
from PIL import Image
import timm
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class FeatureExtractor:
def __init__(self, modelname):
# Load the pre-trained model
self.model = timm.create_model(
modelname, pretrained=True, num_classes=0, global_pool="avg"
)
self.model.eval()
# Get the input size required by the model
self.input_size = self.model.default_cfg["input_size"]
config = resolve_data_config({}, model=modelname)
# Get the preprocessing function provided by TIMM for the model
self.preprocess = create_transform(**config)
def __call__(self, imagepath):
# Preprocess the input image
input_image = Image.open(imagepath).convert("RGB") # Convert to RGB if needed
input_image = self.preprocess(input_image)
# Convert the image to a PyTorch tensor and add a batch dimension
input_tensor = input_image.unsqueeze(0)
# Perform inference
with torch.no_grad():
output = self.model(input_tensor)
# Extract the feature vector
feature_vector = output.squeeze().numpy()
return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()
Créer une collection Milvus
Ensuite, nous devons créer une collection Milvus pour stocker les embeddings de l'image
from pymilvus import MilvusClient
# Set up a Milvus client
client = MilvusClient(uri="example.db")
# Create a collection in quick setup mode
if client.has_collection(collection_name="image_embeddings"):
client.drop_collection(collection_name="image_embeddings")
client.create_collection(
collection_name="image_embeddings",
vector_field_name="vector",
dimension=512,
auto_id=True,
enable_dynamic_field=True,
metric_type="COSINE",
)
Comme pour l'argument de MilvusClient
:
- Définir
uri
comme un fichier local, par exemple./milvus.db
, est la méthode la plus pratique, car elle utilise automatiquement Milvus Lite pour stocker toutes les données dans ce fichier. - Si vous avez des données à grande échelle, vous pouvez configurer un serveur Milvus plus performant sur docker ou kubernetes. Dans cette configuration, veuillez utiliser l'uri du serveur, par exemple
http://localhost:19530
, comme votreuri
. - Si vous souhaitez utiliser Zilliz Cloud, le service cloud entièrement géré pour Milvus, ajustez les adresses
uri
ettoken
, qui correspondent au point de terminaison public et à la clé Api dans Zilliz Cloud.
Insérer les embeddings dans Milvus
Nous allons extraire les embeddings de chaque image à l'aide du modèle ResNet34 et insérer les images de l'ensemble d'entraînement dans Milvus.
import os
extractor = FeatureExtractor("resnet34")
root = "./train"
insert = True
if insert is True:
for dirpath, foldername, filenames in os.walk(root):
for filename in filenames:
if filename.endswith(".JPEG"):
filepath = dirpath + "/" + filename
image_embedding = extractor(filepath)
client.insert(
"image_embeddings",
{"vector": image_embedding, "filename": filepath},
)
from IPython.display import display
query_image = "./test/Afghan_hound/n02088094_4261.JPEG"
results = client.search(
"image_embeddings",
data=[extractor(query_image)],
output_fields=["filename"],
search_params={"metric_type": "COSINE"},
)
images = []
for result in results:
for hit in result[:10]:
filename = hit["entity"]["filename"]
img = Image.open(filename)
img = img.resize((150, 150))
images.append(img)
width = 150 * 5
height = 150 * 2
concatenated_image = Image.new("RGB", (width, height))
for idx, img in enumerate(images):
x = idx % 5
y = idx // 5
concatenated_image.paste(img, (x * 150, y * 150))
display("query")
display(Image.open(query_image).resize((150, 150)))
display("results")
display(concatenated_image)
'query'
png
'results'
Résultats
Nous pouvons constater que la plupart des images appartiennent à la même catégorie que l'image recherchée, à savoir le chien de chasse afghan. Cela signifie que nous avons trouvé des images similaires à l'image recherchée.
Déploiement rapide
Pour savoir comment démarrer une démo en ligne avec ce tutoriel, veuillez vous référer à l 'exemple d'application.