milvus-logo
LFAI
Casa
  • Integrazioni

Ricerca di immagini con Milvus

In questa pagina esamineremo un semplice esempio di ricerca di immagini con Milvus. Il set di dati che stiamo cercando è l'Impressionist-Classifier Dataset trovato su Kaggle. Per questo esempio, abbiamo rehosted i dati in un google drive pubblico.

Per questo esempio, utilizziamo solo il modello Resnet50 pre-addestrato da Torchvision per le incorporazioni. Iniziamo!

Installazione dei requisiti

Per questo esempio, utilizzeremo pymilvus per connetterci a Milvus, torch per eseguire il modello di embedding, torchvision per il modello vero e proprio e la preelaborazione, gdown per scaricare il dataset di esempio e tqdm per caricare le barre.

pip install pymilvus torch gdown torchvision tqdm

Acquisizione dei dati

Utilizzeremo gdown per prelevare lo zip da Google Drive e poi decomprimerlo con la libreria integrata zipfile.

import gdown
import zipfile

url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)

with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
    zip_ref.extractall("./paintings")

La dimensione del set di dati è di 2,35 GB e il tempo necessario per scaricarlo dipende dalle condizioni della rete.

Argomenti globali

Questi sono alcuni dei principali argomenti globali che utilizzeremo per facilitare il monitoraggio e l'aggiornamento.

# Milvus Setup Arguments
COLLECTION_NAME = 'image_search'  # Collection name
DIMENSION = 2048  # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3

Impostazione di Milvus

A questo punto, iniziamo a configurare Milvus. I passaggi sono i seguenti:

  1. Collegarsi all'istanza di Milvus utilizzando l'URI fornito.

    from pymilvus import connections
    
    # Connect to the instance
    connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
    
  2. Se la collezione esiste già, eliminarla.

    from pymilvus import utility
    
    # Remove any previous collections with the same name
    if utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)
    
  3. Creare la collezione che contiene l'ID, il percorso del file dell'immagine e il suo incorporamento.

    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
    
    # Create collection which includes the id, filepath of the image, and image embedding
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),  # VARCHARS need a maximum length, so for this example they are set to 200 characters
        FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    
  4. Creare un indice sulla raccolta appena creata e caricarla in memoria.

    # Create an AutoIndex index for collection
    index_params = {
    'metric_type':'L2',
    'index_type':"IVF_FLAT",
    'params':{'nlist': 16384}
    }
    collection.create_index(field_name="image_embedding", index_params=index_params)
    collection.load()
    

Una volta eseguiti questi passaggi, la raccolta è pronta per essere inserita e ricercata. Tutti i dati aggiunti verranno indicizzati automaticamente e saranno immediatamente disponibili per la ricerca. Se i dati sono molto recenti, la ricerca potrebbe essere più lenta, in quanto la ricerca brute force verrà utilizzata sui dati ancora in fase di indicizzazione.

Inserimento dei dati

Per questo esempio, utilizzeremo il modello ResNet50 fornito da torch e il suo hub di modelli. Per ottenere le incorporazioni, togliamo il livello di classificazione finale, in modo che il modello ci fornisca incorporazioni di 2048 dimensioni. Tutti i modelli di visione presenti su torch utilizzano la stessa pre-elaborazione che abbiamo incluso qui.

Nei prossimi passaggi verranno eseguiti i seguenti passaggi:

  1. Caricare i dati.

    import glob
    
    # Get the filepaths of the images
    paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)
    len(paths)
    
  2. Preelaborazione dei dati in batch.

    import torch
    
    # Load the embedding model with the last layer removed
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
    model = torch.nn.Sequential(*(list(model.children())[:-1]))
    model.eval()
    
  3. Incorporare i dati.

    from torchvision import transforms
    
    # Preprocessing for images
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
  4. Inserimento dei dati.

    from PIL import Image
    from tqdm import tqdm
    
    # Embed function that embeds the batch and inserts it
    def embed(data):
        with torch.no_grad():
            output = model(torch.stack(data[0])).squeeze()
            collection.insert([data[1], output.tolist()])
    
    data_batch = [[],[]]
    
    # Read the images into batches for embedding and insertion
    for path in tqdm(paths):
        im = Image.open(path).convert('RGB')
        data_batch[0].append(preprocess(im))
        data_batch[1].append(path)
        if len(data_batch[0]) % BATCH_SIZE == 0:
            embed(data_batch)
            data_batch = [[],[]]
    
    # Embed and insert the remainder
    if len(data_batch[0]) != 0:
        embed(data_batch)
    
    # Call a flush to index any unsealed segments.
    collection.flush()
    
    • Questa fase è relativamente lunga perché l'incorporazione richiede tempo. Prendete un sorso di caffè e rilassatevi.
    • PyTorch potrebbe non funzionare bene con Python 3.9 e versioni precedenti. Si consiglia di utilizzare Python 3.10 e versioni successive.

Una volta inseriti tutti i dati in Milvus, possiamo iniziare a eseguire le nostre ricerche. In questo esempio, cercheremo due immagini di esempio. Poiché stiamo eseguendo una ricerca in batch, il tempo di ricerca è condiviso tra le immagini del batch.

import glob

# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt

# Embed the search images
def embed(data):
    with torch.no_grad():
        ret = model(torch.stack(data))
        # If more than one image, use squeeze
        if len(ret) > 1:
            return ret.squeeze().tolist()
        # Squeeze would remove batch for single image, so using flatten
        else:
            return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[],[]]

for path in search_paths:
    im = Image.open(path).convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)

for hits_i, hits in enumerate(res):
    axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
    axarr[hits_i][0].set_axis_off()
    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
    for hit_i, hit in enumerate(hits):
        axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
        axarr[hits_i][hit_i + 1].set_axis_off()
        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))

# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')

Il risultato della ricerca dovrebbe essere simile al seguente:

Image search output Risultato della ricerca di immagini

Tradotto daDeepLogo

Feedback

Questa pagina è stata utile?