更新时间:2024-08-15 GMT+08:00
分享

向量检索的客户端代码示例(Java)

Elasticsearch提供了标准的REST接口,以及Java、Python等语言编写的客户端。

本节提供一份创建向量索引、导入向量数据和查询向量数据的Java代码示例,介绍如何使用客户端实现向量检索。

前提条件

根据集群实际版本添加如下Maven依赖,此处以7.10.2举例。

<!-- Elasticsearch high-level REST client used by the Java example below.
     Keep <version> aligned with the actual cluster version; 7.10.2 is used here as the example. -->
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>elasticsearch-rest-high-level-client</artifactId>
    <version>7.10.2</version>
</dependency>

代码示例

package org.example;

import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
import org.apache.http.HttpStatus;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.nio.conn.ssl.SSLIOSessionStrategy;
import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.indices.CreateIndexRequest;
import org.elasticsearch.client.indices.CreateIndexResponse;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;

import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.List;

/**
 * Example client showing how to create a vector index, bulk-write vectors,
 * run a vector query and delete the index through {@link RestHighLevelClient}.
 *
 * <p>All progress and error information is printed to standard output; replace
 * these prints with real handling in production code.
 */
public class ClientExample {
    private final RestHighLevelClient client;

    /**
     * Wraps an already constructed client. The client is closed by {@link #close()}.
     *
     * @param client client used for every request issued by this example
     */
    public ClientExample(RestHighLevelClient client) {
        this.client = client;
    }

    /**
     * Creates a client for a non-security cluster (no authentication).
     *
     * @param hosts  host names or IP addresses of the cluster nodes
     * @param port   port shared by all hosts
     * @param scheme URI scheme, e.g. {@code "http"}
     * @return a new client; the caller is responsible for closing it
     */
    public static RestHighLevelClient getClient(List<String> hosts, int port, String scheme) {
        return new RestHighLevelClient(RestClient.builder(toHttpHosts(hosts, port, scheme)));
    }

    /**
     * Creates a client for a security cluster using basic authentication.
     *
     * <p>The SSL strategy installed here trusts every certificate and skips host name
     * verification (see {@link #trustAllCerts}).
     *
     * @param hosts    host names or IP addresses of the cluster nodes
     * @param port     port shared by all hosts
     * @param scheme   URI scheme, {@code "https"} or {@code "http"}
     * @param user     user name
     * @param password password
     * @return a new client; the caller is responsible for closing it
     * @throws IllegalStateException if the SSL context cannot be initialized
     */
    public static RestHighLevelClient getClient(List<String> hosts, int port, String scheme, String user, String password) {
        final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
        credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(user, password));

        final SSLIOSessionStrategy sessionStrategy =
                new SSLIOSessionStrategy(createTrustAllSslContext(), NoopHostnameVerifier.INSTANCE);

        RestClientBuilder builder = RestClient.builder(toHttpHosts(hosts, port, scheme))
                .setHttpClientConfigCallback(httpClientBuilder -> {
                    httpClientBuilder.disableAuthCaching();
                    httpClientBuilder.setSSLStrategy(sessionStrategy);
                    return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
                });
        return new RestHighLevelClient(builder);
    }

    /** Maps every host name to an {@link HttpHost} with the given port and scheme. */
    private static HttpHost[] toHttpHosts(List<String> hosts, int port, String scheme) {
        return hosts.stream().map(host -> new HttpHost(host, port, scheme)).toArray(HttpHost[]::new);
    }

    /** Builds an SSL context backed by {@link #trustAllCerts}; fails fast instead of returning null. */
    private static SSLContext createTrustAllSslContext() {
        try {
            SSLContext sslContext = SSLContext.getInstance("TLS");
            sslContext.init(null, trustAllCerts, new SecureRandom());
            return sslContext;
        } catch (KeyManagementException | NoSuchAlgorithmException e) {
            // Previously the error was only printed and a null context was used afterwards.
            throw new IllegalStateException("Failed to initialize SSL context", e);
        }
    }

    /**
     * Trust managers that accept every client and server certificate without any check.
     *
     * <p>SECURITY NOTE: this disables certificate validation entirely. Use a trust store
     * containing the cluster's CA certificate when validation is required.
     */
    public static TrustManager[] trustAllCerts = new TrustManager[] {
            new X509TrustManager() {
                @Override
                public void checkClientTrusted(X509Certificate[] chain, String authType) {
                    // Intentionally empty: every client certificate is accepted.
                }

                @Override
                public void checkServerTrusted(X509Certificate[] chain, String authType) {
                    // Intentionally empty: every server certificate is accepted.
                }

                @Override
                public X509Certificate[] getAcceptedIssuers() {
                    // The X509TrustManager contract requires a non-null (possibly empty) array.
                    return new X509Certificate[0];
                }
            }
    };

    /**
     * Creates a vector index with a single vector field named {@code my_vector}.
     *
     * @param index name of the index to create
     * @throws IOException if the request fails
     */
    public void create(String index) throws IOException {
        CreateIndexRequest request = new CreateIndexRequest(index);
        request.settings(Settings.builder()
                .put("index.vector", true) // enable the vector feature
                .put("index.number_of_shards", 1) // number of shards; set according to actual needs
                .put("index.number_of_replicas", 0) // number of replicas; set according to actual needs
        );
        String mapping =
                "{\n" +
                "  \"properties\": {\n" +
                "    \"my_vector\": {\n" +
                "      \"type\": \"vector\",\n" +      // mark this field as a vector
                "      \"indexing\": \"true\",\n" +    // enable index acceleration
                "      \"dimension\": \"2\",\n" +      // vector dimension
                "      \"metric\": \"euclidean\",\n" + // similarity metric
                "      \"algorithm\": \"GRAPH\"\n" +   // index algorithm
                "    }\n" +
                "  }\n" +
                "}";
        request.mapping(mapping, XContentType.JSON);
        CreateIndexResponse response = client.indices().create(request, RequestOptions.DEFAULT);
        if (response.isAcknowledged()) {
            System.out.println("create " + index + " success");
        }
    }

    /**
     * Bulk-writes vectors into the {@code my_vector} field, one document per vector.
     * It is recommended to keep a single batch within 500 documents.
     *
     * @param index   target index
     * @param vectors vectors to write; a null or empty list is a no-op
     * @throws IOException if the bulk or refresh request fails
     */
    public void write(String index, List<float[]> vectors) throws IOException {
        if (vectors == null || vectors.isEmpty()) {
            // A bulk request without any action is rejected by request validation.
            System.out.println("no vectors to write to " + index);
            return;
        }

        BulkRequest request = new BulkRequest();
        for (float[] vec : vectors) {
            request.add(new IndexRequest(index).source("my_vector", vec));
        }

        BulkResponse response = client.bulk(request, RequestOptions.DEFAULT);
        if (response.hasFailures()) {
            System.out.println(response.buildFailureMessage());
        } else {
            System.out.println("write bulk to " + index + " success");
        }

        // Optional: Elasticsearch refreshes by default.
        client.indices().refresh(new RefreshRequest(index), RequestOptions.DEFAULT);
    }

    /**
     * Runs a vector query against the {@code my_vector} field and prints the response.
     *
     * @param index index to search
     * @param query query vector
     * @param size  used both as the number of hits to return and as {@code topk}
     * @throws IOException if the request fails or the response cannot be parsed
     */
    public void search(String index, float[] query, int size) throws IOException {
        // Built with concatenation rather than String.format so that the numbers are never
        // rendered with locale-specific digits, which would produce invalid JSON.
        String body =
            "{\n" +
            "  \"size\":" + size + ",\n" +
            "  \"query\": {\n" +
            "    \"vector\": {\n" +
            "      \"my_vector\": {\n" +  // name of the vector field to query
            "        \"vector\": " + Arrays.toString(query) + ",\n" +
            "        \"topk\":" + size + "\n" +
            "      }\n" +
            "    }\n" +
            "  }\n" +
            "}";
        Request request = new Request("POST", index + "/_search");
        request.setJsonEntity(body);
        Response response = client.getLowLevelClient().performRequest(request);
        // NOTE(review): the low-level client is documented to throw ResponseException for
        // error status codes, so this branch may rarely be reached - confirm.
        if (response.getStatusLine().getStatusCode() != HttpStatus.SC_OK) {
            System.out.println(response.getEntity()); // handle query errors according to business needs
            return;
        }
        // Handle the successful result according to business needs.
        HttpEntity entity = response.getEntity();
        try (XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY,
                DeprecationHandler.IGNORE_DEPRECATIONS, entity.getContent())) {
            SearchResponse searchResponse = SearchResponse.fromXContent(parser);
            System.out.println(searchResponse);
        }
    }

    /**
     * Deletes the given index.
     *
     * @param index name of the index to delete
     * @throws IOException if the request fails
     */
    public void delete(String index) throws IOException {
        DeleteIndexRequest request = new DeleteIndexRequest(index);
        AcknowledgedResponse response = client.indices().delete(request, RequestOptions.DEFAULT);
        if (response.isAcknowledged()) {
            System.out.println("delete " + index + " success");
        }
    }

    /**
     * Closes the underlying client.
     *
     * @throws IOException if closing the client fails
     */
    public void close() throws IOException {
        client.close();
    }

    /**
     * Runs the whole flow: create index, write data, search, delete index.
     * Replace the {@code x.x.x.x} placeholders with real node addresses.
     */
    public static void main(String[] args) throws IOException {
        // For a non-security cluster, use:
        RestHighLevelClient client = getClient(Arrays.asList("x.x.x.x"), 9200, "http");

        /*
         *  For a security cluster with HTTPS enabled, use:
         *  RestHighLevelClient client = getClient(Arrays.asList("x.x.x.x", "x.x.x.x"), 9200, "https", "user_name", "password");
         *  For a security cluster without HTTPS enabled, use:
         *  RestHighLevelClient client = getClient(Arrays.asList("x.x.x.x", "x.x.x.x"), 9200, "http", "user_name", "password");
         */

        ClientExample example = new ClientExample(client);
        String indexName = "my_index";

        try {
            // Create the index.
            example.create(indexName);

            // Write data.
            List<float[]> data = Arrays.asList(new float[]{1.0f, 1.0f}, new float[]{2.0f, 2.0f}, new float[]{3.0f, 3.0f});
            example.write(indexName, data);

            // Query the index.
            float[] queryVector = new float[]{1.0f, 1.0f};
            example.search(indexName, queryVector, 3);

            // Delete the index.
            example.delete(indexName);
        } finally {
            // Close the client even when one of the steps above fails.
            example.close();
        }
    }
}

相关文档