Búsqueda de imágenes con PyTorch y Milvus
Esta guía presenta un ejemplo de integración de PyTorch y Milvus para realizar búsquedas de imágenes utilizando incrustaciones. PyTorch es un potente marco de aprendizaje profundo de código abierto ampliamente utilizado para construir y desplegar modelos de aprendizaje automático. En este ejemplo, aprovecharemos su biblioteca Torchvision y un modelo ResNet50 preentrenado para generar vectores de características (incrustaciones) que representen el contenido de la imagen. Estas incrustaciones se almacenarán en Milvus, una base de datos vectorial de alto rendimiento, para permitir una búsqueda eficiente de similitudes. El conjunto de datos utilizado es el Impressionist-Classifier Dataset de Kaggle. Combinando las capacidades de aprendizaje profundo de PyTorch con la funcionalidad de búsqueda escalable de Milvus, este ejemplo demuestra cómo construir un sistema de recuperación de imágenes robusto y eficiente.
¡Vamos a empezar!
Instalación de los requisitos
Para este ejemplo, vamos a utilizar pymilvus
para conectarnos y utilizar Milvus, torch
para ejecutar el modelo de incrustación, torchvision
para el modelo real y el preprocesamiento, gdown
para descargar el conjunto de datos de ejemplo y tqdm
para cargar las barras.
pip install pymilvus torch gdown torchvision tqdm
Obtener los datos
Vamos a utilizar gdown
para descargar el archivo zip de Google Drive y descomprimirlo con la biblioteca integrada zipfile
.
import gdown
import zipfile
url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)
with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
zip_ref.extractall("./paintings")
El tamaño del conjunto de datos es de 2,35 GB, y el tiempo de descarga dependerá del estado de la red.
Argumentos globales
Estos son algunos de los principales argumentos globales que utilizaremos para facilitar el seguimiento y la actualización.
# Milvus Setup Arguments
COLLECTION_NAME = 'image_search' # Collection name
DIMENSION = 2048 # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3
Configuración de Milvus
Llegados a este punto, vamos a empezar a configurar Milvus. Los pasos son los siguientes:
Conéctese a la instancia de Milvus utilizando el URI proporcionado.
from pymilvus import connections # Connect to the instance connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
Si la colección ya existe, elimínela.
from pymilvus import utility # Remove any previous collections with the same name if utility.has_collection(COLLECTION_NAME): utility.drop_collection(COLLECTION_NAME)
Cree la colección que contiene el ID, la ruta del archivo de la imagen y su incrustación.
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection # Create collection which includes the id, filepath of the image, and image embedding fields = [ FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION) ] schema = CollectionSchema(fields=fields) collection = Collection(name=COLLECTION_NAME, schema=schema)
Cree un índice en la colección recién creada y cárguela en memoria.
# Create an AutoIndex index for collection index_params = { 'metric_type':'L2', 'index_type':"IVF_FLAT", 'params':{'nlist': 16384} } collection.create_index(field_name="image_embedding", index_params=index_params) collection.load()
Una vez realizados estos pasos, la colección estará lista para ser insertada y buscada. Cualquier dato añadido se indexará automáticamente y estará disponible para la búsqueda de forma inmediata. Si los datos son muy recientes, la búsqueda puede ser más lenta, ya que se utilizará la búsqueda de fuerza bruta en los datos que aún están en proceso de indexación.
Insertar los datos
Para este ejemplo, vamos a utilizar el modelo ResNet50 proporcionado por torch
y su hub de modelos. Para obtener las incrustaciones, vamos a quitar la última capa de clasificación, lo que resulta en que el modelo nos da incrustaciones de 2048 dimensiones. Todos los modelos de visión que se encuentran en torch
utilizan el mismo preprocesamiento que hemos incluido aquí.
En estos próximos pasos vamos a:
Cargar los datos.
import glob # Get the filepaths of the images paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True) len(paths)
Preprocesar los datos en lotes.
import torch # Load the embedding model with the last layer removed model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) model = torch.nn.Sequential(*(list(model.children())[:-1])) model.eval()
Incrustar los datos.
from torchvision import transforms # Preprocessing for images preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ])
Insertar los datos.
from PIL import Image from tqdm import tqdm # Embed function that embeds the batch and inserts it def embed(data): with torch.no_grad(): output = model(torch.stack(data[0])).squeeze() collection.insert([data[1], output.tolist()]) data_batch = [[],[]] # Read the images into batches for embedding and insertion for path in tqdm(paths): im = Image.open(path).convert('RGB') data_batch[0].append(preprocess(im)) data_batch[1].append(path) if len(data_batch[0]) % BATCH_SIZE == 0: embed(data_batch) data_batch = [[],[]] # Embed and insert the remainder if len(data_batch[0]) != 0: embed(data_batch) # Call a flush to index any unsealed segments. collection.flush()
- Este paso es relativamente largo porque la incrustación lleva tiempo. Tome un sorbo de café y relájese.
- PyTorch puede no funcionar bien con Python 3.9 y versiones anteriores. Considere usar Python 3.10 y versiones posteriores en su lugar.
Realizar la búsqueda
Con todos los datos insertados en Milvus, podemos empezar a realizar nuestras búsquedas. En este ejemplo, vamos a buscar dos imágenes de ejemplo. Como estamos realizando una búsqueda por lotes, el tiempo de búsqueda se comparte entre todas las imágenes del lote.
import glob
# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt
# Embed the search images
def embed(data):
with torch.no_grad():
ret = model(torch.stack(data))
# If more than one image, use squeeze
if len(ret) > 1:
return ret.squeeze().tolist()
# Squeeze would remove batch for single image, so using flatten
else:
return torch.flatten(ret, start_dim=1).tolist()
data_batch = [[],[]]
for path in search_paths:
im = Image.open(path).convert('RGB')
data_batch[0].append(preprocess(im))
data_batch[1].append(path)
embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
for hits_i, hits in enumerate(res):
axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
axarr[hits_i][0].set_axis_off()
axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
for hit_i, hit in enumerate(hits):
axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
axarr[hits_i][hit_i + 1].set_axis_off()
axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))
# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')
La imagen resultante de la búsqueda debería ser similar a la siguiente:
Resultado de la búsqueda de imágenes