milvus-logo
LFAI
Casa
  • Tutorial

Ricerca di immagini con Milvus

Open In Colab

In questo quaderno mostreremo come utilizzare Milvus per cercare immagini simili in un set di dati. Per dimostrarlo, utilizzeremo un sottoinsieme del dataset ImageNet e cercheremo un'immagine di un cane afgano.

Preparazione del set di dati

Per prima cosa, è necessario caricare il set di dati e disestrarlo per un'ulteriore elaborazione.

!wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
!unzip -q -o reverse_image_search.zip

Requisiti preliminari

Per eseguire questo notebook, è necessario che siano installate le seguenti dipendenze:

  • pymilvus>=2.4.2
  • timm
  • torch
  • numpy
  • sklearn
  • cuscino

Per eseguire Colab, forniamo i comandi pratici per installare le dipendenze necessarie.

$ pip install pymilvus --upgrade
$ pip install timm

Se si utilizza Google Colab, per abilitare le dipendenze appena installate potrebbe essere necessario riavviare il runtime. (Fare clic sul menu "Runtime" nella parte superiore dello schermo e selezionare "Restart session" dal menu a discesa).

Definire l'estrattore di funzioni

Occorre quindi definire un estrattore di caratteristiche che estragga l'embedding da un'immagine utilizzando il modello ResNet-34 di Timm.

import torch
from PIL import Image
import timm
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform


class FeatureExtractor:
    def __init__(self, modelname):
        # Load the pre-trained model
        self.model = timm.create_model(
            modelname, pretrained=True, num_classes=0, global_pool="avg"
        )
        self.model.eval()

        # Get the input size required by the model
        self.input_size = self.model.default_cfg["input_size"]

        config = resolve_data_config({}, model=modelname)
        # Get the preprocessing function provided by TIMM for the model
        self.preprocess = create_transform(**config)

    def __call__(self, imagepath):
        # Preprocess the input image
        input_image = Image.open(imagepath).convert("RGB")  # Convert to RGB if needed
        input_image = self.preprocess(input_image)

        # Convert the image to a PyTorch tensor and add a batch dimension
        input_tensor = input_image.unsqueeze(0)

        # Perform inference
        with torch.no_grad():
            output = self.model(input_tensor)

        # Extract the feature vector
        feature_vector = output.squeeze().numpy()

        return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()

Creare una raccolta Milvus

Occorre poi creare una collezione Milvus per memorizzare gli embedding delle immagini.

from pymilvus import MilvusClient

# Set up a Milvus client
client = MilvusClient(uri="example.db")
# Create a collection in quick setup mode
if client.has_collection(collection_name="image_embeddings"):
    client.drop_collection(collection_name="image_embeddings")
client.create_collection(
    collection_name="image_embeddings",
    vector_field_name="vector",
    dimension=512,
    auto_id=True,
    enable_dynamic_field=True,
    metric_type="COSINE",
)

Per quanto riguarda l'argomento di MilvusClient:

  • L'impostazione di uri come file locale, ad esempio./milvus.db, è il metodo più conveniente, poiché utilizza automaticamente Milvus Lite per memorizzare tutti i dati in questo file.
  • Se si dispone di una grande quantità di dati, è possibile configurare un server Milvus più performante su docker o kubernetes. In questa configurazione, utilizzare l'uri del server, ad esempiohttp://localhost:19530, come uri.
  • Se si desidera utilizzare Zilliz Cloud, il servizio cloud completamente gestito per Milvus, regolare uri e token, che corrispondono all'endpoint pubblico e alla chiave Api di Zilliz Cloud.

Inserire gli embeddings in Milvus

Estraiamo gli embeddings di ogni immagine utilizzando il modello ResNet34 e inseriamo le immagini del set di addestramento in Milvus.

import os

extractor = FeatureExtractor("resnet34")

root = "./train"
insert = True
if insert is True:
    for dirpath, foldername, filenames in os.walk(root):
        for filename in filenames:
            if filename.endswith(".JPEG"):
                filepath = dirpath + "/" + filename
                image_embedding = extractor(filepath)
                client.insert(
                    "image_embeddings",
                    {"vector": image_embedding, "filename": filepath},
                )
from IPython.display import display

query_image = "./test/Afghan_hound/n02088094_4261.JPEG"

results = client.search(
    "image_embeddings",
    data=[extractor(query_image)],
    output_fields=["filename"],
    search_params={"metric_type": "COSINE"},
)
images = []
for result in results:
    for hit in result[:10]:
        filename = hit["entity"]["filename"]
        img = Image.open(filename)
        img = img.resize((150, 150))
        images.append(img)

width = 150 * 5
height = 150 * 2
concatenated_image = Image.new("RGB", (width, height))

for idx, img in enumerate(images):
    x = idx % 5
    y = idx // 5
    concatenated_image.paste(img, (x * 150, y * 150))
display("query")
display(Image.open(query_image).resize((150, 150)))
display("results")
display(concatenated_image)
'query'

png png

'results'

Results Risultati

Possiamo notare che la maggior parte delle immagini appartiene alla stessa categoria dell'immagine ricercata, ovvero il mastino afgano. Ciò significa che abbiamo trovato immagini simili all'immagine di ricerca.

Distribuzione rapida

Per sapere come avviare una demo online con questa esercitazione, consultare l 'applicazione di esempio.