1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7 Gemini,
8 OpenAI,
9 Anthropic,
10 DeepSeek,
11 OpenRouter,
12 Ollama,
13 ZAI,
14 Moonshot,
15 HuggingFace,
16 Minimax,
17 LiteLLM,
18}
19
20#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
21pub struct Usage {
22 pub prompt_tokens: u32,
23 pub completion_tokens: u32,
24 pub total_tokens: u32,
25 pub cached_prompt_tokens: Option<u32>,
26 pub cache_creation_tokens: Option<u32>,
27 pub cache_read_tokens: Option<u32>,
28}
29
30impl Usage {
31 #[inline]
32 fn has_cache_read_metric(&self) -> bool {
33 self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
34 }
35
36 #[inline]
37 fn has_any_cache_metrics(&self) -> bool {
38 self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
39 }
40
41 #[inline]
42 pub fn cache_read_tokens_or_fallback(&self) -> u32 {
43 self.cache_read_tokens
44 .or(self.cached_prompt_tokens)
45 .unwrap_or(0)
46 }
47
48 #[inline]
49 pub fn cache_creation_tokens_or_zero(&self) -> u32 {
50 self.cache_creation_tokens.unwrap_or(0)
51 }
52
53 #[inline]
54 pub fn cache_hit_rate(&self) -> Option<f64> {
55 if !self.has_any_cache_metrics() {
56 return None;
57 }
58 let read = self.cache_read_tokens_or_fallback() as f64;
59 let creation = self.cache_creation_tokens_or_zero() as f64;
60 let total = read + creation;
61 if total > 0.0 {
62 Some((read / total) * 100.0)
63 } else {
64 None
65 }
66 }
67
68 #[inline]
69 pub fn is_cache_hit(&self) -> Option<bool> {
70 self.has_any_cache_metrics()
71 .then(|| self.cache_read_tokens_or_fallback() > 0)
72 }
73
74 #[inline]
75 pub fn is_cache_miss(&self) -> Option<bool> {
76 self.has_any_cache_metrics().then(|| {
77 self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
78 })
79 }
80
81 #[inline]
82 pub fn total_cache_tokens(&self) -> u32 {
83 let read = self.cache_read_tokens_or_fallback();
84 let creation = self.cache_creation_tokens_or_zero();
85 read + creation
86 }
87
88 #[inline]
89 pub fn cache_savings_ratio(&self) -> Option<f64> {
90 if !self.has_cache_read_metric() {
91 return None;
92 }
93 let read = self.cache_read_tokens_or_fallback() as f64;
94 let prompt = self.prompt_tokens as f64;
95 if prompt > 0.0 {
96 Some(read / prompt)
97 } else {
98 None
99 }
100 }
101}
102
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// Builds a `Usage` with fixed token totals (1000 prompt, 200 completion,
    /// 1200 total) and the supplied cache counters.
    fn usage_with_cache(
        cached_prompt: Option<u32>,
        cache_creation: Option<u32>,
        cache_read: Option<u32>,
    ) -> Usage {
        Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: cached_prompt,
            cache_creation_tokens: cache_creation,
            cache_read_tokens: cache_read,
        }
    }

    /// With `cache_read_tokens` absent, every helper must use
    /// `cached_prompt_tokens` as the read counter.
    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        let sample = usage_with_cache(Some(600), Some(150), None);

        assert_eq!(sample.cache_read_tokens_or_fallback(), 600);
        assert_eq!(sample.cache_creation_tokens_or_zero(), 150);
        assert_eq!(sample.total_cache_tokens(), 750);
        assert_eq!(sample.is_cache_hit(), Some(true));
        assert_eq!(sample.is_cache_miss(), Some(false));
        assert_eq!(sample.cache_savings_ratio(), Some(0.6));
        assert_eq!(sample.cache_hit_rate(), Some(80.0));
    }

    /// With no cache counters at all, the `Option`-returning helpers must
    /// report `None` rather than a made-up value.
    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        let sample = usage_with_cache(None, None, None);

        assert_eq!(sample.total_cache_tokens(), 0);
        assert_eq!(sample.is_cache_hit(), None);
        assert_eq!(sample.is_cache_miss(), None);
        assert_eq!(sample.cache_savings_ratio(), None);
        assert_eq!(sample.cache_hit_rate(), None);
    }
}
145
146#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
147pub enum FinishReason {
148 #[default]
149 Stop,
150 Length,
151 ToolCalls,
152 ContentFilter,
153 Pause,
154 Refusal,
155 Error(String),
156}
157
158#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
160pub struct ToolCall {
161 pub id: String,
163
164 #[serde(rename = "type")]
166 pub call_type: String,
167
168 #[serde(skip_serializing_if = "Option::is_none")]
170 pub function: Option<FunctionCall>,
171
172 #[serde(skip_serializing_if = "Option::is_none")]
174 pub text: Option<String>,
175
176 #[serde(skip_serializing_if = "Option::is_none")]
178 pub thought_signature: Option<String>,
179}
180
181#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
183pub struct FunctionCall {
184 pub name: String,
186
187 pub arguments: String,
189}
190
191impl ToolCall {
192 pub fn function(id: String, name: String, arguments: String) -> Self {
194 Self {
195 id,
196 call_type: "function".to_owned(),
197 function: Some(FunctionCall { name, arguments }),
198 text: None,
199 thought_signature: None,
200 }
201 }
202
203 pub fn custom(id: String, name: String, text: String) -> Self {
205 Self {
206 id,
207 call_type: "custom".to_owned(),
208 function: Some(FunctionCall {
209 name,
210 arguments: text.clone(),
211 }),
212 text: Some(text),
213 thought_signature: None,
214 }
215 }
216
217 pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
219 if let Some(ref func) = self.function {
220 parse_tool_arguments(&func.arguments)
221 } else {
222 serde_json::from_str("")
224 }
225 }
226
227 pub fn validate(&self) -> Result<(), String> {
229 if self.id.is_empty() {
230 return Err("Tool call ID cannot be empty".to_owned());
231 }
232
233 match self.call_type.as_str() {
234 "function" => {
235 if let Some(func) = &self.function {
236 if func.name.is_empty() {
237 return Err("Function name cannot be empty".to_owned());
238 }
239 if let Err(e) = self.parsed_arguments() {
241 return Err(format!("Invalid JSON in function arguments: {}", e));
242 }
243 } else {
244 return Err("Function tool call missing function details".to_owned());
245 }
246 }
247 "custom" => {
248 if let Some(func) = &self.function {
250 if func.name.is_empty() {
251 return Err("Custom tool name cannot be empty".to_owned());
252 }
253 } else {
254 return Err("Custom tool call missing function details".to_owned());
255 }
256 }
257 _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
258 }
259
260 Ok(())
261 }
262}
263
264fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
265 let trimmed = raw_arguments.trim();
266 match serde_json::from_str(trimmed) {
267 Ok(parsed) => Ok(parsed),
268 Err(primary_error) => {
269 if let Some(candidate) = extract_balanced_json(trimmed)
270 && let Ok(parsed) = serde_json::from_str(candidate)
271 {
272 return Ok(parsed);
273 }
274 Err(primary_error)
275 }
276 }
277}
278
/// Returns the first balanced `{...}` or `[...]` span in `input`, if any.
///
/// Scanning starts at the first `{` or `[`. Only brackets of that same kind
/// are counted for nesting, and anything inside double-quoted strings
/// (including backslash-escaped characters) is ignored. The returned slice is
/// not guaranteed to be valid JSON; it is merely bracket-balanced.
///
/// Returns `None` when `input` has no opening bracket or the span never closes.
fn extract_balanced_json(input: &str) -> Option<&str> {
    /// Where the scanner currently is relative to string literals.
    enum Scan {
        /// Outside any string literal.
        Structure,
        /// Inside a string literal.
        InString,
        /// Inside a string literal, right after a backslash.
        AfterBackslash,
    }

    let start = input.find(['{', '['])?;
    let tail = &input[start..];
    // `find` guarantees the tail begins with one of the two openers.
    let (open, close) = if tail.starts_with('{') {
        ('{', '}')
    } else {
        ('[', ']')
    };

    let mut nesting = 0usize;
    let mut state = Scan::Structure;

    for (idx, c) in tail.char_indices() {
        state = match state {
            // The escaped character is consumed without being inspected.
            Scan::AfterBackslash => Scan::InString,
            Scan::InString => match c {
                '\\' => Scan::AfterBackslash,
                '"' => Scan::Structure,
                _ => Scan::InString,
            },
            Scan::Structure => {
                if c == '"' {
                    Scan::InString
                } else {
                    if c == open {
                        nesting += 1;
                    } else if c == close {
                        nesting = nesting.saturating_sub(1);
                        if nesting == 0 {
                            return tail.get(..idx + c.len_utf8());
                        }
                    }
                    Scan::Structure
                }
            }
        };
    }

    None
}
324
325#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
327pub struct LLMResponse {
328 pub content: Option<String>,
330
331 pub tool_calls: Option<Vec<ToolCall>>,
333
334 pub model: String,
336
337 pub usage: Option<Usage>,
339
340 pub finish_reason: FinishReason,
342
343 pub reasoning: Option<String>,
345
346 pub reasoning_details: Option<Vec<String>>,
348
349 pub tool_references: Vec<String>,
351
352 pub request_id: Option<String>,
354
355 pub organization_id: Option<String>,
357}
358
359impl LLMResponse {
360 pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
362 Self {
363 content: Some(content.into()),
364 tool_calls: None,
365 model: model.into(),
366 usage: None,
367 finish_reason: FinishReason::Stop,
368 reasoning: None,
369 reasoning_details: None,
370 tool_references: Vec::new(),
371 request_id: None,
372 organization_id: None,
373 }
374 }
375
376 pub fn content_text(&self) -> &str {
378 self.content.as_deref().unwrap_or("")
379 }
380
381 pub fn content_string(&self) -> String {
383 self.content.clone().unwrap_or_default()
384 }
385}
386
387#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
388pub struct LLMErrorMetadata {
389 pub provider: Option<String>,
390 pub status: Option<u16>,
391 pub code: Option<String>,
392 pub request_id: Option<String>,
393 pub organization_id: Option<String>,
394 pub retry_after: Option<String>,
395 pub message: Option<String>,
396}
397
398impl LLMErrorMetadata {
399 pub fn new(
400 provider: impl Into<String>,
401 status: Option<u16>,
402 code: Option<String>,
403 request_id: Option<String>,
404 organization_id: Option<String>,
405 retry_after: Option<String>,
406 message: Option<String>,
407 ) -> Box<Self> {
408 Box::new(Self {
409 provider: Some(provider.into()),
410 status,
411 code,
412 request_id,
413 organization_id,
414 retry_after,
415 message,
416 })
417 }
418}
419
420#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
422#[serde(tag = "type", rename_all = "snake_case")]
423pub enum LLMError {
424 #[error("Authentication failed: {message}")]
425 Authentication {
426 message: String,
427 metadata: Option<Box<LLMErrorMetadata>>,
428 },
429 #[error("Rate limit exceeded")]
430 RateLimit {
431 metadata: Option<Box<LLMErrorMetadata>>,
432 },
433 #[error("Invalid request: {message}")]
434 InvalidRequest {
435 message: String,
436 metadata: Option<Box<LLMErrorMetadata>>,
437 },
438 #[error("Network error: {message}")]
439 Network {
440 message: String,
441 metadata: Option<Box<LLMErrorMetadata>>,
442 },
443 #[error("Provider error: {message}")]
444 Provider {
445 message: String,
446 metadata: Option<Box<LLMErrorMetadata>>,
447 },
448}
449
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Builds a `read_file` function call with the given raw argument text.
    fn read_file_call(raw_arguments: &str) -> ToolCall {
        ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            raw_arguments.to_string(),
        )
    }

    /// Text after a complete JSON object must not prevent parsing.
    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let parsed = read_file_call(r#"{"path":"src/main.rs"} trailing text"#)
            .parsed_arguments()
            .expect("arguments with trailing text should recover");

        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    /// JSON wrapped in a Markdown code fence must still be recovered.
    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let parsed = read_file_call("```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```")
            .parsed_arguments()
            .expect("code-fenced arguments should recover");

        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    /// An object that never closes must be reported as an error.
    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let outcome = read_file_call(r#"{"path":"src/main.rs""#).parsed_arguments();

        assert!(outcome.is_err());
    }
}