Updated on: 2023-06-20 GMT+08:00
Client Code Examples for Vector Search
Elasticsearch provides a standard REST API, as well as clients written in languages such as Java, Python, and Go.
Based on the open-source SIFT1M dataset (http://corpus-texmex.irisa.fr/) and the Python Elasticsearch Client, this section provides sample code that creates a vector index, imports vector data, and queries vector data, showing how to implement vector search with a client.
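The SIFT1M files use the standard .fvecs/.ivecs binary layout: each record is a little-endian int32 dimension d, followed by d float32 (for .fvecs) or int32 (for .ivecs) components. The load_vectors and load_gts helpers in the sample code below rely on this layout. The following is a minimal sketch of the parsing logic, assuming the data has been downloaded to the path used in the example:

import numpy as np

# .fvecs layout: each record = int32 dimension d, then d float32 components.
raw = np.fromfile("./data/sift/sift_base.fvecs", dtype=np.float32)
dim = raw.view(np.int32)[0]                # dimension field of the first record (128 for SIFT1M)
vectors = raw.reshape(-1, 1 + dim)[:, 1:]  # strip the leading dimension field from every record
print(dim, vectors.shape)                  # expected: 128 (1000000, 128)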
Prerequisites
The required Python dependency packages have been installed on the client. If they are not installed yet, run the following commands:
pip install numpy
pip install elasticsearch==7.6.0
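Before running the sample code, it is worth checking that the client can actually reach the cluster. The following is a minimal sketch, assuming a cluster without security mode at the placeholder address used in the example (for a security-mode cluster, credentials would additionally need to be passed, for example via the client's http_auth parameter):

from elasticsearch import Elasticsearch

endpoint = 'http://xxx.xxx.xxx.xxx:9200/'  # placeholder; replace with your cluster address
es = Elasticsearch(endpoint)

# info() sends GET / and returns basic cluster metadata; it raises an exception if the cluster is unreachable
print(es.info())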
Code Example
import numpy as np
import time
import json
from concurrent.futures import ThreadPoolExecutor, wait
from elasticsearch import Elasticsearch
from elasticsearch import helpers

endpoint = 'http://xxx.xxx.xxx.xxx:9200/'
# Build the Elasticsearch client object
es = Elasticsearch(endpoint)

# Index mapping: a 128-dimensional vector field indexed with the GRAPH algorithm
# and the Euclidean distance metric
index_mapping = '''
{
  "settings": {
    "index": {
      "vector": "true"
    }
  },
  "mappings": {
    "properties": {
      "my_vector": {
        "type": "vector",
        "dimension": 128,
        "indexing": true,
        "algorithm": "GRAPH",
        "metric": "euclidean"
      }
    }
  }
}
'''


# Create an index
def create_index(index_name, mapping):
    res = es.indices.create(index=index_name, ignore=400, body=mapping)
    print(res)


# Delete an index
def delete_index(index_name):
    res = es.indices.delete(index=index_name)
    print(res)


# Refresh an index
def refresh_index(index_name):
    res = es.indices.refresh(index=index_name)
    print(res)


# Force-merge the index segments
def merge_index(index_name, seg_cnt=1):
    start = time.time()
    es.indices.forcemerge(index=index_name, max_num_segments=seg_cnt, request_timeout=36000)
    print(f"Merge finished in {time.time() - start} seconds")


# Load vector data (.fvecs: each record is an int32 dimension followed by that many float32 values)
def load_vectors(file_name):
    fv = np.fromfile(file_name, dtype=np.float32)
    dim = fv.view(np.int32)[0]
    vectors = fv.reshape(-1, 1 + dim)[:, 1:]
    return vectors


# Load ground-truth data (.ivecs: same layout, with int32 values)
def load_gts(file_name):
    fv = np.fromfile(file_name, dtype=np.int32)
    dim = fv.view(np.int32)[0]
    gts = fv.reshape(-1, 1 + dim)[:, 1:]
    return gts


# Split the data into chunks of the given size
def partition(ls, size):
    return [ls[i:i + size] for i in range(0, len(ls), size)]


# Import vector data with 8 concurrent bulk writers
def write_index(index_name, vec_file):
    pool = ThreadPoolExecutor(max_workers=8)
    tasks = []
    vectors = load_vectors(vec_file)
    bulk_size = 1000
    partitions = partition(vectors, bulk_size)
    start = time.time()
    start_id = 0
    for vecs in partitions:
        tasks.append(pool.submit(write_bulk, index_name, vecs, start_id))
        start_id += len(vecs)
    wait(tasks)
    print(f"Import finished in {time.time() - start} seconds")


def write_bulk(index_name, vecs, start_id):
    actions = [
        {
            "_index": index_name,
            "my_vector": vecs[j].tolist(),
            "_id": str(j + start_id)
        } for j in range(len(vecs))
    ]
    helpers.bulk(es, actions, request_timeout=3600)


# Query the index and evaluate precision against the ground truth
def search_index(index_name, query_file, gt_file, k):
    print("Start query! Index name: " + index_name)
    queries = load_vectors(query_file)
    gt = load_gts(gt_file)
    took = 0
    precision = []
    for idx, query in enumerate(queries):
        hits = set()
        query_json = {
            "size": k,
            "_source": False,
            "query": {
                "vector": {
                    "my_vector": {
                        "vector": query.tolist(),
                        "topk": k
                    }
                }
            }
        }
        res = es.search(index=index_name, body=json.dumps(query_json))
        for hit in res['hits']['hits']:
            hits.add(int(hit['_id']))
        precision.append(len(hits.intersection(set(gt[idx, :k]))) / k)
        took += res['took']
    print("precision: " + str(sum(precision) / len(precision)))
    print(f"Search finished in {took / 1000:.2f} seconds; average took per query: {took / len(queries):.2f} ms")


if __name__ == "__main__":
    vec_file = r"./data/sift/sift_base.fvecs"
    qry_file = r"./data/sift/sift_query.fvecs"
    gt_file = r"./data/sift/sift_groundtruth.ivecs"
    index = "test"

    create_index(index, index_mapping)
    write_index(index, vec_file)
    merge_index(index)
    refresh_index(index)
    search_index(index, qry_file, gt_file, 10)
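After the data is imported, merged, and refreshed, a single top-k query can also be issued outside the precision-evaluation loop. The following is a minimal sketch against the "test" index built above; the random query vector is only an illustrative stand-in for the request shape, and in practice it would come from the same 128-dimensional space as the indexed data:

import numpy as np

# One top-10 vector query against the "test" index created by the sample code above
query_vector = np.random.rand(128).astype(np.float32)  # illustrative stand-in for a real query
query_json = {
    "size": 10,
    "_source": False,
    "query": {
        "vector": {
            "my_vector": {
                "vector": query_vector.tolist(),
                "topk": 10
            }
        }
    }
}
res = es.search(index="test", body=query_json)
print([hit['_id'] for hit in res['hits']['hits']])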
Parent topic: Vector Search