向量检索的Java代码示例
Elasticsearch提供了标准的REST接口，以及Java、Python等语言编写的客户端。
本节提供一份创建向量索引、导入向量数据和查询向量数据的Java代码示例，介绍如何使用客户端实现向量检索。
前提条件
根据集群实际版本添加如下Maven依赖，此处以7.10.2为例。
<dependency>
<groupId>org.elasticsearch.client</groupId>
<artifactId>elasticsearch-rest-high-level-client</artifactId>
<version>7.10.2</version>
</dependency>
代码示例
package org.example;
import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
import org.apache.http.HttpStatus;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.nio.conn.ssl.SSLIOSessionStrategy;
import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.indices.CreateIndexRequest;
import org.elasticsearch.client.indices.CreateIndexResponse;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.io.InputStream;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
/**
 * Example of vector search with the Elasticsearch high-level REST client: creates a vector
 * index, bulk-writes vectors, runs a vector query and finally deletes the index.
 */
public class ClientExample {
    private final RestHighLevelClient client;

    /**
     * @param client client used for every request; released by {@link #close()}
     * @throws NullPointerException if {@code client} is null
     */
    public ClientExample(RestHighLevelClient client) {
        this.client = Objects.requireNonNull(client, "client");
    }

    /**
     * Creates a client for a non-security cluster (no authentication, default TLS handling).
     *
     * @param hosts  host names or IP addresses of the cluster nodes
     * @param port   HTTP port shared by all hosts
     * @param scheme "http" or "https"
     * @return a new client; the caller must close it
     */
    public static RestHighLevelClient getClient(List<String> hosts, int port, String scheme) {
        HttpHost[] httpHosts = hosts.stream().map(p -> new HttpHost(p, port, scheme)).toArray(HttpHost[]::new);
        return new RestHighLevelClient(RestClient.builder(httpHosts));
    }

    /**
     * Creates a client for a security-enabled cluster using HTTP basic authentication.
     *
     * <p>WARNING: server certificate validation ({@link #trustAllCerts}) and host-name
     * verification ({@link NoopHostnameVerifier}) are both disabled. This is convenient for
     * clusters with self-signed certificates but is vulnerable to man-in-the-middle attacks;
     * use a proper trust store in production.
     *
     * @param hosts    host names or IP addresses of the cluster nodes
     * @param port     HTTP port shared by all hosts
     * @param scheme   "https" for TLS-enabled clusters, "http" otherwise
     * @param user     basic-auth user name
     * @param password basic-auth password
     * @return a new client; the caller must close it
     * @throws IllegalStateException if the TLS context cannot be initialized
     */
    public static RestHighLevelClient getClient(List<String> hosts, int port, String scheme, String user, String password) {
        final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
        credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(user, password));
        final SSLContext sc;
        try {
            // "TLS" is the standard algorithm name; "SSL" refers to the obsolete protocol family.
            sc = SSLContext.getInstance("TLS");
            sc.init(null, trustAllCerts, new SecureRandom());
        } catch (KeyManagementException | NoSuchAlgorithmException e) {
            // Fail fast: without an initialized context there is nothing valid to build the
            // session strategy below from, so continuing would only fail later and less clearly.
            throw new IllegalStateException("Failed to initialize SSL context", e);
        }
        HttpHost[] httpHosts = hosts.stream().map(p -> new HttpHost(p, port, scheme)).toArray(HttpHost[]::new);
        final SSLIOSessionStrategy sessionStrategy = new SSLIOSessionStrategy(sc, NoopHostnameVerifier.INSTANCE);
        RestClientBuilder builder = RestClient.builder(httpHosts).setHttpClientConfigCallback(httpClientBuilder -> {
            httpClientBuilder.disableAuthCaching();
            httpClientBuilder.setSSLStrategy(sessionStrategy);
            return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
        });
        return new RestHighLevelClient(builder);
    }

    /**
     * Trust manager that accepts every client and server certificate chain without checking it.
     * For examples/testing only; see the warning on the secure {@code getClient} overload.
     */
    public static TrustManager[] trustAllCerts = new TrustManager[] {
        new X509TrustManager() {
            @Override
            public void checkClientTrusted(X509Certificate[] chain, String authType) {
                // Intentionally empty: every client chain is accepted.
            }

            @Override
            public void checkServerTrusted(X509Certificate[] chain, String authType) {
                // Intentionally empty: every server chain is accepted.
            }

            @Override
            public X509Certificate[] getAcceptedIssuers() {
                // The X509TrustManager contract requires a non-null (possibly empty) array.
                return new X509Certificate[0];
            }
        }
    };

    /**
     * Creates an index with the vector feature enabled and a 2-dimensional vector field
     * {@code my_vector}.
     *
     * @param index name of the index to create
     * @throws IOException if the request fails
     */
    public void create(String index) throws IOException {
        CreateIndexRequest request = new CreateIndexRequest(index);
        request.settings(Settings.builder()
            .put("index.vector", true)           // enable the vector feature
            .put("index.number_of_shards", 1)    // number of shards; set as required
            .put("index.number_of_replicas", 0)  // number of replicas; set as required
        );
        String mapping =
            "{\n" +
            "  \"properties\": {\n" +
            "    \"my_vector\": {\n" +
            "      \"type\": \"vector\",\n" +      // field type: vector
            "      \"indexing\": \"true\",\n" +    // enable index acceleration
            "      \"dimension\": \"2\",\n" +      // vector dimension
            "      \"metric\": \"euclidean\",\n" + // similarity metric
            "      \"algorithm\": \"GRAPH\"\n" +   // index algorithm
            "    }\n" +
            "  }\n" +
            "}";
        request.mapping(mapping, XContentType.JSON);
        CreateIndexResponse response = client.indices().create(request, RequestOptions.DEFAULT);
        if (response.isAcknowledged()) {
            System.out.println("create " + index + " success");
        }
    }

    /**
     * Bulk-writes vectors into the {@code my_vector} field, one document per vector, then
     * refreshes the index. Keeping a batch under about 500 documents is recommended.
     *
     * @param index   target index
     * @param vectors vectors to write; an empty list is a no-op
     * @throws NullPointerException if {@code vectors} is null
     * @throws IOException          if a request fails
     */
    public void write(String index, List<float[]> vectors) throws IOException {
        Objects.requireNonNull(vectors, "vectors");
        if (vectors.isEmpty()) {
            // A bulk request with no operations fails client-side validation; skip it entirely.
            System.out.println("no vectors to write to " + index);
            return;
        }
        BulkRequest request = new BulkRequest();
        for (float[] vec : vectors) {
            request.add(new IndexRequest(index).source("my_vector", vec));
        }
        BulkResponse response = client.bulk(request, RequestOptions.DEFAULT);
        if (response.hasFailures()) {
            System.out.println(response.buildFailureMessage());
        } else {
            System.out.println("write bulk to " + index + " success");
        }
        // Optional: makes the documents searchable immediately instead of waiting for the
        // periodic refresh.
        client.indices().refresh(new RefreshRequest(index), RequestOptions.DEFAULT);
    }

    /**
     * Runs a vector query against {@code my_vector} and prints the parsed search response.
     *
     * @param index index to search
     * @param query query vector; every component must be finite
     * @param size  number of hits to return, also used as {@code topk}
     * @throws NullPointerException     if {@code query} is null
     * @throws IllegalArgumentException if {@code query} contains NaN or infinite values
     * @throws IOException              if the request fails or the response cannot be parsed
     */
    public void search(String index, float[] query, int size) throws IOException {
        Objects.requireNonNull(query, "query");
        for (float component : query) {
            // JSON has no representation for NaN/Infinity; reject them instead of sending a malformed body.
            if (Float.isNaN(component) || Float.isInfinite(component)) {
                throw new IllegalArgumentException(
                    "query vector must contain only finite values: " + Arrays.toString(query));
            }
        }
        String queryFormat =
            "{\n" +
            "  \"size\":%d,\n" +
            "  \"query\": {\n" +
            "    \"vector\": {\n" +
            "      \"my_vector\": {\n" +   // name of the vector field to query
            "        \"vector\": %s,\n" +
            "        \"topk\":%d\n" +
            "      }\n" +
            "    }\n" +
            "  }\n" +
            "}";
        // Locale.ROOT guarantees ASCII digits for %d regardless of the JVM default locale.
        String body = String.format(Locale.ROOT, queryFormat, size, Arrays.toString(query), size);
        Request request = new Request("POST", index + "/_search");
        request.setJsonEntity(body);
        Response response = client.getLowLevelClient().performRequest(request);
        if (response.getStatusLine().getStatusCode() != HttpStatus.SC_OK) {
            // Handle query errors as the business requires. Error status codes are normally
            // reported by performRequest throwing ResponseException; this covers other non-200 codes.
            System.out.println(response.getStatusLine());
            return;
        }
        // Handle the successful result as the business requires.
        HttpEntity entity = response.getEntity();
        // Close both the parser and the underlying stream so the connection is released.
        try (InputStream content = entity.getContent();
             XContentParser parser = XContentType.JSON.xContent().createParser(
                 NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, content)) {
            SearchResponse searchResponse = SearchResponse.fromXContent(parser);
            System.out.println(searchResponse);
        }
    }

    /**
     * Deletes an index.
     *
     * @param index name of the index to delete
     * @throws IOException if the request fails
     */
    public void delete(String index) throws IOException {
        DeleteIndexRequest request = new DeleteIndexRequest(index);
        AcknowledgedResponse response = client.indices().delete(request, RequestOptions.DEFAULT);
        if (response.isAcknowledged()) {
            System.out.println("delete " + index + " success");
        }
    }

    /**
     * Closes the underlying client.
     *
     * @throws IOException if closing fails
     */
    public void close() throws IOException {
        client.close();
    }

    /**
     * Runs the full example: create index, write data, search, delete index.
     */
    public static void main(String[] args) throws IOException {
        // For a non-security cluster, use:
        RestHighLevelClient client = getClient(Arrays.asList("xx.xx.xx.xx"), 9200, "http");
        /*
         * For a security cluster with HTTPS enabled, use:
         * RestHighLevelClient client = getClient(Arrays.asList("xx.xx.xx.xx", "xx.xx.xx.xx"), 9200, "https", "user_name", "password");
         * For a security cluster without HTTPS, use:
         * RestHighLevelClient client = getClient(Arrays.asList("xx.xx.xx.xx", "xx.xx.xx.xx"), 9200, "http", "user_name", "password");
         */
        ClientExample example = new ClientExample(client);
        try {
            String indexName = "my_index";
            // Create the index.
            example.create(indexName);
            // Write data.
            List<float[]> data = Arrays.asList(new float[]{1.0f, 1.0f}, new float[]{2.0f, 2.0f}, new float[]{3.0f, 3.0f});
            example.write(indexName, data);
            // Query the index.
            float[] queryVector = new float[]{1.0f, 1.0f};
            example.search(indexName, queryVector, 3);
            // Delete the index.
            example.delete(indexName);
        } finally {
            // Always close the client, even if one of the steps above throws.
            example.close();
        }
    }
}