🚀 免費嘗試 Zilliz Cloud,完全托管的 Milvus,體驗速度提升 10 倍!立即嘗試

milvus-logo
LFAI
主頁
  • 整合

使用 PyTorch 和 Milvus 進行圖片搜尋

本指南介紹一個整合 PyTorch 與 Milvus 的範例,以使用嵌入式執行圖像搜尋。PyTorch 是一個強大的開源深度學習框架,廣泛用於建立和部署機器學習模型。在本範例中,我們將利用其 Torchvision 函式庫和預先訓練好的 ResNet50 模型來產生代表圖像內容的特徵向量 (embeddings)。這些嵌入向量將儲存在高效能向量資料庫 Milvus 中,以便進行有效率的相似性搜尋。使用的資料集是來自Kaggle 的 Impressionist-Classifier Dataset。透過結合 PyTorch 的深度學習能力與 Milvus 的可擴充搜尋功能,本範例展示了如何建立一個強大且有效率的圖像檢索系統。

讓我們開始吧

安裝需求

在本範例中,我們將使用pymilvus 連線使用 Milvus,torch 用於執行嵌入模型,torchvision 用於實際模型和預處理,gdown 用於下載範例資料集,tqdm 用於載入欄位。

pip install pymilvus torch gdown torchvision tqdm

擷取資料

我們要使用gdown 從 Google Drive 抓取壓縮檔,然後用內建的zipfile 函式庫來解壓縮。

import gdown
import zipfile

url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)

with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
    zip_ref.extractall("./paintings")

資料集的大小為 2.35 GB,下載所花的時間取決於您的網路狀況。

全局參數

這些是我們將會使用的一些主要全局參數,以便於追蹤和更新。

# Milvus Setup Arguments
COLLECTION_NAME = 'image_search'  # Collection name
DIMENSION = 2048  # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3

設定 Milvus

此時,我們要開始設定 Milvus。步驟如下:

  1. 使用提供的 URI 連線到 Milvus 實例。

    from pymilvus import connections
    
    # Connect to the instance
    connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
    
  2. 如果集合已經存在,請將它刪除。

    from pymilvus import utility
    
    # Remove any previous collections with the same name
    if utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)
    
  3. 建立收藏集,收藏 ID、圖片的檔案路徑及其嵌入。

    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
    
    # Create collection which includes the id, filepath of the image, and image embedding
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),  # VARCHARS need a maximum length, so for this example they are set to 200 characters
        FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    
  4. 在新建立的集合上建立索引,並將其載入記憶體。

    # Create an AutoIndex index for collection
    index_params = {
    'metric_type':'L2',
    'index_type':"IVF_FLAT",
    'params':{'nlist': 16384}
    }
    collection.create_index(field_name="image_embedding", index_params=index_params)
    collection.load()
    

完成這些步驟後,就可以插入集合並進行搜尋。任何新增的資料都會自動建立索引,並立即可供搜尋。如果資料非常新,搜尋速度可能會較慢,因為會對仍在編制索引的資料使用暴力搜尋。

插入資料

在本範例中,我們將使用torch 及其 model hub 所提供的 ResNet50 模型。為了取得嵌入式資料,我們會去掉最後的分類層,結果模型會提供 2048 個維度的嵌入式資料。在torch 上找到的所有視覺模型都使用相同的預處理,我們在這裡也包含了相同的預處理。

在接下來的幾個步驟中,我們將會

  1. 載入資料。

    import glob
    
    # Get the filepaths of the images
    paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)
    len(paths)
    
  2. 將資料分批預先處理。

    import torch
    
    # Load the embedding model with the last layer removed
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
    model = torch.nn.Sequential(*(list(model.children())[:-1]))
    model.eval()
    
  3. 嵌入資料。

    from torchvision import transforms
    
    # Preprocessing for images
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
  4. 插入資料。

    from PIL import Image
    from tqdm import tqdm
    
    # Embed function that embeds the batch and inserts it
    def embed(data):
        with torch.no_grad():
            output = model(torch.stack(data[0])).squeeze()
            collection.insert([data[1], output.tolist()])
    
    data_batch = [[],[]]
    
    # Read the images into batches for embedding and insertion
    for path in tqdm(paths):
        im = Image.open(path).convert('RGB')
        data_batch[0].append(preprocess(im))
        data_batch[1].append(path)
        if len(data_batch[0]) % BATCH_SIZE == 0:
            embed(data_batch)
            data_batch = [[],[]]
    
    # Embed and insert the remainder
    if len(data_batch[0]) != 0:
        embed(data_batch)
    
    # Call a flush to index any unsealed segments.
    collection.flush()
    
    • 這個步驟相對耗時,因為嵌入需要時間。喝一口咖啡,放鬆一下。
    • PyTorch 可能無法在 Python 3.9 及更早的版本中順利運作。請考慮使用 Python 3.10 或更高版本。

將所有資料插入 Milvus 之後,我們就可以開始執行搜尋了。在這個範例中,我們要搜尋兩個範例圖片。由於我們進行的是批次搜尋,因此搜尋時間是由批次中的影像共同分擔的。

import glob

# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt

# Embed the search images
def embed(data):
    with torch.no_grad():
        ret = model(torch.stack(data))
        # If more than one image, use squeeze
        if len(ret) > 1:
            return ret.squeeze().tolist()
        # Squeeze would remove batch for single image, so using flatten
        else:
            return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[],[]]

for path in search_paths:
    im = Image.open(path).convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)

for hits_i, hits in enumerate(res):
    axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
    axarr[hits_i][0].set_axis_off()
    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
    for hit_i, hit in enumerate(hits):
        axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
        axarr[hits_i][hit_i + 1].set_axis_off()
        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))

# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')

搜尋結果的影像應該類似於下圖:

Image search output 影像搜尋輸出

免費嘗試托管的 Milvus

Zilliz Cloud 無縫接入,由 Milvus 提供動力,速度提升 10 倍。

開始使用
反饋

這個頁面有幫助嗎?