// zai_rs/model/async_chat/data.rs

1use super::super::chat_base_request::*;
2use super::super::tools::*;
3use super::super::traits::*;
4use crate::client::http::HttpClient;
5use serde::Serialize;
6use std::marker::PhantomData;
7use validator::Validate;
8
9pub struct AsyncChatCompletion<N, M, S = StreamOff>
10where
11    N: ModelName + AsyncChat,
12    (N, M): Bounded,
13    ChatBody<N, M>: Serialize,
14    S: StreamState,
15{
16    pub key: String,
17    body: ChatBody<N, M>,
18    _stream: PhantomData<S>,
19}
20
21impl<N, M> AsyncChatCompletion<N, M, StreamOff>
22where
23    N: ModelName + AsyncChat,
24    (N, M): Bounded,
25    ChatBody<N, M>: Serialize,
26{
27    pub fn new(model: N, messages: M, key: String) -> Self {
28        let body = ChatBody::new(model, messages);
29        Self {
30            body,
31            key,
32            _stream: PhantomData,
33        }
34    }
35
36    pub fn body_mut(&mut self) -> &mut ChatBody<N, M> {
37        &mut self.body
38    }
39
40    // Fluent, builder-style forwarding methods to mutate inner ChatBody and return Self
41    pub fn add_messages(mut self, messages: M) -> Self {
42        self.body = self.body.add_messages(messages);
43        self
44    }
45    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
46        self.body = self.body.with_request_id(request_id);
47        self
48    }
49    pub fn with_do_sample(mut self, do_sample: bool) -> Self {
50        self.body = self.body.with_do_sample(do_sample);
51        self
52    }
53    #[deprecated(note = "Use enable_stream()/disable_stream() for compile-time guarantees")]
54    pub fn with_stream(mut self, stream: bool) -> Self {
55        self.body = self.body.with_stream(stream);
56        self
57    }
58    pub fn with_tool_stream(mut self, tool_stream: bool) -> Self
59    where
60        N: ToolStreamEnable,
61    {
62        self.body = self.body.with_tool_stream(tool_stream);
63        self
64    }
65
66    pub fn with_temperature(mut self, temperature: f32) -> Self {
67        self.body = self.body.with_temperature(temperature);
68        self
69    }
70    pub fn with_top_p(mut self, top_p: f32) -> Self {
71        self.body = self.body.with_top_p(top_p);
72        self
73    }
74    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
75        self.body = self.body.with_max_tokens(max_tokens);
76        self
77    }
78    pub fn add_tool(mut self, tool: Tools) -> Self {
79        self.body = self.body.add_tools(tool);
80        self
81    }
82    pub fn add_tools(mut self, tools: Vec<Tools>) -> Self {
83        self.body = self.body.extend_tools(tools);
84        self
85    }
86    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
87        self.body = self.body.with_user_id(user_id);
88        self
89    }
90    pub fn with_stop(mut self, stop: String) -> Self {
91        self.body = self.body.with_stop(stop);
92        self
93    }
94
95    // Optional: only available when model supports thinking
96    pub fn with_thinking(mut self, thinking: ThinkingType) -> Self
97    where
98        N: ThinkEnable,
99    {
100        self.body = self.body.with_thinking(thinking);
101        self
102    }
103
104    // Type-state toggles
105    pub fn enable_stream(mut self) -> AsyncChatCompletion<N, M, StreamOn> {
106        self.body.stream = Some(true);
107        AsyncChatCompletion {
108            key: self.key,
109            body: self.body,
110            _stream: PhantomData,
111        }
112    }
113    pub fn disable_stream(mut self) -> AsyncChatCompletion<N, M, StreamOff> {
114        self.body.stream = Some(false);
115        AsyncChatCompletion {
116            key: self.key,
117            body: self.body,
118            _stream: PhantomData,
119        }
120    }
121    /// Validate request parameters for non-stream async chat (StreamOff)
122    pub fn validate(&self) -> anyhow::Result<()> {
123        self.body.validate().map_err(|e| anyhow::anyhow!(e))?;
124        if matches!(self.body.stream, Some(true)) {
125            return Err(anyhow::anyhow!(
126                "stream=true detected; use enable_stream() and streaming APIs instead"
127            ));
128        }
129        Ok(())
130    }
131
132    /// Send the request and parse typed response.
133    /// Automatically runs `validate()` before sending.
134    pub async fn send(
135        &self,
136    ) -> anyhow::Result<crate::model::chat_base_response::ChatCompletionResponse>
137    where
138        N: serde::Serialize,
139        M: serde::Serialize,
140    {
141        self.validate()?;
142        let resp: reqwest::Response = self.post().await?;
143        let parsed = resp
144            .json::<crate::model::chat_base_response::ChatCompletionResponse>()
145            .await?;
146        Ok(parsed)
147    }
148}
149
150impl<N, M> AsyncChatCompletion<N, M, StreamOn>
151where
152    N: ModelName + AsyncChat,
153    (N, M): Bounded,
154    ChatBody<N, M>: Serialize,
155{
156    pub fn with_tool_stream(mut self, tool_stream: bool) -> Self
157    where
158        N: ToolStreamEnable,
159    {
160        self.body = self.body.with_tool_stream(tool_stream);
161        self
162    }
163}
164
165impl<N, M, S> HttpClient for AsyncChatCompletion<N, M, S>
166where
167    N: ModelName + Serialize + AsyncChat,
168    M: Serialize,
169    (N, M): Bounded,
170    S: StreamState,
171{
172    type Body = ChatBody<N, M>;
173    type ApiUrl = &'static str;
174    type ApiKey = String;
175
176    fn api_url(&self) -> &Self::ApiUrl {
177        &"https://open.bigmodel.cn/api/paas/v4/async/chat/completions"
178    }
179    fn api_key(&self) -> &Self::ApiKey {
180        &self.key
181    }
182    fn body(&self) -> &Self::Body {
183        &self.body
184    }
185}
186
187impl<N, M> crate::model::traits::SseStreamable for AsyncChatCompletion<N, M, StreamOn>
188where
189    N: ModelName + Serialize + AsyncChat,
190    M: Serialize,
191    (N, M): Bounded,
192{
193}
194
195// Enable typed streaming extension methods for AsyncChatCompletion<..., StreamOn>
196impl<N, M> crate::model::stream_ext::StreamChatLikeExt for AsyncChatCompletion<N, M, StreamOn>
197where
198    N: ModelName + Serialize + AsyncChat,
199    M: Serialize,
200    (N, M): Bounded,
201{
202}