Поиск изображений с помощью PyTorch и Milvus
В этом руководстве представлен пример интеграции PyTorch и Milvus для выполнения поиска по изображениям с использованием вкраплений. PyTorch - это мощный фреймворк глубокого обучения с открытым исходным кодом, широко используемый для построения и развертывания моделей машинного обучения. В этом примере мы используем библиотеку Torchvision и предварительно обученную модель ResNet50 для создания векторов признаков (вкраплений), которые представляют содержимое изображений. Эти вкрапления будут храниться в Milvus, высокопроизводительной базе данных векторов, чтобы обеспечить эффективный поиск по сходству. В качестве базы данных используется набор данных Impressionist-Classifier Dataset от Kaggle. Сочетая возможности глубокого обучения PyTorch с масштабируемой поисковой функциональностью Milvus, этот пример демонстрирует, как создать надежную и эффективную систему поиска изображений.
Давайте приступим!
Установка требований
В этом примере мы будем использовать pymilvus
для подключения к Milvus, torch
для запуска модели встраивания, torchvision
для собственно модели и препроцессинга, gdown
для загрузки набора данных примера и tqdm
для загрузки баров.
pip install pymilvus torch gdown torchvision tqdm
Захват данных
Мы используем gdown
для захвата zip-архива с Google Drive и последующей его распаковки с помощью встроенной библиотеки zipfile
.
import gdown
import zipfile
url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)
with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
zip_ref.extractall("./paintings")
Размер набора данных составляет 2,35 ГБ, а время, затраченное на его загрузку, зависит от состояния вашей сети.
Глобальные аргументы
Вот некоторые из основных глобальных аргументов, которые мы будем использовать для более удобного отслеживания и обновления.
# Milvus Setup Arguments
COLLECTION_NAME = 'image_search' # Collection name
DIMENSION = 2048 # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3
Настройка Milvus
На данном этапе мы приступим к настройке Milvus. Для этого необходимо выполнить следующие шаги:
Подключитесь к экземпляру Milvus, используя предоставленный URI.
from pymilvus import connections # Connect to the instance connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
Если коллекция уже существует, удалите ее.
from pymilvus import utility # Remove any previous collections with the same name if utility.has_collection(COLLECTION_NAME): utility.drop_collection(COLLECTION_NAME)
Создайте коллекцию, содержащую идентификатор, путь к файлу изображения и его вставку.
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection # Create collection which includes the id, filepath of the image, and image embedding fields = [ FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION) ] schema = CollectionSchema(fields=fields) collection = Collection(name=COLLECTION_NAME, schema=schema)
Создайте индекс на только что созданной коллекции и загрузите ее в память.
# Create an AutoIndex index for collection index_params = { 'metric_type':'L2', 'index_type':"IVF_FLAT", 'params':{'nlist': 16384} } collection.create_index(field_name="image_embedding", index_params=index_params) collection.load()
После выполнения этих действий коллекция будет готова к вставке и поиску. Все добавленные данные будут автоматически проиндексированы и сразу же станут доступны для поиска. Если данные очень свежие, поиск может быть медленнее, так как для данных, которые еще находятся в процессе индексирования, будет использоваться грубая сила.
Вставка данных
Для этого примера мы будем использовать модель ResNet50, предоставленную сайтом torch
и его центром моделей. Чтобы получить вкрапления, мы убираем последний слой классификации, в результате чего модель дает нам вкрапления 2048 измерений. Все модели зрения, найденные на сайте torch
, используют ту же самую предварительную обработку, которую мы включили сюда.
На следующих шагах мы выполним следующие действия:
Загрузим данные.
import glob # Get the filepaths of the images paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True) len(paths)
Предварительная обработка данных.
import torch # Load the embedding model with the last layer removed model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) model = torch.nn.Sequential(*(list(model.children())[:-1])) model.eval()
Встраивание данных.
from torchvision import transforms # Preprocessing for images preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ])
Вставка данных.
from PIL import Image from tqdm import tqdm # Embed function that embeds the batch and inserts it def embed(data): with torch.no_grad(): output = model(torch.stack(data[0])).squeeze() collection.insert([data[1], output.tolist()]) data_batch = [[],[]] # Read the images into batches for embedding and insertion for path in tqdm(paths): im = Image.open(path).convert('RGB') data_batch[0].append(preprocess(im)) data_batch[1].append(path) if len(data_batch[0]) % BATCH_SIZE == 0: embed(data_batch) data_batch = [[],[]] # Embed and insert the remainder if len(data_batch[0]) != 0: embed(data_batch) # Call a flush to index any unsealed segments. collection.flush()
- Этот шаг занимает относительно много времени, потому что вставка требует времени. Сделайте глоток кофе и расслабьтесь.
- PyTorch может плохо работать с Python 3.9 и более ранними версиями. Вместо него лучше использовать Python 3.10 и более поздние версии.
Выполнение поиска
Когда все данные вставлены в Milvus, мы можем приступить к выполнению поиска. В этом примере мы будем искать два примера изображений. Поскольку мы выполняем пакетный поиск, время поиска распределяется между всеми изображениями пакета.
import glob
# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt
# Embed the search images
def embed(data):
with torch.no_grad():
ret = model(torch.stack(data))
# If more than one image, use squeeze
if len(ret) > 1:
return ret.squeeze().tolist()
# Squeeze would remove batch for single image, so using flatten
else:
return torch.flatten(ret, start_dim=1).tolist()
data_batch = [[],[]]
for path in search_paths:
im = Image.open(path).convert('RGB')
data_batch[0].append(preprocess(im))
data_batch[1].append(path)
embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
for hits_i, hits in enumerate(res):
axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
axarr[hits_i][0].set_axis_off()
axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
for hit_i, hit in enumerate(hits):
axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
axarr[hits_i][hit_i + 1].set_axis_off()
axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))
# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')
Результат поиска должен быть похож на следующий:
Результат поиска изображений