milvus-logo
LFAI
Home
  • Integrationen

Bildsuche mit Milvus

Auf dieser Seite werden wir ein einfaches Beispiel für eine Bildsuche mit Milvus durchgehen. Der Datensatz, den wir durchsuchen, ist der Impressionist-Classifier-Datensatz, der auf Kaggle zu finden ist. Für dieses Beispiel haben wir die Daten in einem öffentlichen Google Drive gehostet.

Für dieses Beispiel verwenden wir nur das von Torchvision trainierte Resnet50-Modell für Einbettungen. Los geht's!

Installieren der Voraussetzungen

Für dieses Beispiel verwenden wir pymilvus, um eine Verbindung zu Milvus herzustellen, torch für die Ausführung des Einbettungsmodells, torchvision für das eigentliche Modell und die Vorverarbeitung, gdown zum Herunterladen des Beispieldatensatzes und tqdm zum Laden der Balken.

pip install pymilvus torch gdown torchvision tqdm

Erfassen der Daten

Wir werden gdown verwenden, um den Zip-Datensatz von Google Drive zu holen und ihn dann mit der integrierten Bibliothek zipfile zu dekomprimieren.

import gdown
import zipfile

url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)

with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
    zip_ref.extractall("./paintings")

Die Größe des Datensatzes beträgt 2,35 GB, und die Zeit, die für das Herunterladen benötigt wird, hängt von Ihren Netzwerkbedingungen ab.

Globale Argumente

Dies sind einige der wichtigsten globalen Argumente, die wir zur einfacheren Verfolgung und Aktualisierung verwenden werden.

# Milvus Setup Arguments
COLLECTION_NAME = 'image_search'  # Collection name
DIMENSION = 2048  # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3

Einrichten von Milvus

An dieser Stelle werden wir mit der Einrichtung von Milvus beginnen. Die Schritte sind wie folgt:

  1. Verbinden Sie sich mit der Milvus-Instanz unter Verwendung der angegebenen URI.

    from pymilvus import connections
    
    # Connect to the instance
    connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
    
  2. Wenn die Sammlung bereits existiert, löschen Sie sie.

    from pymilvus import utility
    
    # Remove any previous collections with the same name
    if utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)
    
  3. Erstellen Sie die Sammlung, die die ID, den Dateipfad des Bildes und seine Einbettung enthält.

    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
    
    # Create collection which includes the id, filepath of the image, and image embedding
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),  # VARCHARS need a maximum length, so for this example they are set to 200 characters
        FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    
  4. Erstellen Sie einen Index für die neu erstellte Sammlung und laden Sie sie in den Speicher.

    # Create an AutoIndex index for collection
    index_params = {
    'metric_type':'L2',
    'index_type':"IVF_FLAT",
    'params':{'nlist': 16384}
    }
    collection.create_index(field_name="image_embedding", index_params=index_params)
    collection.load()
    

Sobald diese Schritte abgeschlossen sind, kann die Sammlung eingefügt und durchsucht werden. Alle hinzugefügten Daten werden automatisch indiziert und sind sofort für die Suche verfügbar. Wenn die Daten sehr frisch sind, kann die Suche langsamer sein, da eine Brute-Force-Suche auf Daten angewendet wird, die noch indiziert werden müssen.

Einfügen der Daten

Für dieses Beispiel verwenden wir das ResNet50-Modell, das von torch und seinem Modell-Hub bereitgestellt wird. Um die Einbettungen zu erhalten, wird die letzte Klassifizierungsschicht entfernt, was dazu führt, dass das Modell Einbettungen mit 2048 Dimensionen liefert. Alle Bildverarbeitungsmodelle, die auf torch zu finden sind, verwenden die gleiche Vorverarbeitung, die wir hier mit einbezogen haben.

In den nächsten Schritten werden wir:

  1. Laden der Daten.

    import glob
    
    # Get the filepaths of the images
    paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)
    len(paths)
    
  2. Vorverarbeitung der Daten in Stapeln.

    import torch
    
    # Load the embedding model with the last layer removed
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
    model = torch.nn.Sequential(*(list(model.children())[:-1]))
    model.eval()
    
  3. Einbetten der Daten.

    from torchvision import transforms
    
    # Preprocessing for images
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
  4. Einfügen der Daten.

    from PIL import Image
    from tqdm import tqdm
    
    # Embed function that embeds the batch and inserts it
    def embed(data):
        with torch.no_grad():
            output = model(torch.stack(data[0])).squeeze()
            collection.insert([data[1], output.tolist()])
    
    data_batch = [[],[]]
    
    # Read the images into batches for embedding and insertion
    for path in tqdm(paths):
        im = Image.open(path).convert('RGB')
        data_batch[0].append(preprocess(im))
        data_batch[1].append(path)
        if len(data_batch[0]) % BATCH_SIZE == 0:
            embed(data_batch)
            data_batch = [[],[]]
    
    # Embed and insert the remainder
    if len(data_batch[0]) != 0:
        embed(data_batch)
    
    # Call a flush to index any unsealed segments.
    collection.flush()
    
    • Dieser Schritt ist relativ zeitaufwändig, weil das Einbetten Zeit braucht. Nehmen Sie einen Schluck Kaffee und entspannen Sie sich.
    • PyTorch funktioniert möglicherweise nicht gut mit Python 3.9 und früheren Versionen. Ziehen Sie stattdessen die Verwendung von Python 3.10 und späteren Versionen in Betracht.

Nachdem alle Daten in Milvus eingegeben wurden, können wir mit der Suche beginnen. In diesem Beispiel werden wir nach zwei Beispielbildern suchen. Da wir eine Stapelsuche durchführen, wird die Suchzeit auf die Bilder des Stapels aufgeteilt.

import glob

# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt

# Embed the search images
def embed(data):
    with torch.no_grad():
        ret = model(torch.stack(data))
        # If more than one image, use squeeze
        if len(ret) > 1:
            return ret.squeeze().tolist()
        # Squeeze would remove batch for single image, so using flatten
        else:
            return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[],[]]

for path in search_paths:
    im = Image.open(path).convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)

for hits_i, hits in enumerate(res):
    axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
    axarr[hits_i][0].set_axis_off()
    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
    for hit_i, hit in enumerate(hits):
        axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
        axarr[hits_i][hit_i + 1].set_axis_off()
        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))

# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')

Das Suchergebnis sollte in etwa so aussehen wie das folgende:

Image search output Ausgabe der Bildsuche

Übersetzt vonDeepLogo

Feedback

War diese Seite hilfreich?