Поиск изображений с помощью PyTorch и Milvus
В этом руководстве представлен пример интеграции PyTorch и Milvus для выполнения поиска по изображениям с использованием вкраплений. PyTorch - это мощный фреймворк глубокого обучения с открытым исходным кодом, широко используемый для построения и развертывания моделей машинного обучения. В этом примере мы используем библиотеку Torchvision и предварительно обученную модель ResNet50 для создания векторов признаков (вкраплений), которые представляют содержимое изображений. Эти вкрапления будут храниться в Milvus, высокопроизводительной базе данных векторов, чтобы обеспечить эффективный поиск по сходству. В качестве базы данных используется набор данных Impressionist-Classifier Dataset от Kaggle. Сочетая возможности глубокого обучения PyTorch с масштабируемой поисковой функциональностью Milvus, этот пример демонстрирует, как создать надежную и эффективную систему поиска изображений.
Давайте приступим!
Установка требований
В этом примере мы будем использовать pymilvus для подключения к Milvus, torch для запуска модели встраивания, torchvision для собственно модели и препроцессинга, gdown для загрузки набора данных примера и tqdm для загрузки баров.
pip install pymilvus torch gdown torchvision tqdm
Захват данных
Мы используем gdown для захвата zip-архива с Google Drive и последующей его распаковки с помощью встроенной библиотеки zipfile.
import gdown
import zipfile
url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)
with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
zip_ref.extractall("./paintings")
Размер набора данных составляет 2,35 ГБ, а время, затраченное на его загрузку, зависит от состояния вашей сети.
Глобальные аргументы
Вот некоторые из основных глобальных аргументов, которые мы будем использовать для более удобного отслеживания и обновления.
# Milvus Setup Arguments
COLLECTION_NAME = 'image_search' # Collection name
DIMENSION = 2048 # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3
Настройка Milvus
На данном этапе мы приступим к настройке Milvus. Для этого необходимо выполнить следующие шаги:
Подключитесь к экземпляру Milvus, используя предоставленный URI.
from pymilvus import connections # Connect to the instance connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)Если коллекция уже существует, удалите ее.
from pymilvus import utility # Remove any previous collections with the same name if utility.has_collection(COLLECTION_NAME): utility.drop_collection(COLLECTION_NAME)Создайте коллекцию, содержащую идентификатор, путь к файлу изображения и его вставку.
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection # Create collection which includes the id, filepath of the image, and image embedding fields = [ FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION) ] schema = CollectionSchema(fields=fields) collection = Collection(name=COLLECTION_NAME, schema=schema)Создайте индекс на только что созданной коллекции и загрузите ее в память.
# Create an AutoIndex index for collection index_params = { 'metric_type':'L2', 'index_type':"IVF_FLAT", 'params':{'nlist': 16384} } collection.create_index(field_name="image_embedding", index_params=index_params) collection.load()
После выполнения этих действий коллекция будет готова к вставке и поиску. Все добавленные данные будут автоматически проиндексированы и сразу же станут доступны для поиска. Если данные очень свежие, поиск может быть медленнее, так как для данных, которые еще находятся в процессе индексирования, будет использоваться грубая сила.
Вставка данных
Для этого примера мы будем использовать модель ResNet50, предоставленную сайтом torch и его центром моделей. Чтобы получить вкрапления, мы убираем последний слой классификации, в результате чего модель дает нам вкрапления 2048 измерений. Все модели зрения, найденные на сайте torch, используют ту же самую предварительную обработку, которую мы включили сюда.
На следующих шагах мы выполним следующие действия:
Загрузим данные.
import glob # Get the filepaths of the images paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True) len(paths)Предварительная обработка данных.
import torch # Load the embedding model with the last layer removed model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) model = torch.nn.Sequential(*(list(model.children())[:-1])) model.eval()Встраивание данных.
from torchvision import transforms # Preprocessing for images preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ])Вставка данных.
from PIL import Image from tqdm import tqdm # Embed function that embeds the batch and inserts it def embed(data): with torch.no_grad(): output = model(torch.stack(data[0])).squeeze() collection.insert([data[1], output.tolist()]) data_batch = [[],[]] # Read the images into batches for embedding and insertion for path in tqdm(paths): im = Image.open(path).convert('RGB') data_batch[0].append(preprocess(im)) data_batch[1].append(path) if len(data_batch[0]) % BATCH_SIZE == 0: embed(data_batch) data_batch = [[],[]] # Embed and insert the remainder if len(data_batch[0]) != 0: embed(data_batch) # Call a flush to index any unsealed segments. collection.flush()- Этот шаг занимает относительно много времени, потому что вставка требует времени. Сделайте глоток кофе и расслабьтесь.
- PyTorch может плохо работать с Python 3.9 и более ранними версиями. Вместо него лучше использовать Python 3.10 и более поздние версии.
Выполнение поиска
Когда все данные вставлены в Milvus, мы можем приступить к выполнению поиска. В этом примере мы будем искать два примера изображений. Поскольку мы выполняем пакетный поиск, время поиска распределяется между всеми изображениями пакета.
import glob
# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt
# Embed the search images
def embed(data):
with torch.no_grad():
ret = model(torch.stack(data))
# If more than one image, use squeeze
if len(ret) > 1:
return ret.squeeze().tolist()
# Squeeze would remove batch for single image, so using flatten
else:
return torch.flatten(ret, start_dim=1).tolist()
data_batch = [[],[]]
for path in search_paths:
im = Image.open(path).convert('RGB')
data_batch[0].append(preprocess(im))
data_batch[1].append(path)
embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
for hits_i, hits in enumerate(res):
axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
axarr[hits_i][0].set_axis_off()
axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
for hit_i, hit in enumerate(hits):
axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
axarr[hits_i][hit_i + 1].set_axis_off()
axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))
# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')
Результат поиска должен быть похож на следующий:
Результат поиска изображений