1use serde::{Deserialize, Serialize};
4
5use crate::message::{ContentBlock, Message, ToolDefinition};
6
7#[derive(Clone, Debug, Serialize, Deserialize)]
9pub struct ThinkingConfig {
10 pub enabled: bool,
12 #[serde(skip_serializing_if = "Option::is_none")]
14 pub budget_tokens: Option<u32>,
15}
16
17#[derive(Clone, Debug, Serialize)]
19pub struct CompletionRequest {
20 pub model: String,
22 pub messages: Vec<Message>,
24 pub max_tokens: u32,
26 #[serde(skip_serializing_if = "Option::is_none")]
28 pub system: Option<String>,
29 #[serde(skip_serializing_if = "Vec::is_empty")]
31 pub tools: Vec<ToolDefinition>,
32 #[serde(skip_serializing_if = "std::ops::Not::not")]
34 pub stream: bool,
35 #[serde(skip_serializing_if = "Option::is_none")]
37 pub temperature: Option<f32>,
38 #[serde(skip_serializing_if = "Vec::is_empty")]
40 pub stop_sequences: Vec<String>,
41 #[serde(skip_serializing_if = "Option::is_none")]
43 pub thinking: Option<ThinkingConfig>,
44}
45
46impl CompletionRequest {
47 pub fn new(model: impl Into<String>, messages: Vec<Message>, max_tokens: u32) -> Self {
49 Self {
50 model: model.into(),
51 messages,
52 max_tokens,
53 system: None,
54 tools: Vec::new(),
55 stream: false,
56 temperature: None,
57 stop_sequences: Vec::new(),
58 thinking: None,
59 }
60 }
61
62 #[must_use]
64 pub fn system(mut self, system: impl Into<String>) -> Self {
65 self.system = Some(system.into());
66 self
67 }
68
69 #[must_use]
71 pub fn stream(mut self, stream: bool) -> Self {
72 self.stream = stream;
73 self
74 }
75
76 #[must_use]
78 pub fn temperature(mut self, temp: f32) -> Self {
79 self.temperature = Some(temp);
80 self
81 }
82
83 #[must_use]
85 pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
86 self.tools = tools;
87 self
88 }
89
90 #[must_use]
92 pub fn thinking(mut self, config: ThinkingConfig) -> Self {
93 self.thinking = Some(config);
94 self
95 }
96}
97
98#[derive(Clone, Debug, Deserialize)]
100pub struct CompletionResponse {
101 pub id: String,
103 pub content: Vec<ContentBlock>,
105 pub model: String,
107 pub stop_reason: Option<StopReason>,
109 pub usage: Usage,
111}
112
113#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
115#[serde(rename_all = "snake_case")]
116pub enum StopReason {
117 EndTurn,
119 MaxTokens,
121 StopSequence,
123 ToolUse,
125}
126
127#[derive(Clone, Debug, Default, Serialize, Deserialize)]
129pub struct Usage {
130 #[serde(default)]
132 pub input_tokens: u32,
133 #[serde(default)]
135 pub output_tokens: u32,
136 #[serde(default)]
138 pub cache_read_tokens: u32,
139 #[serde(default)]
141 pub cache_write_tokens: u32,
142}
143
144impl Usage {
145 pub fn total(&self) -> u32 {
147 self.input_tokens + self.output_tokens + self.cache_read_tokens + self.cache_write_tokens
148 }
149}
150
151#[derive(Clone, Debug)]
153pub enum StreamEvent {
154 MessageStart {
156 id: String,
158 model: String,
160 usage: Usage,
162 },
163 ContentBlockStart {
165 index: u32,
167 content_block: ContentBlock,
169 },
170 ContentBlockDelta {
172 index: u32,
174 delta: ContentDelta,
176 },
177 ContentBlockStop {
179 index: u32,
181 },
182 MessageDelta {
184 stop_reason: Option<StopReason>,
186 usage: Usage,
188 },
189 MessageStop,
191 Ping,
193 Error {
195 message: String,
197 },
198}
199
200#[derive(Clone, Debug, Serialize, Deserialize)]
202#[serde(tag = "type", rename_all = "snake_case")]
203pub enum ContentDelta {
204 TextDelta {
206 text: String,
208 },
209 InputJsonDelta {
211 partial_json: String,
213 },
214 ThinkingDelta {
216 text: String,
218 },
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Message;

    /// Model name shared by the request tests.
    const SONNET: &str = "claude-sonnet-4-5-20250929";

    /// Serializes `value` to JSON text, asserting that serialization succeeded.
    fn to_json<T: serde::Serialize>(value: &T) -> String {
        let encoded = serde_json::to_string(value);
        assert!(encoded.is_ok());
        encoded.unwrap_or_default()
    }

    // The builder methods store exactly what they are given.
    #[test]
    fn request_builder() {
        let built = CompletionRequest::new(SONNET, vec![Message::user("hi")], 1024)
            .system("You are helpful")
            .temperature(0.7)
            .stream(true);

        assert_eq!(built.model, SONNET);
        assert_eq!(built.max_tokens, 1024);
        assert!(built.stream);
        assert_eq!(built.temperature, Some(0.7));
        assert_eq!(built.system.as_deref(), Some("You are helpful"));
    }

    // A minimal request serializes its mandatory fields and omits `stream`.
    #[test]
    fn request_serialization() {
        let minimal = CompletionRequest::new(SONNET, vec![Message::user("hi")], 1024);
        let body = to_json(&minimal);

        assert!(body.contains(SONNET));
        assert!(body.contains("1024"));
        assert!(!body.contains("stream"));
    }

    // A response without cache counters parses, and those counters read as zero.
    #[test]
    fn response_parsing() {
        let payload = r#"{
            "id": "msg_123",
            "content": [{"type": "text", "text": "Hello!"}],
            "model": "claude-sonnet-4-5-20250929",
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 10, "output_tokens": 5}
        }"#;

        let parsed = serde_json::from_str::<CompletionResponse>(payload);
        assert!(parsed.is_ok());
        if let Ok(response) = parsed {
            assert_eq!(response.id, "msg_123");
            assert_eq!(response.usage.total(), 15);
            assert_eq!(response.usage.cache_read_tokens, 0);
            assert_eq!(response.usage.cache_write_tokens, 0);
        }
    }

    // Stop reasons are read from their snake_case wire names.
    #[test]
    fn stop_reason_parsing() {
        let parse = |raw: &str| serde_json::from_str::<StopReason>(raw).ok();

        assert_eq!(parse(r#""end_turn""#), Some(StopReason::EndTurn));
        assert_eq!(parse(r#""tool_use""#), Some(StopReason::ToolUse));
    }

    #[test]
    fn usage_total() {
        let counters = Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Usage::default()
        };
        assert_eq!(counters.total(), 150);
    }

    #[test]
    fn usage_total_with_cache_tokens() {
        let counters = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 20,
            cache_write_tokens: 10,
        };
        assert_eq!(counters.total(), 180);
    }

    // The variant tag appears in snake_case in the serialized delta.
    #[test]
    fn content_delta_serialization() {
        let body = to_json(&ContentDelta::TextDelta {
            text: "hello".into(),
        });
        assert!(body.contains("text_delta"));
    }

    #[test]
    fn thinking_config_serialization() {
        let body = to_json(&ThinkingConfig {
            enabled: true,
            budget_tokens: Some(10_000),
        });
        assert!(body.contains("true"));
        assert!(body.contains("10000"));
    }

    // A `None` budget is dropped from the output rather than written as null.
    #[test]
    fn thinking_config_without_budget() {
        let body = to_json(&ThinkingConfig {
            enabled: true,
            budget_tokens: None,
        });
        assert!(body.contains("true"));
        assert!(!body.contains("budget_tokens"));
    }

    #[test]
    fn thinking_config_roundtrip() {
        let original = ThinkingConfig {
            enabled: true,
            budget_tokens: Some(5000),
        };
        let encoded = serde_json::to_string(&original).unwrap_or_default();

        let decoded = serde_json::from_str::<ThinkingConfig>(&encoded);
        assert!(decoded.is_ok());
        if let Ok(config) = decoded {
            assert!(config.enabled);
            assert_eq!(config.budget_tokens, Some(5000));
        }
    }

    #[test]
    fn thinking_delta_variant() {
        let body = to_json(&ContentDelta::ThinkingDelta {
            text: "Let me think...".into(),
        });
        assert!(body.contains("thinking_delta"));
    }

    #[test]
    fn usage_with_cache_tokens_deserialization() {
        let raw = r#"{"input_tokens": 100, "output_tokens": 50, "cache_read_tokens": 20, "cache_write_tokens": 10}"#;

        let parsed = serde_json::from_str::<Usage>(raw);
        assert!(parsed.is_ok());
        if let Ok(counters) = parsed {
            assert_eq!(counters.input_tokens, 100);
            assert_eq!(counters.output_tokens, 50);
            assert_eq!(counters.cache_read_tokens, 20);
            assert_eq!(counters.cache_write_tokens, 10);
            assert_eq!(counters.total(), 180);
        }
    }

    // Counters missing from the input default to zero.
    #[test]
    fn usage_without_cache_tokens_deserialization() {
        let raw = r#"{"input_tokens": 100, "output_tokens": 50}"#;

        let parsed = serde_json::from_str::<Usage>(raw);
        assert!(parsed.is_ok());
        if let Ok(counters) = parsed {
            assert_eq!(counters.cache_read_tokens, 0);
            assert_eq!(counters.cache_write_tokens, 0);
            assert_eq!(counters.total(), 150);
        }
    }

    #[test]
    fn request_with_thinking() {
        let config = ThinkingConfig {
            enabled: true,
            budget_tokens: Some(10_000),
        };
        let req =
            CompletionRequest::new("claude-opus-4", vec![Message::user("hi")], 16384).thinking(config);

        let stored = req.thinking.as_ref();
        assert!(stored.is_some());
        if let Some(thinking) = stored {
            assert!(thinking.enabled);
            assert_eq!(thinking.budget_tokens, Some(10_000));
        }
    }
}