Milvus로 이미지 검색하기
이 페이지에서는 Milvus를 사용한 간단한 이미지 검색 예제를 살펴보겠습니다. 우리가 검색하는 데이터 세트는 Kaggle에 있는 인상주의-분류자 데이터 세트입니다. 이 예제에서는 공용 Google 드라이브에 데이터를 리호스팅했습니다.
이 예제에서는 임베딩을 위해 Torchvision에서 사전 학습된 Resnet50 모델을 사용하고 있습니다. 시작해 보겠습니다!
요구 사항 설치하기
이 예제에서는 Milvus 사용을 위한 연결은 pymilvus
, 임베딩 모델 실행은 torch
, 실제 모델 및 전처리는 torchvision
, 예제 데이터 세트 다운로드는 gdown
, 바 로딩은 tqdm
을 사용할 것입니다.
pip install pymilvus torch gdown torchvision tqdm
데이터 가져오기
gdown
을 사용하여 Google 드라이브에서 압축 파일을 가져온 다음 기본 제공 zipfile
라이브러리로 압축을 풀겠습니다.
import gdown
import zipfile
url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)
with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
zip_ref.extractall("./paintings")
데이터 세트의 크기는 2.35GB이며 다운로드하는 데 걸리는 시간은 네트워크 상태에 따라 다릅니다.
글로벌 인수
다음은 추적 및 업데이트를 쉽게 하기 위해 사용할 주요 글로벌 인수의 일부입니다.
# Milvus Setup Arguments
COLLECTION_NAME = 'image_search' # Collection name
DIMENSION = 2048 # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3
Milvus 설정하기
이제 Milvus 설정을 시작하겠습니다. 단계는 다음과 같습니다:
제공된 URI를 사용하여 Milvus 인스턴스에 연결합니다.
from pymilvus import connections # Connect to the instance connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
컬렉션이 이미 존재하는 경우 삭제합니다.
from pymilvus import utility # Remove any previous collections with the same name if utility.has_collection(COLLECTION_NAME): utility.drop_collection(COLLECTION_NAME)
ID, 이미지의 파일 경로, 임베딩이 포함된 컬렉션을 생성합니다.
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection # Create collection which includes the id, filepath of the image, and image embedding fields = [ FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION) ] schema = CollectionSchema(fields=fields) collection = Collection(name=COLLECTION_NAME, schema=schema)
새로 생성된 컬렉션에 인덱스를 생성하고 메모리에 로드합니다.
# Create an AutoIndex index for collection index_params = { 'metric_type':'L2', 'index_type':"IVF_FLAT", 'params':{'nlist': 16384} } collection.create_index(field_name="image_embedding", index_params=index_params) collection.load()
이 단계가 완료되면 컬렉션을 삽입하고 검색할 준비가 된 것입니다. 추가된 모든 데이터는 자동으로 색인화되어 즉시 검색할 수 있습니다. 데이터가 매우 새 데이터인 경우, 아직 인덱싱이 진행 중인 데이터에 대해 무차별 대입 검색이 사용되므로 검색 속도가 느려질 수 있습니다.
데이터 삽입하기
이 예에서는 torch
에서 제공하는 ResNet50 모델과 해당 모델 허브를 사용하겠습니다. 임베딩을 얻기 위해 최종 분류 계층을 제거하여 2048개의 차원을 임베딩하는 모델을 만들 것입니다. torch
에 있는 모든 비전 모델은 여기에 포함된 것과 동일한 전처리를 사용합니다.
다음 몇 단계는 다음과 같습니다:
데이터 로드하기.
import glob # Get the filepaths of the images paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True) len(paths)
데이터를 일괄 처리로 전처리하기.
import torch # Load the embedding model with the last layer removed model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) model = torch.nn.Sequential(*(list(model.children())[:-1])) model.eval()
데이터 임베딩.
from torchvision import transforms # Preprocessing for images preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ])
데이터 삽입.
from PIL import Image from tqdm import tqdm # Embed function that embeds the batch and inserts it def embed(data): with torch.no_grad(): output = model(torch.stack(data[0])).squeeze() collection.insert([data[1], output.tolist()]) data_batch = [[],[]] # Read the images into batches for embedding and insertion for path in tqdm(paths): im = Image.open(path).convert('RGB') data_batch[0].append(preprocess(im)) data_batch[1].append(path) if len(data_batch[0]) % BATCH_SIZE == 0: embed(data_batch) data_batch = [[],[]] # Embed and insert the remainder if len(data_batch[0]) != 0: embed(data_batch) # Call a flush to index any unsealed segments. collection.flush()
- 이 단계는 임베딩에 시간이 걸리기 때문에 상대적으로 시간이 오래 걸립니다. 커피 한 모금 마시고 긴장을 푸세요.
- Python 3.9 및 이전 버전에서는 PyTorch가 제대로 작동하지 않을 수 있습니다. 대신 Python 3.10 이상 버전을 사용하는 것이 좋습니다.
검색 수행하기
Milvus에 모든 데이터를 삽입했으면 검색을 시작할 수 있습니다. 이 예제에서는 두 개의 예제 이미지를 검색하겠습니다. 일괄 검색을 수행하므로 검색 시간은 일괄 이미지 전체에 걸쳐 공유됩니다.
import glob
# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt
# Embed the search images
def embed(data):
with torch.no_grad():
ret = model(torch.stack(data))
# If more than one image, use squeeze
if len(ret) > 1:
return ret.squeeze().tolist()
# Squeeze would remove batch for single image, so using flatten
else:
return torch.flatten(ret, start_dim=1).tolist()
data_batch = [[],[]]
for path in search_paths:
im = Image.open(path).convert('RGB')
data_batch[0].append(preprocess(im))
data_batch[1].append(path)
embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
for hits_i, hits in enumerate(res):
axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
axarr[hits_i][0].set_axis_off()
axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
for hit_i, hit in enumerate(hits):
axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
axarr[hits_i][hit_i + 1].set_axis_off()
axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))
# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')
검색 결과 이미지는 다음과 비슷해야 합니다:
이미지 검색 출력