Skip to main content

zai_rs/model/
chat_base_request.rs

1//! # Chat Request Body
2//!
3//! Provides the core [`ChatBody`] structure shared by all chat-completion
4//! endpoints. The generic type parameters enforce compile-time compatibility
5//! between model and message types.
6//!
7//! # Type Parameters
8//!
9//! - `N` — Model identifier type (must implement
10//!   [`ModelName`](super::traits::ModelName))
11//! - `M` — Message type (must form a [`Bounded`](super::traits::Bounded) pair
12//!   with `N`)
13
14use serde::Serialize;
15use validator::*;
16
17use super::{tools::*, traits::*};
18
/// Main request body structure for chat API calls.
///
/// This structure represents a complete chat request with all possible
/// configuration options. It uses generic types to support different model
/// names and message types while maintaining type safety through trait bounds.
///
/// Every optional field is skipped during serialization while it is `None`.
/// The `#[validate(...)]` constraints are checked by the derived `Validate`
/// implementation; the builder methods on this type do not check them.
///
/// # Type Parameters
///
/// * `N` - The model name type, must implement [`ModelName`]
/// * `M` - The message type, must form a [`Bounded`] pair with `N`
///
/// # Examples
///
/// ```rust,ignore
/// use crate::model::{chat_message_types::TextMessage, chat_models::GLM4_6};
///
/// // `model` is a typed model marker (not a string) and this struct does not
/// // derive `Default`, so build it through `new` plus the `with_*` methods.
/// let chat_body = ChatBody::new(GLM4_6 {}, TextMessage::user("Hello, how are you?"))
///     .add_message(TextMessage::assistant("I'm doing well, thank you!"))
///     .with_temperature(0.7)
///     .with_max_tokens(1000);
/// ```
#[derive(Debug, Clone, Validate, Serialize)]
pub struct ChatBody<N, M>
where
    N: ModelName,
    (N, M): Bounded,
{
    /// The model to use for the chat completion.
    pub model: N,

    /// A list of messages comprising the conversation so far.
    pub messages: Vec<M>,

    /// A unique identifier for the request. Optional field that will be omitted
    /// from serialization if not provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,

    /// Optional thinking configuration, expressed as a [`ThinkingType`] value
    /// (not free-form text). The `with_thinking` builder is only offered for
    /// models that implement `ThinkEnable`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<ThinkingType>,

    /// Whether to use sampling during generation. When `true`, the model will
    /// use probabilistic sampling; when `false`, it will use deterministic
    /// generation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub do_sample: Option<bool>,

    /// Whether to stream back partial message deltas as they are generated.
    /// When `true`, responses will be sent as server-sent events.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    /// Whether to enable streaming of tool calls (streaming function call
    /// parameters). Supported by GLM-5.1, GLM-5, GLM-5-Turbo, GLM-4.7, and
    /// GLM-4.6 models. Defaults to false when omitted.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_stream: Option<bool>,

    /// Controls randomness in the output. Higher values (closer to 1.0) make
    /// the output more random, while lower values (closer to 0.0) make it
    /// more deterministic. Validated to lie between 0.0 and 1.0.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 0.0, max = 1.0))]
    pub temperature: Option<f32>,

    /// Controls diversity via nucleus sampling. Only tokens with cumulative
    /// probability up to `top_p` are considered. Validated to lie between
    /// 0.0 and 1.0.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 0.0, max = 1.0))]
    pub top_p: Option<f32>,

    /// The maximum number of tokens to generate in the completion.
    /// Validated to lie between 1 and 98304.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 1, max = 98304))]
    pub max_tokens: Option<u32>,

    /// A list of tools the model may call. Currently supports function calling,
    /// web search, and retrieval tools.
    /// Note: server expects an array; we model this as a vector of tool items.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tools>>,

    // NOTE: the API's `tool_choice` (enum<string>) parameter is intentionally
    // not modelled here because it is not needed for now.
    /// A unique identifier representing your end-user, which can help monitor
    /// and detect abuse. Validated to be between 6 and 128 characters long.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(length(min = 6, max = 128))]
    pub user_id: Option<String>,

    /// Stop sequence at which the API will stop generating further tokens.
    /// When present, validation requires the list to hold exactly one entry.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(length(min = 1, max = 1))]
    pub stop: Option<Vec<String>>,

    /// An object specifying the format that the model must output.
    /// Can be either text or JSON object format.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
}
130
131impl<N, M> ChatBody<N, M>
132where
133    N: ModelName,
134    (N, M): Bounded,
135{
136    pub fn new(model: N, messages: M) -> Self {
137        Self {
138            model,
139            messages: vec![messages],
140            request_id: None,
141            thinking: None,
142            do_sample: None,
143            stream: None,
144            tool_stream: None,
145            temperature: None,
146            top_p: None,
147            max_tokens: None,
148            tools: None,
149            user_id: None,
150            stop: None,
151            response_format: None,
152        }
153    }
154
155    pub fn add_messages(mut self, messages: M) -> Self {
156        self.messages.push(messages);
157        self
158    }
159    /// Add a single message to the conversation (preferred over add_messages
160    /// for clarity when adding one message).
161    pub fn add_message(mut self, message: M) -> Self {
162        self.messages.push(message);
163        self
164    }
165    /// Add multiple messages to the conversation at once.
166    pub fn extend_messages(mut self, messages: impl IntoIterator<Item = M>) -> Self {
167        self.messages.extend(messages);
168        self
169    }
170    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
171        self.request_id = Some(request_id.into());
172        self
173    }
174    pub fn with_do_sample(mut self, do_sample: bool) -> Self {
175        self.do_sample = Some(do_sample);
176        self
177    }
178    pub fn with_stream(mut self, stream: bool) -> Self {
179        self.stream = Some(stream);
180        self
181    }
182    pub fn with_temperature(mut self, temperature: f32) -> Self {
183        self.temperature = Some(temperature);
184        self
185    }
186    pub fn with_top_p(mut self, top_p: f32) -> Self {
187        self.top_p = Some(top_p);
188        self
189    }
190    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
191        self.max_tokens = Some(max_tokens);
192        self
193    }
194    /// Deprecated: use `add_tools` (single) or `extend_tools` (Vec) on
195    /// ChatBody, or prefer ChatCompletion::add_tool / add_tools at the
196    /// client layer.
197    #[deprecated(note = "with_tools is deprecated; use add_tool/add_tools instead")]
198    pub fn with_tools(mut self, tools: impl Into<Vec<Tools>>) -> Self {
199        self.tools = Some(tools.into());
200        self
201    }
202    pub fn add_tools(mut self, tools: Tools) -> Self {
203        self.tools.get_or_insert(Vec::new()).push(tools);
204        self
205    }
206    pub fn extend_tools(mut self, tools: Vec<Tools>) -> Self {
207        self.tools.get_or_insert(Vec::new()).extend(tools);
208        self
209    }
210    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
211        self.user_id = Some(user_id.into());
212        self
213    }
214    pub fn with_stop(mut self, stop: String) -> Self {
215        self.stop.get_or_insert_with(Vec::new).push(stop);
216        self
217    }
218}
219
220impl<N, M> ChatBody<N, M>
221where
222    N: ModelName + ThinkEnable,
223    (N, M): Bounded,
224{
225    /// Adds thinking text to the chat body for models that support thinking
226    /// capabilities.
227    ///
228    /// This method is only available for models that implement the
229    /// [`ThinkEnable`] trait, ensuring type safety for thinking-enabled
230    /// models.
231    ///
232    /// # Arguments
233    ///
234    /// * `thinking` - The thinking prompt or reasoning text to add
235    ///
236    /// # Returns
237    ///
238    /// Returns `self` with the thinking field set, allowing for method
239    /// chaining.
240    ///
241    /// # Examples
242    ///
243    /// ```rust,ignore
244    /// let chat_body = ChatBody::new(model, messages)
245    ///     .with_thinking("Let me think step by step about this problem...");
246    /// ```
247    pub fn with_thinking(mut self, thinking: ThinkingType) -> Self {
248        self.thinking = Some(thinking);
249        self
250    }
251}
252
253// Only available when the model supports streaming tool calls (GLM-4.6)
254impl<N, M> ChatBody<N, M>
255where
256    N: ModelName + ToolStreamEnable,
257    (N, M): Bounded,
258{
259    /// Enables streaming tool calls. Supported by GLM-5.1, GLM-5, GLM-5-Turbo,
260    /// GLM-4.7, and GLM-4.6 models. Default is false when omitted.
261    pub fn with_tool_stream(mut self, tool_stream: bool) -> Self {
262        self.tool_stream = Some(tool_stream);
263        if tool_stream {
264            // Enabling tool_stream implies stream=true
265            self.stream = Some(true);
266        }
267        self
268    }
269}
270
271// 为方便使用,实现从单个Tools到Vec<Tools>的转换
272impl From<Tools> for Vec<Tools> {
273    fn from(tool: Tools) -> Self {
274        vec![tool]
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{chat_message_types::TextMessage, chat_models::GLM4_6};

    /// Concrete body type exercised by every test in this module.
    type Body = ChatBody<GLM4_6, TextMessage>;

    /// Builds a GLM-4.6 body whose only message is a user message with `text`.
    fn body_with_user(text: &'static str) -> Body {
        ChatBody::new(GLM4_6 {}, TextMessage::user(text))
    }

    #[test]
    fn test_with_tool_stream_sets_both_fields() {
        let enabled = body_with_user("test").with_tool_stream(true);
        assert_eq!(enabled.tool_stream, Some(true));
        assert_eq!(enabled.stream, Some(true));
    }

    #[test]
    fn test_with_tool_stream_false_does_not_force_stream() {
        let disabled = body_with_user("test").with_tool_stream(false);
        assert_eq!(disabled.tool_stream, Some(false));
        // Disabling tool streaming must not switch regular streaming on.
        assert!(!matches!(disabled.stream, Some(true)));
    }

    #[test]
    fn test_add_tools_accumulates() {
        let function = crate::model::tools::Function::new(
            "test_fn",
            "A test function",
            serde_json::json!({"type": "object"}),
        );
        let with_tool =
            body_with_user("test").add_tools(crate::model::tools::Tools::Function { function });
        // A present list holding exactly one tool.
        assert_eq!(with_tool.tools.map(|tools| tools.len()), Some(1));
    }

    #[test]
    fn test_extend_messages() {
        let extra = vec![TextMessage::assistant("second"), TextMessage::user("third")];
        let conversation = body_with_user("first").extend_messages(extra);
        assert_eq!(conversation.messages.len(), 3);
        // The original message must still come first.
        if let TextMessage::User { content } = &conversation.messages[0] {
            assert_eq!(content, "first");
        } else {
            panic!("Expected User message");
        }
    }

    #[test]
    fn test_add_message() {
        let conversation = body_with_user("first").add_message(TextMessage::assistant("second"));
        assert_eq!(conversation.messages.len(), 2);
    }
}