1use serde::{Deserialize, Serialize};
4
/// Which LLM provider backend a request or response belongs to.
///
/// No serde attributes are applied, so each unit variant uses serde's default
/// externally tagged form: it serializes as its variant name (e.g. `"OpenAI"`).
/// Renaming or reordering variants therefore changes the wire format.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackendKind {
    Gemini,
    OpenAI,
    Anthropic,
    DeepSeek,
    OpenRouter,
    Ollama,
    ZAI,
    Moonshot,
    HuggingFace,
    Minimax,
}
18
/// Token accounting attached to a single LLM response.
///
/// The cache-related fields are `Option` so that "not reported" (`None`) stays
/// distinguishable from "reported as zero" (`Some(0)`). The helper methods in
/// `impl Usage` return `None` when no cache metric was reported at all.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct Usage {
    /// Prompt token count; the denominator of `cache_savings_ratio`.
    pub prompt_tokens: u32,
    /// Completion token count.
    pub completion_tokens: u32,
    /// Total token count as reported; not derived from the other fields here.
    pub total_tokens: u32,
    /// Cached prompt tokens; used as the cache-read figure only when
    /// `cache_read_tokens` is `None`.
    pub cached_prompt_tokens: Option<u32>,
    /// Tokens written to the cache, when reported.
    pub cache_creation_tokens: Option<u32>,
    /// Tokens read from the cache, when reported; takes precedence over
    /// `cached_prompt_tokens` (even when it is `Some(0)`).
    pub cache_read_tokens: Option<u32>,
}
28
29impl Usage {
30 #[inline]
31 fn has_cache_read_metric(&self) -> bool {
32 self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
33 }
34
35 #[inline]
36 fn has_any_cache_metrics(&self) -> bool {
37 self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
38 }
39
40 #[inline]
41 pub fn cache_read_tokens_or_fallback(&self) -> u32 {
42 self.cache_read_tokens
43 .or(self.cached_prompt_tokens)
44 .unwrap_or(0)
45 }
46
47 #[inline]
48 pub fn cache_creation_tokens_or_zero(&self) -> u32 {
49 self.cache_creation_tokens.unwrap_or(0)
50 }
51
52 #[inline]
53 pub fn cache_hit_rate(&self) -> Option<f64> {
54 if !self.has_any_cache_metrics() {
55 return None;
56 }
57 let read = self.cache_read_tokens_or_fallback() as f64;
58 let creation = self.cache_creation_tokens_or_zero() as f64;
59 let total = read + creation;
60 if total > 0.0 {
61 Some((read / total) * 100.0)
62 } else {
63 None
64 }
65 }
66
67 #[inline]
68 pub fn is_cache_hit(&self) -> Option<bool> {
69 self.has_any_cache_metrics()
70 .then(|| self.cache_read_tokens_or_fallback() > 0)
71 }
72
73 #[inline]
74 pub fn is_cache_miss(&self) -> Option<bool> {
75 self.has_any_cache_metrics().then(|| {
76 self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
77 })
78 }
79
80 #[inline]
81 pub fn total_cache_tokens(&self) -> u32 {
82 let read = self.cache_read_tokens_or_fallback();
83 let creation = self.cache_creation_tokens_or_zero();
84 read + creation
85 }
86
87 #[inline]
88 pub fn cache_savings_ratio(&self) -> Option<f64> {
89 if !self.has_cache_read_metric() {
90 return None;
91 }
92 let read = self.cache_read_tokens_or_fallback() as f64;
93 let prompt = self.prompt_tokens as f64;
94 if prompt > 0.0 {
95 Some(read / prompt)
96 } else {
97 None
98 }
99 }
100}
101
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// Fixture with fixed prompt/completion totals and caller-chosen cache fields.
    fn usage_with_cache(
        cached_prompt_tokens: Option<u32>,
        cache_creation_tokens: Option<u32>,
        cache_read_tokens: Option<u32>,
    ) -> Usage {
        Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens,
            cache_creation_tokens,
            cache_read_tokens,
        }
    }

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        // Only the fallback field carries the read count here.
        let usage = usage_with_cache(Some(600), Some(150), None);

        // Raw accessors.
        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);

        // Derived classifications and ratios.
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        // Nothing reported: sums default to zero, everything else is unknown.
        let usage = usage_with_cache(None, None, None);

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}
144
/// Reason a response finished, as stored in `LLMResponse::finish_reason`.
///
/// Defaults to `Stop` (via `#[default]`). With serde's default externally
/// tagged representation, unit variants serialize as their name and
/// `Error(String)` serializes as `{"Error": "<text>"}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum FinishReason {
    #[default]
    Stop,
    Length,
    ToolCalls,
    ContentFilter,
    Pause,
    Refusal,
    /// Carries a free-form description of the error.
    Error(String),
}
156
/// A tool call as carried in `LLMResponse::tool_calls`.
///
/// The constructors in `impl ToolCall` set `call_type` to `"function"` or
/// `"custom"`, and `ToolCall::validate` rejects any other value.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call identifier; `validate` requires it to be non-empty.
    pub id: String,

    /// Call kind (`"function"` or `"custom"`), serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub call_type: String,

    /// Name and raw argument string; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCall>,

    /// Raw text payload (set by `ToolCall::custom`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Opaque signature string carried alongside the call; omitted when `None`.
    /// NOTE(review): its provider-specific meaning is not visible in this file — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
179
/// Name and arguments of a tool call.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Optional namespace qualifier.
    ///
    /// Deserializes to `None` when absent and is omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub namespace: Option<String>,

    /// Tool name; `ToolCall::validate` rejects an empty name.
    pub name: String,

    /// Raw argument string. For `"function"` calls it is expected to hold JSON
    /// (see `ToolCall::parsed_arguments`). `ToolCall::custom` stores its raw text
    /// here verbatim.
    pub arguments: String,
}
193
194impl ToolCall {
195 pub fn function(id: String, name: String, arguments: String) -> Self {
197 Self::function_with_namespace(id, None, name, arguments)
198 }
199
200 pub fn function_with_namespace(
202 id: String,
203 namespace: Option<String>,
204 name: String,
205 arguments: String,
206 ) -> Self {
207 Self {
208 id,
209 call_type: "function".to_owned(),
210 function: Some(FunctionCall {
211 namespace,
212 name,
213 arguments,
214 }),
215 text: None,
216 thought_signature: None,
217 }
218 }
219
220 pub fn custom(id: String, name: String, text: String) -> Self {
222 Self {
223 id,
224 call_type: "custom".to_owned(),
225 function: Some(FunctionCall {
226 namespace: None,
227 name,
228 arguments: text.clone(),
229 }),
230 text: Some(text),
231 thought_signature: None,
232 }
233 }
234
235 pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
237 if let Some(ref func) = self.function {
238 parse_tool_arguments(&func.arguments)
239 } else {
240 serde_json::from_str("")
242 }
243 }
244
245 pub fn validate(&self) -> Result<(), String> {
247 if self.id.is_empty() {
248 return Err("Tool call ID cannot be empty".to_owned());
249 }
250
251 match self.call_type.as_str() {
252 "function" => {
253 if let Some(func) = &self.function {
254 if func.name.is_empty() {
255 return Err("Function name cannot be empty".to_owned());
256 }
257 if let Err(e) = self.parsed_arguments() {
259 return Err(format!("Invalid JSON in function arguments: {}", e));
260 }
261 } else {
262 return Err("Function tool call missing function details".to_owned());
263 }
264 }
265 "custom" => {
266 if let Some(func) = &self.function {
268 if func.name.is_empty() {
269 return Err("Custom tool name cannot be empty".to_owned());
270 }
271 } else {
272 return Err("Custom tool call missing function details".to_owned());
273 }
274 }
275 _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
276 }
277
278 Ok(())
279 }
280}
281
282fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
283 let trimmed = raw_arguments.trim();
284 match serde_json::from_str(trimmed) {
285 Ok(parsed) => Ok(parsed),
286 Err(primary_error) => {
287 if let Some(candidate) = extract_balanced_json(trimmed)
288 && let Ok(parsed) = serde_json::from_str(candidate)
289 {
290 return Ok(parsed);
291 }
292 if let Some(candidate) = repair_tag_polluted_json(trimmed)
293 && let Ok(parsed) = serde_json::from_str(&candidate)
294 {
295 return Ok(parsed);
296 }
297 Err(primary_error)
298 }
299 }
300}
301
/// Returns the first balanced JSON object or array embedded in `input`.
///
/// Scanning starts at the first `{` or `[`. Only that opener's bracket kind is
/// counted, and brackets inside string literals (including escaped quotes) are
/// ignored. The returned slice runs from that opener through its matching closer.
/// Returns `None` when there is no opener or it is never closed. The slice is
/// not validated; callers still need to parse it.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let tail = &input[start..];
    let open = tail.chars().next()?;
    let close = match open {
        '{' => '}',
        '[' => ']',
        _ => return None,
    };

    let mut nesting = 0usize;
    let mut inside_string = false;
    let mut after_backslash = false;

    for (idx, c) in tail.char_indices() {
        if inside_string {
            // Inside a string literal: only escapes and the closing quote matter.
            match (after_backslash, c) {
                (true, _) => after_backslash = false,
                (false, '\\') => after_backslash = true,
                (false, '"') => inside_string = false,
                _ => {}
            }
            continue;
        }

        if c == '"' {
            inside_string = true;
        } else if c == open {
            nesting += 1;
        } else if c == close {
            nesting = nesting.saturating_sub(1);
            if nesting == 0 {
                return tail.get(..idx + c.len_utf8());
            }
        }
    }

    None
}
347
348fn repair_tag_polluted_json(input: &str) -> Option<String> {
349 let start = input.find(['{', '['])?;
350 let candidate = input.get(start..)?;
351 let boundary = find_provider_markup_boundary(candidate)?;
352 if boundary == 0 {
353 return None;
354 }
355
356 close_incomplete_json_prefix(candidate[..boundary].trim_end())
357}
358
/// Returns the byte offset of the earliest provider tool-call marker in `input`.
///
/// Every char boundary is tested against a fixed list of markup fragments
/// (e.g. `</parameter>`, `</invoke>`, `<tool_call>`, `<minimax:tool_call>` and
/// the stray `<</`). Returns `None` when no marker occurs anywhere.
fn find_provider_markup_boundary(input: &str) -> Option<usize> {
    const PROVIDER_MARKERS: &[&str] = &[
        "<</",
        "</parameter>",
        "</invoke>",
        "</minimax:tool_call>",
        "<minimax:tool_call>",
        "<parameter name=\"",
        "<invoke name=\"",
        "<tool_call>",
        "</tool_call>",
    ];

    for (offset, _) in input.char_indices() {
        let Some(rest) = input.get(offset..) else {
            continue;
        };
        if PROVIDER_MARKERS.iter().any(|marker| rest.starts_with(marker)) {
            return Some(offset);
        }
    }

    None
}
380
/// Closes a JSON prefix that was cut off mid-value.
///
/// Walks `prefix` while tracking string-literal state and the stack of open
/// `{` / `[`. It then appends what is needed to close them:
/// - a `"` if the prefix ends inside a string;
/// - the matching `}` / `]` closers, innermost first.
///
/// If the prefix ends inside a string on a dangling backslash (an escape
/// sequence cut in half), that backslash is dropped before the closing quote is
/// added. Otherwise the appended `"` would itself be escaped and the string
/// would never terminate.
///
/// Returns `None` for an empty prefix, or when a closer does not match the
/// innermost open bracket (the input is not a JSON prefix). The result is only
/// a candidate: callers must still parse it (e.g. a prefix ending in `,` or `:`
/// still yields invalid JSON).
fn close_incomplete_json_prefix(prefix: &str) -> Option<String> {
    if prefix.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(prefix.len() + 8);
    let mut expected_closers = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for ch in prefix.chars() {
        repaired.push(ch);

        if in_string {
            if escaped {
                escaped = false;
                continue;
            }

            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => expected_closers.push('}'),
            '[' => expected_closers.push(']'),
            '}' | ']' => {
                if expected_closers.pop() != Some(ch) {
                    return None;
                }
            }
            _ => {}
        }
    }

    if in_string {
        if escaped {
            // Drop the half-written escape so the closing quote is not escaped.
            repaired.pop();
        }
        repaired.push('"');
    }
    // Innermost bracket was pushed last, so close in reverse push order.
    repaired.extend(expected_closers.into_iter().rev());

    Some(repaired)
}
430
/// A completed LLM response: generated content, any tool calls, and usage and
/// finish metadata.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct LLMResponse {
    /// Generated text, if any. `content_text` / `content_string` fall back to "".
    pub content: Option<String>,

    /// Tool calls, if any.
    pub tool_calls: Option<Vec<ToolCall>>,

    /// Model identifier as given to `LLMResponse::new`.
    pub model: String,

    /// Token usage, when reported.
    pub usage: Option<Usage>,

    /// Why the response finished; `LLMResponse::new` sets `FinishReason::Stop`.
    pub finish_reason: FinishReason,

    /// Reasoning text, if any.
    pub reasoning: Option<String>,

    /// Additional reasoning entries as raw strings, if any.
    pub reasoning_details: Option<Vec<String>>,

    /// Referenced tool identifiers (presumably — confirm with producers).
    /// NOTE(review): unlike the `Option` fields, this has no `#[serde(default)]`,
    /// so deserialization fails if the key is missing — confirm that is intended.
    pub tool_references: Vec<String>,

    /// Request identifier reported alongside the response, if any.
    pub request_id: Option<String>,

    /// Organization identifier reported alongside the response, if any.
    pub organization_id: Option<String>,
}
464
465impl LLMResponse {
466 pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
468 Self {
469 content: Some(content.into()),
470 tool_calls: None,
471 model: model.into(),
472 usage: None,
473 finish_reason: FinishReason::Stop,
474 reasoning: None,
475 reasoning_details: None,
476 tool_references: Vec::new(),
477 request_id: None,
478 organization_id: None,
479 }
480 }
481
482 pub fn content_text(&self) -> &str {
484 self.content.as_deref().unwrap_or("")
485 }
486
487 pub fn content_string(&self) -> String {
489 self.content.clone().unwrap_or_default()
490 }
491}
492
/// Optional diagnostic details attached to an `LLMError` variant.
///
/// Every field is optional; `LLMErrorMetadata::new` always fills `provider`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct LLMErrorMetadata {
    /// Provider name.
    pub provider: Option<String>,
    /// Numeric status code (presumably HTTP — confirm at call sites).
    pub status: Option<u16>,
    /// Provider-specific error code string.
    pub code: Option<String>,
    /// Request identifier, if reported.
    pub request_id: Option<String>,
    /// Organization identifier, if reported.
    pub organization_id: Option<String>,
    /// Retry-after value kept as a raw, unparsed string.
    pub retry_after: Option<String>,
    /// Free-form error message.
    pub message: Option<String>,
}
503
504impl LLMErrorMetadata {
505 pub fn new(
506 provider: impl Into<String>,
507 status: Option<u16>,
508 code: Option<String>,
509 request_id: Option<String>,
510 organization_id: Option<String>,
511 retry_after: Option<String>,
512 message: Option<String>,
513 ) -> Box<Self> {
514 Box::new(Self {
515 provider: Some(provider.into()),
516 status,
517 code,
518 request_id,
519 organization_id,
520 retry_after,
521 message,
522 })
523 }
524}
525
/// Categorized LLM request error; every variant carries optional
/// `LLMErrorMetadata`.
///
/// Serialized as an internally tagged object: a `"type"` key holding the
/// snake_case variant name (e.g. `"rate_limit"`) alongside the variant's fields.
/// `Display` text comes from the `#[error(...)]` attributes below.
#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum LLMError {
    /// Displays as `Authentication failed: {message}`.
    #[error("Authentication failed: {message}")]
    Authentication {
        message: String,
        metadata: Option<Box<LLMErrorMetadata>>,
    },
    /// Displays a fixed message; any details live only in `metadata`.
    #[error("Rate limit exceeded")]
    RateLimit {
        metadata: Option<Box<LLMErrorMetadata>>,
    },
    /// Displays as `Invalid request: {message}`.
    #[error("Invalid request: {message}")]
    InvalidRequest {
        message: String,
        metadata: Option<Box<LLMErrorMetadata>>,
    },
    /// Displays as `Network error: {message}`.
    #[error("Network error: {message}")]
    Network {
        message: String,
        metadata: Option<Box<LLMErrorMetadata>>,
    },
    /// Displays as `Provider error: {message}`.
    #[error("Provider error: {message}")]
    Provider {
        message: String,
        metadata: Option<Box<LLMErrorMetadata>>,
    },
}
555
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Shorthand for `ToolCall::function` taking string slices.
    fn function_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall::function(id.to_string(), name.to_string(), arguments.to_string())
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = function_call(
            "call_read",
            "read_file",
            r#"{"path":"src/main.rs"} trailing text"#,
        );

        let parsed = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = function_call(
            "call_read",
            "read_file",
            "```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```",
        );

        let parsed = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        // Truncated JSON with no provider markup has nothing to recover from.
        let call = function_call("call_read", "read_file", r#"{"path":"src/main.rs""#);

        assert!(call.parsed_arguments().is_err());
    }

    #[test]
    fn parsed_arguments_recovers_truncated_minimax_markup() {
        let call = function_call(
            "call_search",
            "unified_search",
            "{\"action\": \"grep\", \"pattern\": \"persistent_memory\", \"path\": \"vtcode-core/src</parameter>\n<</invoke>\n</minimax:tool_call>",
        );

        let parsed = call
            .parsed_arguments()
            .expect("minimax markup spillover should recover");
        let expected = json!({
            "action": "grep",
            "pattern": "persistent_memory",
            "path": "vtcode-core/src"
        });
        assert_eq!(parsed, expected);
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_string(),
            Some("workspace".to_string()),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"}"#.to_string(),
        );

        let json = serde_json::to_value(&call).expect("tool call should serialize");
        let function = &json["function"];
        assert_eq!(function["namespace"], "workspace");
        assert_eq!(function["name"], "read_file");
    }
}