🚀 Prueba Zilliz Cloud, el Milvus completamente gestionado, gratis—¡experimenta un rendimiento 10 veces más rápido! Prueba Ahora>>

Milvus
Zilliz
Home
  • Integraciones
  • Home
  • Docs
  • Integraciones

  • Modelos de incrustación

  • PyTorch

Búsqueda de imágenes con PyTorch y Milvus

Esta guía presenta un ejemplo de integración de PyTorch y Milvus para realizar búsquedas de imágenes utilizando incrustaciones. PyTorch es un potente marco de aprendizaje profundo de código abierto ampliamente utilizado para construir y desplegar modelos de aprendizaje automático. En este ejemplo, aprovecharemos su biblioteca Torchvision y un modelo ResNet50 preentrenado para generar vectores de características (incrustaciones) que representen el contenido de la imagen. Estas incrustaciones se almacenarán en Milvus, una base de datos vectorial de alto rendimiento, para permitir una búsqueda eficiente de similitudes. El conjunto de datos utilizado es el Impressionist-Classifier Dataset de Kaggle. Combinando las capacidades de aprendizaje profundo de PyTorch con la funcionalidad de búsqueda escalable de Milvus, este ejemplo demuestra cómo construir un sistema de recuperación de imágenes robusto y eficiente.

¡Vamos a empezar!

Instalación de los requisitos

Para este ejemplo, vamos a utilizar pymilvus para conectarnos y utilizar Milvus, torch para ejecutar el modelo de incrustación, torchvision para el modelo real y el preprocesamiento, gdown para descargar el conjunto de datos de ejemplo y tqdm para cargar las barras.

pip install pymilvus torch gdown torchvision tqdm

Obtener los datos

Vamos a utilizar gdown para descargar el archivo zip de Google Drive y descomprimirlo con la biblioteca integrada zipfile.

import gdown
import zipfile

url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)

with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
    zip_ref.extractall("./paintings")

El tamaño del conjunto de datos es de 2,35 GB, y el tiempo de descarga dependerá del estado de la red.

Argumentos globales

Estos son algunos de los principales argumentos globales que utilizaremos para facilitar el seguimiento y la actualización.

# Milvus Setup Arguments
COLLECTION_NAME = 'image_search'  # Collection name
DIMENSION = 2048  # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3

Configuración de Milvus

Llegados a este punto, vamos a empezar a configurar Milvus. Los pasos son los siguientes:

  1. Conéctese a la instancia de Milvus utilizando el URI proporcionado.

    from pymilvus import connections
    
    # Connect to the instance
    connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
    
  2. Si la colección ya existe, elimínela.

    from pymilvus import utility
    
    # Remove any previous collections with the same name
    if utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)
    
  3. Cree la colección que contiene el ID, la ruta del archivo de la imagen y su incrustación.

    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
    
    # Create collection which includes the id, filepath of the image, and image embedding
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),  # VARCHARS need a maximum length, so for this example they are set to 200 characters
        FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    
  4. Cree un índice en la colección recién creada y cárguela en memoria.

    # Create an AutoIndex index for collection
    index_params = {
    'metric_type':'L2',
    'index_type':"IVF_FLAT",
    'params':{'nlist': 16384}
    }
    collection.create_index(field_name="image_embedding", index_params=index_params)
    collection.load()
    

Una vez realizados estos pasos, la colección estará lista para ser insertada y buscada. Cualquier dato añadido se indexará automáticamente y estará disponible para la búsqueda de forma inmediata. Si los datos son muy recientes, la búsqueda puede ser más lenta, ya que se utilizará la búsqueda de fuerza bruta en los datos que aún están en proceso de indexación.

Insertar los datos

Para este ejemplo, vamos a utilizar el modelo ResNet50 proporcionado por torch y su hub de modelos. Para obtener las incrustaciones, vamos a quitar la última capa de clasificación, lo que resulta en que el modelo nos da incrustaciones de 2048 dimensiones. Todos los modelos de visión que se encuentran en torch utilizan el mismo preprocesamiento que hemos incluido aquí.

En estos próximos pasos vamos a:

  1. Cargar los datos.

    import glob
    
    # Get the filepaths of the images
    paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)
    len(paths)
    
  2. Preprocesar los datos en lotes.

    import torch
    
    # Load the embedding model with the last layer removed
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
    model = torch.nn.Sequential(*(list(model.children())[:-1]))
    model.eval()
    
  3. Incrustar los datos.

    from torchvision import transforms
    
    # Preprocessing for images
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
  4. Insertar los datos.

    from PIL import Image
    from tqdm import tqdm
    
    # Embed function that embeds the batch and inserts it
    def embed(data):
        with torch.no_grad():
            output = model(torch.stack(data[0])).squeeze()
            collection.insert([data[1], output.tolist()])
    
    data_batch = [[],[]]
    
    # Read the images into batches for embedding and insertion
    for path in tqdm(paths):
        im = Image.open(path).convert('RGB')
        data_batch[0].append(preprocess(im))
        data_batch[1].append(path)
        if len(data_batch[0]) % BATCH_SIZE == 0:
            embed(data_batch)
            data_batch = [[],[]]
    
    # Embed and insert the remainder
    if len(data_batch[0]) != 0:
        embed(data_batch)
    
    # Call a flush to index any unsealed segments.
    collection.flush()
    
    • Este paso es relativamente largo porque la incrustación lleva tiempo. Tome un sorbo de café y relájese.
    • PyTorch puede no funcionar bien con Python 3.9 y versiones anteriores. Considere usar Python 3.10 y versiones posteriores en su lugar.

Con todos los datos insertados en Milvus, podemos empezar a realizar nuestras búsquedas. En este ejemplo, vamos a buscar dos imágenes de ejemplo. Como estamos realizando una búsqueda por lotes, el tiempo de búsqueda se comparte entre todas las imágenes del lote.

import glob

# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt

# Embed the search images
def embed(data):
    with torch.no_grad():
        ret = model(torch.stack(data))
        # If more than one image, use squeeze
        if len(ret) > 1:
            return ret.squeeze().tolist()
        # Squeeze would remove batch for single image, so using flatten
        else:
            return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[],[]]

for path in search_paths:
    im = Image.open(path).convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)

for hits_i, hits in enumerate(res):
    axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
    axarr[hits_i][0].set_axis_off()
    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
    for hit_i, hit in enumerate(hits):
        axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
        axarr[hits_i][hit_i + 1].set_axis_off()
        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))

# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')

La imagen resultante de la búsqueda debería ser similar a la siguiente:

Image search output Resultado de la búsqueda de imágenes

Try Managed Milvus for Free

Zilliz Cloud is hassle-free, powered by Milvus and 10x faster.

Get Started
Feedback

¿Fue útil esta página?