使用 OKhttp3 实现 ChatGLM HTTP 调用(SSE、异步、同步)

发布时间 2023-11-06 16:07:39作者: Cheyaoyao

为了熟悉下 OKhttp 和 ChatGLM 接口,写几个 demo 试试

1. 准备工作

ChatGLM接口文档可知,每次 HTTP 调用都需要带上一个鉴权 token,而组装这个 token,我们需要先获取一个 API Key,这个可从智谱AI开放平台 API Keys 页面获得,API Key 包含 “用户标识 id” 和 “签名密钥 secret”,即格式为 {id}.{secret}

获取 token 和接口请求参数的代码在最后的附录中

2. SSE 调用

SSE(Sever-Sent Event),就是浏览器向服务器发送一个HTTP请求,保持长连接,服务器不断单向地向浏览器推送“信息”(message),这么做是为了节约网络资源,不用一直发请求,建立新连接。

// 设置请求参数
RequestParam requestParam = new RequestParam();
List<RequestParam.Prompt> prompts = new ArrayList<>();
prompts.add(RequestParam.Prompt.builder()
        .role(Role.user.getCode())
        .content("你好,我想问你一些 Java 相关的问题")
        .build());
requestParam.setPrompt(prompts);

// 创建请求体
MediaType json = MediaType.parse("application/json; charset=utf-8");
RequestBody requestBody = RequestBody.create(json, requestParam.toString());

// 创建请求对象
Request request = new Request.Builder()
        .url("https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_turbo/sse-invoke")
        .post(requestBody) // 请求体
        .addHeader("Authorization", "Bearer " + token)
        .addHeader("Accept", "text/event-stream")
        .build();

// 开启 Http 客户端
OkHttpClient okHttpClient = new OkHttpClient.Builder()
        .connectTimeout(10, TimeUnit.SECONDS)   // 建立连接的超时时间
        .readTimeout(10, TimeUnit.MINUTES)  // 建立连接后读取数据的超时时间
        .build();

// 创建一个 CountDownLatch 对象,其初始计数为1,表示需要等待一个事件发生后才能继续执行。
CountDownLatch eventLatch = new CountDownLatch(1);

// 实例化EventSource,注册EventSource监听器 -- 创建一个用于处理服务器发送事件的实例,并定义处理事件的回调逻辑
RealEventSource realEventSource = new RealEventSource(request, new EventSourceListener() {

    @Override
    public void onEvent(EventSource eventSource, String id, String type, String data) {
        System.out.println(data);   // 请求到的数据
        if ("finish".equals(type)) {    // 消息类型,add 增量,finish 结束,error 错误,interrupted 中断
            eventLatch.countDown();
        }
    }

});

// 与服务器建立连接
realEventSource.connect(okHttpClient); 

// await() 方法被调用来阻塞当前线程,直到 CountDownLatch 的计数变为0。
eventLatch.await();

3. 异步调用

根据文档描述,首先得通过异步 POST 请求获得 task_id ,再根据 task_id 发送 GET 请求获得最终结果

// TODO 设置请求参数,同 SSE 调用

// 开启 Http 客户端
OkHttpClient okHttpClient = new OkHttpClient();

// 创建请求体
MediaType json = MediaType.parse("application/json; charset=utf-8");
RequestBody requestBody = RequestBody.create(json, requestParam.toString());

// 第一步:发送异步请求(POST)获取 task_id,并存放到 taskIdFuture 中
CompletableFuture<String> taskIdFuture = new CompletableFuture<>();

Request requestForTaskId = new Request.Builder()
        .url("https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_turbo/async-invoke")
        .post(requestBody)
        .addHeader("Authorization", "Bearer " + token)
        .build();

// 创建一个新的异步 HTTP 请求,并指定请求的回调函数
okHttpClient.newCall(requestForTaskId).enqueue(new Callback() {
    // 在请求成功并返回响应时被调用
    @Override
    public void onResponse(Call call, Response response) throws IOException {
        if (response.isSuccessful()) {
            String responseBody = response.body().string();
            System.out.println("requestForTaskId: " + responseBody);
            // 解析 JSON 响应获取 task_id
            JSONObject jsonObject = JSON.parseObject(responseBody);
            String taskId = jsonObject.getJSONObject("data").getString("task_id");
            // 将结果设置到 CompletableFuture
            taskIdFuture.complete(taskId);
        } else {
            taskIdFuture.completeExceptionally(new Exception("Request for task_id failed"));
        }
    }

    // 在请求失败时被调用
    @Override
    public void onFailure(Call call, IOException e) {
        taskIdFuture.completeExceptionally(e);
    }
});

// 阻塞主线程,等待 CompletableFuture 的结果,设置了最大等待时间
String taskId = taskIdFuture.get(10, TimeUnit.SECONDS);
System.out.println("Task ID: " + taskId);

// TODO 第二步,使用 task_id 发送同步请求(GET)获取最终响应结果(和第四节基本一样)

4. 同步调用

// TODO 设置请求参数,同 SSE 调用

// 开启 Http 客户端
OkHttpClient client = new OkHttpClient();

// 创建请求体
MediaType json = MediaType.parse("application/json; charset=utf-8");
RequestBody requestBody = RequestBody.create(json, requestParam.toString());

// 创建请求对象
Request request = new Request.Builder()
        .url("https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_turbo/invoke")
        .post(requestBody) 
        .addHeader("Authorization", "Bearer " + token)
        .build();

// 发送请求
Response response = client.newCall(request).execute();

// 处理响应
if (response.isSuccessful()) {
    String responseBody = response.body().string();
    System.out.println("Response: " + responseBody);
} else {
    System.out.println("Request failed: " + response.code() + " " + response.message());
}

5. 附录

5.1 组装鉴权 token

// 这里的 secret 是 API Key 中的 {secret} 部分
Algorithm algorithm = Algorithm.HMAC256(secret.getBytes(StandardCharsets.UTF_8));
Map<String, Object> payload = new HashMap<>();
// 这里的 id 是 API Key 中的 {id} 部分
payload.put("api_key", id);
payload.put("exp", System.currentTimeMillis() + 30 * 60 * 1000L);  // 过期时间, 30分钟
payload.put("timestamp", Calendar.getInstance().getTimeInMillis()); // 时间戳
Map<String, Object> headerClaims = new HashMap<>();
headerClaims.put("alg", "HS256");
headerClaims.put("sign_type", "SIGN");
String token = JWT.create().withPayload(payload).withHeader(headerClaims).sign(algorithm);

5.2 接口请求参数

@Data
@JsonInclude(JsonInclude.Include.NON_NULL)  
@Builder    
@NoArgsConstructor  
@AllArgsConstructor     
public class RequestParam {

    @JsonProperty("request_id")
    private String requestId = String.format("gpt-%d", System.currentTimeMillis());

    private float temperature = 0.9f;

    @JsonProperty("top_p")
    private float topP = 0.7f;

    /**
     * 输入给模型的会话信息
     * 用户输入的内容;role=user
     * 挟带历史的内容;role=assistant
     */
    private List<RequestParam.Prompt> prompt;

    private boolean incremental = true;

    private String sseFormat = "data";

    @Data
    @Builder    
    @NoArgsConstructor  
    @AllArgsConstructor     
    public static class Prompt {
        private String role;
        private String content;
    }

    @Override
    public String toString() {
        Map<String, Object> paramsMap = new HashMap<>();
        paramsMap.put("request_id", requestId);
        paramsMap.put("prompt", prompt);
        paramsMap.put("incremental", incremental);
        paramsMap.put("temperature", temperature);
        paramsMap.put("top_p", topP);
        paramsMap.put("sseFormat", sseFormat);
        try {
            return new ObjectMapper().writeValueAsString(paramsMap);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }
}