Skip to main content

zai_rs/model/chat/
data.rs

1//! # Chat Completion Data Models
2//!
3//! This module defines the core data structures for chat completion requests,
4//! implementing type-safe chat interactions with the Zhipu AI API.
5//!
6//! ## Type-State Pattern
7//!
8//! The implementation uses Rust's type system to enforce compile-time
9//! guarantees about streaming capabilities through phantom types
10//! (`StreamOn`/`StreamOff`).
11//!
12//! ## Features
13//!
14//! - **Type-safe model binding** - Compile-time verification of model-message
15//!   compatibility
16//! - **Builder pattern** - Fluent API for request construction
17//! - **Streaming support** - Type-state based streaming capability enforcement
18//! - **Tool integration** - Support for function calling and tool usage
19//! - **Parameter control** - Temperature, top-p, max tokens, and other
20//!   generation parameters
21
22use std::marker::PhantomData;
23
24use serde::Serialize;
25use validator::Validate;
26
27use super::super::{chat_base_request::*, tools::*, traits::*};
28use crate::client::http::HttpClient;
29
30// Type-state is defined in model::traits::{StreamState, StreamOn, StreamOff}
31
32/// Type-safe chat completion request structure.
33///
34/// This struct represents a chat completion request with compile-time
35/// guarantees for model compatibility and streaming capabilities.
36///
37/// ## Type Parameters
38///
39/// - `N` - The AI model type (must implement `ModelName + Chat`)
40/// - `M` - The message type (must form a valid bound with the model)
41/// - `S` - Stream state (`StreamOn` or `StreamOff`, defaults to `StreamOff`)
42///
43/// ## Examples
44///
45/// ```rust,ignore
46/// let model = GLM4_5_flash {};
47/// let messages = TextMessage::user("Hello, how are you?");
48/// let request = ChatCompletion::new(model, messages, api_key);
49/// ```
50pub struct ChatCompletion<N, M, S = StreamOff>
51where
52    N: ModelName + Chat,
53    (N, M): Bounded,
54    ChatBody<N, M>: Serialize,
55    S: StreamState,
56{
57    /// API key for authentication with the Zhipu AI service.
58    pub key: String,
59
60    /// API endpoint URL for chat completions.
61    /// Defaults to "https://open.bigmodel.cn/api/paas/v4/chat/completions"
62    /// but can be customized using the `with_url()` method.
63    pub url: String,
64
65    /// The request body containing model, messages, and parameters.
66    body: ChatBody<N, M>,
67
68    /// Phantom data to track streaming capability at compile time.
69    _stream: PhantomData<S>,
70}
71
72impl<N, M> ChatCompletion<N, M, StreamOff>
73where
74    N: ModelName + Chat,
75    (N, M): Bounded,
76    ChatBody<N, M>: Serialize,
77{
78    /// Creates a new non-streaming chat completion request.
79    ///
80    /// ## Arguments
81    ///
82    /// * `model` - The AI model to use for completion
83    /// * `messages` - The conversation messages
84    /// * `key` - API key for authentication
85    ///
86    /// ## Returns
87    ///
88    /// A new `ChatCompletion` instance configured for non-streaming requests.
89    pub fn new(model: N, messages: M, key: String) -> ChatCompletion<N, M, StreamOff> {
90        let body = ChatBody::new(model, messages);
91        ChatCompletion {
92            body,
93            key,
94            url: "https://open.bigmodel.cn/api/paas/v4/chat/completions".to_string(),
95            _stream: PhantomData,
96        }
97    }
98
99    /// Gets mutable access to the request body for further customization.
100    ///
101    /// This method allows modification of request parameters after initial
102    /// creation.
103    pub fn body_mut(&mut self) -> &mut ChatBody<N, M> {
104        &mut self.body
105    }
106
107    /// Adds additional messages to the conversation.
108    ///
109    /// This method provides a fluent interface for building conversation
110    /// context.
111    ///
112    /// ## Arguments
113    ///
114    /// * `messages` - Additional messages to append to the conversation
115    ///
116    /// ## Returns
117    ///
118    /// Self with the updated message collection, enabling method chaining.
119    pub fn add_messages(mut self, messages: M) -> Self {
120        self.body = self.body.add_messages(messages);
121        self
122    }
123    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
124        self.body = self.body.with_request_id(request_id);
125        self
126    }
127    pub fn with_do_sample(mut self, do_sample: bool) -> Self {
128        self.body = self.body.with_do_sample(do_sample);
129        self
130    }
131
132    pub fn with_temperature(mut self, temperature: f32) -> Self {
133        self.body = self.body.with_temperature(temperature);
134        self
135    }
136    pub fn with_top_p(mut self, top_p: f32) -> Self {
137        self.body = self.body.with_top_p(top_p);
138        self
139    }
140    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
141        self.body = self.body.with_max_tokens(max_tokens);
142        self
143    }
144    pub fn add_tool(mut self, tool: Tools) -> Self {
145        self.body = self.body.add_tools(tool);
146        self
147    }
148    pub fn add_tools(mut self, tools: Vec<Tools>) -> Self {
149        self.body = self.body.extend_tools(tools);
150        self
151    }
152    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
153        self.body = self.body.with_user_id(user_id);
154        self
155    }
156    pub fn with_stop(mut self, stop: String) -> Self {
157        self.body = self.body.with_stop(stop);
158        self
159    }
160
161    /// Sets a custom API endpoint URL for this chat completion request.
162    ///
163    /// This method allows overriding the default API endpoint with a custom
164    /// URL, enabling support for different deployment environments or proxy
165    /// configurations.
166    ///
167    /// ## Arguments
168    ///
169    /// * `url` - The custom API endpoint URL
170    ///
171    /// ## Returns
172    ///
173    /// Self with the updated URL, enabling method chaining.
174    ///
175    /// ## Examples
176    ///
177    /// ```rust,ignore
178    /// let request = ChatCompletion::new(model, messages, api_key)
179    ///     .with_url("https://custom-api.example.com/v1/chat/completions");
180    /// ```
181    pub fn with_url(mut self, url: impl Into<String>) -> Self {
182        self.url = url.into();
183        self
184    }
185
186    /// Sets the URL to the coding plan endpoint.
187    ///
188    /// This method configures the chat completion request to use the
189    /// coding-specific API endpoint "https://open.bigmodel.cn/api/coding/paas/v4/chat/completions".
190    ///
191    /// ## Returns
192    ///
193    /// Self with the coding plan URL, enabling method chaining.
194    ///
195    /// ## Examples
196    ///
197    /// ```rust,ignore
198    /// let request = ChatCompletion::new(model, messages, api_key)
199    ///     .with_coding_plan();
200    /// ```
201    pub fn with_coding_plan(mut self) -> Self {
202        self.url = "https://open.bigmodel.cn/api/coding/paas/v4/chat/completions".to_string();
203        self
204    }
205
206    // Optional: only available when model supports thinking
207    pub fn with_thinking(mut self, thinking: ThinkingType) -> Self
208    where
209        N: ThinkEnable,
210    {
211        self.body = self.body.with_thinking(thinking);
212        self
213    }
214
215    /// Enables streaming for this chat completion request.
216    ///
217    /// This method transitions the request to streaming mode, allowing
218    /// real-time response processing through Server-Sent Events (SSE).
219    ///
220    /// ## Returns
221    ///
222    /// A new `ChatCompletion` instance with streaming enabled (`StreamOn`).
223    pub fn enable_stream(mut self) -> ChatCompletion<N, M, StreamOn> {
224        self.body.stream = Some(true);
225        ChatCompletion {
226            key: self.key,
227            url: self.url,
228            body: self.body,
229            _stream: PhantomData,
230        }
231    }
232
233    /// Validate request parameters for non-stream chat (StreamOff)
234    pub fn validate(&self) -> crate::ZaiResult<()> {
235        // Field-level validation from ChatBody
236        // (temperature/top_p/max_tokens/user_id/stop...)
237
238        self.body
239            .validate()
240            .map_err(crate::client::error::ZaiError::from)?;
241        // Ensure not accidentally enabling stream in StreamOff state
242
243        if matches!(self.body.stream, Some(true)) {
244            return Err(crate::client::error::ZaiError::ApiError {
245                code: 1200,
246                message: "stream=true detected; use enable_stream() and streaming APIs instead"
247                    .to_string(),
248            });
249        }
250
251        Ok(())
252    }
253
254    pub async fn send(
255        &self,
256    ) -> crate::ZaiResult<crate::model::chat_base_response::ChatCompletionResponse>
257    where
258        N: serde::Serialize,
259        M: serde::Serialize,
260    {
261        self.validate()?;
262
263        let resp: reqwest::Response = self.post().await?;
264
265        let parsed = resp
266            .json::<crate::model::chat_base_response::ChatCompletionResponse>()
267            .await?;
268
269        Ok(parsed)
270    }
271}
272
273impl<N, M> ChatCompletion<N, M, StreamOn>
274where
275    N: ModelName + Chat,
276    (N, M): Bounded,
277    ChatBody<N, M>: Serialize,
278{
279    pub fn with_tool_stream(mut self, tool_stream: bool) -> Self
280    where
281        N: ToolStreamEnable,
282    {
283        self.body = self.body.with_tool_stream(tool_stream);
284        self
285    }
286
287    /// Disables streaming for this chat completion request.
288    ///
289    /// This method ensures the request will receive a complete response
290    /// rather than streaming chunks.
291    ///
292    /// ## Returns
293    ///
294    /// A new `ChatCompletion` instance with streaming disabled (`StreamOff`).
295    pub fn disable_stream(mut self) -> ChatCompletion<N, M, StreamOff> {
296        self.body.stream = Some(false);
297        // Reset tool_stream when disabling streaming since tool_stream depends on
298        // stream
299        self.body.tool_stream = None;
300        ChatCompletion {
301            key: self.key,
302            url: self.url,
303            body: self.body,
304            _stream: PhantomData,
305        }
306    }
307}
308
309impl<N, M, S> HttpClient for ChatCompletion<N, M, S>
310where
311    N: ModelName + Serialize + Chat,
312    M: Serialize,
313    (N, M): Bounded,
314    S: StreamState,
315{
316    type Body = ChatBody<N, M>;
317    type ApiUrl = String;
318    type ApiKey = String;
319
320    /// Returns the API endpoint URL for chat completions.
321    fn api_url(&self) -> &Self::ApiUrl {
322        &self.url
323    }
324    fn api_key(&self) -> &Self::ApiKey {
325        &self.key
326    }
327    fn body(&self) -> &Self::Body {
328        &self.body
329    }
330}
331
332/// Enables Server-Sent Events (SSE) streaming for streaming-enabled chat
333/// completions.
334///
335/// This implementation allows streaming chat completions to be processed
336/// incrementally as responses arrive from the API.
337impl<N, M> crate::model::traits::SseStreamable for ChatCompletion<N, M, StreamOn>
338where
339    N: ModelName + Serialize + Chat,
340    M: Serialize,
341    (N, M): Bounded,
342{
343}
344
345/// Provides streaming extension methods for streaming-enabled chat completions.
346///
347/// This implementation enables the use of streaming-specific methods
348/// for processing chat responses in real-time.
349impl<N, M> crate::model::stream_ext::StreamChatLikeExt for ChatCompletion<N, M, StreamOn>
350where
351    N: ModelName + Serialize + Chat,
352    M: Serialize,
353    (N, M): Bounded,
354{
355}