🚀 Попробуйте Zilliz Cloud, полностью управляемый Milvus, бесплатно — ощутите 10-кратное увеличение производительности! Попробовать сейчас>

milvus-logo
LFAI
Главная
  • Интеграции

Поиск изображений с помощью PyTorch и Milvus

В этом руководстве представлен пример интеграции PyTorch и Milvus для выполнения поиска по изображениям с использованием вкраплений. PyTorch - это мощный фреймворк глубокого обучения с открытым исходным кодом, широко используемый для построения и развертывания моделей машинного обучения. В этом примере мы используем библиотеку Torchvision и предварительно обученную модель ResNet50 для создания векторов признаков (вкраплений), которые представляют содержимое изображений. Эти вкрапления будут храниться в Milvus, высокопроизводительной базе данных векторов, чтобы обеспечить эффективный поиск по сходству. В качестве базы данных используется набор данных Impressionist-Classifier Dataset от Kaggle. Сочетая возможности глубокого обучения PyTorch с масштабируемой поисковой функциональностью Milvus, этот пример демонстрирует, как создать надежную и эффективную систему поиска изображений.

Давайте приступим!

Установка требований

В этом примере мы будем использовать pymilvus для подключения к Milvus, torch для запуска модели встраивания, torchvision для собственно модели и препроцессинга, gdown для загрузки набора данных примера и tqdm для загрузки баров.

pip install pymilvus torch gdown torchvision tqdm

Захват данных

Мы используем gdown для захвата zip-архива с Google Drive и последующей его распаковки с помощью встроенной библиотеки zipfile.

import gdown
import zipfile

url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)

with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
    zip_ref.extractall("./paintings")

Размер набора данных составляет 2,35 ГБ, а время, затраченное на его загрузку, зависит от состояния вашей сети.

Глобальные аргументы

Вот некоторые из основных глобальных аргументов, которые мы будем использовать для более удобного отслеживания и обновления.

# Milvus Setup Arguments
COLLECTION_NAME = 'image_search'  # Collection name
DIMENSION = 2048  # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3

Настройка Milvus

На данном этапе мы приступим к настройке Milvus. Для этого необходимо выполнить следующие шаги:

  1. Подключитесь к экземпляру Milvus, используя предоставленный URI.

    from pymilvus import connections
    
    # Connect to the instance
    connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
    
  2. Если коллекция уже существует, удалите ее.

    from pymilvus import utility
    
    # Remove any previous collections with the same name
    if utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)
    
  3. Создайте коллекцию, содержащую идентификатор, путь к файлу изображения и его вставку.

    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
    
    # Create collection which includes the id, filepath of the image, and image embedding
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),  # VARCHARS need a maximum length, so for this example they are set to 200 characters
        FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    
  4. Создайте индекс на только что созданной коллекции и загрузите ее в память.

    # Create an AutoIndex index for collection
    index_params = {
    'metric_type':'L2',
    'index_type':"IVF_FLAT",
    'params':{'nlist': 16384}
    }
    collection.create_index(field_name="image_embedding", index_params=index_params)
    collection.load()
    

После выполнения этих действий коллекция будет готова к вставке и поиску. Все добавленные данные будут автоматически проиндексированы и сразу же станут доступны для поиска. Если данные очень свежие, поиск может быть медленнее, так как для данных, которые еще находятся в процессе индексирования, будет использоваться грубая сила.

Вставка данных

Для этого примера мы будем использовать модель ResNet50, предоставленную сайтом torch и его центром моделей. Чтобы получить вкрапления, мы убираем последний слой классификации, в результате чего модель дает нам вкрапления 2048 измерений. Все модели зрения, найденные на сайте torch, используют ту же самую предварительную обработку, которую мы включили сюда.

На следующих шагах мы выполним следующие действия:

  1. Загрузим данные.

    import glob
    
    # Get the filepaths of the images
    paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)
    len(paths)
    
  2. Предварительная обработка данных.

    import torch
    
    # Load the embedding model with the last layer removed
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
    model = torch.nn.Sequential(*(list(model.children())[:-1]))
    model.eval()
    
  3. Встраивание данных.

    from torchvision import transforms
    
    # Preprocessing for images
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
  4. Вставка данных.

    from PIL import Image
    from tqdm import tqdm
    
    # Embed function that embeds the batch and inserts it
    def embed(data):
        with torch.no_grad():
            output = model(torch.stack(data[0])).squeeze()
            collection.insert([data[1], output.tolist()])
    
    data_batch = [[],[]]
    
    # Read the images into batches for embedding and insertion
    for path in tqdm(paths):
        im = Image.open(path).convert('RGB')
        data_batch[0].append(preprocess(im))
        data_batch[1].append(path)
        if len(data_batch[0]) % BATCH_SIZE == 0:
            embed(data_batch)
            data_batch = [[],[]]
    
    # Embed and insert the remainder
    if len(data_batch[0]) != 0:
        embed(data_batch)
    
    # Call a flush to index any unsealed segments.
    collection.flush()
    
    • Этот шаг занимает относительно много времени, потому что вставка требует времени. Сделайте глоток кофе и расслабьтесь.
    • PyTorch может плохо работать с Python 3.9 и более ранними версиями. Вместо него лучше использовать Python 3.10 и более поздние версии.

Когда все данные вставлены в Milvus, мы можем приступить к выполнению поиска. В этом примере мы будем искать два примера изображений. Поскольку мы выполняем пакетный поиск, время поиска распределяется между всеми изображениями пакета.

import glob

# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt

# Embed the search images
def embed(data):
    with torch.no_grad():
        ret = model(torch.stack(data))
        # If more than one image, use squeeze
        if len(ret) > 1:
            return ret.squeeze().tolist()
        # Squeeze would remove batch for single image, so using flatten
        else:
            return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[],[]]

for path in search_paths:
    im = Image.open(path).convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)

for hits_i, hits in enumerate(res):
    axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
    axarr[hits_i][0].set_axis_off()
    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
    for hit_i, hit in enumerate(hits):
        axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
        axarr[hits_i][hit_i + 1].set_axis_off()
        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))

# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')

Результат поиска должен быть похож на следующий:

Image search output Результат поиска изображений

Попробуйте Managed Milvus бесплатно

Zilliz Cloud работает без проблем, поддерживается Milvus и в 10 раз быстрее.

Начать
Обратная связь

Была ли эта страница полезной?