Ricerca di immagini con Milvus
In questa pagina esamineremo un semplice esempio di ricerca di immagini con Milvus. Il set di dati che stiamo cercando è l'Impressionist-Classifier Dataset trovato su Kaggle. Per questo esempio, abbiamo rehosted i dati in un google drive pubblico.
Per questo esempio, utilizziamo solo il modello Resnet50 pre-addestrato da Torchvision per le incorporazioni. Iniziamo!
Installazione dei requisiti
Per questo esempio, utilizzeremo pymilvus
per connetterci a Milvus, torch
per eseguire il modello di embedding, torchvision
per il modello vero e proprio e la preelaborazione, gdown
per scaricare il dataset di esempio e tqdm
per caricare le barre.
pip install pymilvus torch gdown torchvision tqdm
Acquisizione dei dati
Utilizzeremo gdown
per prelevare lo zip da Google Drive e poi decomprimerlo con la libreria integrata zipfile
.
import gdown
import zipfile
url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)
with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
zip_ref.extractall("./paintings")
La dimensione del set di dati è di 2,35 GB e il tempo necessario per scaricarlo dipende dalle condizioni della rete.
Argomenti globali
Questi sono alcuni dei principali argomenti globali che utilizzeremo per facilitare il monitoraggio e l'aggiornamento.
# Milvus Setup Arguments
COLLECTION_NAME = 'image_search' # Collection name
DIMENSION = 2048 # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3
Impostazione di Milvus
A questo punto, iniziamo a configurare Milvus. I passaggi sono i seguenti:
Collegarsi all'istanza di Milvus utilizzando l'URI fornito.
from pymilvus import connections # Connect to the instance connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
Se la collezione esiste già, eliminarla.
from pymilvus import utility # Remove any previous collections with the same name if utility.has_collection(COLLECTION_NAME): utility.drop_collection(COLLECTION_NAME)
Creare la collezione che contiene l'ID, il percorso del file dell'immagine e il suo incorporamento.
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection # Create collection which includes the id, filepath of the image, and image embedding fields = [ FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION) ] schema = CollectionSchema(fields=fields) collection = Collection(name=COLLECTION_NAME, schema=schema)
Creare un indice sulla raccolta appena creata e caricarla in memoria.
# Create an AutoIndex index for collection index_params = { 'metric_type':'L2', 'index_type':"IVF_FLAT", 'params':{'nlist': 16384} } collection.create_index(field_name="image_embedding", index_params=index_params) collection.load()
Una volta eseguiti questi passaggi, la raccolta è pronta per essere inserita e ricercata. Tutti i dati aggiunti verranno indicizzati automaticamente e saranno immediatamente disponibili per la ricerca. Se i dati sono molto recenti, la ricerca potrebbe essere più lenta, in quanto la ricerca brute force verrà utilizzata sui dati ancora in fase di indicizzazione.
Inserimento dei dati
Per questo esempio, utilizzeremo il modello ResNet50 fornito da torch
e il suo hub di modelli. Per ottenere le incorporazioni, togliamo il livello di classificazione finale, in modo che il modello ci fornisca incorporazioni di 2048 dimensioni. Tutti i modelli di visione presenti su torch
utilizzano la stessa pre-elaborazione che abbiamo incluso qui.
Nei prossimi passaggi verranno eseguiti i seguenti passaggi:
Caricare i dati.
import glob # Get the filepaths of the images paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True) len(paths)
Preelaborazione dei dati in batch.
import torch # Load the embedding model with the last layer removed model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) model = torch.nn.Sequential(*(list(model.children())[:-1])) model.eval()
Incorporare i dati.
from torchvision import transforms # Preprocessing for images preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ])
Inserimento dei dati.
from PIL import Image from tqdm import tqdm # Embed function that embeds the batch and inserts it def embed(data): with torch.no_grad(): output = model(torch.stack(data[0])).squeeze() collection.insert([data[1], output.tolist()]) data_batch = [[],[]] # Read the images into batches for embedding and insertion for path in tqdm(paths): im = Image.open(path).convert('RGB') data_batch[0].append(preprocess(im)) data_batch[1].append(path) if len(data_batch[0]) % BATCH_SIZE == 0: embed(data_batch) data_batch = [[],[]] # Embed and insert the remainder if len(data_batch[0]) != 0: embed(data_batch) # Call a flush to index any unsealed segments. collection.flush()
- Questa fase è relativamente lunga perché l'incorporazione richiede tempo. Prendete un sorso di caffè e rilassatevi.
- PyTorch potrebbe non funzionare bene con Python 3.9 e versioni precedenti. Si consiglia di utilizzare Python 3.10 e versioni successive.
Esecuzione della ricerca
Una volta inseriti tutti i dati in Milvus, possiamo iniziare a eseguire le nostre ricerche. In questo esempio, cercheremo due immagini di esempio. Poiché stiamo eseguendo una ricerca in batch, il tempo di ricerca è condiviso tra le immagini del batch.
import glob
# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt
# Embed the search images
def embed(data):
with torch.no_grad():
ret = model(torch.stack(data))
# If more than one image, use squeeze
if len(ret) > 1:
return ret.squeeze().tolist()
# Squeeze would remove batch for single image, so using flatten
else:
return torch.flatten(ret, start_dim=1).tolist()
data_batch = [[],[]]
for path in search_paths:
im = Image.open(path).convert('RGB')
data_batch[0].append(preprocess(im))
data_batch[1].append(path)
embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
for hits_i, hits in enumerate(res):
axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
axarr[hits_i][0].set_axis_off()
axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
for hit_i, hit in enumerate(hits):
axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
axarr[hits_i][hit_i + 1].set_axis_off()
axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))
# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')
Il risultato della ricerca dovrebbe essere simile al seguente:
Risultato della ricerca di immagini