🚀 완전 관리형 Milvus인 Zilliz Cloud를 무료로 체험해보세요—10배 더 빠른 성능을 경험하세요! 지금 체험하기>>

Milvus
Zilliz
홈페이지
  • 통합

PyTorch와 Milvus로 이미지 검색하기

이 가이드에서는 임베딩을 사용해 이미지 검색을 수행하기 위해 PyTorch와 Milvus를 통합하는 예제를 소개합니다. PyTorch는 머신 러닝 모델을 구축하고 배포하는 데 널리 사용되는 강력한 오픈 소스 딥 러닝 프레임워크입니다. 이 예제에서는 Torchvision 라이브러리와 사전 학습된 ResNet50 모델을 활용하여 이미지 콘텐츠를 나타내는 특징 벡터(임베딩)를 생성합니다. 이러한 임베딩은 고성능 벡터 데이터베이스인 Milvus에 저장되어 효율적인 유사도 검색을 가능하게 합니다. 사용된 데이터 세트는 Kaggle의 인상주의-분류자 데이터 세트입니다. 이 예제는 PyTorch의 딥 러닝 기능과 Milvus의 확장 가능한 검색 기능을 결합하여 강력하고 효율적인 이미지 검색 시스템을 구축하는 방법을 보여줍니다.

시작해 보겠습니다!

요구 사항 설치하기

이 예제에서는 Milvus 사용을 위한 연결은 pymilvus, 임베딩 모델 실행은 torch, 실제 모델 및 전처리는 torchvision, 예제 데이터 세트 다운로드는 gdown, 로딩 바는 tqdm 를 사용할 것입니다.

pip install pymilvus torch gdown torchvision tqdm

데이터 가져오기

gdown 을 사용하여 Google 드라이브에서 압축 파일을 가져온 다음 기본 제공 zipfile 라이브러리로 압축을 풀겠습니다.

import gdown
import zipfile

url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)

with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
    zip_ref.extractall("./paintings")

데이터 세트의 크기는 2.35GB이며 다운로드하는 데 걸리는 시간은 네트워크 상태에 따라 다릅니다.

글로벌 인수

다음은 추적 및 업데이트를 쉽게 하기 위해 사용할 주요 글로벌 인수의 일부입니다.

# Milvus Setup Arguments
COLLECTION_NAME = 'image_search'  # Collection name
DIMENSION = 2048  # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3

Milvus 설정하기

이제 Milvus 설정을 시작하겠습니다. 단계는 다음과 같습니다:

  1. 제공된 URI를 사용하여 Milvus 인스턴스에 연결합니다.

    from pymilvus import connections
    
    # Connect to the instance
    connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
    
  2. 컬렉션이 이미 존재하는 경우 삭제합니다.

    from pymilvus import utility
    
    # Remove any previous collections with the same name
    if utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)
    
  3. ID, 이미지의 파일 경로, 임베딩이 포함된 컬렉션을 생성합니다.

    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
    
    # Create collection which includes the id, filepath of the image, and image embedding
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),  # VARCHARS need a maximum length, so for this example they are set to 200 characters
        FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    
  4. 새로 생성된 컬렉션에 인덱스를 생성하고 메모리에 로드합니다.

    # Create an AutoIndex index for collection
    index_params = {
    'metric_type':'L2',
    'index_type':"IVF_FLAT",
    'params':{'nlist': 16384}
    }
    collection.create_index(field_name="image_embedding", index_params=index_params)
    collection.load()
    

이 단계가 완료되면 컬렉션을 삽입하고 검색할 준비가 된 것입니다. 추가된 모든 데이터는 자동으로 색인화되어 즉시 검색할 수 있습니다. 데이터가 매우 새 데이터인 경우, 아직 색인 작업이 진행 중인 데이터에 무차별 대입 검색이 사용되므로 검색 속도가 느려질 수 있습니다.

데이터 삽입하기

이 예에서는 torch 에서 제공하는 ResNet50 모델과 해당 모델 허브를 사용하겠습니다. 임베딩을 얻기 위해 최종 분류 계층을 제거하여 2048개의 차원을 임베딩하는 모델을 만들 것입니다. torch 에 있는 모든 비전 모델은 여기에 포함된 것과 동일한 전처리를 사용합니다.

다음 몇 단계는 다음과 같습니다:

  1. 데이터 로드하기.

    import glob
    
    # Get the filepaths of the images
    paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)
    len(paths)
    
  2. 데이터를 일괄 처리로 전처리하기.

    import torch
    
    # Load the embedding model with the last layer removed
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
    model = torch.nn.Sequential(*(list(model.children())[:-1]))
    model.eval()
    
  3. 데이터 임베딩.

    from torchvision import transforms
    
    # Preprocessing for images
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
  4. 데이터 삽입.

    from PIL import Image
    from tqdm import tqdm
    
    # Embed function that embeds the batch and inserts it
    def embed(data):
        with torch.no_grad():
            output = model(torch.stack(data[0])).squeeze()
            collection.insert([data[1], output.tolist()])
    
    data_batch = [[],[]]
    
    # Read the images into batches for embedding and insertion
    for path in tqdm(paths):
        im = Image.open(path).convert('RGB')
        data_batch[0].append(preprocess(im))
        data_batch[1].append(path)
        if len(data_batch[0]) % BATCH_SIZE == 0:
            embed(data_batch)
            data_batch = [[],[]]
    
    # Embed and insert the remainder
    if len(data_batch[0]) != 0:
        embed(data_batch)
    
    # Call a flush to index any unsealed segments.
    collection.flush()
    
    • 이 단계는 임베딩에 시간이 걸리기 때문에 상대적으로 시간이 오래 걸립니다. 커피 한 모금 마시고 긴장을 푸세요.
    • Python 3.9 및 이전 버전에서는 PyTorch가 제대로 작동하지 않을 수 있습니다. 대신 Python 3.10 이상 버전을 사용하는 것이 좋습니다.

Milvus에 모든 데이터를 삽입했으면 검색을 시작할 수 있습니다. 이 예제에서는 두 개의 예제 이미지를 검색하겠습니다. 일괄 검색을 수행하므로 검색 시간은 일괄 이미지 전체에 걸쳐 공유됩니다.

import glob

# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt

# Embed the search images
def embed(data):
    with torch.no_grad():
        ret = model(torch.stack(data))
        # If more than one image, use squeeze
        if len(ret) > 1:
            return ret.squeeze().tolist()
        # Squeeze would remove batch for single image, so using flatten
        else:
            return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[],[]]

for path in search_paths:
    im = Image.open(path).convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)

for hits_i, hits in enumerate(res):
    axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
    axarr[hits_i][0].set_axis_off()
    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
    for hit_i, hit in enumerate(hits):
        axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
        axarr[hits_i][hit_i + 1].set_axis_off()
        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))

# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')

검색 결과 이미지는 다음과 비슷해야 합니다:

Image search output 이미지 검색 출력

Try Managed Milvus for Free

Zilliz Cloud is hassle-free, powered by Milvus and 10x faster.

Get Started
피드백

이 페이지가 도움이 되었나요?