1use std::marker::PhantomData;
2
3use serde::Serialize;
4use validator::Validate;
5
6use super::super::{chat_base_request::*, tools::*, traits::*};
7use crate::client::http::HttpClient;
8
9pub struct AsyncChatCompletion<N, M, S = StreamOff>
10where
11 N: ModelName + AsyncChat,
12 (N, M): Bounded,
13 ChatBody<N, M>: Serialize,
14 S: StreamState,
15{
16 pub key: String,
17 body: ChatBody<N, M>,
18 _stream: PhantomData<S>,
19}
20
21impl<N, M> AsyncChatCompletion<N, M, StreamOff>
22where
23 N: ModelName + AsyncChat,
24 (N, M): Bounded,
25 ChatBody<N, M>: Serialize,
26{
27 pub fn new(model: N, messages: M, key: String) -> Self {
28 let body = ChatBody::new(model, messages);
29 Self {
30 body,
31 key,
32 _stream: PhantomData,
33 }
34 }
35
36 pub fn body_mut(&mut self) -> &mut ChatBody<N, M> {
37 &mut self.body
38 }
39
40 pub fn add_messages(mut self, messages: M) -> Self {
43 self.body = self.body.add_messages(messages);
44 self
45 }
46 pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
47 self.body = self.body.with_request_id(request_id);
48 self
49 }
50 pub fn with_do_sample(mut self, do_sample: bool) -> Self {
51 self.body = self.body.with_do_sample(do_sample);
52 self
53 }
54 #[deprecated(note = "Use enable_stream()/disable_stream() for compile-time guarantees")]
55 pub fn with_stream(mut self, stream: bool) -> Self {
56 self.body = self.body.with_stream(stream);
57 self
58 }
59 pub fn with_tool_stream(mut self, tool_stream: bool) -> Self
60 where
61 N: ToolStreamEnable,
62 {
63 self.body = self.body.with_tool_stream(tool_stream);
64 self
65 }
66
67 pub fn with_temperature(mut self, temperature: f32) -> Self {
68 self.body = self.body.with_temperature(temperature);
69 self
70 }
71 pub fn with_top_p(mut self, top_p: f32) -> Self {
72 self.body = self.body.with_top_p(top_p);
73 self
74 }
75 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
76 self.body = self.body.with_max_tokens(max_tokens);
77 self
78 }
79 pub fn add_tool(mut self, tool: Tools) -> Self {
80 self.body = self.body.add_tools(tool);
81 self
82 }
83 pub fn add_tools(mut self, tools: Vec<Tools>) -> Self {
84 self.body = self.body.extend_tools(tools);
85 self
86 }
87 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
88 self.body = self.body.with_user_id(user_id);
89 self
90 }
91 pub fn with_stop(mut self, stop: String) -> Self {
92 self.body = self.body.with_stop(stop);
93 self
94 }
95
96 pub fn with_thinking(mut self, thinking: ThinkingType) -> Self
98 where
99 N: ThinkEnable,
100 {
101 self.body = self.body.with_thinking(thinking);
102 self
103 }
104
105 pub fn enable_stream(mut self) -> AsyncChatCompletion<N, M, StreamOn> {
107 self.body.stream = Some(true);
108 AsyncChatCompletion {
109 key: self.key,
110 body: self.body,
111 _stream: PhantomData,
112 }
113 }
114
115 pub fn validate(&self) -> crate::ZaiResult<()> {
117 self.body
118 .validate()
119 .map_err(crate::client::error::ZaiError::from)?;
120 if matches!(self.body.stream, Some(true)) {
121 return Err(crate::client::error::ZaiError::ApiError {
122 code: 1200,
123 message: "stream=true detected; use enable_stream() and streaming APIs instead"
124 .to_string(),
125 });
126 }
127
128 Ok(())
129 }
130
131 pub async fn send(
132 &self,
133 ) -> crate::ZaiResult<crate::model::chat_base_response::ChatCompletionResponse>
134 where
135 N: serde::Serialize,
136 M: serde::Serialize,
137 {
138 self.validate()?;
139
140 let resp: reqwest::Response = self.post().await?;
141
142 let parsed = resp
143 .json::<crate::model::chat_base_response::ChatCompletionResponse>()
144 .await?;
145 Ok(parsed)
146 }
147}
148
149impl<N, M> AsyncChatCompletion<N, M, StreamOn>
150where
151 N: ModelName + AsyncChat,
152 (N, M): Bounded,
153 ChatBody<N, M>: Serialize,
154{
155 pub fn with_tool_stream(mut self, tool_stream: bool) -> Self
156 where
157 N: ToolStreamEnable,
158 {
159 self.body = self.body.with_tool_stream(tool_stream);
160 self
161 }
162
163 pub fn disable_stream(mut self) -> AsyncChatCompletion<N, M, StreamOff> {
164 self.body.stream = Some(false);
165 self.body.tool_stream = None;
168 AsyncChatCompletion {
169 key: self.key,
170 body: self.body,
171 _stream: PhantomData,
172 }
173 }
174}
175
176impl<N, M, S> HttpClient for AsyncChatCompletion<N, M, S>
177where
178 N: ModelName + Serialize + AsyncChat,
179 M: Serialize,
180 (N, M): Bounded,
181 S: StreamState,
182{
183 type Body = ChatBody<N, M>;
184 type ApiUrl = &'static str;
185 type ApiKey = String;
186
187 fn api_url(&self) -> &Self::ApiUrl {
188 &"https://open.bigmodel.cn/api/paas/v4/async/chat/completions"
189 }
190 fn api_key(&self) -> &Self::ApiKey {
191 &self.key
192 }
193 fn body(&self) -> &Self::Body {
194 &self.body
195 }
196}
197
198impl<N, M> crate::model::traits::SseStreamable for AsyncChatCompletion<N, M, StreamOn>
199where
200 N: ModelName + Serialize + AsyncChat,
201 M: Serialize,
202 (N, M): Bounded,
203{
204}
205
206impl<N, M> crate::model::stream_ext::StreamChatLikeExt for AsyncChatCompletion<N, M, StreamOn>
209where
210 N: ModelName + Serialize + AsyncChat,
211 M: Serialize,
212 (N, M): Bounded,
213{
214}