1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7 Gemini,
8 OpenAI,
9 Anthropic,
10 DeepSeek,
11 OpenRouter,
12 Ollama,
13 XAI,
14 ZAI,
15 Moonshot,
16 HuggingFace,
17 Minimax,
18}
19
20#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
21pub struct Usage {
22 pub prompt_tokens: u32,
23 pub completion_tokens: u32,
24 pub total_tokens: u32,
25 pub cached_prompt_tokens: Option<u32>,
26 pub cache_creation_tokens: Option<u32>,
27 pub cache_read_tokens: Option<u32>,
28}
29
30impl Usage {
31 #[inline]
32 pub fn cache_hit_rate(&self) -> Option<f64> {
33 let read = self.cache_read_tokens? as f64;
34 let creation = self.cache_creation_tokens? as f64;
35 let total = read + creation;
36 if total > 0.0 {
37 Some((read / total) * 100.0)
38 } else {
39 None
40 }
41 }
42
43 #[inline]
44 pub fn is_cache_hit(&self) -> Option<bool> {
45 Some(self.cache_read_tokens? > 0)
46 }
47
48 #[inline]
49 pub fn is_cache_miss(&self) -> Option<bool> {
50 Some(self.cache_creation_tokens? > 0 && self.cache_read_tokens? == 0)
51 }
52
53 #[inline]
54 pub fn total_cache_tokens(&self) -> u32 {
55 let read = self.cache_read_tokens.unwrap_or(0);
56 let creation = self.cache_creation_tokens.unwrap_or(0);
57 read + creation
58 }
59
60 #[inline]
61 pub fn cache_savings_ratio(&self) -> Option<f64> {
62 let read = self.cache_read_tokens? as f64;
63 let prompt = self.prompt_tokens as f64;
64 if prompt > 0.0 {
65 Some(read / prompt)
66 } else {
67 None
68 }
69 }
70}
71
72#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
73pub enum FinishReason {
74 #[default]
75 Stop,
76 Length,
77 ToolCalls,
78 ContentFilter,
79 Pause,
80 Refusal,
81 Error(String),
82}
83
84#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
86pub struct ToolCall {
87 pub id: String,
89
90 #[serde(rename = "type")]
92 pub call_type: String,
93
94 #[serde(skip_serializing_if = "Option::is_none")]
96 pub function: Option<FunctionCall>,
97
98 #[serde(skip_serializing_if = "Option::is_none")]
100 pub text: Option<String>,
101
102 #[serde(skip_serializing_if = "Option::is_none")]
104 pub thought_signature: Option<String>,
105}
106
107#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
109pub struct FunctionCall {
110 pub name: String,
112
113 pub arguments: String,
115}
116
117impl ToolCall {
118 pub fn function(id: String, name: String, arguments: String) -> Self {
120 Self {
121 id,
122 call_type: "function".to_owned(),
123 function: Some(FunctionCall { name, arguments }),
124 text: None,
125 thought_signature: None,
126 }
127 }
128
129 pub fn custom(id: String, name: String, text: String) -> Self {
131 Self {
132 id,
133 call_type: "custom".to_owned(),
134 function: Some(FunctionCall {
135 name,
136 arguments: text.clone(),
137 }),
138 text: Some(text),
139 thought_signature: None,
140 }
141 }
142
143 pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
145 if let Some(ref func) = self.function {
146 parse_tool_arguments(&func.arguments)
147 } else {
148 serde_json::from_str("")
150 }
151 }
152
153 pub fn validate(&self) -> Result<(), String> {
155 if self.id.is_empty() {
156 return Err("Tool call ID cannot be empty".to_owned());
157 }
158
159 match self.call_type.as_str() {
160 "function" => {
161 if let Some(func) = &self.function {
162 if func.name.is_empty() {
163 return Err("Function name cannot be empty".to_owned());
164 }
165 if let Err(e) = self.parsed_arguments() {
167 return Err(format!("Invalid JSON in function arguments: {}", e));
168 }
169 } else {
170 return Err("Function tool call missing function details".to_owned());
171 }
172 }
173 "custom" => {
174 if let Some(func) = &self.function {
176 if func.name.is_empty() {
177 return Err("Custom tool name cannot be empty".to_owned());
178 }
179 } else {
180 return Err("Custom tool call missing function details".to_owned());
181 }
182 }
183 _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
184 }
185
186 Ok(())
187 }
188}
189
190fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
191 let trimmed = raw_arguments.trim();
192 match serde_json::from_str(trimmed) {
193 Ok(parsed) => Ok(parsed),
194 Err(primary_error) => {
195 if let Some(candidate) = extract_balanced_json(trimmed)
196 && let Ok(parsed) = serde_json::from_str(candidate)
197 {
198 return Ok(parsed);
199 }
200 Err(primary_error)
201 }
202 }
203}
204
/// Returns the first bracket-balanced `{…}` or `[…]` span in `input`.
///
/// Scanning starts at the first `{` or `[`. Only that opening bracket kind and its
/// matching closer are counted. Brackets inside double-quoted strings, including those
/// after backslash escapes, are ignored. Returns `None` when there is no opening
/// bracket or it is never closed.
///
/// The scan works on raw bytes. Every delimiter involved is ASCII, and UTF-8
/// continuation bytes never equal an ASCII byte, so the result matches a char-wise scan.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let bytes = input.as_bytes();
    let begin = bytes.iter().position(|&b| b == b'{' || b == b'[')?;
    let open = *bytes.get(begin)?;
    let close = match open {
        b'{' => b'}',
        _ => b']',
    };

    let mut nesting: usize = 0;
    let mut inside_string = false;
    let mut after_backslash = false;

    for (idx, &byte) in bytes.iter().enumerate().skip(begin) {
        if inside_string {
            if after_backslash {
                after_backslash = false;
            } else if byte == b'\\' {
                after_backslash = true;
            } else if byte == b'"' {
                inside_string = false;
            }
            continue;
        }

        if byte == b'"' {
            inside_string = true;
        } else if byte == open {
            nesting += 1;
        } else if byte == close {
            nesting = nesting.saturating_sub(1);
            if nesting == 0 {
                // `close` is a single ASCII byte, so `idx` ends on a char boundary.
                return input.get(begin..=idx);
            }
        }
    }

    None
}
250
251#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
253pub struct LLMResponse {
254 pub content: Option<String>,
256
257 pub tool_calls: Option<Vec<ToolCall>>,
259
260 pub model: String,
262
263 pub usage: Option<Usage>,
265
266 pub finish_reason: FinishReason,
268
269 pub reasoning: Option<String>,
271
272 pub reasoning_details: Option<Vec<String>>,
274
275 pub tool_references: Vec<String>,
277
278 pub request_id: Option<String>,
280
281 pub organization_id: Option<String>,
283}
284
285impl LLMResponse {
286 pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
288 Self {
289 content: Some(content.into()),
290 tool_calls: None,
291 model: model.into(),
292 usage: None,
293 finish_reason: FinishReason::Stop,
294 reasoning: None,
295 reasoning_details: None,
296 tool_references: Vec::new(),
297 request_id: None,
298 organization_id: None,
299 }
300 }
301
302 pub fn content_text(&self) -> &str {
304 self.content.as_deref().unwrap_or("")
305 }
306
307 pub fn content_string(&self) -> String {
309 self.content.clone().unwrap_or_default()
310 }
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
314pub struct LLMErrorMetadata {
315 pub provider: Option<String>,
316 pub status: Option<u16>,
317 pub code: Option<String>,
318 pub request_id: Option<String>,
319 pub organization_id: Option<String>,
320 pub retry_after: Option<String>,
321 pub message: Option<String>,
322}
323
324impl LLMErrorMetadata {
325 pub fn new(
326 provider: impl Into<String>,
327 status: Option<u16>,
328 code: Option<String>,
329 request_id: Option<String>,
330 organization_id: Option<String>,
331 retry_after: Option<String>,
332 message: Option<String>,
333 ) -> Box<Self> {
334 Box::new(Self {
335 provider: Some(provider.into()),
336 status,
337 code,
338 request_id,
339 organization_id,
340 retry_after,
341 message,
342 })
343 }
344}
345
346#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
348#[serde(tag = "type", rename_all = "snake_case")]
349pub enum LLMError {
350 #[error("Authentication failed: {message}")]
351 Authentication {
352 message: String,
353 metadata: Option<Box<LLMErrorMetadata>>,
354 },
355 #[error("Rate limit exceeded")]
356 RateLimit {
357 metadata: Option<Box<LLMErrorMetadata>>,
358 },
359 #[error("Invalid request: {message}")]
360 InvalidRequest {
361 message: String,
362 metadata: Option<Box<LLMErrorMetadata>>,
363 },
364 #[error("Network error: {message}")]
365 Network {
366 message: String,
367 metadata: Option<Box<LLMErrorMetadata>>,
368 },
369 #[error("Provider error: {message}")]
370 Provider {
371 message: String,
372 metadata: Option<Box<LLMErrorMetadata>>,
373 },
374}
375
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Builds the `read_file` function call shared by these tests around `raw_arguments`.
    fn read_file_call(raw_arguments: &str) -> ToolCall {
        ToolCall::function(
            "call_read".to_owned(),
            "read_file".to_owned(),
            raw_arguments.to_owned(),
        )
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let parsed = read_file_call(r#"{"path":"src/main.rs"} trailing text"#)
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let parsed = read_file_call("```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```")
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let call = read_file_call(r#"{"path":"src/main.rs""#);
        assert!(call.parsed_arguments().is_err());
    }
}
419}