Búsqueda de imágenes con Milvus
En este cuaderno, le mostraremos cómo utilizar Milvus para buscar imágenes similares en un conjunto de datos. Utilizaremos un subconjunto del conjunto de datos ImageNet y buscaremos una imagen de un sabueso afgano para demostrarlo.
Preparación del conjunto de datos
En primer lugar, tenemos que cargar el conjunto de datos y extraerlo para su posterior procesamiento.
!wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
!unzip -q -o reverse_image_search.zip
Requisitos previos
Para ejecutar este cuaderno, es necesario tener instaladas las siguientes dependencias:
- pymilvus>=2.4.2
- timm
- torch
- numpy
- sklearn
- almohada
Para ejecutar Colab, proporcionamos los prácticos comandos para instalar las dependencias necesarias.
$ pip install pymilvus --upgrade
$ pip install timm
Si utilizas Google Colab, para activar las dependencias recién instaladas, es posible que tengas que reiniciar el tiempo de ejecución. (Haga clic en el menú "Tiempo de ejecución" en la parte superior de la pantalla y seleccione "Reiniciar sesión" en el menú desplegable).
Definir el extractor de características
A continuación, necesitamos definir un extractor de características que extraiga la incrustación de una imagen utilizando el modelo ResNet-34 de timm.
import torch
from PIL import Image
import timm
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class FeatureExtractor:
def __init__(self, modelname):
# Load the pre-trained model
self.model = timm.create_model(
modelname, pretrained=True, num_classes=0, global_pool="avg"
)
self.model.eval()
# Get the input size required by the model
self.input_size = self.model.default_cfg["input_size"]
config = resolve_data_config({}, model=modelname)
# Get the preprocessing function provided by TIMM for the model
self.preprocess = create_transform(**config)
def __call__(self, imagepath):
# Preprocess the input image
input_image = Image.open(imagepath).convert("RGB") # Convert to RGB if needed
input_image = self.preprocess(input_image)
# Convert the image to a PyTorch tensor and add a batch dimension
input_tensor = input_image.unsqueeze(0)
# Perform inference
with torch.no_grad():
output = self.model(input_tensor)
# Extract the feature vector
feature_vector = output.squeeze().numpy()
return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()
Crear una colección Milvus
A continuación, tenemos que crear una colección Milvus para almacenar las incrustaciones de la imagen
from pymilvus import MilvusClient
# Set up a Milvus client
client = MilvusClient(uri="example.db")
# Create a collection in quick setup mode
if client.has_collection(collection_name="image_embeddings"):
client.drop_collection(collection_name="image_embeddings")
client.create_collection(
collection_name="image_embeddings",
vector_field_name="vector",
dimension=512,
auto_id=True,
enable_dynamic_field=True,
metric_type="COSINE",
)
En cuanto al argumento de MilvusClient
:
- Establecer el
uri
como un archivo local, por ejemplo./milvus.db
, es el método más conveniente, ya que utiliza automáticamente Milvus Lite para almacenar todos los datos en este archivo. - Si tiene una gran escala de datos, puede configurar un servidor Milvus más eficiente en docker o kubernetes. En esta configuración, por favor utilice la uri del servidor, por ejemplo
http://localhost:19530
, como suuri
. - Si desea utilizar Zilliz Cloud, el servicio en la nube totalmente gestionado para Milvus, ajuste los
uri
ytoken
, que corresponden al Public Endpoint y a la clave Api en Zilliz Cloud.
Insertar los Embeddings en Milvus
Extraeremos los embeddings de cada imagen utilizando el modelo ResNet34 e insertaremos las imágenes del conjunto de entrenamiento en Milvus.
import os
extractor = FeatureExtractor("resnet34")
root = "./train"
insert = True
if insert is True:
for dirpath, foldername, filenames in os.walk(root):
for filename in filenames:
if filename.endswith(".JPEG"):
filepath = dirpath + "/" + filename
image_embedding = extractor(filepath)
client.insert(
"image_embeddings",
{"vector": image_embedding, "filename": filepath},
)
from IPython.display import display
query_image = "./test/Afghan_hound/n02088094_4261.JPEG"
results = client.search(
"image_embeddings",
data=[extractor(query_image)],
output_fields=["filename"],
search_params={"metric_type": "COSINE"},
)
images = []
for result in results:
for hit in result[:10]:
filename = hit["entity"]["filename"]
img = Image.open(filename)
img = img.resize((150, 150))
images.append(img)
width = 150 * 5
height = 150 * 2
concatenated_image = Image.new("RGB", (width, height))
for idx, img in enumerate(images):
x = idx % 5
y = idx // 5
concatenated_image.paste(img, (x * 150, y * 150))
display("query")
display(Image.open(query_image).resize((150, 150)))
display("results")
display(concatenated_image)
'query'
png
'results'
Resultados
Podemos ver que la mayoría de las imágenes son de la misma categoría que la imagen de búsqueda, que es el sabueso afgano. Esto significa que hemos encontrado imágenes similares a la imagen de búsqueda.
Despliegue rápido
Para saber cómo iniciar una demostración en línea con este tutorial, consulte la aplicación de ejemplo.