milvus-logo
LFAI
フロントページへ
  • 統合

MilvusとCohereを用いた質問応答

このページでは、ベクトルデータベースとしてMilvusを、埋め込みシステムとしてCohereを使用して、SQuADデータセットに基づいた質問応答システムを作成する方法を説明します。

始める前に

このページのコードスニペットにはpymilvuscoherepandasnumpytqdmのインストールが必要です。これらのパッケージのうち、pymilvusはMilvusのクライアントです。システムに存在しない場合は、以下のコマンドを実行してインストールしてください:

pip install pymilvus cohere pandas numpy tqdm

次に、このガイドで使用するモジュールをロードする必要がある。

import cohere
import pandas
import numpy as np
from tqdm import tqdm
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

パラメータ

ここでは、以下のスニペットで使用されるパラメータを見つけることができます。いくつかはあなたの環境に合わせて変更する必要があります。それぞれの横には、それが何であるかの説明があります。

FILE = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json'  # The SQuAD dataset url
COLLECTION_NAME = 'question_answering_db'  # Collection name
DIMENSION = 1024  # Embeddings size, cohere embeddings default to 4096 with the large model
COUNT = 5000  # How many questions to embed and insert into Milvus
BATCH_SIZE = 96 # How large of batches to use for embedding and insertion
MILVUS_HOST = 'localhost'  # Milvus server URI
MILVUS_PORT = '19530'
COHERE_API_KEY = 'replace-this-with-the-cohere-api-key'  # API key obtained from Cohere

このページで使われているモデルやデータセットについての詳細は、co:hereと SQuADを参照してほしい。

データセットの準備

この例では、質問に答えるための真理源としてStanford Question Answering Dataset (SQuAD)を使用します。このデータセットはJSONファイルの形で提供されており、pandasを使って読み込みます。

# Download the dataset
dataset = pandas.read_json(FILE)

# Clean up the dataset by grabbing all the question answer pairs
simplified_records = []
for x in dataset['data']:
    for y in x['paragraphs']:
        for z in y['qas']:
            if len(z['answers']) != 0:
                simplified_records.append({'question': z['question'], 'answer': z['answers'][0]['text']})

# Grab the amount of records based on COUNT
simplified_records = pandas.DataFrame.from_records(simplified_records)
simplified_records = simplified_records.sample(n=min(COUNT, len(simplified_records)), random_state = 42)

# Check the length of the cleaned dataset matches count
print(len(simplified_records))

出力はデータセットのレコード数です。

5000

コレクションの作成

このセクションでは、Milvusとこのユースケースのためのデータベースのセットアップを扱う。Milvusの中で、コレクションをセットアップし、インデックスを作成する必要がある。

# Connect to Milvus Database
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

# Remove collection if it already exists
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)

# Create collection which includes the id, title, and embedding.
fields = [
    FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name='original_question', dtype=DataType.VARCHAR, max_length=1000),
    FieldSchema(name='answer', dtype=DataType.VARCHAR, max_length=1000),
    FieldSchema(name='original_question_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]
schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)

# Create an IVF_FLAT index for collection.
index_params = {
    'metric_type':'IP',
    'index_type':"IVF_FLAT",
    'params':{"nlist": 1024}
}
collection.create_index(field_name="original_question_embedding", index_params=index_params)
collection.load()

データの挿入

コレクションをセットアップしたら、データの挿入を開始する必要がある。これは3つのステップで行われます。

  • データを読み込む、
  • オリジナルの質問を埋め込む
  • Milvusで作成したコレクションにデータを挿入します。

この例では、データには元の質問、元の質問の埋め込み、元の質問に対する回答が含まれます。

# Set up a co:here client.
cohere_client = cohere.Client(COHERE_API_KEY)

# Extract embeddings from questions using Cohere
def embed(texts, input_type):
    res = cohere_client.embed(texts, model='embed-multilingual-v3.0', input_type=input_type)
    return res.embeddings

# Insert each question, answer, and qustion embedding
total = pandas.DataFrame()
for batch in tqdm(np.array_split(simplified_records, (COUNT/BATCH_SIZE) + 1)):
    questions = batch['question'].tolist()
    embeddings = embed(questions, "search_document")
    
    data = [
        {
            'original_question': x,
            'answer': batch['answer'].tolist()[i],
            'original_question_embedding': embeddings[i]
        } for i, x in enumerate(questions)
    ]

    collection.insert(data=data)

time.sleep(10)

質問する

すべてのデータがMilvusコレクションに挿入されると、質問フレーズをCohereで埋め込み、コレクションで検索することで、システムに質問することができます。

挿入直後のデータに対して実行される検索は、インデックス付けされていないデータを総当たりで検索するため、少し遅くなるかもしれません。新しいデータが自動的にインデックス化されると、検索はスピードアップする。

# Search the cluster for an answer to a question text
def search(text, top_k = 5):

    # AUTOINDEX does not require any search params 
    search_params = {}

    results = collection.search(
        data = embed([text], "search_query"),  # Embeded the question
        anns_field='original_question_embedding',
        param=search_params,
        limit = top_k,  # Limit to top_k results per search
        output_fields=['original_question', 'answer']  # Include the original question and answer in the result
    )

    distances = results[0].distances
    entities = [ x.entity.to_dict()['entity'] for x in results[0] ]

    ret = [ {
        "answer": x[1]["answer"],
        "distance": x[0],
        "original_question": x[1]['original_question']
    } for x in zip(distances, entities)]

    return ret

# Ask these questions
search_questions = ['What kills bacteria?', 'What\'s the biggest dog?']

# Print out the results in order of [answer, similarity score, original question]

ret = [ { "question": x, "candidates": search(x) } for x in search_questions ]

出力は以下のようになるはずです:

# Output
#
# [
#     {
#         "question": "What kills bacteria?",
#         "candidates": [
#             {
#                 "answer": "farming",
#                 "distance": 0.6261022090911865,
#                 "original_question": "What makes bacteria resistant to antibiotic treatment?"
#             },
#             {
#                 "answer": "Phage therapy",
#                 "distance": 0.6093736886978149,
#                 "original_question": "What has been talked about to treat resistant bacteria?"
#             },
#             {
#                 "answer": "oral contraceptives",
#                 "distance": 0.5902313590049744,
#                 "original_question": "In therapy, what does the antibacterial interact with?"
#             },
#             {
#                 "answer": "slowing down the multiplication of bacteria or killing the bacteria",
#                 "distance": 0.5874154567718506,
#                 "original_question": "How do antibiotics work?"
#             },
#             {
#                 "answer": "in intensive farming to promote animal growth",
#                 "distance": 0.5667208433151245,
#                 "original_question": "Besides in treating human disease where else are antibiotics used?"
#             }
#         ]
#     },
#     {
#         "question": "What's the biggest dog?",
#         "candidates": [
#             {
#                 "answer": "English Mastiff",
#                 "distance": 0.7875324487686157,
#                 "original_question": "What breed was the largest dog known to have lived?"
#             },
#             {
#                 "answer": "forest elephants",
#                 "distance": 0.5886962413787842,
#                 "original_question": "What large animals reside in the national park?"
#             },
#             {
#                 "answer": "Rico",
#                 "distance": 0.5634892582893372,
#                 "original_question": "What is the name of the dog that could ID over 200 things?"
#             },
#             {
#                 "answer": "Iditarod Trail Sled Dog Race",
#                 "distance": 0.546872615814209,
#                 "original_question": "Which dog-sled race in Alaska is the most famous?"
#             },
#             {
#                 "answer": "part of the family",
#                 "distance": 0.5387814044952393,
#                 "original_question": "Most people today describe their dogs as what?"
#             }
#         ]
#     }
# ]

翻訳DeepLogo

フィードバック

このページは役に立ちましたか ?