milvus-logo
LFAI
홈페이지
  • 튜토리얼

Milvus로 이미지 검색하기

Open In Colab

이 노트북에서는 Milvus를 사용해 데이터 세트에서 유사한 이미지를 검색하는 방법을 보여드리겠습니다. 이를 보여드리기 위해 ImageNet 데이터 집합의 하위 집합을 사용한 다음 아프간 사냥개 이미지를 검색해 보겠습니다.

데이터 세트 준비

먼저, 추가 처리를 위해 데이터 집합을 로드하고 압축을 풀어야 합니다.

!wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
!unzip -q -o reverse_image_search.zip

전제 조건

이 노트북을 실행하려면 다음 종속성이 설치되어 있어야 합니다:

  • pymilvus>=2.4.2
  • timm
  • torch
  • numpy
  • sklearn
  • pillow

Colab을 실행하기 위해 필요한 종속성을 설치하는 편리한 명령어를 제공합니다.

$ pip install pymilvus --upgrade
$ pip install timm

Google Colab을 사용하는 경우 방금 설치한 종속 요소를 사용하려면 런타임을 다시 시작해야 할 수 있습니다. (화면 상단의 "런타임" 메뉴를 클릭하고 드롭다운 메뉴에서 "세션 다시 시작"을 선택합니다.)

특징 추출기 정의하기

그런 다음 Timm의 ResNet-34 모델을 사용하여 이미지에서 임베딩을 추출하는 특징 추출기를 정의해야 합니다.

import torch
from PIL import Image
import timm
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform


class FeatureExtractor:
    def __init__(self, modelname):
        # Load the pre-trained model
        self.model = timm.create_model(
            modelname, pretrained=True, num_classes=0, global_pool="avg"
        )
        self.model.eval()

        # Get the input size required by the model
        self.input_size = self.model.default_cfg["input_size"]

        config = resolve_data_config({}, model=modelname)
        # Get the preprocessing function provided by TIMM for the model
        self.preprocess = create_transform(**config)

    def __call__(self, imagepath):
        # Preprocess the input image
        input_image = Image.open(imagepath).convert("RGB")  # Convert to RGB if needed
        input_image = self.preprocess(input_image)

        # Convert the image to a PyTorch tensor and add a batch dimension
        input_tensor = input_image.unsqueeze(0)

        # Perform inference
        with torch.no_grad():
            output = self.model(input_tensor)

        # Extract the feature vector
        feature_vector = output.squeeze().numpy()

        return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()

Milvus 컬렉션 만들기

그런 다음 이미지 임베딩을 저장할 Milvus 컬렉션을 생성해야 합니다.

from pymilvus import MilvusClient

# Set up a Milvus client
client = MilvusClient(uri="example.db")
# Create a collection in quick setup mode
if client.has_collection(collection_name="image_embeddings"):
    client.drop_collection(collection_name="image_embeddings")
client.create_collection(
    collection_name="image_embeddings",
    vector_field_name="vector",
    dimension=512,
    auto_id=True,
    enable_dynamic_field=True,
    metric_type="COSINE",
)

MilvusClient 의 인수를 사용합니다:

  • uri 를 로컬 파일(예:./milvus.db)로 설정하는 것이 가장 편리한 방법인데, 이 파일에 모든 데이터를 저장하기 위해 Milvus Lite를 자동으로 활용하기 때문입니다.
  • 데이터 규모가 큰 경우, 도커나 쿠버네티스에 더 고성능의 Milvus 서버를 설정할 수 있습니다. 이 설정에서는 서버 URL(예:http://localhost:19530)을 uri 으로 사용하세요.
  • 밀버스의 완전 관리형 클라우드 서비스인 질리즈 클라우드를 사용하려면, 질리즈 클라우드의 퍼블릭 엔드포인트와 API 키에 해당하는 uritoken 을 조정하세요.

밀버스에 임베딩 삽입하기

ResNet34 모델을 사용하여 각 이미지의 임베딩을 추출하고 학습 세트의 이미지를 Milvus에 삽입합니다.

import os

extractor = FeatureExtractor("resnet34")

root = "./train"
insert = True
if insert is True:
    for dirpath, foldername, filenames in os.walk(root):
        for filename in filenames:
            if filename.endswith(".JPEG"):
                filepath = dirpath + "/" + filename
                image_embedding = extractor(filepath)
                client.insert(
                    "image_embeddings",
                    {"vector": image_embedding, "filename": filepath},
                )
from IPython.display import display

query_image = "./test/Afghan_hound/n02088094_4261.JPEG"

results = client.search(
    "image_embeddings",
    data=[extractor(query_image)],
    output_fields=["filename"],
    search_params={"metric_type": "COSINE"},
)
images = []
for result in results:
    for hit in result[:10]:
        filename = hit["entity"]["filename"]
        img = Image.open(filename)
        img = img.resize((150, 150))
        images.append(img)

width = 150 * 5
height = 150 * 2
concatenated_image = Image.new("RGB", (width, height))

for idx, img in enumerate(images):
    x = idx % 5
    y = idx // 5
    concatenated_image.paste(img, (x * 150, y * 150))
display("query")
display(Image.open(query_image).resize((150, 150)))
display("results")
display(concatenated_image)
'query'

png png

'results'

Results 결과

대부분의 이미지가 검색 이미지인 아프간 하운드와 같은 카테고리에 속하는 것을 볼 수 있습니다. 이는 검색 이미지와 유사한 이미지를 찾았음을 의미합니다.

빠른 배포

이 튜토리얼을 통해 온라인 데모를 시작하는 방법에 대해 알아보려면 예제 애플리케이션을 참조하세요.

번역DeepLogo

피드백

이 페이지가 도움이 되었나요?