Bildsuche mit Milvus
Auf dieser Seite werden wir ein einfaches Beispiel für eine Bildsuche mit Milvus durchgehen. Der Datensatz, den wir durchsuchen, ist der Impressionist-Classifier-Datensatz, der auf Kaggle zu finden ist. Für dieses Beispiel haben wir die Daten in einem öffentlichen Google Drive gehostet.
Für dieses Beispiel verwenden wir nur das von Torchvision trainierte Resnet50-Modell für Einbettungen. Los geht's!
Installieren der Voraussetzungen
Für dieses Beispiel verwenden wir pymilvus
, um uns mit Milvus zu verbinden, torch
für die Ausführung des Einbettungsmodells, torchvision
für das eigentliche Modell und die Vorverarbeitung, gdown
für das Herunterladen des Beispieldatensatzes und tqdm
für das Laden der Balken.
pip install pymilvus torch gdown torchvision tqdm
Erfassen der Daten
Wir werden gdown
verwenden, um den Zip-Datensatz von Google Drive zu holen und ihn dann mit der integrierten Bibliothek zipfile
zu dekomprimieren.
import gdown
import zipfile
url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)
with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
zip_ref.extractall("./paintings")
Die Größe des Datensatzes beträgt 2,35 GB, und die Zeit, die für das Herunterladen benötigt wird, hängt von Ihren Netzwerkbedingungen ab.
Globale Argumente
Dies sind einige der wichtigsten globalen Argumente, die wir zur einfacheren Verfolgung und Aktualisierung verwenden werden.
# Milvus Setup Arguments
COLLECTION_NAME = 'image_search' # Collection name
DIMENSION = 2048 # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3
Einrichten von Milvus
An dieser Stelle werden wir mit der Einrichtung von Milvus beginnen. Die Schritte sind wie folgt:
Verbinden Sie sich mit der Milvus-Instanz unter Verwendung der angegebenen URI.
from pymilvus import connections # Connect to the instance connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
Wenn die Sammlung bereits existiert, löschen Sie sie.
from pymilvus import utility # Remove any previous collections with the same name if utility.has_collection(COLLECTION_NAME): utility.drop_collection(COLLECTION_NAME)
Erstellen Sie die Sammlung, die die ID, den Dateipfad des Bildes und seine Einbettung enthält.
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection # Create collection which includes the id, filepath of the image, and image embedding fields = [ FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION) ] schema = CollectionSchema(fields=fields) collection = Collection(name=COLLECTION_NAME, schema=schema)
Erstellen Sie einen Index für die neu erstellte Sammlung und laden Sie sie in den Speicher.
# Create an AutoIndex index for collection index_params = { 'metric_type':'L2', 'index_type':"IVF_FLAT", 'params':{'nlist': 16384} } collection.create_index(field_name="image_embedding", index_params=index_params) collection.load()
Sobald diese Schritte abgeschlossen sind, kann die Sammlung eingefügt und durchsucht werden. Alle hinzugefügten Daten werden automatisch indiziert und sind sofort für die Suche verfügbar. Wenn die Daten sehr frisch sind, kann die Suche langsamer sein, da eine Brute-Force-Suche auf Daten angewendet wird, die noch indiziert werden müssen.
Einfügen der Daten
Für dieses Beispiel verwenden wir das ResNet50-Modell, das von torch
und seinem Modell-Hub bereitgestellt wird. Um die Einbettungen zu erhalten, wird die letzte Klassifizierungsschicht entfernt, was dazu führt, dass das Modell Einbettungen mit 2048 Dimensionen liefert. Alle Bildverarbeitungsmodelle, die auf torch
zu finden sind, verwenden die gleiche Vorverarbeitung, die wir hier mit einbezogen haben.
In den nächsten Schritten werden wir:
Laden der Daten.
import glob # Get the filepaths of the images paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True) len(paths)
Vorverarbeitung der Daten in Stapeln.
import torch # Load the embedding model with the last layer removed model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) model = torch.nn.Sequential(*(list(model.children())[:-1])) model.eval()
Einbetten der Daten.
from torchvision import transforms # Preprocessing for images preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ])
Einfügen der Daten.
from PIL import Image from tqdm import tqdm # Embed function that embeds the batch and inserts it def embed(data): with torch.no_grad(): output = model(torch.stack(data[0])).squeeze() collection.insert([data[1], output.tolist()]) data_batch = [[],[]] # Read the images into batches for embedding and insertion for path in tqdm(paths): im = Image.open(path).convert('RGB') data_batch[0].append(preprocess(im)) data_batch[1].append(path) if len(data_batch[0]) % BATCH_SIZE == 0: embed(data_batch) data_batch = [[],[]] # Embed and insert the remainder if len(data_batch[0]) != 0: embed(data_batch) # Call a flush to index any unsealed segments. collection.flush()
- Dieser Schritt ist relativ zeitaufwändig, weil das Einbetten Zeit braucht. Nehmen Sie einen Schluck Kaffee und entspannen Sie sich.
- PyTorch funktioniert möglicherweise nicht gut mit Python 3.9 und früheren Versionen. Ziehen Sie stattdessen die Verwendung von Python 3.10 und späteren Versionen in Betracht.
Durchführen der Suche
Nachdem alle Daten in Milvus eingegeben wurden, können wir mit der Suche beginnen. In diesem Beispiel werden wir nach zwei Beispielbildern suchen. Da wir eine Stapelsuche durchführen, wird die Suchzeit auf die Bilder des Stapels aufgeteilt.
import glob
# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt
# Embed the search images
def embed(data):
with torch.no_grad():
ret = model(torch.stack(data))
# If more than one image, use squeeze
if len(ret) > 1:
return ret.squeeze().tolist()
# Squeeze would remove batch for single image, so using flatten
else:
return torch.flatten(ret, start_dim=1).tolist()
data_batch = [[],[]]
for path in search_paths:
im = Image.open(path).convert('RGB')
data_batch[0].append(preprocess(im))
data_batch[1].append(path)
embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
for hits_i, hits in enumerate(res):
axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
axarr[hits_i][0].set_axis_off()
axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
for hit_i, hit in enumerate(hits):
axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
axarr[hits_i][hit_i + 1].set_axis_off()
axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))
# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')
Das Suchergebnis sollte in etwa so aussehen wie das folgende:
Ausgabe der Bildsuche