Milvusを使った画像検索
このノートブックでは、Milvusを使ってデータセット内の類似画像を検索する方法を紹介します。ImageNetデータセットのサブセットを使用し、アフガンハウンドの画像を検索します。
データセットの準備
まず、データセットをロードし、さらなる処理のために抽出を解除する必要があります。
!wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
!unzip -q -o reverse_image_search.zip
前提条件
このノートブックを実行するには、以下の依存関係がインストールされている必要がある:
- pymilvus>=2.4.2
- timm
- トーチ
- numpy
- sklearn
- 枕
Colabを実行するために、必要な依存関係をインストールするための便利なコマンドを提供します。
$ pip install pymilvus --upgrade
$ pip install timm
Google Colabをご利用の場合、インストールしたばかりの依存関係を有効にするには、ランタイムを再起動する必要があります。(画面上部の "Runtime "メニューをクリックし、ドロップダウンメニューから "Restart session "を選択してください)。
フィーチャー・エクストラクターの定義
次に、timmのResNet-34モデルを用いて画像から埋め込みを抽出する特徴抽出器を定義します。
import torch
from PIL import Image
import timm
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class FeatureExtractor:
def __init__(self, modelname):
# Load the pre-trained model
self.model = timm.create_model(
modelname, pretrained=True, num_classes=0, global_pool="avg"
)
self.model.eval()
# Get the input size required by the model
self.input_size = self.model.default_cfg["input_size"]
config = resolve_data_config({}, model=modelname)
# Get the preprocessing function provided by TIMM for the model
self.preprocess = create_transform(**config)
def __call__(self, imagepath):
# Preprocess the input image
input_image = Image.open(imagepath).convert("RGB") # Convert to RGB if needed
input_image = self.preprocess(input_image)
# Convert the image to a PyTorch tensor and add a batch dimension
input_tensor = input_image.unsqueeze(0)
# Perform inference
with torch.no_grad():
output = self.model(input_tensor)
# Extract the feature vector
feature_vector = output.squeeze().numpy()
return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()
Milvusコレクションの作成
埋め込み画像を格納するMilvusコレクションを作成します。
from pymilvus import MilvusClient
# Set up a Milvus client
client = MilvusClient(uri="example.db")
# Create a collection in quick setup mode
if client.has_collection(collection_name="image_embeddings"):
client.drop_collection(collection_name="image_embeddings")
client.create_collection(
collection_name="image_embeddings",
vector_field_name="vector",
dimension=512,
auto_id=True,
enable_dynamic_field=True,
metric_type="COSINE",
)
引数としてMilvusClient
を指定します:
./milvus.db
のように、uri
をローカルファイルとして設定するのが最も便利な方法である。このファイルには自動的にMilvus Liteが利用され、すべてのデータが格納される。- データ規模が大きい場合は、dockerやkubernetes上に、よりパフォーマンスの高いMilvusサーバを構築することができます。このセットアップでは、サーバの uri、例えば
http://localhost:19530
をuri
として使用してください。 - MilvusのフルマネージドクラウドサービスであるZilliz Cloudを使用する場合は、Zilliz CloudのPublic EndpointとApi keyに対応する
uri
とtoken
を調整してください。
Milvusへのエンベッディングの挿入
ResNet34モデルを用いて各画像のエンベッディングを抽出し、学習セットからMilvusに画像を挿入します。
import os
extractor = FeatureExtractor("resnet34")
root = "./train"
insert = True
if insert is True:
for dirpath, foldername, filenames in os.walk(root):
for filename in filenames:
if filename.endswith(".JPEG"):
filepath = dirpath + "/" + filename
image_embedding = extractor(filepath)
client.insert(
"image_embeddings",
{"vector": image_embedding, "filename": filepath},
)
from IPython.display import display
query_image = "./test/Afghan_hound/n02088094_4261.JPEG"
results = client.search(
"image_embeddings",
data=[extractor(query_image)],
output_fields=["filename"],
search_params={"metric_type": "COSINE"},
)
images = []
for result in results:
for hit in result[:10]:
filename = hit["entity"]["filename"]
img = Image.open(filename)
img = img.resize((150, 150))
images.append(img)
width = 150 * 5
height = 150 * 2
concatenated_image = Image.new("RGB", (width, height))
for idx, img in enumerate(images):
x = idx % 5
y = idx // 5
concatenated_image.paste(img, (x * 150, y * 150))
display("query")
display(Image.open(query_image).resize((150, 150)))
display("results")
display(concatenated_image)
'query'
png
'results'
結果
ほとんどの画像が検索画像と同じカテゴリー(アフガンハウンド)の画像であることがわかる。つまり、検索画像と類似した画像が見つかったということです。
クイックデプロイ
このチュートリアルでオンラインデモを開始する方法については、サンプルアプリケーションを参照してください。