milvus-logo
LFAI
Home
  • Integraciones

Búsqueda de imágenes con Milvus

En esta página, vamos a repasar un ejemplo sencillo de búsqueda de imágenes utilizando Milvus. El conjunto de datos que estamos buscando es el Impressionist-Classifier Dataset que se encuentra en Kaggle. Para este ejemplo, hemos vuelto a alojar los datos en una unidad de Google pública.

Para este ejemplo, sólo estamos utilizando el modelo preentrenado Resnet50 de Torchvision para las incrustaciones. ¡Vamos a empezar!

Instalación de los requisitos

Para este ejemplo, vamos a utilizar pymilvus para conectarnos y utilizar Milvus, torch para ejecutar el modelo de incrustación, torchvision para el modelo real y el preprocesamiento, gdown para descargar el conjunto de datos de ejemplo y tqdm para cargar las barras.

pip install pymilvus torch gdown torchvision tqdm

Obtener los datos

Vamos a utilizar gdown para descargar el archivo zip de Google Drive y descomprimirlo con la biblioteca integrada zipfile.

import gdown
import zipfile

url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)

with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
    zip_ref.extractall("./paintings")

El tamaño del conjunto de datos es de 2,35 GB, y el tiempo de descarga depende del estado de la red.

Argumentos globales

Estos son algunos de los principales argumentos globales que utilizaremos para facilitar el seguimiento y la actualización.

# Milvus Setup Arguments
COLLECTION_NAME = 'image_search'  # Collection name
DIMENSION = 2048  # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3

Configuración de Milvus

Llegados a este punto, vamos a empezar a configurar Milvus. Los pasos son los siguientes:

  1. Conéctese a la instancia de Milvus utilizando el URI proporcionado.

    from pymilvus import connections
    
    # Connect to the instance
    connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
    
  2. Si la colección ya existe, elimínela.

    from pymilvus import utility
    
    # Remove any previous collections with the same name
    if utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)
    
  3. Cree la colección que contiene el ID, la ruta del archivo de la imagen y su incrustación.

    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
    
    # Create collection which includes the id, filepath of the image, and image embedding
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),  # VARCHARS need a maximum length, so for this example they are set to 200 characters
        FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    
  4. Cree un índice en la colección recién creada y cárguela en memoria.

    # Create an AutoIndex index for collection
    index_params = {
    'metric_type':'L2',
    'index_type':"IVF_FLAT",
    'params':{'nlist': 16384}
    }
    collection.create_index(field_name="image_embedding", index_params=index_params)
    collection.load()
    

Una vez realizados estos pasos, la colección estará lista para ser insertada y buscada. Cualquier dato añadido se indexará automáticamente y estará disponible para la búsqueda de forma inmediata. Si los datos son muy recientes, la búsqueda puede ser más lenta, ya que se utilizará la búsqueda de fuerza bruta en los datos que aún están en proceso de indexación.

Insertar los datos

Para este ejemplo, vamos a utilizar el modelo ResNet50 proporcionado por torch y su hub de modelos. Para obtener las incrustaciones, vamos a quitar la última capa de clasificación, lo que resulta en que el modelo nos da incrustaciones de 2048 dimensiones. Todos los modelos de visión que se encuentran en torch utilizan el mismo preprocesamiento que hemos incluido aquí.

En estos próximos pasos vamos a:

  1. Cargar los datos.

    import glob
    
    # Get the filepaths of the images
    paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)
    len(paths)
    
  2. Preprocesar los datos en lotes.

    import torch
    
    # Load the embedding model with the last layer removed
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
    model = torch.nn.Sequential(*(list(model.children())[:-1]))
    model.eval()
    
  3. Incrustar los datos.

    from torchvision import transforms
    
    # Preprocessing for images
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
  4. Insertar los datos.

    from PIL import Image
    from tqdm import tqdm
    
    # Embed function that embeds the batch and inserts it
    def embed(data):
        with torch.no_grad():
            output = model(torch.stack(data[0])).squeeze()
            collection.insert([data[1], output.tolist()])
    
    data_batch = [[],[]]
    
    # Read the images into batches for embedding and insertion
    for path in tqdm(paths):
        im = Image.open(path).convert('RGB')
        data_batch[0].append(preprocess(im))
        data_batch[1].append(path)
        if len(data_batch[0]) % BATCH_SIZE == 0:
            embed(data_batch)
            data_batch = [[],[]]
    
    # Embed and insert the remainder
    if len(data_batch[0]) != 0:
        embed(data_batch)
    
    # Call a flush to index any unsealed segments.
    collection.flush()
    
    • Este paso es relativamente largo porque la incrustación lleva tiempo. Tome un sorbo de café y relájese.
    • PyTorch puede no funcionar bien con Python 3.9 y versiones anteriores. Considere usar Python 3.10 y versiones posteriores en su lugar.

Con todos los datos insertados en Milvus, podemos empezar a realizar nuestras búsquedas. En este ejemplo, vamos a buscar dos imágenes de ejemplo. Como estamos realizando una búsqueda por lotes, el tiempo de búsqueda se comparte entre todas las imágenes del lote.

import glob

# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt

# Embed the search images
def embed(data):
    with torch.no_grad():
        ret = model(torch.stack(data))
        # If more than one image, use squeeze
        if len(ret) > 1:
            return ret.squeeze().tolist()
        # Squeeze would remove batch for single image, so using flatten
        else:
            return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[],[]]

for path in search_paths:
    im = Image.open(path).convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)

for hits_i, hits in enumerate(res):
    axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
    axarr[hits_i][0].set_axis_off()
    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
    for hit_i, hit in enumerate(hits):
        axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
        axarr[hits_i][hit_i + 1].set_axis_off()
        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))

# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')

La imagen resultante de la búsqueda debería ser similar a la siguiente:

Image search output Resultado de la búsqueda de imágenes

Traducido porDeepLogo

Feedback

¿Fue útil esta página?