文档首页/
云搜索服务 CSS/
用户指南/
使用Elasticsearch搜索数据/
增强Elasticsearch集群搜索能力/
配置Elasticsearch集群向量检索/
向量检索的客户端代码示例(Java)
更新时间:2024-08-15 GMT+08:00
向量检索的客户端代码示例(Java)
Elasticsearch提供了标准的REST接口,以及Java、Python等语言编写的客户端。
本节提供一份创建向量索引、导入向量数据、查询向量数据以及删除索引的Java代码示例,介绍如何使用客户端实现向量检索。
前提条件
根据集群的实际版本添加如下Maven依赖(依赖版本需与集群版本保持一致),此处以7.10.2版本为例。
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>elasticsearch-rest-high-level-client</artifactId>
    <version>7.10.2</version>
</dependency>
代码示例
package org.example;

import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
import org.apache.http.HttpStatus;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.nio.conn.ssl.SSLIOSessionStrategy;
import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.indices.CreateIndexRequest;
import org.elasticsearch.client.indices.CreateIndexResponse;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;

import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

/**
 * End-to-end vector search example: create a vector index, bulk-write vectors,
 * run a vector query, and delete the index.
 *
 * <p>The wrapped {@link RestHighLevelClient} is owned by this object and is released by
 * {@link #close()}, so instances can be used with try-with-resources.
 */
public class ClientExample implements Closeable {
    private final RestHighLevelClient client;

    public ClientExample(RestHighLevelClient client) {
        this.client = client;
    }

    /**
     * Creates a client for a non-security-mode cluster (no authentication).
     *
     * @param hosts  node addresses
     * @param port   port shared by all nodes
     * @param scheme "http" or "https"
     * @return a new client; the caller is responsible for closing it
     */
    public static RestHighLevelClient getClient(List<String> hosts, int port, String scheme) {
        return new RestHighLevelClient(RestClient.builder(toHttpHosts(hosts, port, scheme)));
    }

    /**
     * Creates a client for a security-mode cluster using basic authentication.
     *
     * <p>NOTE: the SSL setup trusts every server certificate and skips hostname
     * verification. That is convenient for a demo; in production, load the cluster's
     * CA certificate into a trust store instead.
     *
     * @param hosts    node addresses
     * @param port     port shared by all nodes
     * @param scheme   "https" for HTTPS-enabled clusters, "http" otherwise
     * @param user     user name
     * @param password password
     * @return a new client; the caller is responsible for closing it
     * @throws IllegalStateException if the SSL context cannot be initialized
     */
    public static RestHighLevelClient getClient(List<String> hosts, int port, String scheme, String user,
            String password) {
        final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
        credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(user, password));
        final SSLIOSessionStrategy sessionStrategy =
                new SSLIOSessionStrategy(createTrustAllSslContext(), NoopHostnameVerifier.INSTANCE);
        RestClientBuilder builder = RestClient.builder(toHttpHosts(hosts, port, scheme))
                .setHttpClientConfigCallback(httpClientBuilder -> {
                    httpClientBuilder.disableAuthCaching();
                    httpClientBuilder.setSSLStrategy(sessionStrategy);
                    return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
                });
        return new RestHighLevelClient(builder);
    }

    /** Maps every host name to an {@link HttpHost} with the given port and scheme. */
    private static HttpHost[] toHttpHosts(List<String> hosts, int port, String scheme) {
        return hosts.stream().map(p -> new HttpHost(p, port, scheme)).toArray(HttpHost[]::new);
    }

    /** Builds a TLS context that accepts any certificate (see {@link #trustAllCerts}). */
    private static SSLContext createTrustAllSslContext() {
        try {
            // "TLS" lets the JDK negotiate the best supported TLS version; "SSL" names the legacy protocol.
            SSLContext context = SSLContext.getInstance("TLS");
            context.init(null, trustAllCerts, new SecureRandom());
            return context;
        } catch (KeyManagementException | NoSuchAlgorithmException e) {
            // Fail fast with the cause instead of continuing with a null context.
            throw new IllegalStateException("failed to initialize SSL context", e);
        }
    }

    /** Trust managers that accept every certificate chain. Intended for demos only. */
    public static TrustManager[] trustAllCerts = new TrustManager[] {
        new X509TrustManager() {
            @Override
            public void checkClientTrusted(X509Certificate[] chain, String authType) {
                // Intentionally empty: every client certificate is accepted.
            }

            @Override
            public void checkServerTrusted(X509Certificate[] chain, String authType) {
                // Intentionally empty: every server certificate is accepted.
            }

            @Override
            public X509Certificate[] getAcceptedIssuers() {
                // The X509TrustManager contract requires a non-null (possibly empty) array.
                return new X509Certificate[0];
            }
        }
    };

    /**
     * Creates an index with the vector feature enabled and a 2-dimensional vector
     * field named {@code my_vector}.
     *
     * @param index index name
     * @throws IOException on transport failure
     */
    public void create(String index) throws IOException {
        CreateIndexRequest request = new CreateIndexRequest(index);
        request.settings(Settings.builder()
                .put("index.vector", true) // enable the vector feature
                .put("index.number_of_shards", 1) // number of shards; set as needed
                .put("index.number_of_replicas", 0) // number of replicas; set as needed
        );
        String mapping = "{\n"
                + " \"properties\": {\n"
                + " \"my_vector\": {\n"
                + " \"type\": \"vector\",\n" // declare the field as a vector field
                + " \"indexing\": \"true\",\n" // enable index acceleration
                + " \"dimension\": \"2\",\n" // vector dimension
                + " \"metric\": \"euclidean\",\n" // similarity metric
                + " \"algorithm\": \"GRAPH\"\n" // index algorithm
                + " }\n"
                + " }\n"
                + "}";
        request.mapping(mapping, XContentType.JSON);
        CreateIndexResponse response = client.indices().create(request, RequestOptions.DEFAULT);
        if (response.isAcknowledged()) {
            System.out.println("create " + index + " success");
        }
    }

    /**
     * Bulk-writes vectors into the {@code my_vector} field and then refreshes the index.
     * Keeping each batch under roughly 500 documents is recommended.
     *
     * @param index   target index
     * @param vectors vectors to write; a null or empty list is a no-op
     * @throws IOException on transport failure
     */
    public void write(String index, List<float[]> vectors) throws IOException {
        if (vectors == null || vectors.isEmpty()) {
            // An empty BulkRequest fails client-side validation ("no requests added").
            return;
        }
        BulkRequest request = new BulkRequest();
        for (float[] vec : vectors) {
            request.add(new IndexRequest(index).source("my_vector", vec));
        }
        BulkResponse response = client.bulk(request, RequestOptions.DEFAULT);
        if (response.hasFailures()) {
            System.out.println(response.buildFailureMessage());
        } else {
            System.out.println("write bulk to " + index + " success");
        }
        // Optional: Elasticsearch refreshes periodically by default; refresh now so the
        // documents are immediately searchable.
        client.indices().refresh(new RefreshRequest(index), RequestOptions.DEFAULT);
    }

    /**
     * Runs a top-k vector query against {@code my_vector} and prints the parsed response.
     *
     * @param index index to search
     * @param query query vector
     * @param size  number of hits to return; also used as {@code topk}
     * @throws IOException on transport failure, or when the low-level client reports an
     *                     error status (it throws ResponseException for non-2xx replies)
     */
    public void search(String index, float[] query, int size) throws IOException {
        String queryFormat = "{\n"
                + " \"size\":%d,\n"
                + " \"query\": {\n"
                + " \"vector\": {\n"
                + " \"my_vector\": {\n" // name of the vector field to query
                + " \"vector\": %s,\n"
                + " \"topk\":%d\n"
                + " }\n"
                + " }\n"
                + " }\n"
                + "}";
        // Locale.ROOT keeps the numbers ASCII regardless of the JVM's default locale.
        String body = String.format(Locale.ROOT, queryFormat, size, Arrays.toString(query), size);
        Request request = new Request("POST", "/" + index + "/_search");
        request.setJsonEntity(body);
        Response response = client.getLowLevelClient().performRequest(request);
        if (response.getStatusLine().getStatusCode() != HttpStatus.SC_OK) {
            // Handle the unexpected status as the business requires; print the body, not the entity object.
            System.out.println(readBody(response.getEntity()));
            return;
        }
        // Handle a successful result as the business requires.
        HttpEntity entity = response.getEntity();
        try (InputStream content = entity.getContent();
             XContentParser parser = XContentType.JSON.xContent().createParser(
                     NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, content)) {
            SearchResponse searchResponse = SearchResponse.fromXContent(parser);
            System.out.println(searchResponse);
        }
    }

    /** Reads an HTTP entity fully as UTF-8 text; a null entity yields "". */
    private static String readBody(HttpEntity entity) throws IOException {
        if (entity == null) {
            return "";
        }
        try (InputStream in = entity.getContent()) {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buffer = new byte[4096];
            int read;
            while ((read = in.read(buffer)) != -1) {
                out.write(buffer, 0, read);
            }
            return new String(out.toByteArray(), StandardCharsets.UTF_8);
        }
    }

    /**
     * Deletes the given index.
     *
     * @param index index name
     * @throws IOException on transport failure
     */
    public void delete(String index) throws IOException {
        DeleteIndexRequest request = new DeleteIndexRequest(index);
        AcknowledgedResponse response = client.indices().delete(request, RequestOptions.DEFAULT);
        if (response.isAcknowledged()) {
            System.out.println("delete " + index + " success");
        }
    }

    /** Closes the underlying client and its connections. */
    @Override
    public void close() throws IOException {
        client.close();
    }

    public static void main(String[] args) throws IOException {
        // For a non-security-mode cluster:
        RestHighLevelClient client = getClient(Arrays.asList("x.x.x.x"), 9200, "http");
        /*
         * For a security-mode cluster with HTTPS enabled:
         * RestHighLevelClient client = getClient(Arrays.asList("x.x.x.x", "x.x.x.x"), 9200, "https", "user_name", "password");
         * For a security-mode cluster with HTTPS disabled:
         * RestHighLevelClient client = getClient(Arrays.asList("x.x.x.x", "x.x.x.x"), 9200, "http", "user_name", "password");
         */
        // try-with-resources closes the client even if one of the steps below throws.
        try (ClientExample example = new ClientExample(client)) {
            String indexName = "my_index";
            // Create the index.
            example.create(indexName);
            // Write data.
            List<float[]> data = Arrays.asList(
                    new float[] {1.0f, 1.0f}, new float[] {2.0f, 2.0f}, new float[] {3.0f, 3.0f});
            example.write(indexName, data);
            // Query the index.
            float[] queryVector = new float[] {1.0f, 1.0f};
            example.search(indexName, queryVector, 3);
            // Delete the index.
            example.delete(indexName);
        }
    }
}