milvus-logo
LFAI
Home
  • Integraciones

Búsqueda de películas usando Milvus y SentenceTransformers

En este ejemplo, vamos a realizar una búsqueda de artículos en Wikipedia utilizando Milvus y la biblioteca SentenceTransformers. El conjunto de datos que estamos buscando es el Wikipedia-Movie-Plots Dataset que se encuentra en Kaggle. Para este ejemplo, hemos vuelto a alojar los datos en una unidad de Google pública.

Vamos a empezar.

Requisitos de instalación

Para este ejemplo, vamos a utilizar pymilvus para conectarnos y utilizar Milvus, sentencetransformers para generar incrustaciones vectoriales y gdown para descargar el conjunto de datos de ejemplo.

pip install pymilvus sentence-transformers gdown

Obtención de los datos

Vamos a utilizar gdown para obtener el zip de Google Drive y luego descomprimirlo con la biblioteca incorporada zipfile.

import gdown
url = 'https://drive.google.com/uc?id=11ISS45aO2ubNCGaC3Lvd3D7NT8Y7MeO8'
output = './movies.zip'
gdown.download(url, output)

import zipfile

with zipfile.ZipFile("./movies.zip","r") as zip_ref:
    zip_ref.extractall("./movies")

Parámetros globales

Aquí podemos encontrar los principales argumentos que necesitan ser modificados para ejecutar con sus propias cuentas. Al lado de cada uno hay una descripción de lo que es.

# Milvus Setup Arguments
COLLECTION_NAME = 'movies_db'  # Collection name
DIMENSION = 384  # Embeddings size
COUNT = 1000  # Number of vectors to insert
MILVUS_HOST = 'localhost'
MILVUS_PORT = '19530'

# Inference Arguments
BATCH_SIZE = 128

# Search Arguments
TOP_K = 3

Configuración de Milvus

Llegados a este punto, vamos a empezar a configurar Milvus. Los pasos son los siguientes:

  1. Conéctese a la instancia de Milvus utilizando el URI proporcionado.

    from pymilvus import connections
    
    # Connect to Milvus Database
    connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
    
  2. Si la colección ya existe, elimínela.

    from pymilvus import utility
    
    # Remove any previous collections with the same name
    if utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)
    
  3. Cree la colección que contiene el id, el título de la película y las incrustaciones del texto de la trama.

    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
    
    
    # Create collection which includes the id, title, and embedding.
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='title', dtype=DataType.VARCHAR, max_length=200),  # VARCHARS need a maximum length, so for this example they are set to 200 characters
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    
  4. Cree un índice en la colección recién creada y cárguela en memoria.

    # Create an IVF_FLAT index for collection.
    index_params = {
        'metric_type':'L2',
        'index_type':"IVF_FLAT",
        'params':{'nlist': 1536}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    collection.load()
    

Una vez realizados estos pasos, la colección estará lista para ser insertada y buscada. Cualquier dato que se añada se indexará automáticamente y estará disponible para realizar búsquedas inmediatamente. Si los datos son muy recientes, la búsqueda puede ser más lenta, ya que se utilizará la búsqueda de fuerza bruta en los datos que aún están en proceso de indexación.

Insertar los datos

Para este ejemplo, vamos a utilizar el modelo miniLM de SentenceTransformers para crear incrustaciones del texto del gráfico. Este modelo devuelve incrustaciones de 384-dim.

En los siguientes pasos vamos a:

  1. Cargar los datos.
  2. Incrustar los datos del texto del gráfico utilizando SentenceTransformers.
  3. Insertar los datos en Milvus.
import csv
from sentence_transformers import SentenceTransformer

transformer = SentenceTransformer('all-MiniLM-L6-v2')

# Extract the book titles
def csv_load(file):
    with open(file, newline='') as f:
        reader = csv.reader(f, delimiter=',')
        for row in reader:
            if '' in (row[1], row[7]):
                continue
            yield (row[1], row[7])


# Extract embedding from text using OpenAI
def embed_insert(data):
    embeds = transformer.encode(data[1]) 
    ins = [
            data[0],
            [x for x in embeds]
    ]
    collection.insert(ins)

import time

data_batch = [[],[]]

count = 0

for title, plot in csv_load('./movies/plots.csv'):
    if count <= COUNT:
        data_batch[0].append(title)
        data_batch[1].append(plot)
        if len(data_batch[0]) % BATCH_SIZE == 0:
            embed_insert(data_batch)
            data_batch = [[],[]]
        count += 1
    else:
        break

# Embed and insert the remainder
if len(data_batch[0]) != 0:
    embed_insert(data_batch)

# Call a flush to index any unsealed segments.
collection.flush()

La operación anterior es relativamente lenta porque la incrustación lleva tiempo. Para mantener el tiempo consumido a un nivel aceptable, intente ajustar COUNT en los parámetros globales a un valor apropiado. Tómese un descanso y disfrute de una taza de café.

Con todos los datos insertados en Milvus, podemos empezar a realizar nuestras búsquedas. En este ejemplo, vamos a buscar películas basándonos en la trama. Como estamos realizando una búsqueda por lotes, el tiempo de búsqueda se reparte entre todas las búsquedas de películas.

# Search for titles that closest match these phrases.
search_terms = ['A movie about cars', 'A movie about monsters']

# Search the database based on input text
def embed_search(data):
    embeds = transformer.encode(data) 
    return [x for x in embeds]

search_data = embed_search(search_terms)

start = time.time()
res = collection.search(
    data=search_data,  # Embeded search value
    anns_field="embedding",  # Search across embeddings
    param={},
    limit = TOP_K,  # Limit to top_k results per search
    output_fields=['title']  # Include title field in result
)
end = time.time()

for hits_i, hits in enumerate(res):
    print('Title:', search_terms[hits_i])
    print('Search Time:', end-start)
    print('Results:')
    for hit in hits:
        print( hit.entity.get('title'), '----', hit.distance)
    print()

El resultado debería ser similar al siguiente:

Title: A movie about cars
Search Time: 0.08636689186096191
Results:
Youth's Endearing Charm ---- 1.0954499244689941
From Leadville to Aspen: A Hold-Up in the Rockies ---- 1.1019384860992432
Gentlemen of Nerve ---- 1.1331942081451416

Title: A movie about monsters
Search Time: 0.08636689186096191
Results:
The Suburbanite ---- 1.0666425228118896
Youth's Endearing Charm ---- 1.1072258949279785
The Godless Girl ---- 1.1511223316192627