🚀 Coba Zilliz Cloud, Milvus yang sepenuhnya terkelola, secara gratis—rasakan performa 10x lebih cepat! Coba Sekarang>>

milvus-logo
LFAI
Beranda
  • Integrasi
  • Home
  • Docs
  • Integrasi

  • Model Penyematan

  • PyTorch

Pencarian Gambar dengan PyTorch dan Milvus

Panduan ini memperkenalkan contoh pengintegrasian PyTorch dan Milvus untuk melakukan pencarian gambar menggunakan penyematan. PyTorch adalah kerangka kerja pembelajaran mendalam sumber terbuka yang kuat dan banyak digunakan untuk membangun dan menerapkan model pembelajaran mesin. Dalam contoh ini, kita akan memanfaatkan pustaka Torchvision dan model ResNet50 yang telah dilatih sebelumnya untuk menghasilkan vektor fitur (embedding) yang merepresentasikan konten gambar. Embeddings ini akan disimpan di Milvus, database vektor berkinerja tinggi, untuk memungkinkan pencarian kemiripan yang efisien. Dataset yang digunakan adalah Impressionist-Classifier Dataset dari Kaggle. Dengan menggabungkan kemampuan pembelajaran mendalam PyTorch dengan fungsionalitas pencarian yang dapat diskalakan dari Milvus, contoh ini mendemonstrasikan cara membangun sistem pengambilan gambar yang kuat dan efisien.

Mari kita mulai!

Menginstal persyaratan

Untuk contoh ini, kita akan menggunakan pymilvus untuk terhubung menggunakan Milvus, torch untuk menjalankan model penyematan, torchvision untuk model aktual dan prapemrosesan, gdown untuk mengunduh dataset contoh dan tqdm untuk memuat bilah.

pip install pymilvus torch gdown torchvision tqdm

Mengambil data

Kita akan menggunakan gdown untuk mengambil zip dari Google Drive dan kemudian mendekompresnya dengan pustaka zipfile bawaan.

import gdown
import zipfile

url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)

with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
    zip_ref.extractall("./paintings")

Ukuran dataset adalah 2,35 GB, dan waktu yang dihabiskan untuk mengunduhnya tergantung pada kondisi jaringan Anda.

Argumen Global

Berikut ini adalah beberapa argumen global utama yang akan kita gunakan untuk memudahkan pelacakan dan pembaruan.

# Milvus Setup Arguments
COLLECTION_NAME = 'image_search'  # Collection name
DIMENSION = 2048  # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3

Menyiapkan Milvus

Pada tahap ini, kita akan mulai menyiapkan Milvus. Langkah-langkahnya adalah sebagai berikut:

  1. Hubungkan ke instans Milvus menggunakan URI yang disediakan.

    from pymilvus import connections
    
    # Connect to the instance
    connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
    
  2. Jika koleksinya sudah ada, hapus saja.

    from pymilvus import utility
    
    # Remove any previous collections with the same name
    if utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)
    
  3. Buat koleksi yang menyimpan ID, jalur berkas gambar, dan penyematannya.

    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
    
    # Create collection which includes the id, filepath of the image, and image embedding
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),  # VARCHARS need a maximum length, so for this example they are set to 200 characters
        FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    
  4. Buat indeks pada koleksi yang baru dibuat dan muat ke dalam memori.

    # Create an AutoIndex index for collection
    index_params = {
    'metric_type':'L2',
    'index_type':"IVF_FLAT",
    'params':{'nlist': 16384}
    }
    collection.create_index(field_name="image_embedding", index_params=index_params)
    collection.load()
    

Setelah langkah-langkah ini selesai, koleksi siap untuk dimasukkan ke dalam dan dicari. Setiap data yang ditambahkan akan diindeks secara otomatis dan tersedia untuk segera dicari. Jika data masih sangat baru, pencarian mungkin akan lebih lambat karena pencarian brute force akan digunakan pada data yang masih dalam proses pengindeksan.

Memasukkan data

Untuk contoh ini, kita akan menggunakan model ResNet50 yang disediakan oleh torch dan hub modelnya. Untuk mendapatkan penyematan, kita akan membuang lapisan klasifikasi akhir, yang menghasilkan model yang memberikan penyematan 2048 dimensi. Semua model visi yang ditemukan di torch menggunakan preprocessing yang sama dengan yang kami sertakan di sini.

Dalam beberapa langkah berikutnya kita akan melakukannya:

  1. Memuat data.

    import glob
    
    # Get the filepaths of the images
    paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)
    len(paths)
    
  2. Memproses data menjadi beberapa kelompok.

    import torch
    
    # Load the embedding model with the last layer removed
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
    model = torch.nn.Sequential(*(list(model.children())[:-1]))
    model.eval()
    
  3. Menanamkan data.

    from torchvision import transforms
    
    # Preprocessing for images
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
  4. Memasukkan data.

    from PIL import Image
    from tqdm import tqdm
    
    # Embed function that embeds the batch and inserts it
    def embed(data):
        with torch.no_grad():
            output = model(torch.stack(data[0])).squeeze()
            collection.insert([data[1], output.tolist()])
    
    data_batch = [[],[]]
    
    # Read the images into batches for embedding and insertion
    for path in tqdm(paths):
        im = Image.open(path).convert('RGB')
        data_batch[0].append(preprocess(im))
        data_batch[1].append(path)
        if len(data_batch[0]) % BATCH_SIZE == 0:
            embed(data_batch)
            data_batch = [[],[]]
    
    # Embed and insert the remainder
    if len(data_batch[0]) != 0:
        embed(data_batch)
    
    # Call a flush to index any unsealed segments.
    collection.flush()
    
    • Langkah ini relatif memakan waktu karena penyematan membutuhkan waktu. Minumlah seteguk kopi dan bersantailah.
    • PyTorch mungkin tidak bekerja dengan baik dengan Python 3.9 dan versi sebelumnya. Pertimbangkan untuk menggunakan Python 3.10 dan versi yang lebih baru.

Dengan semua data yang telah dimasukkan ke dalam Milvus, kita dapat mulai melakukan pencarian. Pada contoh ini, kita akan mencari dua contoh gambar. Karena kita melakukan pencarian batch, waktu pencarian dibagi ke seluruh gambar dalam batch tersebut.

import glob

# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt

# Embed the search images
def embed(data):
    with torch.no_grad():
        ret = model(torch.stack(data))
        # If more than one image, use squeeze
        if len(ret) > 1:
            return ret.squeeze().tolist()
        # Squeeze would remove batch for single image, so using flatten
        else:
            return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[],[]]

for path in search_paths:
    im = Image.open(path).convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)

for hits_i, hits in enumerate(res):
    axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
    axarr[hits_i][0].set_axis_off()
    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
    for hit_i, hit in enumerate(hits):
        axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
        axarr[hits_i][hit_i + 1].set_axis_off()
        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))

# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')

Gambar hasil pencarian akan terlihat seperti berikut ini:

Image search output Output pencarian gambar

Coba Milvus yang Dikelola secara Gratis

Zilliz Cloud bebas masalah, didukung oleh Milvus dan 10x lebih cepat.

Mulai
Umpan balik

Apakah halaman ini bermanfaat?