Ricerca di immagini con PyTorch e Milvus
Questa guida presenta un esempio di integrazione di PyTorch e Milvus per eseguire ricerche di immagini utilizzando gli embeddings. PyTorch è un potente framework open-source per l'apprendimento profondo ampiamente utilizzato per costruire e distribuire modelli di apprendimento automatico. In questo esempio, sfrutteremo la sua libreria Torchvision e un modello ResNet50 pre-addestrato per generare vettori di caratteristiche (embeddings) che rappresentano il contenuto delle immagini. Questi embeddings saranno memorizzati in Milvus, un database vettoriale ad alte prestazioni, per consentire un'efficiente ricerca di similarità. Il dataset utilizzato è l'Impressionist-Classifier Dataset di Kaggle. Combinando le capacità di deep learning di PyTorch con le funzionalità di ricerca scalabili di Milvus, questo esempio dimostra come costruire un sistema di recupero delle immagini robusto ed efficiente.
Iniziamo!
Installazione dei requisiti
Per questo esempio, utilizzeremo pymilvus per connetterci a Milvus, torch per eseguire il modello di embedding, torchvision per il modello vero e proprio e la preelaborazione, gdown per scaricare il set di dati di esempio e tqdm per caricare le barre.
pip install pymilvus torch gdown torchvision tqdm
Acquisizione dei dati
Utilizzeremo gdown per prelevare lo zip da Google Drive e poi decomprimerlo con la libreria integrata zipfile.
import gdown
import zipfile
url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)
with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
zip_ref.extractall("./paintings")
La dimensione del set di dati è di 2,35 GB e il tempo necessario per scaricarlo dipende dalle condizioni della rete.
Argomenti globali
Questi sono alcuni dei principali argomenti globali che utilizzeremo per facilitare il monitoraggio e l'aggiornamento.
# Milvus Setup Arguments
COLLECTION_NAME = 'image_search' # Collection name
DIMENSION = 2048 # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3
Impostazione di Milvus
A questo punto, iniziamo a configurare Milvus. I passaggi sono i seguenti:
Collegarsi all'istanza di Milvus utilizzando l'URI fornito.
from pymilvus import connections # Connect to the instance connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)Se la collezione esiste già, eliminarla.
from pymilvus import utility # Remove any previous collections with the same name if utility.has_collection(COLLECTION_NAME): utility.drop_collection(COLLECTION_NAME)Creare la collezione che contiene l'ID, il percorso del file dell'immagine e il suo incorporamento.
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection # Create collection which includes the id, filepath of the image, and image embedding fields = [ FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION) ] schema = CollectionSchema(fields=fields) collection = Collection(name=COLLECTION_NAME, schema=schema)Creare un indice sulla raccolta appena creata e caricarla in memoria.
# Create an AutoIndex index for collection index_params = { 'metric_type':'L2', 'index_type':"IVF_FLAT", 'params':{'nlist': 16384} } collection.create_index(field_name="image_embedding", index_params=index_params) collection.load()
Una volta eseguiti questi passaggi, la raccolta è pronta per essere inserita e ricercata. Tutti i dati aggiunti verranno indicizzati automaticamente e saranno immediatamente disponibili per la ricerca. Se i dati sono molto recenti, la ricerca potrebbe essere più lenta, in quanto la ricerca brute force verrà utilizzata sui dati ancora in fase di indicizzazione.
Inserimento dei dati
Per questo esempio, utilizzeremo il modello ResNet50 fornito da torch e il suo hub di modelli. Per ottenere le incorporazioni, togliamo il livello di classificazione finale, in modo che il modello ci fornisca incorporazioni di 2048 dimensioni. Tutti i modelli di visione presenti su torch utilizzano la stessa pre-elaborazione che abbiamo incluso qui.
Nei prossimi passi verranno eseguite le seguenti operazioni:
Caricare i dati.
import glob # Get the filepaths of the images paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True) len(paths)Preelaborazione dei dati in batch.
import torch # Load the embedding model with the last layer removed model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) model = torch.nn.Sequential(*(list(model.children())[:-1])) model.eval()Incorporare i dati.
from torchvision import transforms # Preprocessing for images preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ])Inserimento dei dati.
from PIL import Image from tqdm import tqdm # Embed function that embeds the batch and inserts it def embed(data): with torch.no_grad(): output = model(torch.stack(data[0])).squeeze() collection.insert([data[1], output.tolist()]) data_batch = [[],[]] # Read the images into batches for embedding and insertion for path in tqdm(paths): im = Image.open(path).convert('RGB') data_batch[0].append(preprocess(im)) data_batch[1].append(path) if len(data_batch[0]) % BATCH_SIZE == 0: embed(data_batch) data_batch = [[],[]] # Embed and insert the remainder if len(data_batch[0]) != 0: embed(data_batch) # Call a flush to index any unsealed segments. collection.flush()- Questa fase è relativamente lunga perché l'incorporazione richiede tempo. Prendete un sorso di caffè e rilassatevi.
- PyTorch potrebbe non funzionare bene con Python 3.9 e versioni precedenti. Si consiglia di utilizzare Python 3.10 e versioni successive.
Esecuzione della ricerca
Una volta inseriti tutti i dati in Milvus, possiamo iniziare a eseguire le nostre ricerche. In questo esempio, cercheremo due immagini di esempio. Poiché stiamo eseguendo una ricerca in batch, il tempo di ricerca è condiviso tra le immagini del batch.
import glob
# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt
# Embed the search images
def embed(data):
with torch.no_grad():
ret = model(torch.stack(data))
# If more than one image, use squeeze
if len(ret) > 1:
return ret.squeeze().tolist()
# Squeeze would remove batch for single image, so using flatten
else:
return torch.flatten(ret, start_dim=1).tolist()
data_batch = [[],[]]
for path in search_paths:
im = Image.open(path).convert('RGB')
data_batch[0].append(preprocess(im))
data_batch[1].append(path)
embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
for hits_i, hits in enumerate(res):
axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
axarr[hits_i][0].set_axis_off()
axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
for hit_i, hit in enumerate(hits):
axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
axarr[hits_i][hit_i + 1].set_axis_off()
axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))
# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')
Il risultato della ricerca dovrebbe essere simile al seguente:
Risultato della ricerca di immagini