Skip to main content

xai_rust/models/
batch.rs

1//! Batch API types.
2
3use serde::{Deserialize, Serialize};
4
5use super::response::Response;
6
7/// A batch for processing multiple requests.
8#[derive(Debug, Clone, Deserialize)]
9pub struct Batch {
10    /// Unique identifier for the batch.
11    pub id: String,
12    /// Name of the batch.
13    pub name: String,
14    /// Current status of the batch.
15    pub status: BatchStatus,
16    /// Number of requests in the batch.
17    #[serde(default)]
18    pub request_count: u32,
19    /// Number of completed requests.
20    #[serde(default)]
21    pub completed_count: u32,
22    /// Number of failed requests.
23    #[serde(default)]
24    pub failed_count: u32,
25    /// Creation timestamp.
26    #[serde(default)]
27    pub created_at: Option<i64>,
28    /// Completion timestamp.
29    #[serde(default)]
30    pub completed_at: Option<i64>,
31}
32
33/// Status of a batch.
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
35#[serde(rename_all = "snake_case")]
36pub enum BatchStatus {
37    /// Batch is being created.
38    Creating,
39    /// Batch is queued for processing.
40    Queued,
41    /// Batch is currently processing.
42    Processing,
43    /// Batch completed successfully.
44    Completed,
45    /// Batch was cancelled.
46    Cancelled,
47    /// Batch failed.
48    Failed,
49}
50
51/// Response from listing batches.
52#[derive(Debug, Clone, Deserialize)]
53pub struct BatchListResponse {
54    /// List of batches.
55    pub data: Vec<Batch>,
56    /// Pagination token for next page.
57    #[serde(default)]
58    pub next_token: Option<String>,
59}
60
61/// A request within a batch.
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct BatchRequest {
64    /// Custom ID for tracking this request.
65    pub custom_id: String,
66    /// The model to use.
67    pub model: String,
68    /// Input messages.
69    pub input: Vec<crate::models::message::Message>,
70    /// Optional temperature parameter.
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub temperature: Option<f32>,
73    /// Optional maximum tokens to generate.
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub max_tokens: Option<u32>,
76}
77
78impl BatchRequest {
79    /// Create a new batch request.
80    pub fn new(custom_id: impl Into<String>, model: impl Into<String>) -> Self {
81        Self {
82            custom_id: custom_id.into(),
83            model: model.into(),
84            input: Vec::new(),
85            temperature: None,
86            max_tokens: None,
87        }
88    }
89
90    /// Add a message.
91    pub fn message(mut self, message: crate::models::message::Message) -> Self {
92        self.input.push(message);
93        self
94    }
95
96    /// Add multiple messages.
97    pub fn messages(mut self, messages: Vec<crate::models::message::Message>) -> Self {
98        self.input.extend(messages);
99        self
100    }
101
102    /// Set temperature.
103    pub fn temperature(mut self, temp: f32) -> Self {
104        self.temperature = Some(temp);
105        self
106    }
107
108    /// Set max tokens.
109    pub fn max_tokens(mut self, max: u32) -> Self {
110        self.max_tokens = Some(max);
111        self
112    }
113}
114
115/// Response from listing batch requests.
116#[derive(Debug, Clone, Deserialize)]
117pub struct BatchRequestListResponse {
118    /// List of batch requests.
119    pub data: Vec<BatchRequestInfo>,
120    /// Pagination token for next page.
121    #[serde(default)]
122    pub next_token: Option<String>,
123}
124
125/// Information about a batch request.
126#[derive(Debug, Clone, Deserialize)]
127pub struct BatchRequestInfo {
128    /// The batch request ID.
129    pub id: String,
130    /// Custom ID provided when creating the request.
131    pub custom_id: String,
132    /// Status of this request.
133    pub status: BatchRequestStatus,
134}
135
136/// Status of a batch request.
137#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
138#[serde(rename_all = "snake_case")]
139pub enum BatchRequestStatus {
140    /// Request is pending.
141    Pending,
142    /// Request is processing.
143    Processing,
144    /// Request completed.
145    Completed,
146    /// Request failed.
147    Failed,
148}
149
150/// Response from listing batch results.
151#[derive(Debug, Clone, Deserialize)]
152pub struct BatchResultListResponse {
153    /// List of results.
154    pub data: Vec<BatchResult>,
155    /// Pagination token for next page.
156    #[serde(default)]
157    pub next_token: Option<String>,
158}
159
160/// Result of a batch request.
161#[derive(Debug, Clone, Deserialize)]
162pub struct BatchResult {
163    /// The batch request ID.
164    pub batch_request_id: String,
165    /// Custom ID provided when creating the request.
166    pub custom_id: String,
167    /// Error code (0 = success).
168    #[serde(default)]
169    pub error_code: i32,
170    /// Error message if failed.
171    #[serde(default)]
172    pub error_message: Option<String>,
173    /// The response if successful.
174    #[serde(default)]
175    pub response: Option<Response>,
176}
177
178impl BatchResult {
179    /// Check if this result is successful.
180    pub fn is_success(&self) -> bool {
181        self.error_code == 0
182    }
183
184    /// Check if this result has an error.
185    pub fn has_error(&self) -> bool {
186        self.error_code != 0
187    }
188
189    /// Get the response text if successful.
190    pub fn text(&self) -> Option<String> {
191        self.response.as_ref().and_then(|r| r.output_text())
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{assistant, user};
    use crate::models::response::Response;
    use serde_json::json;

    /// Build a `BatchResult` fixture from its individual parts.
    fn make_result(
        batch_request_id: &str,
        custom_id: &str,
        error_code: i32,
        error_message: Option<&str>,
        response: Option<Response>,
    ) -> BatchResult {
        BatchResult {
            batch_request_id: batch_request_id.to_string(),
            custom_id: custom_id.to_string(),
            error_code,
            error_message: error_message.map(|text| text.to_string()),
            response,
        }
    }

    /// The builder keeps the IDs, gathers every message and records both
    /// optional parameters.
    #[test]
    fn batch_request_builder_collects_messages_and_options() {
        let follow_ups = vec![assistant("ready"), user("status")];
        let built = BatchRequest::new("req-1", "grok-4")
            .message(user("hello"))
            .messages(follow_ups)
            .temperature(0.3)
            .max_tokens(42);

        assert_eq!(built.custom_id, "req-1");
        assert_eq!(built.model, "grok-4");
        assert_eq!(built.input.len(), 3);
        assert_eq!(built.temperature, Some(0.3));
        assert_eq!(built.max_tokens, Some(42));
    }

    /// `is_success`, `has_error` and `text` follow the error code and the
    /// attached response.
    #[test]
    fn batch_result_helpers_reflect_error_state_and_text() {
        // Error code 0 with no response: successful, but there is no text.
        let without_response = make_result("br-1", "custom-1", 0, None, None);
        assert!(without_response.is_success());
        assert!(!without_response.has_error());
        assert!(without_response.text().is_none());

        // Non-zero error code: reported as an error, message preserved.
        let failed = make_result("br-2", "custom-2", 7, Some("throttled"), None);
        assert!(!failed.is_success());
        assert!(failed.has_error());
        assert_eq!(failed.error_message.as_deref(), Some("throttled"));

        // Error code 0 with a response attached: its text is surfaced.
        let payload = json!({
            "id": "resp",
            "model": "grok-4",
            "output": [{
                "type": "message",
                "role": "assistant",
                "content": [{"type": "text", "text": "ok"}]
            }]
        });
        let parsed: Response = serde_json::from_value(payload).unwrap();

        let succeeded = make_result("br-3", "custom-3", 0, None, Some(parsed));
        assert_eq!(succeeded.text().as_deref(), Some("ok"));
    }
}