更新时间:2024-10-12 GMT+08:00
向量检索的客户端代码示例(Java)
OpenSearch提供了标准的REST接口，以及使用Java、Python等语言编写的客户端。
本节以Java代码为例，演示如何使用客户端实现向量检索，包括创建向量索引、写入向量数据、查询向量数据以及删除索引。
前提条件
请根据集群的实际版本添加如下Maven依赖，此处以OpenSearch 1.3.6版本为例。
<dependency>
    <groupId>org.opensearch.client</groupId>
    <artifactId>opensearch-rest-high-level-client</artifactId>
    <version>1.3.6</version>
</dependency>
代码示例
完整的Java示例代码如下：
package org.example; import org.apache.http.HttpEntity; import org.apache.http.HttpHost; import org.apache.http.HttpStatus; import org.apache.http.auth.AuthScope; import org.apache.http.auth.UsernamePasswordCredentials; import org.apache.http.client.CredentialsProvider; import org.apache.http.conn.ssl.NoopHostnameVerifier; import org.apache.http.impl.client.BasicCredentialsProvider; import org.apache.http.nio.conn.ssl.SSLIOSessionStrategy; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.admin.indices.refresh.RefreshRequest; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.client.Request; import org.opensearch.client.RequestOptions; import org.opensearch.client.Response; import org.opensearch.client.RestClient; import org.opensearch.client.RestClientBuilder; import org.opensearch.client.RestHighLevelClient; import org.opensearch.client.indices.CreateIndexRequest; import org.opensearch.client.indices.CreateIndexResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.DeprecationHandler; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; import java.io.IOException; import java.security.KeyManagementException; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.security.cert.X509Certificate; import java.util.Arrays; import java.util.List; public class ClientExampleOS { private final RestHighLevelClient client; public ClientExampleOS(RestHighLevelClient client) { this.client = client; } // 
创建非安全集群访问客户端 public static RestHighLevelClient getClient(List<String> hosts, int port, String scheme) { HttpHost[] httpHosts = hosts.stream().map(p -> new HttpHost(p, port, scheme)).toArray(HttpHost[]::new); return new RestHighLevelClient(RestClient.builder(httpHosts)); } // 创建安全集群访问客户端 public static RestHighLevelClient getClient(List<String> hosts, int port, String scheme, String user, String password) { final CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(user, password)); SSLContext sc = null; try { sc = SSLContext.getInstance("SSL"); sc.init(null, trustAllCerts, new SecureRandom()); } catch (KeyManagementException | NoSuchAlgorithmException e) { e.printStackTrace(); } HttpHost[] httpHosts = hosts.stream().map(p -> new HttpHost(p, port, scheme)).toArray(HttpHost[]::new); final SSLIOSessionStrategy sessionStrategy = new SSLIOSessionStrategy(sc, NoopHostnameVerifier.INSTANCE); RestClientBuilder builder = RestClient.builder(httpHosts).setHttpClientConfigCallback(httpClientBuilder -> { httpClientBuilder.disableAuthCaching(); httpClientBuilder.setSSLStrategy(sessionStrategy); return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider); }); return new RestHighLevelClient(builder); } public static TrustManager[] trustAllCerts = new TrustManager[] { new X509TrustManager() { @Override public void checkClientTrusted(X509Certificate[] chain, String authType) { } @Override public void checkServerTrusted(X509Certificate[] chain, String authType) { } @Override public X509Certificate[] getAcceptedIssuers() { return null; } } }; // 创建索引 public void create(String index) throws IOException { CreateIndexRequest request = new CreateIndexRequest(index); request.settings(Settings.builder() .put("index.vector", true) // 开启向量特性 .put("index.number_of_shards", 1) // 索引分片数, 根据实际需求设置 .put("index.number_of_replicas", 0) // 索引副本数,根据实际需求设置 ); String mapping = "{" + " 
\"properties\": {" + " \"my_vector\": {" + " \"type\": \"vector\"," + // 设置该字段为向量类型 " \"indexing\": \"true\"," + // 开启索引加速 " \"dimension\": \"2\"," + // 向量索引 " \"metric\": \"euclidean\"," + // 相似度度量 " \"algorithm\": \"GRAPH\"" + // 索引算法 " }" + " }" + "}"; request.mapping(mapping, XContentType.JSON); CreateIndexResponse response = client.indices().create(request, RequestOptions.DEFAULT); if (response.isAcknowledged()) { System.out.println("create " + index + " success"); } } // 写入数据,一批数据的大小建议控制在500条向量以内 public void write(String index, List<float[]> vectors) throws IOException { BulkRequest request = new BulkRequest(); for (float[] vec : vectors) { request.add(new IndexRequest(index).source("my_vector", vec)); } BulkResponse response = client.bulk(request, RequestOptions.DEFAULT); if (response.hasFailures()) { System.out.println(response.buildFailureMessage()); } else { System.out.println("write bulk to " + index + " success"); } // 可选,ES会默认刷新 client.indices().refresh(new RefreshRequest(index), RequestOptions.DEFAULT); } // 查询向量 public void search(String index, float[] query, int size) throws IOException { String queryFormat = "{\n" + " \"size\":%d,\n" + " \"query\": {\n" + " \"vector\": {\n" + " \"my_vector\": {\n" + // 查询向量字段名称 " \"vector\": %s,\n" + " \"topk\":%d\n" + " }\n" + " }\n" + " }\n" + "}"; String body = String.format(queryFormat, size, Arrays.toString(query), size); Request request = new Request("POST", index + "/_search"); request.setJsonEntity(body); Response response = client.getLowLevelClient().performRequest(request); if (response.getStatusLine().getStatusCode() != HttpStatus.SC_OK) { System.out.println(response.getEntity()); // 根据业务需求处理查询错误 return; } // 根据业务需求处理正常返回结果 HttpEntity entity = response.getEntity(); XContentType xContentType = XContentType.fromMediaTypeOrFormat("application/json"); XContentParser parser = xContentType.xContent().createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, entity.getContent()); 
SearchResponse searchResponse = SearchResponse.fromXContent(parser); System.out.println(searchResponse); } // 删除索引 public void delete(String index) throws IOException { DeleteIndexRequest request = new DeleteIndexRequest(index); AcknowledgedResponse response = client.indices().delete(request, RequestOptions.DEFAULT); if (response.isAcknowledged()) { System.out.println("delete " + index + " success"); } } public void close() throws IOException { client.close(); } public static void main(String[] args) throws IOException { // 对于非安全集群,使用: RestHighLevelClient client = getClient(Arrays.asList("x.x.x.x"), 9200, "http"); /* * 对于开启了https的安全集群,使用: * RestHighLevelClient client = getClient(Arrays.asList("x.x.x.x", "x.x.x.x"), 9200, "https", "user_name", "password"); * 对于未开启https的安全集群,使用: * RestHighLevelClient client = getClient(Arrays.asList("x.x.x.x", "x.x.x.x"), 9200, "http", "user_name", "password"); */ ClientExampleOS example = new ClientExampleOS(client); String indexName = "my_index"; // 创建索引 example.create(indexName); // 写入数据 List<float[]> data = Arrays.asList(new float[]{1.0f, 1.0f}, new float[]{2.0f, 2.0f}, new float[]{3.0f, 3.0f}); example.write(indexName, data); // 查询索引 float[] queryVector = new float[]{1.0f, 1.0f}; example.search(indexName, queryVector, 3); // 删除索引 example.delete(indexName); // 关闭客户端 example.close(); } } |
父主题: 配置OpenSearch集群向量检索