1use serde::{Deserialize, Serialize};
4
5use super::response::Response;
6
7#[derive(Debug, Clone, Deserialize)]
9pub struct Batch {
10 pub id: String,
12 pub name: String,
14 pub status: BatchStatus,
16 #[serde(default)]
18 pub request_count: u32,
19 #[serde(default)]
21 pub completed_count: u32,
22 #[serde(default)]
24 pub failed_count: u32,
25 #[serde(default)]
27 pub created_at: Option<i64>,
28 #[serde(default)]
30 pub completed_at: Option<i64>,
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
35#[serde(rename_all = "snake_case")]
36pub enum BatchStatus {
37 Creating,
39 Queued,
41 Processing,
43 Completed,
45 Cancelled,
47 Failed,
49}
50
51#[derive(Debug, Clone, Deserialize)]
53pub struct BatchListResponse {
54 pub data: Vec<Batch>,
56 #[serde(default)]
58 pub next_token: Option<String>,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct BatchRequest {
64 pub custom_id: String,
66 pub model: String,
68 pub input: Vec<crate::models::message::Message>,
70 #[serde(skip_serializing_if = "Option::is_none")]
72 pub temperature: Option<f32>,
73 #[serde(skip_serializing_if = "Option::is_none")]
75 pub max_tokens: Option<u32>,
76}
77
78impl BatchRequest {
79 pub fn new(custom_id: impl Into<String>, model: impl Into<String>) -> Self {
81 Self {
82 custom_id: custom_id.into(),
83 model: model.into(),
84 input: Vec::new(),
85 temperature: None,
86 max_tokens: None,
87 }
88 }
89
90 pub fn message(mut self, message: crate::models::message::Message) -> Self {
92 self.input.push(message);
93 self
94 }
95
96 pub fn messages(mut self, messages: Vec<crate::models::message::Message>) -> Self {
98 self.input.extend(messages);
99 self
100 }
101
102 pub fn temperature(mut self, temp: f32) -> Self {
104 self.temperature = Some(temp);
105 self
106 }
107
108 pub fn max_tokens(mut self, max: u32) -> Self {
110 self.max_tokens = Some(max);
111 self
112 }
113}
114
115#[derive(Debug, Clone, Deserialize)]
117pub struct BatchRequestListResponse {
118 pub data: Vec<BatchRequestInfo>,
120 #[serde(default)]
122 pub next_token: Option<String>,
123}
124
125#[derive(Debug, Clone, Deserialize)]
127pub struct BatchRequestInfo {
128 pub id: String,
130 pub custom_id: String,
132 pub status: BatchRequestStatus,
134}
135
136#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
138#[serde(rename_all = "snake_case")]
139pub enum BatchRequestStatus {
140 Pending,
142 Processing,
144 Completed,
146 Failed,
148}
149
150#[derive(Debug, Clone, Deserialize)]
152pub struct BatchResultListResponse {
153 pub data: Vec<BatchResult>,
155 #[serde(default)]
157 pub next_token: Option<String>,
158}
159
160#[derive(Debug, Clone, Deserialize)]
162pub struct BatchResult {
163 pub batch_request_id: String,
165 pub custom_id: String,
167 #[serde(default)]
169 pub error_code: i32,
170 #[serde(default)]
172 pub error_message: Option<String>,
173 #[serde(default)]
175 pub response: Option<Response>,
176}
177
178impl BatchResult {
179 pub fn is_success(&self) -> bool {
181 self.error_code == 0
182 }
183
184 pub fn has_error(&self) -> bool {
186 self.error_code != 0
187 }
188
189 pub fn text(&self) -> Option<String> {
191 self.response.as_ref().and_then(|r| r.output_text())
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{assistant, user};
    use crate::models::response::Response;
    use serde_json::json;

    /// Assembles a `BatchResult` fixture from borrowed pieces so each test
    /// case stays on one line.
    fn fixture(
        batch_request_id: &str,
        custom_id: &str,
        error_code: i32,
        error_message: Option<&str>,
        response: Option<Response>,
    ) -> BatchResult {
        BatchResult {
            batch_request_id: batch_request_id.to_string(),
            custom_id: custom_id.to_string(),
            error_code,
            error_message: error_message.map(str::to_string),
            response,
        }
    }

    #[test]
    fn batch_request_builder_collects_messages_and_options() {
        let follow_ups = vec![assistant("ready"), user("status")];

        let built = BatchRequest::new("req-1", "grok-4")
            .message(user("hello"))
            .messages(follow_ups)
            .temperature(0.3)
            .max_tokens(42);

        assert_eq!(built.custom_id, "req-1");
        assert_eq!(built.model, "grok-4");
        // One message from `message` plus two from `messages`.
        assert_eq!(built.input.len(), 3);
        assert_eq!(built.temperature, Some(0.3));
        assert_eq!(built.max_tokens, Some(42));
    }

    #[test]
    fn batch_result_helpers_reflect_error_state_and_text() {
        // Zero error code and no response: success, but no text to return.
        let pending = fixture("br-1", "custom-1", 0, None, None);
        assert!(pending.is_success());
        assert!(!pending.has_error());
        assert!(pending.text().is_none());

        // Non-zero error code flips both helpers.
        let failure = fixture("br-2", "custom-2", 7, Some("throttled"), None);
        assert!(!failure.is_success());
        assert!(failure.has_error());
        assert_eq!(failure.error_message.as_deref(), Some("throttled"));

        // A response with one assistant text part surfaces through `text()`.
        let payload = json!({
            "id": "resp",
            "model": "grok-4",
            "output": [{
                "type": "message",
                "role": "assistant",
                "content": [{"type": "text", "text": "ok"}]
            }]
        });
        let response: Response = serde_json::from_value(payload).unwrap();

        let success = fixture("br-3", "custom-3", 0, None, Some(response));
        assert_eq!(success.text().as_deref(), Some("ok"));
    }
}