Pencarian Gambar dengan Milvus
Dalam buku catatan ini, kami akan menunjukkan kepada Anda cara menggunakan Milvus untuk mencari gambar yang mirip dalam sebuah dataset. Kami akan menggunakan subset dari dataset ImageNet, kemudian mencari gambar anjing Afghan untuk mendemonstrasikan hal ini.
Persiapan Dataset
Pertama, kita perlu memuat dataset dan mengekstraknya untuk diproses lebih lanjut.
!wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
!unzip -q -o reverse_image_search.zip
Prasyarat
Untuk menjalankan notebook ini, Anda perlu menginstal dependensi berikut ini:
- pymilvus>=2.4.2
- timm
- torch
- numpy
- sklearn
- bantal
Untuk menjalankan Colab, kami menyediakan perintah praktis untuk menginstal dependensi yang diperlukan.
$ pip install pymilvus --upgrade
$ pip install timm
Jika Anda menggunakan Google Colab, untuk mengaktifkan dependensi yang baru saja diinstal, Anda mungkin perlu memulai ulang runtime. (Klik menu "Runtime" di bagian atas layar, dan pilih "Restart session" dari menu tarik-turun).
Tentukan Pengekstrak Fitur
Kemudian, kita perlu mendefinisikan ekstraktor fitur yang mengekstrak penyematan dari sebuah gambar menggunakan model ResNet-34 dari timm.
import torch
from PIL import Image
import timm
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class FeatureExtractor:
def __init__(self, modelname):
# Load the pre-trained model
self.model = timm.create_model(
modelname, pretrained=True, num_classes=0, global_pool="avg"
)
self.model.eval()
# Get the input size required by the model
self.input_size = self.model.default_cfg["input_size"]
config = resolve_data_config({}, model=modelname)
# Get the preprocessing function provided by TIMM for the model
self.preprocess = create_transform(**config)
def __call__(self, imagepath):
# Preprocess the input image
input_image = Image.open(imagepath).convert("RGB") # Convert to RGB if needed
input_image = self.preprocess(input_image)
# Convert the image to a PyTorch tensor and add a batch dimension
input_tensor = input_image.unsqueeze(0)
# Perform inference
with torch.no_grad():
output = self.model(input_tensor)
# Extract the feature vector
feature_vector = output.squeeze().numpy()
return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()
Membuat Koleksi Milvus
Kemudian kita perlu membuat koleksi Milvus untuk menyimpan embedding gambar
from pymilvus import MilvusClient
# Set up a Milvus client
client = MilvusClient(uri="example.db")
# Create a collection in quick setup mode
if client.has_collection(collection_name="image_embeddings"):
client.drop_collection(collection_name="image_embeddings")
client.create_collection(
collection_name="image_embeddings",
vector_field_name="vector",
dimension=512,
auto_id=True,
enable_dynamic_field=True,
metric_type="COSINE",
)
Adapun argumen dari MilvusClient
:
- Menetapkan
uri
sebagai file lokal, misalnya./milvus.db
, adalah metode yang paling mudah, karena secara otomatis menggunakan Milvus Lite untuk menyimpan semua data dalam file ini. - Jika Anda memiliki data dalam skala besar, Anda dapat mengatur server Milvus yang lebih berkinerja pada docker atau kubernetes. Dalam pengaturan ini, silakan gunakan uri server, misalnya
http://localhost:19530
, sebagaiuri
. - Jika Anda ingin menggunakan Zilliz Cloud, layanan cloud yang dikelola sepenuhnya untuk Milvus, sesuaikan
uri
dantoken
, yang sesuai dengan kunci Public Endpoint dan Api di Zilliz Cloud.
Memasukkan Embeddings ke Milvus
Kami akan mengekstrak embeddings dari setiap gambar menggunakan model ResNet34 dan memasukkan gambar dari set pelatihan ke Milvus.
import os
extractor = FeatureExtractor("resnet34")
root = "./train"
insert = True
if insert is True:
for dirpath, foldername, filenames in os.walk(root):
for filename in filenames:
if filename.endswith(".JPEG"):
filepath = dirpath + "/" + filename
image_embedding = extractor(filepath)
client.insert(
"image_embeddings",
{"vector": image_embedding, "filename": filepath},
)
from IPython.display import display
query_image = "./test/Afghan_hound/n02088094_4261.JPEG"
results = client.search(
"image_embeddings",
data=[extractor(query_image)],
output_fields=["filename"],
search_params={"metric_type": "COSINE"},
)
images = []
for result in results:
for hit in result[:10]:
filename = hit["entity"]["filename"]
img = Image.open(filename)
img = img.resize((150, 150))
images.append(img)
width = 150 * 5
height = 150 * 2
concatenated_image = Image.new("RGB", (width, height))
for idx, img in enumerate(images):
x = idx % 5
y = idx // 5
concatenated_image.paste(img, (x * 150, y * 150))
display("query")
display(Image.open(query_image).resize((150, 150)))
display("results")
display(concatenated_image)
'query'
png
'results'
Hasil
Kita dapat melihat bahwa sebagian besar gambar berasal dari kategori yang sama dengan gambar pencarian, yaitu anjing Afghan. Ini berarti kami menemukan gambar yang mirip dengan gambar pencarian.
Penyebaran Cepat
Untuk mempelajari tentang cara memulai demo online dengan tutorial ini, silakan lihat contoh aplikasi.