milvus-logo
LFAI
Home
  • Anleitungen

Bildsuche mit Milvus

Open In Colab

In diesem Notizbuch zeigen wir Ihnen, wie Sie Milvus verwenden können, um nach ähnlichen Bildern in einem Datensatz zu suchen. Wir werden eine Teilmenge des ImageNet-Datensatzes verwenden und dann nach einem Bild eines afghanischen Hundes suchen, um dies zu demonstrieren.

Vorbereitung des Datensatzes

Zunächst müssen wir den Datensatz laden und für die weitere Verarbeitung extrahieren.

!wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
!unzip -q -o reverse_image_search.zip

Vorraussetzungen

Um dieses Notizbuch auszuführen, müssen Sie die folgenden Abhängigkeiten installiert haben:

  • pymilvus>=2.4.2
  • timm
  • torch
  • numpy
  • sklearn
  • pillow

Um Colab auszuführen, stellen wir die praktischen Befehle zur Installation der erforderlichen Abhängigkeiten zur Verfügung.

$ pip install pymilvus --upgrade
$ pip install timm

Wenn Sie Google Colab verwenden, müssen Sie möglicherweise die Runtime neu starten, um die gerade installierten Abhängigkeiten zu aktivieren. (Klicken Sie auf das Menü "Runtime" am oberen Rand des Bildschirms und wählen Sie "Restart session" aus dem Dropdown-Menü).

Definieren Sie den Feature Extractor

Dann müssen wir einen Feature-Extraktor definieren, der die Einbettung aus einem Bild mit Hilfe des ResNet-34-Modells von timm extrahiert.

import torch
from PIL import Image
import timm
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform


class FeatureExtractor:
    def __init__(self, modelname):
        # Load the pre-trained model
        self.model = timm.create_model(
            modelname, pretrained=True, num_classes=0, global_pool="avg"
        )
        self.model.eval()

        # Get the input size required by the model
        self.input_size = self.model.default_cfg["input_size"]

        config = resolve_data_config({}, model=modelname)
        # Get the preprocessing function provided by TIMM for the model
        self.preprocess = create_transform(**config)

    def __call__(self, imagepath):
        # Preprocess the input image
        input_image = Image.open(imagepath).convert("RGB")  # Convert to RGB if needed
        input_image = self.preprocess(input_image)

        # Convert the image to a PyTorch tensor and add a batch dimension
        input_tensor = input_image.unsqueeze(0)

        # Perform inference
        with torch.no_grad():
            output = self.model(input_tensor)

        # Extract the feature vector
        feature_vector = output.squeeze().numpy()

        return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()

Erstellen einer Milvus-Sammlung

Dann müssen wir eine Milvus-Sammlung erstellen, um die Bildeinbettungen zu speichern

from pymilvus import MilvusClient

# Set up a Milvus client
client = MilvusClient(uri="example.db")
# Create a collection in quick setup mode
if client.has_collection(collection_name="image_embeddings"):
    client.drop_collection(collection_name="image_embeddings")
client.create_collection(
    collection_name="image_embeddings",
    vector_field_name="vector",
    dimension=512,
    auto_id=True,
    enable_dynamic_field=True,
    metric_type="COSINE",
)

Für das Argument von MilvusClient:

  • Die Einstellung von uri als lokale Datei, z. B../milvus.db, ist die bequemste Methode, da sie automatisch Milvus Lite verwendet, um alle Daten in dieser Datei zu speichern.
  • Wenn Sie große Datenmengen haben, können Sie einen leistungsfähigeren Milvus-Server auf Docker oder Kubernetes einrichten. Bei dieser Einrichtung verwenden Sie bitte die Server-Uri, z. B.http://localhost:19530, als uri.
  • Wenn Sie Zilliz Cloud, den vollständig verwalteten Cloud-Service für Milvus, verwenden möchten, passen Sie uri und token an, die dem öffentlichen Endpunkt und dem Api-Schlüssel in Zilliz Cloud entsprechen.

Einfügen der Einbettungen in Milvus

Wir extrahieren die Einbettungen jedes Bildes mit Hilfe des ResNet34-Modells und fügen die Bilder aus dem Trainingsset in Milvus ein.

import os

extractor = FeatureExtractor("resnet34")

root = "./train"
insert = True
if insert is True:
    for dirpath, foldername, filenames in os.walk(root):
        for filename in filenames:
            if filename.endswith(".JPEG"):
                filepath = dirpath + "/" + filename
                image_embedding = extractor(filepath)
                client.insert(
                    "image_embeddings",
                    {"vector": image_embedding, "filename": filepath},
                )
from IPython.display import display

query_image = "./test/Afghan_hound/n02088094_4261.JPEG"

results = client.search(
    "image_embeddings",
    data=[extractor(query_image)],
    output_fields=["filename"],
    search_params={"metric_type": "COSINE"},
)
images = []
for result in results:
    for hit in result[:10]:
        filename = hit["entity"]["filename"]
        img = Image.open(filename)
        img = img.resize((150, 150))
        images.append(img)

width = 150 * 5
height = 150 * 2
concatenated_image = Image.new("RGB", (width, height))

for idx, img in enumerate(images):
    x = idx % 5
    y = idx // 5
    concatenated_image.paste(img, (x * 150, y * 150))
display("query")
display(Image.open(query_image).resize((150, 150)))
display("results")
display(concatenated_image)
'query'

png png

'results'

Results Ergebnisse

Wir können sehen, dass die meisten Bilder aus der gleichen Kategorie wie das Suchbild stammen, nämlich der afghanische Jagdhund. Dies bedeutet, dass wir ähnliche Bilder wie das Suchbild gefunden haben.

Schnelles Einsetzen

Wie Sie eine Online-Demo mit diesem Tutorial starten können, erfahren Sie in der Beispielanwendung.