1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7 Gemini,
8 OpenAI,
9 Anthropic,
10 DeepSeek,
11 OpenRouter,
12 Ollama,
13 ZAI,
14 Moonshot,
15 HuggingFace,
16 Minimax,
17}
18
19#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
20pub struct Usage {
21 pub prompt_tokens: u32,
22 pub completion_tokens: u32,
23 pub total_tokens: u32,
24 pub cached_prompt_tokens: Option<u32>,
25 pub cache_creation_tokens: Option<u32>,
26 pub cache_read_tokens: Option<u32>,
27}
28
29impl Usage {
30 #[inline]
31 fn has_cache_read_metric(&self) -> bool {
32 self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
33 }
34
35 #[inline]
36 fn has_any_cache_metrics(&self) -> bool {
37 self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
38 }
39
40 #[inline]
41 pub fn cache_read_tokens_or_fallback(&self) -> u32 {
42 self.cache_read_tokens
43 .or(self.cached_prompt_tokens)
44 .unwrap_or(0)
45 }
46
47 #[inline]
48 pub fn cache_creation_tokens_or_zero(&self) -> u32 {
49 self.cache_creation_tokens.unwrap_or(0)
50 }
51
52 #[inline]
53 pub fn cache_hit_rate(&self) -> Option<f64> {
54 if !self.has_any_cache_metrics() {
55 return None;
56 }
57 let read = self.cache_read_tokens_or_fallback() as f64;
58 let creation = self.cache_creation_tokens_or_zero() as f64;
59 let total = read + creation;
60 if total > 0.0 {
61 Some((read / total) * 100.0)
62 } else {
63 None
64 }
65 }
66
67 #[inline]
68 pub fn is_cache_hit(&self) -> Option<bool> {
69 self.has_any_cache_metrics()
70 .then(|| self.cache_read_tokens_or_fallback() > 0)
71 }
72
73 #[inline]
74 pub fn is_cache_miss(&self) -> Option<bool> {
75 self.has_any_cache_metrics().then(|| {
76 self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
77 })
78 }
79
80 #[inline]
81 pub fn total_cache_tokens(&self) -> u32 {
82 let read = self.cache_read_tokens_or_fallback();
83 let creation = self.cache_creation_tokens_or_zero();
84 read + creation
85 }
86
87 #[inline]
88 pub fn cache_savings_ratio(&self) -> Option<f64> {
89 if !self.has_cache_read_metric() {
90 return None;
91 }
92 let read = self.cache_read_tokens_or_fallback() as f64;
93 let prompt = self.prompt_tokens as f64;
94 if prompt > 0.0 {
95 Some(read / prompt)
96 } else {
97 None
98 }
99 }
100}
101
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// Builds a `Usage` with fixed prompt/completion totals and the given
    /// cache counters, so each test only spells out what it varies.
    fn usage_with_cache(
        cached_prompt_tokens: Option<u32>,
        cache_creation_tokens: Option<u32>,
        cache_read_tokens: Option<u32>,
    ) -> Usage {
        Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens,
            cache_creation_tokens,
            cache_read_tokens,
        }
    }

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        let usage = usage_with_cache(Some(600), Some(150), None);

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        let usage = usage_with_cache(None, None, None);

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}
144
145#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
146pub enum FinishReason {
147 #[default]
148 Stop,
149 Length,
150 ToolCalls,
151 ContentFilter,
152 Pause,
153 Refusal,
154 Error(String),
155}
156
157#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
159pub struct ToolCall {
160 pub id: String,
162
163 #[serde(rename = "type")]
165 pub call_type: String,
166
167 #[serde(skip_serializing_if = "Option::is_none")]
169 pub function: Option<FunctionCall>,
170
171 #[serde(skip_serializing_if = "Option::is_none")]
173 pub text: Option<String>,
174
175 #[serde(skip_serializing_if = "Option::is_none")]
177 pub thought_signature: Option<String>,
178}
179
180#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
182pub struct FunctionCall {
183 #[serde(default, skip_serializing_if = "Option::is_none")]
185 pub namespace: Option<String>,
186
187 pub name: String,
189
190 pub arguments: String,
192}
193
194impl ToolCall {
195 pub fn function(id: String, name: String, arguments: String) -> Self {
197 Self::function_with_namespace(id, None, name, arguments)
198 }
199
200 pub fn function_with_namespace(
202 id: String,
203 namespace: Option<String>,
204 name: String,
205 arguments: String,
206 ) -> Self {
207 Self {
208 id,
209 call_type: "function".to_owned(),
210 function: Some(FunctionCall {
211 namespace,
212 name,
213 arguments,
214 }),
215 text: None,
216 thought_signature: None,
217 }
218 }
219
220 pub fn custom(id: String, name: String, text: String) -> Self {
222 Self {
223 id,
224 call_type: "custom".to_owned(),
225 function: Some(FunctionCall {
226 namespace: None,
227 name,
228 arguments: text.clone(),
229 }),
230 text: Some(text),
231 thought_signature: None,
232 }
233 }
234
235 pub fn is_custom(&self) -> bool {
237 self.call_type == "custom"
238 }
239
240 pub fn tool_name(&self) -> Option<&str> {
242 self.function
243 .as_ref()
244 .map(|function| function.name.as_str())
245 }
246
247 pub fn raw_input(&self) -> Option<&str> {
249 self.text.as_deref().or_else(|| {
250 self.function
251 .as_ref()
252 .map(|function| function.arguments.as_str())
253 })
254 }
255
256 pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
258 if let Some(ref func) = self.function {
259 parse_tool_arguments(&func.arguments)
260 } else {
261 serde_json::from_str("")
263 }
264 }
265
266 pub fn execution_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
272 if self.is_custom() {
273 return Ok(serde_json::Value::String(
274 self.raw_input().unwrap_or_default().to_string(),
275 ));
276 }
277
278 self.parsed_arguments()
279 }
280
281 pub fn validate(&self) -> Result<(), String> {
283 if self.id.is_empty() {
284 return Err("Tool call ID cannot be empty".to_owned());
285 }
286
287 match self.call_type.as_str() {
288 "function" => {
289 if let Some(func) = &self.function {
290 if func.name.is_empty() {
291 return Err("Function name cannot be empty".to_owned());
292 }
293 if let Err(e) = self.parsed_arguments() {
295 return Err(format!("Invalid JSON in function arguments: {}", e));
296 }
297 } else {
298 return Err("Function tool call missing function details".to_owned());
299 }
300 }
301 "custom" => {
302 if let Some(func) = &self.function {
304 if func.name.is_empty() {
305 return Err("Custom tool name cannot be empty".to_owned());
306 }
307 } else {
308 return Err("Custom tool call missing function details".to_owned());
309 }
310 }
311 _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
312 }
313
314 Ok(())
315 }
316}
317
318fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
319 let trimmed = raw_arguments.trim();
320 match serde_json::from_str(trimmed) {
321 Ok(parsed) => Ok(parsed),
322 Err(primary_error) => {
323 if let Some(candidate) = extract_balanced_json(trimmed)
324 && let Ok(parsed) = serde_json::from_str(candidate)
325 {
326 return Ok(parsed);
327 }
328 if let Some(candidate) = repair_tag_polluted_json(trimmed)
329 && let Ok(parsed) = serde_json::from_str(&candidate)
330 {
331 return Ok(parsed);
332 }
333 Err(primary_error)
334 }
335 }
336}
337
/// Returns the first balanced JSON object or array span embedded in `input`.
///
/// Scanning begins at the first `{` or `[`. Only brackets of that same kind
/// are counted, and everything inside string literals (including escaped
/// quotes) is ignored. The span is not validated; callers must parse it.
/// Returns `None` when there is no opener or the opener is never balanced.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let (open, close) = match input.as_bytes().get(start).copied()? {
        b'{' => (b'{', b'}'),
        b'[' => (b'[', b']'),
        _ => return None,
    };

    let mut nesting = 0usize;
    let mut inside_string = false;
    let mut after_backslash = false;

    // All structural characters are ASCII, and the bytes of multi-byte UTF-8
    // characters are all >= 0x80, so a byte-level scan cannot misfire.
    for (offset, &byte) in input.as_bytes()[start..].iter().enumerate() {
        if inside_string {
            if after_backslash {
                after_backslash = false;
            } else if byte == b'\\' {
                after_backslash = true;
            } else if byte == b'"' {
                inside_string = false;
            }
            continue;
        }

        if byte == b'"' {
            inside_string = true;
        } else if byte == open {
            nesting += 1;
        } else if byte == close {
            nesting = nesting.saturating_sub(1);
            if nesting == 0 {
                // The closer is one byte, so the span ends right after it.
                return input.get(start..start + offset + 1);
            }
        }
    }

    None
}
383
384fn repair_tag_polluted_json(input: &str) -> Option<String> {
385 let start = input.find(['{', '['])?;
386 let candidate = input.get(start..)?;
387 let boundary = find_provider_markup_boundary(candidate)?;
388 if boundary == 0 {
389 return None;
390 }
391
392 close_incomplete_json_prefix(candidate[..boundary].trim_end())
393}
394
/// Returns the byte offset of the first provider tool-call markup tag in
/// `input` (such as `</parameter>`, `<invoke name="`, or `<minimax:tool_call>`).
/// Returns `None` if there is no tag.
///
/// Tags are matched anywhere, including inside JSON string literals.
fn find_provider_markup_boundary(input: &str) -> Option<usize> {
    const PROVIDER_MARKERS: [&str; 9] = [
        "<</",
        "</parameter>",
        "</invoke>",
        "</minimax:tool_call>",
        "<minimax:tool_call>",
        "<parameter name=\"",
        "<invoke name=\"",
        "<tool_call>",
        "</tool_call>",
    ];

    // Every marker starts with '<', so only those positions can match.
    input
        .match_indices('<')
        .map(|(offset, _)| offset)
        .find(|&offset| {
            let tail = &input[offset..];
            PROVIDER_MARKERS.iter().any(|marker| tail.starts_with(marker))
        })
}
416
/// Closes a truncated JSON prefix. It terminates an open string literal and
/// then appends the closer for every still-open object or array, innermost
/// first.
///
/// Returns `None` for an empty prefix, or when a `}`/`]` in the prefix does not
/// match the most recent opener (so the text is not a JSON prefix). The result
/// is not guaranteed to be valid JSON (a dangling key or trailing comma is left
/// as-is), so callers must still parse it.
///
/// If the prefix stops right after a backslash inside a string (an incomplete
/// escape), that backslash is dropped. Otherwise the appended `"` would become
/// an escaped quote and the string would never terminate.
fn close_incomplete_json_prefix(prefix: &str) -> Option<String> {
    if prefix.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(prefix.len() + 8);
    let mut expected_closers = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for ch in prefix.chars() {
        repaired.push(ch);

        if in_string {
            if escaped {
                escaped = false;
                continue;
            }

            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => expected_closers.push('}'),
            '[' => expected_closers.push(']'),
            '}' | ']' => {
                if expected_closers.pop() != Some(ch) {
                    return None;
                }
            }
            _ => {}
        }
    }

    if in_string {
        if escaped {
            // `escaped` is only still set when the last char pushed was the
            // lone backslash that started an unfinished escape sequence.
            repaired.pop();
        }
        repaired.push('"');
    }
    while let Some(closer) = expected_closers.pop() {
        repaired.push(closer);
    }

    Some(repaired)
}
466
467#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
469pub struct LLMResponse {
470 pub content: Option<String>,
472
473 pub tool_calls: Option<Vec<ToolCall>>,
475
476 pub model: String,
478
479 pub usage: Option<Usage>,
481
482 pub finish_reason: FinishReason,
484
485 pub reasoning: Option<String>,
487
488 pub reasoning_details: Option<Vec<String>>,
490
491 pub tool_references: Vec<String>,
493
494 pub request_id: Option<String>,
496
497 pub organization_id: Option<String>,
499}
500
501impl LLMResponse {
502 pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
504 Self {
505 content: Some(content.into()),
506 tool_calls: None,
507 model: model.into(),
508 usage: None,
509 finish_reason: FinishReason::Stop,
510 reasoning: None,
511 reasoning_details: None,
512 tool_references: Vec::new(),
513 request_id: None,
514 organization_id: None,
515 }
516 }
517
518 pub fn content_text(&self) -> &str {
520 self.content.as_deref().unwrap_or("")
521 }
522
523 pub fn content_string(&self) -> String {
525 self.content.clone().unwrap_or_default()
526 }
527}
528
529#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
530pub struct LLMErrorMetadata {
531 pub provider: Option<String>,
532 pub status: Option<u16>,
533 pub code: Option<String>,
534 pub request_id: Option<String>,
535 pub organization_id: Option<String>,
536 pub retry_after: Option<String>,
537 pub message: Option<String>,
538}
539
540impl LLMErrorMetadata {
541 pub fn new(
542 provider: impl Into<String>,
543 status: Option<u16>,
544 code: Option<String>,
545 request_id: Option<String>,
546 organization_id: Option<String>,
547 retry_after: Option<String>,
548 message: Option<String>,
549 ) -> Box<Self> {
550 Box::new(Self {
551 provider: Some(provider.into()),
552 status,
553 code,
554 request_id,
555 organization_id,
556 retry_after,
557 message,
558 })
559 }
560}
561
562#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
564#[serde(tag = "type", rename_all = "snake_case")]
565pub enum LLMError {
566 #[error("Authentication failed: {message}")]
567 Authentication {
568 message: String,
569 metadata: Option<Box<LLMErrorMetadata>>,
570 },
571 #[error("Rate limit exceeded")]
572 RateLimit {
573 metadata: Option<Box<LLMErrorMetadata>>,
574 },
575 #[error("Invalid request: {message}")]
576 InvalidRequest {
577 message: String,
578 metadata: Option<Box<LLMErrorMetadata>>,
579 },
580 #[error("Network error: {message}")]
581 Network {
582 message: String,
583 metadata: Option<Box<LLMErrorMetadata>>,
584 },
585 #[error("Provider error: {message}")]
586 Provider {
587 message: String,
588 metadata: Option<Box<LLMErrorMetadata>>,
589 },
590}
591
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Shorthand for a `"function"` tool call built from string slices.
    fn function_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall::function(id.to_owned(), name.to_owned(), arguments.to_owned())
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = function_call(
            "call_read",
            "read_file",
            r#"{"path":"src/main.rs"} trailing text"#,
        );

        let parsed = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = function_call(
            "call_read",
            "read_file",
            "```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```",
        );

        let parsed = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let call = function_call("call_read", "read_file", r#"{"path":"src/main.rs""#);

        assert!(call.parsed_arguments().is_err());
    }

    #[test]
    fn parsed_arguments_recovers_truncated_minimax_markup() {
        // The JSON is cut off mid-string by MiniMax tool-call markup.
        let raw = concat!(
            "{\"action\": \"grep\", \"pattern\": \"persistent_memory\", ",
            "\"path\": \"vtcode-core/src</parameter>\n",
            "<</invoke>\n",
            "</minimax:tool_call>",
        );
        let call = function_call("call_search", "unified_search", raw);

        let parsed = call
            .parsed_arguments()
            .expect("minimax markup spillover should recover");
        let expected = json!({
            "action": "grep",
            "pattern": "persistent_memory",
            "path": "vtcode-core/src"
        });
        assert_eq!(parsed, expected);
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_owned(),
            Some("workspace".to_owned()),
            "read_file".to_owned(),
            r#"{"path":"src/main.rs"}"#.to_owned(),
        );

        let value = serde_json::to_value(&call).expect("tool call should serialize");
        let function = &value["function"];
        assert_eq!(function["namespace"], "workspace");
        assert_eq!(function["name"], "read_file");
    }

    #[test]
    fn custom_tool_call_exposes_raw_execution_arguments() {
        let patch = String::from("*** Begin Patch\n*** End Patch\n");
        let call = ToolCall::custom(
            "call_patch".to_owned(),
            "apply_patch".to_owned(),
            patch.clone(),
        );

        assert!(call.is_custom());
        assert_eq!(call.tool_name(), Some("apply_patch"));
        assert_eq!(call.raw_input(), Some(patch.as_str()));
        let execution = call.execution_arguments().expect("custom arguments");
        assert_eq!(execution, json!(patch));
        assert!(
            call.parsed_arguments().is_err(),
            "custom tool payload should stay freeform rather than JSON"
        );
    }
}