更新时间:2024-08-29 GMT+08:00
分享

自定义模型

如果使用的模型既不是盘古模型，也不是兼容OpenAI API的开源模型（例如闭源模型，或在裸机上部署的自定义推理服务），可以通过继承AbstractLLM来自定义模型。示例代码如下：
@Slf4j
public class CustomLLM extends AbstractLLM<LLMResp> {
    /** Sampling temperature sent with every request to the custom inference service. */
    private static final double DEFAULT_TEMPERATURE = 0.3;

    /** Fallback wait (in seconds) for streaming calls when the HTTP config does not set one. */
    private static final int DEFAULT_ASYNC_WAIT_SECONDS = 300;

    /**
     * Creates the custom model.
     *
     * @param llmConfig LLM configuration
     */
    public CustomLLM(LLMConfig llmConfig) {
        super(llmConfig);
    }

    /**
     * Sends the conversation to the custom inference service whose URL is read from the
     * {@code llm.custom.api.url} configuration item, and returns the first answer.
     *
     * <p>If {@code llmConfig.getLlmParamConfig().isStream()} is true, the request is sent with the
     * {@code Inference-Type: stream} header through the async client. The streamed chunks are
     * assembled via {@code StreamHelper}. Otherwise, a plain blocking POST is sent.
     *
     * @param chatMessages conversation messages; only their contents are sent, as the {@code data} list
     * @param llmParamConfig per-call parameters (NOTE(review): currently unused; stream mode is read
     *     from {@code llmConfig} instead — confirm this is intended)
     * @return the model answer
     * @throws PanguDevSDKException if the URL is not configured, the HTTP call fails, the service
     *     returns a non-2xx status, or the response contains no answer
     */
    @Override
    protected LLMResp getLLMResponse(List<ConversationMessage> chatMessages, LLMParamConfig llmParamConfig) {
        final String requestBody = buildRequestBody(chatMessages);
        final HttpPost httpPost = new HttpPost(getApiUrl());
        httpPost.setEntity(new StringEntity(requestBody, ContentType.APPLICATION_JSON));
        if (llmConfig.getLlmParamConfig().isStream()) {
            return sendStreamRequest(httpPost);
        }
        return sendRequest(httpPost);
    }

    /** Serializes the request JSON ({@code temperature} plus the message contents) and logs it. */
    private static String buildRequestBody(List<ConversationMessage> chatMessages) {
        final Map<String, Object> request = new HashMap<>();
        request.put("temperature", DEFAULT_TEMPERATURE);
        request.put("data", chatMessages.stream().map(ConversationMessage::getContent).collect(Collectors.toList()));
        final String requestBody = JSON.toJSONString(request);
        // Pretty-print directly from the map instead of re-parsing the serialized string.
        log.info("request body : \n{}", JSON.toJSONString(request, true));
        return requestBody;
    }

    /** Reads the service URL from configuration, failing fast when it is missing. */
    private static String getApiUrl() {
        final String url = ConfigLoadUtil.getStringConf(null, "llm.custom.api.url");
        if (StringUtils.isEmpty(url)) {
            throw new PanguDevSDKException("the llm.custom.api.url is not config");
        }
        return url;
    }

    /** Sends a streaming request and assembles the answer from the received chunks. */
    private LLMResp sendStreamRequest(HttpPost httpPost) {
        httpPost.setHeader(new BasicHeader("Inference-Type", "stream"));
        // NOTE(review): if HttpUtil.getHttpAsyncClient returns a new client per call, it should be
        // closed after use; if it is shared, it must not be. Confirm ownership before closing it here.
        final CloseableHttpAsyncClient httpclient = HttpUtil.getHttpAsyncClient(false);
        Future<HttpResponse> future = null;
        try {
            httpclient.start();
            final String callBackId = SecurityUtil.getUUID();
            final List<PanguChatChunk> panguChatChunks = new ArrayList<>();
            future = httpclient.execute(HttpAsyncMethods.create(httpPost),
                StreamHelper.getAsyncConsumer(streamCallBack, callBackId, panguChatChunks),
                StreamHelper.getCallBack(streamCallBack, callBackId, httpPost));
            final int waitSeconds = Optional.ofNullable(llmConfig.getHttpConfig().getAsyncHttpWaitSeconds())
                .orElse(DEFAULT_ASYNC_WAIT_SECONDS);
            future.get(waitSeconds, TimeUnit.SECONDS);
            final PanguChatResp allRespFromChunk = StreamHelper.getAllRespFromChunk(panguChatChunks);
            return LLMResp.builder().answer(allRespFromChunk.getChoices().get(0).getMessage().getContent()).build();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new PanguDevSDKException(e);
        } catch (Exception e) {
            throw new PanguDevSDKException(e);
        } finally {
            // Do not leave a timed-out or failed request running in the background.
            if (future != null && !future.isDone()) {
                future.cancel(true);
            }
        }
    }

    /** Sends a blocking request and extracts the answer from the JSON response. */
    private static LLMResp sendRequest(HttpPost httpPost) {
        final CloseableHttpClient httpClient = HttpUtil.getHttpClient(false);
        // Close the response (and release its connection) on every path; the client itself is
        // obtained from HttpUtil and is not closed here.
        try (CloseableHttpResponse response = httpClient.execute(httpPost)) {
            final String responseStr = response.getEntity() == null
                ? ""
                : EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
            final int statusCode = response.getStatusLine().getStatusCode();
            if (statusCode < 200 || statusCode >= 300) {
                throw new PanguDevSDKException(
                    "custom llm request failed, status: " + statusCode + ", body: " + responseStr);
            }
            final JSONObject jsonObject = JSON.parseObject(responseStr);
            log.info("response: \n{}", JSON.toJSONString(jsonObject, true));
            return LLMResp.builder().answer(extractAnswer(jsonObject)).build();
        } catch (IOException e) {
            throw new PanguDevSDKException(e);
        }
    }

    /**
     * Returns {@code answers[0].content} from the response. The answers are read from the nested
     * {@code result} object when present, otherwise from the top-level object.
     */
    private static String extractAnswer(JSONObject jsonObject) {
        if (jsonObject == null) {
            throw new PanguDevSDKException("the custom llm response is empty");
        }
        JSONObject result = jsonObject.getJSONObject("result");
        if (result == null) {
            result = jsonObject;
        }
        final List<Object> answers = result.getJSONArray("answers");
        if (answers == null || answers.isEmpty()) {
            throw new PanguDevSDKException("the custom llm response contains no answers");
        }
        return ((JSONObject) answers.get(0)).getString("content");
    }

    /**
     * Wraps a cached answer in a response marked as coming from the cache.
     *
     * @param cache cached answer text
     * @return response with {@code isFromCache} set to true
     */
    @Override
    protected LLMResp getLLMResponseFromCache(String cache) {
        return LLMResp.builder().answer(cache).isFromCache(true).build();
    }
}

相关文档