// saorsa_ai/types.rs
1//! Request, response, and streaming types for LLM APIs.
2
3use serde::{Deserialize, Serialize};
4
5use crate::message::{ContentBlock, Message, ToolDefinition};
6
7/// Configuration for extended thinking/reasoning.
8#[derive(Clone, Debug, Serialize, Deserialize)]
9pub struct ThinkingConfig {
10    /// Whether thinking is enabled.
11    pub enabled: bool,
12    /// Maximum tokens for thinking budget.
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub budget_tokens: Option<u32>,
15}
16
/// A completion request to send to an LLM provider.
///
/// Construct with [`CompletionRequest::new`], then chain the builder
/// methods (`system`, `stream`, `temperature`, `tools`, `thinking`) to
/// fill in optional fields. Unset/empty optional fields are skipped
/// during serialization so the wire payload stays minimal.
#[derive(Clone, Debug, Serialize)]
pub struct CompletionRequest {
    /// The model identifier.
    pub model: String,
    /// The conversation messages.
    pub messages: Vec<Message>,
    /// Maximum tokens to generate.
    pub max_tokens: u32,
    /// Optional system prompt; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Available tools; omitted from JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Whether to stream the response; omitted from JSON when `false`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    /// Sampling temperature (0.0-1.0).
    /// NOTE(review): range is not enforced here — presumably validated by
    /// the provider; confirm whether clamping is expected.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Stop sequences; omitted from JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    /// Extended thinking configuration; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<ThinkingConfig>,
}
45
46impl CompletionRequest {
47    /// Create a new request with required fields.
48    pub fn new(model: impl Into<String>, messages: Vec<Message>, max_tokens: u32) -> Self {
49        Self {
50            model: model.into(),
51            messages,
52            max_tokens,
53            system: None,
54            tools: Vec::new(),
55            stream: false,
56            temperature: None,
57            stop_sequences: Vec::new(),
58            thinking: None,
59        }
60    }
61
62    /// Set the system prompt.
63    #[must_use]
64    pub fn system(mut self, system: impl Into<String>) -> Self {
65        self.system = Some(system.into());
66        self
67    }
68
69    /// Set streaming mode.
70    #[must_use]
71    pub fn stream(mut self, stream: bool) -> Self {
72        self.stream = stream;
73        self
74    }
75
76    /// Set the temperature.
77    #[must_use]
78    pub fn temperature(mut self, temp: f32) -> Self {
79        self.temperature = Some(temp);
80        self
81    }
82
83    /// Add tools.
84    #[must_use]
85    pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
86        self.tools = tools;
87        self
88    }
89
90    /// Set extended thinking configuration.
91    #[must_use]
92    pub fn thinking(mut self, config: ThinkingConfig) -> Self {
93        self.thinking = Some(config);
94        self
95    }
96}
97
/// A completion response from an LLM provider.
///
/// Deserialized from the provider's JSON payload; `content` may contain
/// multiple [`ContentBlock`]s (e.g. text alongside tool use).
#[derive(Clone, Debug, Deserialize)]
pub struct CompletionResponse {
    /// Unique response ID assigned by the provider.
    pub id: String,
    /// The response content blocks, in generation order.
    pub content: Vec<ContentBlock>,
    /// The model that generated this response.
    pub model: String,
    /// Why the model stopped generating; `None` if the provider omitted it.
    pub stop_reason: Option<StopReason>,
    /// Token usage information.
    pub usage: Usage,
}
112
113/// Why the model stopped generating.
114#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
115#[serde(rename_all = "snake_case")]
116pub enum StopReason {
117    /// The model finished its response naturally.
118    EndTurn,
119    /// The response hit the max_tokens limit.
120    MaxTokens,
121    /// A stop sequence was encountered.
122    StopSequence,
123    /// The model wants to use a tool.
124    ToolUse,
125}
126
127/// Token usage information.
128#[derive(Clone, Debug, Default, Serialize, Deserialize)]
129pub struct Usage {
130    /// Number of input tokens.
131    #[serde(default)]
132    pub input_tokens: u32,
133    /// Number of output tokens.
134    #[serde(default)]
135    pub output_tokens: u32,
136    /// Number of input tokens read from cache.
137    #[serde(default)]
138    pub cache_read_tokens: u32,
139    /// Number of input tokens written to cache.
140    #[serde(default)]
141    pub cache_write_tokens: u32,
142}
143
144impl Usage {
145    /// Total tokens (input + output + cache).
146    pub fn total(&self) -> u32 {
147        self.input_tokens + self.output_tokens + self.cache_read_tokens + self.cache_write_tokens
148    }
149}
150
/// A streaming event from the LLM provider.
///
/// NOTE(review): variant names suggest the usual SSE lifecycle
/// (message start → per-block start/delta/stop → message delta →
/// message stop) — confirm ordering guarantees against the provider
/// adapter that produces these events.
#[derive(Clone, Debug)]
pub enum StreamEvent {
    /// The message has started.
    MessageStart {
        /// Response ID.
        id: String,
        /// Model name.
        model: String,
        /// Usage so far.
        usage: Usage,
    },
    /// A content block has started.
    ContentBlockStart {
        /// Index of the content block.
        index: u32,
        /// The initial content block (may be partial).
        content_block: ContentBlock,
    },
    /// A delta (incremental update) to a content block.
    ContentBlockDelta {
        /// Index of the content block this delta applies to.
        index: u32,
        /// The delta content.
        delta: ContentDelta,
    },
    /// A content block has finished.
    ContentBlockStop {
        /// Index of the content block.
        index: u32,
    },
    /// Final message metadata.
    MessageDelta {
        /// Why the model stopped, if reported.
        stop_reason: Option<StopReason>,
        /// Final usage info.
        usage: Usage,
    },
    /// The message is complete.
    MessageStop,
    /// Keepalive ping.
    Ping,
    /// An error occurred.
    Error {
        /// Error message.
        message: String,
    },
}
199
/// Delta content for streaming updates.
///
/// Internally tagged on the wire: a `"type"` field selects the variant,
/// with snake_case names (`text_delta`, `input_json_delta`,
/// `thinking_delta`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// A text delta.
    TextDelta {
        /// The incremental text.
        text: String,
    },
    /// A tool input delta (partial JSON).
    InputJsonDelta {
        /// Partial JSON string; fragments must be concatenated by the
        /// consumer before parsing.
        partial_json: String,
    },
    /// A thinking/reasoning delta.
    ThinkingDelta {
        /// The incremental thinking text.
        text: String,
    },
}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Message;

    #[test]
    fn request_builder() {
        // Builder methods are order-independent; exercise them chained.
        let request = CompletionRequest::new(
            "claude-sonnet-4-5-20250929",
            vec![Message::user("hi")],
            1024,
        )
        .stream(true)
        .temperature(0.7)
        .system("You are helpful");
        assert_eq!(request.model, "claude-sonnet-4-5-20250929");
        assert_eq!(request.max_tokens, 1024);
        assert!(request.stream);
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.system, Some("You are helpful".into()));
    }

    #[test]
    fn request_serialization() {
        let request = CompletionRequest::new(
            "claude-sonnet-4-5-20250929",
            vec![Message::user("hi")],
            1024,
        );
        let serialized = serde_json::to_string(&request);
        assert!(serialized.is_ok());
        let body = serialized.as_deref().unwrap_or("");
        assert!(body.contains("claude-sonnet-4-5-20250929"));
        assert!(body.contains("1024"));
        // stream=false should not be serialized
        assert!(!body.contains("stream"));
    }

    #[test]
    fn response_parsing() {
        let payload = r#"{
            "id": "msg_123",
            "content": [{"type": "text", "text": "Hello!"}],
            "model": "claude-sonnet-4-5-20250929",
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 10, "output_tokens": 5}
        }"#;
        let parsed: std::result::Result<CompletionResponse, _> = serde_json::from_str(payload);
        assert!(parsed.is_ok());
        if let Ok(response) = parsed {
            assert_eq!(response.id, "msg_123");
            // total = 10 + 5 + 0 (cache_read) + 0 (cache_write)
            assert_eq!(response.usage.total(), 15);
            assert_eq!(response.usage.cache_read_tokens, 0);
            assert_eq!(response.usage.cache_write_tokens, 0);
        }
    }

    #[test]
    fn stop_reason_parsing() {
        // Table-driven: each wire string must map to its variant.
        for (json, expected) in [
            (r#""end_turn""#, StopReason::EndTurn),
            (r#""tool_use""#, StopReason::ToolUse),
        ] {
            let parsed: Result<StopReason, _> = serde_json::from_str(json);
            assert_eq!(parsed.ok(), Some(expected));
        }
    }

    #[test]
    fn usage_total() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        };
        assert_eq!(usage.total(), 150);
    }

    #[test]
    fn usage_total_with_cache_tokens() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 20,
            cache_write_tokens: 10,
        };
        assert_eq!(usage.total(), 180);
    }

    #[test]
    fn content_delta_serialization() {
        let delta = ContentDelta::TextDelta {
            text: "hello".into(),
        };
        let encoded = serde_json::to_string(&delta);
        assert!(encoded.is_ok());
        assert!(encoded.as_deref().unwrap_or("").contains("text_delta"));
    }

    #[test]
    fn thinking_config_serialization() {
        let config = ThinkingConfig {
            enabled: true,
            budget_tokens: Some(10_000),
        };
        let encoded = serde_json::to_string(&config);
        assert!(encoded.is_ok());
        let body = encoded.as_deref().unwrap_or("");
        assert!(body.contains("true"));
        assert!(body.contains("10000"));
    }

    #[test]
    fn thinking_config_without_budget() {
        let config = ThinkingConfig {
            enabled: true,
            budget_tokens: None,
        };
        let encoded = serde_json::to_string(&config);
        assert!(encoded.is_ok());
        let body = encoded.as_deref().unwrap_or("");
        assert!(body.contains("true"));
        // budget_tokens is skip_serializing_if = Option::is_none
        assert!(!body.contains("budget_tokens"));
    }

    #[test]
    fn thinking_config_roundtrip() {
        let original = ThinkingConfig {
            enabled: true,
            budget_tokens: Some(5000),
        };
        let encoded = serde_json::to_string(&original).unwrap_or_default();
        let decoded: std::result::Result<ThinkingConfig, _> = serde_json::from_str(&encoded);
        assert!(decoded.is_ok());
        if let Ok(config) = decoded {
            assert!(config.enabled);
            assert_eq!(config.budget_tokens, Some(5000));
        }
    }

    #[test]
    fn thinking_delta_variant() {
        let delta = ContentDelta::ThinkingDelta {
            text: "Let me think...".into(),
        };
        let encoded = serde_json::to_string(&delta);
        assert!(encoded.is_ok());
        assert!(encoded.as_deref().unwrap_or("").contains("thinking_delta"));
    }

    #[test]
    fn usage_with_cache_tokens_deserialization() {
        let payload = r#"{"input_tokens": 100, "output_tokens": 50, "cache_read_tokens": 20, "cache_write_tokens": 10}"#;
        let decoded: std::result::Result<Usage, _> = serde_json::from_str(payload);
        assert!(decoded.is_ok());
        if let Ok(usage) = decoded {
            assert_eq!(usage.input_tokens, 100);
            assert_eq!(usage.output_tokens, 50);
            assert_eq!(usage.cache_read_tokens, 20);
            assert_eq!(usage.cache_write_tokens, 10);
            assert_eq!(usage.total(), 180);
        }
    }

    #[test]
    fn usage_without_cache_tokens_deserialization() {
        let payload = r#"{"input_tokens": 100, "output_tokens": 50}"#;
        let decoded: std::result::Result<Usage, _> = serde_json::from_str(payload);
        assert!(decoded.is_ok());
        if let Ok(usage) = decoded {
            // Missing cache fields must default to zero.
            assert_eq!(usage.cache_read_tokens, 0);
            assert_eq!(usage.cache_write_tokens, 0);
            assert_eq!(usage.total(), 150);
        }
    }

    #[test]
    fn request_with_thinking() {
        let request = CompletionRequest::new("claude-opus-4", vec![Message::user("hi")], 16384)
            .thinking(ThinkingConfig {
                enabled: true,
                budget_tokens: Some(10_000),
            });
        assert!(request.thinking.is_some());
        if let Some(config) = &request.thinking {
            assert!(config.enabled);
            assert_eq!(config.budget_tokens, Some(10_000));
        }
    }
}