更新时间:2024-10-16 GMT+08:00
自定义模型
如果使用的模型既不是盘古模型,也不是兼容OpenAI API的开源模型(例如闭源模型,或在裸机上部署的自定义推理服务),可以通过继承AbstractLLM类来自定义模型。示例中推理服务的地址从配置项llm.custom.api.url读取。示例代码如下:
@Slf4j public class CustomLLM extends AbstractLLM<LLMResp> { /** * 初始化 * * @param llmConfig llm参数配置 */ public CustomLLM(LLMConfig llmConfig) { super(llmConfig); } @Override protected LLMResp getLLMResponse(List<ConversationMessage> chatMessages, LLMParamConfig llmParamConfig) { // 构造请求体 Map<String, Object> request = new HashMap<>(); request.put("temperature", 0.3); request.put("data", chatMessages.stream().map(ConversationMessage::getContent).collect(Collectors.toList())); final String requestBody = JSON.toJSONString(request); log.info("request body : \n{}", JSON.toJSONString(JSON.parseObject(requestBody), true)); // 从配置项读取url,构造post消息 String url = ConfigLoadUtil.getStringConf(null, "llm.custom.api.url"); if (StringUtils.isEmpty(url)) { throw new PanguDevSDKException("the llm.custom.api.url is not config"); } HttpPost httpPost = new HttpPost(url); httpPost.setEntity(new StringEntity(requestBody, ContentType.APPLICATION_JSON)); // 发送消息并处理响应 String responseStr; if (llmConfig.getLlmParamConfig().isStream()) { // 处理流式请求 httpPost.setHeader(new BasicHeader("Inference-Type", "stream")); final CloseableHttpAsyncClient httpclient = HttpUtil.getHttpAsyncClient(false); try { httpclient.start(); final String callBackId = SecurityUtil.getUUID(); final List<PanguChatChunk> panguChatChunks = new ArrayList<>(); Future<HttpResponse> future = httpclient.execute(HttpAsyncMethods.create(httpPost), StreamHelper.getAsyncConsumer(streamCallBack, callBackId, panguChatChunks), StreamHelper.getCallBack(streamCallBack, callBackId, httpPost)); future.get(Optional.ofNullable(llmConfig.getHttpConfig().getAsyncHttpWaitSeconds()).orElse(300), TimeUnit.SECONDS); final PanguChatResp allRespFromChunk = StreamHelper.getAllRespFromChunk(panguChatChunks); return LLMResp.builder().answer(allRespFromChunk.getChoices().get(0).getMessage().getContent()).build(); } catch (Exception e) { throw new PanguDevSDKException(e); } } else { // 处理非流式请求 final CloseableHttpClient httpClient = 
HttpUtil.getHttpClient(false); try { final CloseableHttpResponse response = httpClient.execute(httpPost); responseStr = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8); log.info("response: \n{}", JSON.toJSONString(JSON.parseObject(responseStr), true)); // 解析结果 final JSONObject jsonObject = JSON.parseObject(responseStr); JSONObject result = jsonObject.getJSONObject("result"); if (result == null) { result = jsonObject; } final String content = ((JSONObject) result.getJSONArray("answers").get(0)).getString("content"); return LLMResp.builder().answer(content).build(); } catch (IOException e) { throw new PanguDevSDKException(e); } } } @Override protected LLMResp getLLMResponseFromCache(String cache) { return LLMResp.builder().answer(cache).isFromCache(true).build(); } }
父主题: LLMs(语言模型)