1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7 Gemini,
8 OpenAI,
9 Anthropic,
10 DeepSeek,
11 OpenRouter,
12 Ollama,
13 ZAI,
14 Moonshot,
15 HuggingFace,
16 Minimax,
17 LiteLLM,
18}
19
20#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
21pub struct Usage {
22 pub prompt_tokens: u32,
23 pub completion_tokens: u32,
24 pub total_tokens: u32,
25 pub cached_prompt_tokens: Option<u32>,
26 pub cache_creation_tokens: Option<u32>,
27 pub cache_read_tokens: Option<u32>,
28}
29
30impl Usage {
31 #[inline]
32 fn has_cache_read_metric(&self) -> bool {
33 self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
34 }
35
36 #[inline]
37 fn has_any_cache_metrics(&self) -> bool {
38 self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
39 }
40
41 #[inline]
42 pub fn cache_read_tokens_or_fallback(&self) -> u32 {
43 self.cache_read_tokens
44 .or(self.cached_prompt_tokens)
45 .unwrap_or(0)
46 }
47
48 #[inline]
49 pub fn cache_creation_tokens_or_zero(&self) -> u32 {
50 self.cache_creation_tokens.unwrap_or(0)
51 }
52
53 #[inline]
54 pub fn cache_hit_rate(&self) -> Option<f64> {
55 if !self.has_any_cache_metrics() {
56 return None;
57 }
58 let read = self.cache_read_tokens_or_fallback() as f64;
59 let creation = self.cache_creation_tokens_or_zero() as f64;
60 let total = read + creation;
61 if total > 0.0 {
62 Some((read / total) * 100.0)
63 } else {
64 None
65 }
66 }
67
68 #[inline]
69 pub fn is_cache_hit(&self) -> Option<bool> {
70 self.has_any_cache_metrics()
71 .then(|| self.cache_read_tokens_or_fallback() > 0)
72 }
73
74 #[inline]
75 pub fn is_cache_miss(&self) -> Option<bool> {
76 self.has_any_cache_metrics().then(|| {
77 self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
78 })
79 }
80
81 #[inline]
82 pub fn total_cache_tokens(&self) -> u32 {
83 let read = self.cache_read_tokens_or_fallback();
84 let creation = self.cache_creation_tokens_or_zero();
85 read + creation
86 }
87
88 #[inline]
89 pub fn cache_savings_ratio(&self) -> Option<f64> {
90 if !self.has_cache_read_metric() {
91 return None;
92 }
93 let read = self.cache_read_tokens_or_fallback() as f64;
94 let prompt = self.prompt_tokens as f64;
95 if prompt > 0.0 {
96 Some(read / prompt)
97 } else {
98 None
99 }
100 }
101}
102
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// Builds a `Usage` with fixed totals (1000 prompt / 200 completion / 1200
    /// total) and the given cache fields.
    fn sample_usage(
        cached_prompt_tokens: Option<u32>,
        cache_creation_tokens: Option<u32>,
        cache_read_tokens: Option<u32>,
    ) -> Usage {
        Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens,
            cache_creation_tokens,
            cache_read_tokens,
        }
    }

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        // Only the legacy-style `cached_prompt_tokens` field carries the read count.
        let usage = sample_usage(Some(600), Some(150), None);

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        // With no cache fields at all, the Option-returning helpers report "unknown".
        let usage = sample_usage(None, None, None);

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}
145
146#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
147pub enum FinishReason {
148 #[default]
149 Stop,
150 Length,
151 ToolCalls,
152 ContentFilter,
153 Pause,
154 Refusal,
155 Error(String),
156}
157
158#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
160pub struct ToolCall {
161 pub id: String,
163
164 #[serde(rename = "type")]
166 pub call_type: String,
167
168 #[serde(skip_serializing_if = "Option::is_none")]
170 pub function: Option<FunctionCall>,
171
172 #[serde(skip_serializing_if = "Option::is_none")]
174 pub text: Option<String>,
175
176 #[serde(skip_serializing_if = "Option::is_none")]
178 pub thought_signature: Option<String>,
179}
180
181#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
183pub struct FunctionCall {
184 #[serde(default, skip_serializing_if = "Option::is_none")]
186 pub namespace: Option<String>,
187
188 pub name: String,
190
191 pub arguments: String,
193}
194
195impl ToolCall {
196 pub fn function(id: String, name: String, arguments: String) -> Self {
198 Self::function_with_namespace(id, None, name, arguments)
199 }
200
201 pub fn function_with_namespace(
203 id: String,
204 namespace: Option<String>,
205 name: String,
206 arguments: String,
207 ) -> Self {
208 Self {
209 id,
210 call_type: "function".to_owned(),
211 function: Some(FunctionCall {
212 namespace,
213 name,
214 arguments,
215 }),
216 text: None,
217 thought_signature: None,
218 }
219 }
220
221 pub fn custom(id: String, name: String, text: String) -> Self {
223 Self {
224 id,
225 call_type: "custom".to_owned(),
226 function: Some(FunctionCall {
227 namespace: None,
228 name,
229 arguments: text.clone(),
230 }),
231 text: Some(text),
232 thought_signature: None,
233 }
234 }
235
236 pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
238 if let Some(ref func) = self.function {
239 parse_tool_arguments(&func.arguments)
240 } else {
241 serde_json::from_str("")
243 }
244 }
245
246 pub fn validate(&self) -> Result<(), String> {
248 if self.id.is_empty() {
249 return Err("Tool call ID cannot be empty".to_owned());
250 }
251
252 match self.call_type.as_str() {
253 "function" => {
254 if let Some(func) = &self.function {
255 if func.name.is_empty() {
256 return Err("Function name cannot be empty".to_owned());
257 }
258 if let Err(e) = self.parsed_arguments() {
260 return Err(format!("Invalid JSON in function arguments: {}", e));
261 }
262 } else {
263 return Err("Function tool call missing function details".to_owned());
264 }
265 }
266 "custom" => {
267 if let Some(func) = &self.function {
269 if func.name.is_empty() {
270 return Err("Custom tool name cannot be empty".to_owned());
271 }
272 } else {
273 return Err("Custom tool call missing function details".to_owned());
274 }
275 }
276 _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
277 }
278
279 Ok(())
280 }
281}
282
283fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
284 let trimmed = raw_arguments.trim();
285 match serde_json::from_str(trimmed) {
286 Ok(parsed) => Ok(parsed),
287 Err(primary_error) => {
288 if let Some(candidate) = extract_balanced_json(trimmed)
289 && let Ok(parsed) = serde_json::from_str(candidate)
290 {
291 return Ok(parsed);
292 }
293 Err(primary_error)
294 }
295 }
296}
297
/// Returns the first bracket-balanced span in `input` that starts at the first
/// `{` or `[`.
///
/// Only the bracket kind found at that start position is counted. Brackets
/// inside double-quoted strings are ignored, and a backslash inside a string
/// escapes the next character. The returned slice runs from the opening bracket
/// through its matching close.
///
/// Returns `None` when there is no opening bracket or it is never closed. The
/// span is not validated as JSON, so callers must still parse it.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let open = input[start..].chars().next()?;
    let close = match open {
        '{' => '}',
        '[' => ']',
        _ => return None,
    };

    let mut depth: usize = 0;
    let mut inside_string = false;
    let mut escape_pending = false;

    for (offset, ch) in input[start..].char_indices() {
        match (inside_string, ch) {
            // The character after a backslash inside a string is taken literally.
            (true, _) if escape_pending => escape_pending = false,
            (true, '\\') => escape_pending = true,
            (true, '"') => inside_string = false,
            (true, _) => {}
            (false, '"') => inside_string = true,
            (false, c) if c == open => depth += 1,
            (false, c) if c == close => {
                depth = depth.saturating_sub(1);
                if depth == 0 {
                    let end = start + offset + c.len_utf8();
                    return input.get(start..end);
                }
            }
            (false, _) => {}
        }
    }

    None
}
343
344#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
346pub struct LLMResponse {
347 pub content: Option<String>,
349
350 pub tool_calls: Option<Vec<ToolCall>>,
352
353 pub model: String,
355
356 pub usage: Option<Usage>,
358
359 pub finish_reason: FinishReason,
361
362 pub reasoning: Option<String>,
364
365 pub reasoning_details: Option<Vec<String>>,
367
368 pub tool_references: Vec<String>,
370
371 pub request_id: Option<String>,
373
374 pub organization_id: Option<String>,
376}
377
378impl LLMResponse {
379 pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
381 Self {
382 content: Some(content.into()),
383 tool_calls: None,
384 model: model.into(),
385 usage: None,
386 finish_reason: FinishReason::Stop,
387 reasoning: None,
388 reasoning_details: None,
389 tool_references: Vec::new(),
390 request_id: None,
391 organization_id: None,
392 }
393 }
394
395 pub fn content_text(&self) -> &str {
397 self.content.as_deref().unwrap_or("")
398 }
399
400 pub fn content_string(&self) -> String {
402 self.content.clone().unwrap_or_default()
403 }
404}
405
406#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
407pub struct LLMErrorMetadata {
408 pub provider: Option<String>,
409 pub status: Option<u16>,
410 pub code: Option<String>,
411 pub request_id: Option<String>,
412 pub organization_id: Option<String>,
413 pub retry_after: Option<String>,
414 pub message: Option<String>,
415}
416
417impl LLMErrorMetadata {
418 pub fn new(
419 provider: impl Into<String>,
420 status: Option<u16>,
421 code: Option<String>,
422 request_id: Option<String>,
423 organization_id: Option<String>,
424 retry_after: Option<String>,
425 message: Option<String>,
426 ) -> Box<Self> {
427 Box::new(Self {
428 provider: Some(provider.into()),
429 status,
430 code,
431 request_id,
432 organization_id,
433 retry_after,
434 message,
435 })
436 }
437}
438
439#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
441#[serde(tag = "type", rename_all = "snake_case")]
442pub enum LLMError {
443 #[error("Authentication failed: {message}")]
444 Authentication {
445 message: String,
446 metadata: Option<Box<LLMErrorMetadata>>,
447 },
448 #[error("Rate limit exceeded")]
449 RateLimit {
450 metadata: Option<Box<LLMErrorMetadata>>,
451 },
452 #[error("Invalid request: {message}")]
453 InvalidRequest {
454 message: String,
455 metadata: Option<Box<LLMErrorMetadata>>,
456 },
457 #[error("Network error: {message}")]
458 Network {
459 message: String,
460 metadata: Option<Box<LLMErrorMetadata>>,
461 },
462 #[error("Provider error: {message}")]
463 Provider {
464 message: String,
465 metadata: Option<Box<LLMErrorMetadata>>,
466 },
467}
468
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Builds the `read_file` function call (id `call_read`) with the given raw arguments.
    fn read_file_call(raw_arguments: &str) -> ToolCall {
        ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            raw_arguments.to_string(),
        )
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = read_file_call(r#"{"path":"src/main.rs"} trailing text"#);

        let parsed = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let fenced = "```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```";

        let parsed = read_file_call(fenced)
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let call = read_file_call(r#"{"path":"src/main.rs""#);
        assert!(call.parsed_arguments().is_err());
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_string(),
            Some("workspace".to_string()),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"}"#.to_string(),
        );

        let serialized = serde_json::to_value(&call).expect("tool call should serialize");
        let function = &serialized["function"];
        assert_eq!(function["namespace"], "workspace");
        assert_eq!(function["name"], "read_file");
    }
}