Pencarian Gambar dengan PyTorch dan Milvus
Panduan ini memperkenalkan contoh pengintegrasian PyTorch dan Milvus untuk melakukan pencarian gambar menggunakan penyematan. PyTorch adalah kerangka kerja pembelajaran mendalam sumber terbuka yang kuat dan banyak digunakan untuk membangun dan menerapkan model pembelajaran mesin. Dalam contoh ini, kita akan memanfaatkan pustaka Torchvision dan model ResNet50 yang telah dilatih sebelumnya untuk menghasilkan vektor fitur (embedding) yang merepresentasikan konten gambar. Embeddings ini akan disimpan di Milvus, database vektor berkinerja tinggi, untuk memungkinkan pencarian kemiripan yang efisien. Dataset yang digunakan adalah Impressionist-Classifier Dataset dari Kaggle. Dengan menggabungkan kemampuan pembelajaran mendalam PyTorch dengan fungsionalitas pencarian yang dapat diskalakan dari Milvus, contoh ini mendemonstrasikan cara membangun sistem pengambilan gambar yang kuat dan efisien.
Mari kita mulai!
Menginstal persyaratan
Untuk contoh ini, kita akan menggunakan pymilvus
untuk terhubung menggunakan Milvus, torch
untuk menjalankan model penyematan, torchvision
untuk model aktual dan prapemrosesan, gdown
untuk mengunduh dataset contoh dan tqdm
untuk memuat bilah.
pip install pymilvus torch gdown torchvision tqdm
Mengambil data
Kita akan menggunakan gdown
untuk mengambil zip dari Google Drive dan kemudian mendekompresnya dengan pustaka zipfile
bawaan.
import gdown
import zipfile
url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)
with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
zip_ref.extractall("./paintings")
Ukuran dataset adalah 2,35 GB, dan waktu yang dihabiskan untuk mengunduhnya tergantung pada kondisi jaringan Anda.
Argumen Global
Berikut ini adalah beberapa argumen global utama yang akan kita gunakan untuk memudahkan pelacakan dan pembaruan.
# Milvus Setup Arguments
COLLECTION_NAME = 'image_search' # Collection name
DIMENSION = 2048 # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3
Menyiapkan Milvus
Pada tahap ini, kita akan mulai menyiapkan Milvus. Langkah-langkahnya adalah sebagai berikut:
Hubungkan ke instans Milvus menggunakan URI yang disediakan.
from pymilvus import connections # Connect to the instance connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
Jika koleksinya sudah ada, hapus saja.
from pymilvus import utility # Remove any previous collections with the same name if utility.has_collection(COLLECTION_NAME): utility.drop_collection(COLLECTION_NAME)
Buat koleksi yang menyimpan ID, jalur berkas gambar, dan penyematannya.
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection # Create collection which includes the id, filepath of the image, and image embedding fields = [ FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION) ] schema = CollectionSchema(fields=fields) collection = Collection(name=COLLECTION_NAME, schema=schema)
Buat indeks pada koleksi yang baru dibuat dan muat ke dalam memori.
# Create an AutoIndex index for collection index_params = { 'metric_type':'L2', 'index_type':"IVF_FLAT", 'params':{'nlist': 16384} } collection.create_index(field_name="image_embedding", index_params=index_params) collection.load()
Setelah langkah-langkah ini selesai, koleksi siap untuk dimasukkan ke dalam dan dicari. Setiap data yang ditambahkan akan diindeks secara otomatis dan tersedia untuk segera dicari. Jika data masih sangat baru, pencarian mungkin akan lebih lambat karena pencarian brute force akan digunakan pada data yang masih dalam proses pengindeksan.
Memasukkan data
Untuk contoh ini, kita akan menggunakan model ResNet50 yang disediakan oleh torch
dan hub modelnya. Untuk mendapatkan penyematan, kita akan membuang lapisan klasifikasi akhir, yang menghasilkan model yang memberikan penyematan 2048 dimensi. Semua model visi yang ditemukan di torch
menggunakan preprocessing yang sama dengan yang kami sertakan di sini.
Dalam beberapa langkah berikutnya kita akan melakukannya:
Memuat data.
import glob # Get the filepaths of the images paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True) len(paths)
Memproses data menjadi beberapa kelompok.
import torch # Load the embedding model with the last layer removed model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) model = torch.nn.Sequential(*(list(model.children())[:-1])) model.eval()
Menanamkan data.
from torchvision import transforms # Preprocessing for images preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ])
Memasukkan data.
from PIL import Image from tqdm import tqdm # Embed function that embeds the batch and inserts it def embed(data): with torch.no_grad(): output = model(torch.stack(data[0])).squeeze() collection.insert([data[1], output.tolist()]) data_batch = [[],[]] # Read the images into batches for embedding and insertion for path in tqdm(paths): im = Image.open(path).convert('RGB') data_batch[0].append(preprocess(im)) data_batch[1].append(path) if len(data_batch[0]) % BATCH_SIZE == 0: embed(data_batch) data_batch = [[],[]] # Embed and insert the remainder if len(data_batch[0]) != 0: embed(data_batch) # Call a flush to index any unsealed segments. collection.flush()
- Langkah ini relatif memakan waktu karena penyematan membutuhkan waktu. Minumlah seteguk kopi dan bersantailah.
- PyTorch mungkin tidak bekerja dengan baik dengan Python 3.9 dan versi sebelumnya. Pertimbangkan untuk menggunakan Python 3.10 dan versi yang lebih baru.
Melakukan pencarian
Dengan semua data yang telah dimasukkan ke dalam Milvus, kita dapat mulai melakukan pencarian. Pada contoh ini, kita akan mencari dua contoh gambar. Karena kita melakukan pencarian batch, waktu pencarian dibagi ke seluruh gambar dalam batch tersebut.
import glob
# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt
# Embed the search images
def embed(data):
with torch.no_grad():
ret = model(torch.stack(data))
# If more than one image, use squeeze
if len(ret) > 1:
return ret.squeeze().tolist()
# Squeeze would remove batch for single image, so using flatten
else:
return torch.flatten(ret, start_dim=1).tolist()
data_batch = [[],[]]
for path in search_paths:
im = Image.open(path).convert('RGB')
data_batch[0].append(preprocess(im))
data_batch[1].append(path)
embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
for hits_i, hits in enumerate(res):
axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
axarr[hits_i][0].set_axis_off()
axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
for hit_i, hit in enumerate(hits):
axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
axarr[hits_i][hit_i + 1].set_axis_off()
axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))
# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')
Gambar hasil pencarian akan terlihat seperti berikut ini:
Output pencarian gambar