Búsqueda de imágenes con Milvus
En esta página, vamos a repasar un ejemplo sencillo de búsqueda de imágenes utilizando Milvus. El conjunto de datos que estamos buscando es el Impressionist-Classifier Dataset que se encuentra en Kaggle. Para este ejemplo, hemos vuelto a alojar los datos en una unidad de Google pública.
Para este ejemplo, sólo estamos utilizando el modelo preentrenado Resnet50 de Torchvision para las incrustaciones. ¡Vamos a empezar!
Instalación de los requisitos
Para este ejemplo, vamos a utilizar pymilvus
para conectarnos y utilizar Milvus, torch
para ejecutar el modelo de incrustación, torchvision
para el modelo real y el preprocesamiento, gdown
para descargar el conjunto de datos de ejemplo y tqdm
para cargar las barras.
pip install pymilvus torch gdown torchvision tqdm
Obtener los datos
Vamos a utilizar gdown
para descargar el archivo zip de Google Drive y descomprimirlo con la biblioteca integrada zipfile
.
import gdown
import zipfile
url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)
with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
zip_ref.extractall("./paintings")
El tamaño del conjunto de datos es de 2,35 GB, y el tiempo de descarga dependerá del estado de la red.
Argumentos globales
Estos son algunos de los principales argumentos globales que utilizaremos para facilitar el seguimiento y la actualización.
# Milvus Setup Arguments
COLLECTION_NAME = 'image_search' # Collection name
DIMENSION = 2048 # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3
Configuración de Milvus
Llegados a este punto, vamos a empezar a configurar Milvus. Los pasos son los siguientes:
Conéctese a la instancia de Milvus utilizando el URI proporcionado.
from pymilvus import connections # Connect to the instance connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
Si la colección ya existe, elimínela.
from pymilvus import utility # Remove any previous collections with the same name if utility.has_collection(COLLECTION_NAME): utility.drop_collection(COLLECTION_NAME)
Cree la colección que contiene el ID, la ruta del archivo de la imagen y su incrustación.
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection # Create collection which includes the id, filepath of the image, and image embedding fields = [ FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION) ] schema = CollectionSchema(fields=fields) collection = Collection(name=COLLECTION_NAME, schema=schema)
Cree un índice en la colección recién creada y cárguela en memoria.
# Create an AutoIndex index for collection index_params = { 'metric_type':'L2', 'index_type':"IVF_FLAT", 'params':{'nlist': 16384} } collection.create_index(field_name="image_embedding", index_params=index_params) collection.load()
Una vez realizados estos pasos, la colección estará lista para ser insertada y buscada. Cualquier dato añadido se indexará automáticamente y estará disponible para la búsqueda de forma inmediata. Si los datos son muy recientes, la búsqueda puede ser más lenta, ya que se utilizará la búsqueda de fuerza bruta en los datos que aún están en proceso de indexación.
Insertar los datos
Para este ejemplo, vamos a utilizar el modelo ResNet50 proporcionado por torch
y su hub de modelos. Para obtener las incrustaciones, vamos a quitar la última capa de clasificación, lo que resulta en que el modelo nos da incrustaciones de 2048 dimensiones. Todos los modelos de visión que se encuentran en torch
utilizan el mismo preprocesamiento que hemos incluido aquí.
En estos próximos pasos vamos a:
Cargar los datos.
import glob # Get the filepaths of the images paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True) len(paths)
Preprocesar los datos en lotes.
import torch # Load the embedding model with the last layer removed model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) model = torch.nn.Sequential(*(list(model.children())[:-1])) model.eval()
Incrustar los datos.
from torchvision import transforms # Preprocessing for images preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ])
Insertar los datos.
from PIL import Image from tqdm import tqdm # Embed function that embeds the batch and inserts it def embed(data): with torch.no_grad(): output = model(torch.stack(data[0])).squeeze() collection.insert([data[1], output.tolist()]) data_batch = [[],[]] # Read the images into batches for embedding and insertion for path in tqdm(paths): im = Image.open(path).convert('RGB') data_batch[0].append(preprocess(im)) data_batch[1].append(path) if len(data_batch[0]) % BATCH_SIZE == 0: embed(data_batch) data_batch = [[],[]] # Embed and insert the remainder if len(data_batch[0]) != 0: embed(data_batch) # Call a flush to index any unsealed segments. collection.flush()
- Este paso es relativamente largo porque la incrustación lleva tiempo. Tome un sorbo de café y relájese.
- PyTorch puede no funcionar bien con Python 3.9 y versiones anteriores. Considere usar Python 3.10 y versiones posteriores en su lugar.
Realizar la búsqueda
Con todos los datos insertados en Milvus, podemos empezar a realizar nuestras búsquedas. En este ejemplo, vamos a buscar dos imágenes de ejemplo. Como estamos realizando una búsqueda por lotes, el tiempo de búsqueda se comparte entre todas las imágenes del lote.
import glob
# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt
# Embed the search images
def embed(data):
with torch.no_grad():
ret = model(torch.stack(data))
# If more than one image, use squeeze
if len(ret) > 1:
return ret.squeeze().tolist()
# Squeeze would remove batch for single image, so using flatten
else:
return torch.flatten(ret, start_dim=1).tolist()
data_batch = [[],[]]
for path in search_paths:
im = Image.open(path).convert('RGB')
data_batch[0].append(preprocess(im))
data_batch[1].append(path)
embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
for hits_i, hits in enumerate(res):
axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
axarr[hits_i][0].set_axis_off()
axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
for hit_i, hit in enumerate(hits):
axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
axarr[hits_i][hit_i + 1].set_axis_off()
axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))
# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')
La imagen resultante de la búsqueda debería ser similar a la siguiente:
Resultado de la búsqueda de imágenes