Skip to main content

zai_rs/model/async_chat/
data.rs

1use std::marker::PhantomData;
2
3use serde::Serialize;
4use validator::Validate;
5
6use super::super::{chat_base_request::*, tools::*, traits::*};
7use crate::client::http::HttpClient;
8
9pub struct AsyncChatCompletion<N, M, S = StreamOff>
10where
11    N: ModelName + AsyncChat,
12    (N, M): Bounded,
13    ChatBody<N, M>: Serialize,
14    S: StreamState,
15{
16    pub key: String,
17    body: ChatBody<N, M>,
18    _stream: PhantomData<S>,
19}
20
21impl<N, M> AsyncChatCompletion<N, M, StreamOff>
22where
23    N: ModelName + AsyncChat,
24    (N, M): Bounded,
25    ChatBody<N, M>: Serialize,
26{
27    pub fn new(model: N, messages: M, key: String) -> Self {
28        let body = ChatBody::new(model, messages);
29        Self {
30            body,
31            key,
32            _stream: PhantomData,
33        }
34    }
35
36    pub fn body_mut(&mut self) -> &mut ChatBody<N, M> {
37        &mut self.body
38    }
39
40    // Fluent, builder-style forwarding methods to mutate inner ChatBody and return
41    // Self
42    pub fn add_messages(mut self, messages: M) -> Self {
43        self.body = self.body.add_messages(messages);
44        self
45    }
46    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
47        self.body = self.body.with_request_id(request_id);
48        self
49    }
50    pub fn with_do_sample(mut self, do_sample: bool) -> Self {
51        self.body = self.body.with_do_sample(do_sample);
52        self
53    }
54    #[deprecated(note = "Use enable_stream()/disable_stream() for compile-time guarantees")]
55    pub fn with_stream(mut self, stream: bool) -> Self {
56        self.body = self.body.with_stream(stream);
57        self
58    }
59    pub fn with_tool_stream(mut self, tool_stream: bool) -> Self
60    where
61        N: ToolStreamEnable,
62    {
63        self.body = self.body.with_tool_stream(tool_stream);
64        self
65    }
66
67    pub fn with_temperature(mut self, temperature: f32) -> Self {
68        self.body = self.body.with_temperature(temperature);
69        self
70    }
71    pub fn with_top_p(mut self, top_p: f32) -> Self {
72        self.body = self.body.with_top_p(top_p);
73        self
74    }
75    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
76        self.body = self.body.with_max_tokens(max_tokens);
77        self
78    }
79    pub fn add_tool(mut self, tool: Tools) -> Self {
80        self.body = self.body.add_tools(tool);
81        self
82    }
83    pub fn add_tools(mut self, tools: Vec<Tools>) -> Self {
84        self.body = self.body.extend_tools(tools);
85        self
86    }
87    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
88        self.body = self.body.with_user_id(user_id);
89        self
90    }
91    pub fn with_stop(mut self, stop: String) -> Self {
92        self.body = self.body.with_stop(stop);
93        self
94    }
95
96    // Optional: only available when model supports thinking
97    pub fn with_thinking(mut self, thinking: ThinkingType) -> Self
98    where
99        N: ThinkEnable,
100    {
101        self.body = self.body.with_thinking(thinking);
102        self
103    }
104
105    // Type-state toggles
106    pub fn enable_stream(mut self) -> AsyncChatCompletion<N, M, StreamOn> {
107        self.body.stream = Some(true);
108        AsyncChatCompletion {
109            key: self.key,
110            body: self.body,
111            _stream: PhantomData,
112        }
113    }
114
115    /// Validate request parameters for non-stream async chat (StreamOff)
116    pub fn validate(&self) -> crate::ZaiResult<()> {
117        self.body
118            .validate()
119            .map_err(crate::client::error::ZaiError::from)?;
120        if matches!(self.body.stream, Some(true)) {
121            return Err(crate::client::error::ZaiError::ApiError {
122                code: 1200,
123                message: "stream=true detected; use enable_stream() and streaming APIs instead"
124                    .to_string(),
125            });
126        }
127
128        Ok(())
129    }
130
131    pub async fn send(
132        &self,
133    ) -> crate::ZaiResult<crate::model::chat_base_response::ChatCompletionResponse>
134    where
135        N: serde::Serialize,
136        M: serde::Serialize,
137    {
138        self.validate()?;
139
140        let resp: reqwest::Response = self.post().await?;
141
142        let parsed = resp
143            .json::<crate::model::chat_base_response::ChatCompletionResponse>()
144            .await?;
145        Ok(parsed)
146    }
147}
148
149impl<N, M> AsyncChatCompletion<N, M, StreamOn>
150where
151    N: ModelName + AsyncChat,
152    (N, M): Bounded,
153    ChatBody<N, M>: Serialize,
154{
155    pub fn with_tool_stream(mut self, tool_stream: bool) -> Self
156    where
157        N: ToolStreamEnable,
158    {
159        self.body = self.body.with_tool_stream(tool_stream);
160        self
161    }
162
163    pub fn disable_stream(mut self) -> AsyncChatCompletion<N, M, StreamOff> {
164        self.body.stream = Some(false);
165        // Reset tool_stream when disabling streaming since tool_stream depends on
166        // stream
167        self.body.tool_stream = None;
168        AsyncChatCompletion {
169            key: self.key,
170            body: self.body,
171            _stream: PhantomData,
172        }
173    }
174}
175
176impl<N, M, S> HttpClient for AsyncChatCompletion<N, M, S>
177where
178    N: ModelName + Serialize + AsyncChat,
179    M: Serialize,
180    (N, M): Bounded,
181    S: StreamState,
182{
183    type Body = ChatBody<N, M>;
184    type ApiUrl = &'static str;
185    type ApiKey = String;
186
187    fn api_url(&self) -> &Self::ApiUrl {
188        &"https://open.bigmodel.cn/api/paas/v4/async/chat/completions"
189    }
190    fn api_key(&self) -> &Self::ApiKey {
191        &self.key
192    }
193    fn body(&self) -> &Self::Body {
194        &self.body
195    }
196}
197
198impl<N, M> crate::model::traits::SseStreamable for AsyncChatCompletion<N, M, StreamOn>
199where
200    N: ModelName + Serialize + AsyncChat,
201    M: Serialize,
202    (N, M): Bounded,
203{
204}
205
206// Enable typed streaming extension methods for AsyncChatCompletion<...,
207// StreamOn>
208impl<N, M> crate::model::stream_ext::StreamChatLikeExt for AsyncChatCompletion<N, M, StreamOn>
209where
210    N: ModelName + Serialize + AsyncChat,
211    M: Serialize,
212    (N, M): Bounded,
213{
214}