更新时间:2024-09-14 GMT+08:00

向量检索的客户端代码示例(Python)

Elasticsearch提供了标准的REST接口,以及Java、Python等语言编写的客户端。

本节提供一份创建向量索引、导入向量数据和查询向量数据的Python代码示例,介绍如何使用客户端实现向量检索。

前提条件

客户端已经安装Python依赖包elasticsearch。如果未安装,可以执行如下命令安装:

# 根据集群实际版本填写,此处以7.6举例
pip install elasticsearch==7.6

代码示例

from elasticsearch import Elasticsearch
from elasticsearch import helpers

# 创建Elasticsearch客户端
def get_client(hosts: list, user: str = None, password: str = None):
    """Build an Elasticsearch client for the given hosts.

    Args:
        hosts: Node addresses handed to the ``Elasticsearch`` constructor.
        user: Optional user name for HTTP authentication.
        password: Optional password for HTTP authentication.

    Returns:
        An ``Elasticsearch`` instance. Credentials are only applied when
        both ``user`` and ``password`` are truthy; otherwise a client
        without authentication options is returned.
    """
    if not (user and password):
        return Elasticsearch(hosts)

    # NOTE(review): verify_certs=False / ssl_show_warn=False are passed
    # whenever credentials are supplied; confirm this is acceptable for
    # the target cluster before using it outside of a sample.
    return Elasticsearch(
        hosts,
        http_auth=(user, password),
        verify_certs=False,
        ssl_show_warn=False,
    )

# 创建索引表
def create(client: Elasticsearch, index: str):
    """Create an index with a single vector field and print the response.

    Args:
        client: Elasticsearch client used to issue the request.
        index: Name of the index to create.
    """
    # Index-level settings.
    index_settings = {
        "vector": "true",  # enable the vector feature
        "number_of_shards": 1,  # shard count, adjust to actual needs
        "number_of_replicas": 0,  # replica count, adjust to actual needs
    }

    # Definition of the vector field; further fields can be added to the
    # "properties" mapping below as needed.
    vector_field = {
        "type": "vector",
        "dimension": 2,
        "indexing": True,
        "algorithm": "GRAPH",
        "metric": "euclidean",
    }

    request_body = {
        "settings": {"index": index_settings},
        "mappings": {"properties": {"my_vector": vector_field}},
    }

    response = client.indices.create(index=index, body=request_body)
    print("create index result: ", response)

# 写入数据
def write(client: Elasticsearch, index: str, vecs: list, bulk_size=500):
    """Bulk-index vectors into ``index`` and refresh the index afterwards.

    The vectors are sent in batches of at most ``bulk_size`` documents.
    Each document stores its vector in the ``my_vector`` field.

    Args:
        client: Elasticsearch client used to issue the requests.
        index: Name of the target index.
        vecs: List of vectors to write, one document per vector.
        bulk_size: Maximum number of documents per bulk request.

    Raises:
        ValueError: If ``bulk_size`` is not a positive integer.
    """
    if bulk_size <= 0:
        # A zero step makes range() fail and a negative step would
        # silently write nothing, so reject both up front.
        raise ValueError("bulk_size must be a positive integer")

    for start in range(0, len(vecs), bulk_size):
        actions = [
            {
                "_index": index,
                "my_vector": vec,
                # Add other fields here as needed.
            }
            for vec in vecs[start:start + bulk_size]
        ]
        # raise_on_error=False makes helpers.bulk return the per-document
        # failures instead of raising, so the error branch below can
        # actually be reached.
        success, errors = helpers.bulk(
            client,
            actions,
            request_timeout=3600,
            raise_on_error=False,
        )
        if errors:
            # Handle errors according to actual needs.
            print("write bulk failed with errors: ", errors)
        else:
            print("write bulk {} docs success".format(success))
    client.indices.refresh(index=index, request_timeout=3600)

# 查询向量索引
def search(client: Elasticsearch, index: str, query: list, size: int):
    """Run a vector query against ``index`` and print the response.

    Args:
        client: Elasticsearch client used to issue the request.
        index: Name of the index to query.
        query: Query vector, sent as the ``vector`` value of the clause.
        size: Used both as the request ``size`` and as the ``topk`` value
            of the vector clause.
    """
    # Query clause for the "my_vector" field; choose another query form
    # if it suits the use case better.
    vector_clause = {
        "vector": query,
        "topk": size,
    }
    request_body = {
        "size": size,
        "query": {"vector": {"my_vector": vector_clause}},
    }

    response = client.search(index=index, body=request_body)
    print("search index result: ", response)

# 删除索引
def delete(client: Elasticsearch, index: str):
    """Delete ``index`` and print the response.

    Args:
        client: Elasticsearch client used to issue the request.
        index: Name of the index to delete.
    """
    response = client.indices.delete(index=index)
    print("delete index result: ", response)

if __name__ == '__main__':
    # For a non-security cluster, use:
    es_client = get_client(hosts=['http://x.x.x.x:9200'])

    # For a security cluster with HTTPS enabled, use:
    # es_client = get_client(hosts=['https://x.x.x.x:9200', 'https://x.x.x.x:9200'], user='xxxxx', password='xxxxx')

    # For a security cluster with HTTPS disabled, use:
    # es_client = get_client(hosts=['http://x.x.x.x:9200', 'http://x.x.x.x:9200'], user='xxxxx', password='xxxxx')

    # Name of the test index
    index_name = "my_index"

    # Create the index. This stays outside the try block so that a failed
    # creation never leads to deleting an index this run did not create.
    create(es_client, index=index_name)

    try:
        # Write data
        data = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
        write(es_client, index=index_name, vecs=data)

        # Query the index
        query_vector = [1.0, 1.0]
        search(es_client, index=index_name, query=query_vector, size=3)
    finally:
        # Delete the index even if writing or searching failed, so the
        # test index is not left behind for the next run.
        delete(es_client, index=index_name)