milvus-logo
LFAI
Home
  • Tutoriales

Búsqueda de imágenes con Milvus

Open In Colab

En este cuaderno, le mostraremos cómo utilizar Milvus para buscar imágenes similares en un conjunto de datos. Utilizaremos un subconjunto del conjunto de datos ImageNet y buscaremos una imagen de un sabueso afgano para demostrarlo.

Preparación del conjunto de datos

En primer lugar, tenemos que cargar el conjunto de datos y extraerlo para su posterior procesamiento.

!wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
!unzip -q -o reverse_image_search.zip

Requisitos previos

Para ejecutar este cuaderno, es necesario tener instaladas las siguientes dependencias:

  • pymilvus>=2.4.2
  • timm
  • torch
  • numpy
  • sklearn
  • almohada

Para ejecutar Colab, proporcionamos los prácticos comandos para instalar las dependencias necesarias.

$ pip install pymilvus --upgrade
$ pip install timm

Si utilizas Google Colab, para activar las dependencias que acabas de instalar, es posible que tengas que reiniciar el tiempo de ejecución. (Haga clic en el menú "Tiempo de ejecución" en la parte superior de la pantalla y seleccione "Reiniciar sesión" en el menú desplegable).

Definir el extractor de características

A continuación, necesitamos definir un extractor de características que extraiga la incrustación de una imagen utilizando el modelo ResNet-34 de timm.

import torch
from PIL import Image
import timm
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform


class FeatureExtractor:
    def __init__(self, modelname):
        # Load the pre-trained model
        self.model = timm.create_model(
            modelname, pretrained=True, num_classes=0, global_pool="avg"
        )
        self.model.eval()

        # Get the input size required by the model
        self.input_size = self.model.default_cfg["input_size"]

        config = resolve_data_config({}, model=modelname)
        # Get the preprocessing function provided by TIMM for the model
        self.preprocess = create_transform(**config)

    def __call__(self, imagepath):
        # Preprocess the input image
        input_image = Image.open(imagepath).convert("RGB")  # Convert to RGB if needed
        input_image = self.preprocess(input_image)

        # Convert the image to a PyTorch tensor and add a batch dimension
        input_tensor = input_image.unsqueeze(0)

        # Perform inference
        with torch.no_grad():
            output = self.model(input_tensor)

        # Extract the feature vector
        feature_vector = output.squeeze().numpy()

        return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()

Crear una colección Milvus

A continuación, tenemos que crear una colección Milvus para almacenar las incrustaciones de la imagen

from pymilvus import MilvusClient

# Set up a Milvus client
client = MilvusClient(uri="example.db")
# Create a collection in quick setup mode
if client.has_collection(collection_name="image_embeddings"):
    client.drop_collection(collection_name="image_embeddings")
client.create_collection(
    collection_name="image_embeddings",
    vector_field_name="vector",
    dimension=512,
    auto_id=True,
    enable_dynamic_field=True,
    metric_type="COSINE",
)

En cuanto al argumento de MilvusClient:

  • Establecer el uri como un archivo local, por ejemplo./milvus.db, es el método más conveniente, ya que utiliza automáticamente Milvus Lite para almacenar todos los datos en este archivo.
  • Si tiene una gran escala de datos, puede configurar un servidor Milvus más eficiente en docker o kubernetes. En esta configuración, por favor utilice la uri del servidor, por ejemplohttp://localhost:19530, como su uri.
  • Si desea utilizar Zilliz Cloud, el servicio en la nube totalmente gestionado para Milvus, ajuste los uri y token, que corresponden al Public Endpoint y a la clave Api en Zilliz Cloud.

Insertar los Embeddings en Milvus

Extraeremos los embeddings de cada imagen utilizando el modelo ResNet34 e insertaremos las imágenes del conjunto de entrenamiento en Milvus.

import os

extractor = FeatureExtractor("resnet34")

root = "./train"
insert = True
if insert is True:
    for dirpath, foldername, filenames in os.walk(root):
        for filename in filenames:
            if filename.endswith(".JPEG"):
                filepath = dirpath + "/" + filename
                image_embedding = extractor(filepath)
                client.insert(
                    "image_embeddings",
                    {"vector": image_embedding, "filename": filepath},
                )
from IPython.display import display

query_image = "./test/Afghan_hound/n02088094_4261.JPEG"

results = client.search(
    "image_embeddings",
    data=[extractor(query_image)],
    output_fields=["filename"],
    search_params={"metric_type": "COSINE"},
)
images = []
for result in results:
    for hit in result[:10]:
        filename = hit["entity"]["filename"]
        img = Image.open(filename)
        img = img.resize((150, 150))
        images.append(img)

width = 150 * 5
height = 150 * 2
concatenated_image = Image.new("RGB", (width, height))

for idx, img in enumerate(images):
    x = idx % 5
    y = idx // 5
    concatenated_image.paste(img, (x * 150, y * 150))
display("query")
display(Image.open(query_image).resize((150, 150)))
display("results")
display(concatenated_image)
'query'

png png

'results'

Results Resultados

Podemos ver que la mayoría de las imágenes son de la misma categoría que la imagen de búsqueda, que es el sabueso afgano. Esto significa que hemos encontrado imágenes similares a la imagen de búsqueda.

Despliegue rápido

Para saber cómo iniciar una demostración en línea con este tutorial, consulte la aplicación de ejemplo.