Bildsuche mit Milvus
In diesem Notizbuch zeigen wir Ihnen, wie Sie Milvus verwenden können, um nach ähnlichen Bildern in einem Datensatz zu suchen. Wir werden eine Teilmenge des ImageNet-Datensatzes verwenden und dann nach einem Bild eines afghanischen Hundes suchen, um dies zu demonstrieren.
Vorbereitung des Datensatzes
Zunächst müssen wir den Datensatz laden und für die weitere Verarbeitung extrahieren.
!wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
!unzip -q -o reverse_image_search.zip
Vorraussetzungen
Um dieses Notizbuch auszuführen, müssen Sie die folgenden Abhängigkeiten installiert haben:
- pymilvus>=2.4.2
- timm
- torch
- numpy
- sklearn
- pillow
Um Colab auszuführen, stellen wir die praktischen Befehle zur Installation der erforderlichen Abhängigkeiten zur Verfügung.
$ pip install pymilvus --upgrade
$ pip install timm
Wenn Sie Google Colab verwenden, müssen Sie möglicherweise die Runtime neu starten, um die gerade installierten Abhängigkeiten zu aktivieren. (Klicken Sie auf das Menü "Runtime" am oberen Rand des Bildschirms und wählen Sie "Restart session" aus dem Dropdown-Menü).
Definieren Sie den Feature Extractor
Dann müssen wir einen Feature-Extraktor definieren, der die Einbettung aus einem Bild mit Hilfe des ResNet-34-Modells von timm extrahiert.
import torch
from PIL import Image
import timm
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class FeatureExtractor:
def __init__(self, modelname):
# Load the pre-trained model
self.model = timm.create_model(
modelname, pretrained=True, num_classes=0, global_pool="avg"
)
self.model.eval()
# Get the input size required by the model
self.input_size = self.model.default_cfg["input_size"]
config = resolve_data_config({}, model=modelname)
# Get the preprocessing function provided by TIMM for the model
self.preprocess = create_transform(**config)
def __call__(self, imagepath):
# Preprocess the input image
input_image = Image.open(imagepath).convert("RGB") # Convert to RGB if needed
input_image = self.preprocess(input_image)
# Convert the image to a PyTorch tensor and add a batch dimension
input_tensor = input_image.unsqueeze(0)
# Perform inference
with torch.no_grad():
output = self.model(input_tensor)
# Extract the feature vector
feature_vector = output.squeeze().numpy()
return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()
Erstellen einer Milvus-Sammlung
Dann müssen wir eine Milvus-Sammlung erstellen, um die Bildeinbettungen zu speichern
from pymilvus import MilvusClient
# Set up a Milvus client
client = MilvusClient(uri="example.db")
# Create a collection in quick setup mode
if client.has_collection(collection_name="image_embeddings"):
client.drop_collection(collection_name="image_embeddings")
client.create_collection(
collection_name="image_embeddings",
vector_field_name="vector",
dimension=512,
auto_id=True,
enable_dynamic_field=True,
metric_type="COSINE",
)
Für das Argument von MilvusClient
:
- Die Einstellung von
uri
als lokale Datei, z. B../milvus.db
, ist die bequemste Methode, da sie automatisch Milvus Lite verwendet, um alle Daten in dieser Datei zu speichern. - Wenn Sie große Datenmengen haben, können Sie einen leistungsfähigeren Milvus-Server auf Docker oder Kubernetes einrichten. Bei dieser Einrichtung verwenden Sie bitte die Server-Uri, z. B.
http://localhost:19530
, alsuri
. - Wenn Sie Zilliz Cloud, den vollständig verwalteten Cloud-Service für Milvus, verwenden möchten, passen Sie
uri
undtoken
an, die dem öffentlichen Endpunkt und dem Api-Schlüssel in Zilliz Cloud entsprechen.
Einfügen der Einbettungen in Milvus
Wir extrahieren die Einbettungen jedes Bildes mit Hilfe des ResNet34-Modells und fügen die Bilder aus dem Trainingsset in Milvus ein.
import os
extractor = FeatureExtractor("resnet34")
root = "./train"
insert = True
if insert is True:
for dirpath, foldername, filenames in os.walk(root):
for filename in filenames:
if filename.endswith(".JPEG"):
filepath = dirpath + "/" + filename
image_embedding = extractor(filepath)
client.insert(
"image_embeddings",
{"vector": image_embedding, "filename": filepath},
)
from IPython.display import display
query_image = "./test/Afghan_hound/n02088094_4261.JPEG"
results = client.search(
"image_embeddings",
data=[extractor(query_image)],
output_fields=["filename"],
search_params={"metric_type": "COSINE"},
)
images = []
for result in results:
for hit in result[:10]:
filename = hit["entity"]["filename"]
img = Image.open(filename)
img = img.resize((150, 150))
images.append(img)
width = 150 * 5
height = 150 * 2
concatenated_image = Image.new("RGB", (width, height))
for idx, img in enumerate(images):
x = idx % 5
y = idx // 5
concatenated_image.paste(img, (x * 150, y * 150))
display("query")
display(Image.open(query_image).resize((150, 150)))
display("results")
display(concatenated_image)
'query'
png
'results'
Ergebnisse
Wir können sehen, dass die meisten Bilder aus der gleichen Kategorie wie das Suchbild stammen, nämlich der afghanische Jagdhund. Dies bedeutet, dass wir ähnliche Bilder wie das Suchbild gefunden haben.
Schnelles Einsetzen
Wie Sie eine Online-Demo mit diesem Tutorial starten können, erfahren Sie in der Beispielanwendung.