向量检索的Python代码示例
Elasticsearch提供了标准的REST接口,以及Java、Python等语言编写的客户端。
本节提供一份创建向量索引、导入向量数据和查询向量数据的Python代码示例,介绍如何使用客户端实现向量检索。
前提条件
客户端已经安装Python依赖包。如果未安装,可以执行如下命令安装:
# 根据集群实际版本填写,此处以7.6举例
pip install elasticsearch==7.6
代码示例
from elasticsearch import Elasticsearch
from elasticsearch import helpers
# 创建Elasticsearch客户端
def get_client(hosts: list, user: str = None, password: str = None):
    """Build an Elasticsearch client for the given hosts.

    Args:
        hosts: Node addresses, passed unchanged to ``Elasticsearch``.
        user: Optional user name.
        password: Optional password.

    Returns:
        An ``Elasticsearch`` instance. HTTP auth is configured only when both
        ``user`` and ``password`` are truthy; in that case certificate
        verification and the SSL warning are switched off as well.
    """
    # Without a complete credential pair, open a plain connection.
    if not (user and password):
        return Elasticsearch(hosts)

    credentials = (user, password)
    return Elasticsearch(
        hosts,
        http_auth=credentials,
        verify_certs=False,
        ssl_show_warn=False,
    )
# 创建索引表
def create(client: Elasticsearch, index: str):
    """Create an index with a vector field and print the response.

    Args:
        client: Elasticsearch client used to send the request.
        index: Name of the index to create.
    """
    # Index-level settings.
    index_settings = {
        "index": {
            "vector": "true",  # enable the vector feature
            "number_of_shards": 1,  # shard count; adjust to actual needs
            "number_of_replicas": 0,  # replica count; adjust to actual needs
        }
    }

    # Definition of the vector field stored under "my_vector".
    vector_field = {
        "type": "vector",
        "dimension": 2,
        "indexing": True,
        "algorithm": "GRAPH",
        "metric": "euclidean",
    }

    index_mapping = {
        "settings": index_settings,
        "mappings": {
            # Further fields can be added to "properties" as required.
            "properties": {"my_vector": vector_field},
        },
    }

    response = client.indices.create(index=index, body=index_mapping)
    print("create index result: ", response)
# 写入数据
def write(client: Elasticsearch, index: str, vecs: list, bulk_size=500):
    """Bulk-index vectors into ``index`` in batches, then refresh the index.

    Args:
        client: Elasticsearch client used for the bulk and refresh requests.
        index: Name of the target index.
        vecs: Vectors to write; each one is stored in the ``my_vector`` field
            of its own document.
        bulk_size: Maximum number of documents sent per bulk request.

    Raises:
        ValueError: If ``bulk_size`` is not positive.
    """
    # A negative step would make range() empty and silently write nothing.
    if bulk_size <= 0:
        raise ValueError("bulk_size must be a positive integer")

    for start in range(0, len(vecs), bulk_size):
        actions = [
            {
                "_index": index,
                "my_vector": vec,
                # Further fields can be added here as required.
            }
            for vec in vecs[start:start + bulk_size]
        ]
        # helpers.bulk raises on document-level failures by default, which
        # would leave the error branch below unreachable. With
        # raise_on_error=False the failures are returned as a list instead.
        success, errors = helpers.bulk(
            client,
            actions,
            request_timeout=3600,
            raise_on_error=False,
        )
        if errors:
            # Handle the failed items according to actual needs.
            print("write bulk failed with errors: ", errors)
        else:
            print("write bulk {} docs success".format(success))

    client.indices.refresh(index=index, request_timeout=3600)
# 查询向量索引
def search(client: Elasticsearch, index: str, query: list, size: int):
    """Run a vector query against ``index`` and print the response.

    Args:
        client: Elasticsearch client used to send the request.
        index: Name of the index to query.
        query: Query vector matched against the ``my_vector`` field.
        size: Used both as the request ``size`` and as the query ``topk``.
    """
    # Query clause; other query forms can be used according to actual needs.
    vector_clause = {
        "vector": query,
        "topk": size,
    }
    query_body = {
        "size": size,
        "query": {"vector": {"my_vector": vector_clause}},
    }

    response = client.search(index=index, body=query_body)
    print("search index result: ", response)
# 删除索引
def delete(client: Elasticsearch, index: str):
    """Delete ``index`` and print the response.

    Args:
        client: Elasticsearch client used to send the request.
        index: Name of the index to delete.
    """
    print("delete index result: ", client.indices.delete(index=index))
if __name__ == '__main__':
    # End-to-end demo: create an index, write vectors, query them, and
    # delete the index again.

    # For a non-security cluster, use:
    es_client = get_client(hosts=['http://xx.xx.xx.xx:9200'])
    # For a security cluster with HTTPS enabled, use:
    # es_client = get_client(hosts=['https://xx.xx.xx.xx:9200', 'https://xx.xx.xx.xx:9200'], user='xxxxx', password='xxxxx')
    # For a security cluster without HTTPS enabled, use:
    # es_client = get_client(hosts=['http://xx.xx.xx.xx:9200', 'http://xx.xx.xx.xx:9200'], user='xxxxx', password='xxxxx')

    # Name of the test index.
    index_name = "my_index"

    # Create the index. This stays outside the try block so that a failed
    # create never triggers a delete of an index this run did not create.
    create(es_client, index=index_name)
    try:
        # Write data.
        data = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
        write(es_client, index=index_name, vecs=data)

        # Query the index.
        query_vector = [1.0, 1.0]
        search(es_client, index=index_name, query=query_vector, size=3)
    finally:
        # Delete the index even if writing or searching raised, so the test
        # index is not left behind.
        delete(es_client, index=index_name)