Skip to main content

vtcode_commons/
llm.rs

1//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    Mistral,
12    OpenRouter,
13    Ollama,
14    LlamaCpp,
15    ZAI,
16    Moonshot,
17    HuggingFace,
18    Minimax,
19    MiMo,
20    OpenCodeZen,
21    OpenCodeGo,
22    Qwen,
23    StepFun,
24    Evolink,
25    Poolside,
26}
27
28#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
29pub struct Usage {
30    pub prompt_tokens: u32,
31    pub completion_tokens: u32,
32    pub total_tokens: u32,
33    pub cached_prompt_tokens: Option<u32>,
34    pub cache_creation_tokens: Option<u32>,
35    pub cache_read_tokens: Option<u32>,
36    /// Per-iteration token usage for Anthropic server-side fallback and compaction.
37    /// Each entry represents one sampling pass (message, fallback_message, or compaction).
38    #[serde(default, skip_serializing_if = "Option::is_none")]
39    pub iterations: Option<Vec<serde_json::Value>>,
40}
41
42impl Usage {
43    #[inline]
44    fn has_cache_read_metric(&self) -> bool {
45        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
46    }
47
48    #[inline]
49    fn has_any_cache_metrics(&self) -> bool {
50        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
51    }
52
53    #[inline]
54    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
55        self.cache_read_tokens
56            .or(self.cached_prompt_tokens)
57            .unwrap_or(0)
58    }
59
60    #[inline]
61    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
62        self.cache_creation_tokens.unwrap_or(0)
63    }
64
65    #[inline]
66    pub fn cache_hit_rate(&self) -> Option<f64> {
67        if !self.has_any_cache_metrics() {
68            return None;
69        }
70        let read = self.cache_read_tokens_or_fallback() as f64;
71        let creation = self.cache_creation_tokens_or_zero() as f64;
72        let total = read + creation;
73        if total > 0.0 {
74            Some((read / total) * 100.0)
75        } else {
76            None
77        }
78    }
79
80    #[inline]
81    pub fn is_cache_hit(&self) -> Option<bool> {
82        self.has_any_cache_metrics()
83            .then(|| self.cache_read_tokens_or_fallback() > 0)
84    }
85
86    #[inline]
87    pub fn is_cache_miss(&self) -> Option<bool> {
88        self.has_any_cache_metrics().then(|| {
89            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
90        })
91    }
92
93    #[inline]
94    pub fn total_cache_tokens(&self) -> u32 {
95        let read = self.cache_read_tokens_or_fallback();
96        let creation = self.cache_creation_tokens_or_zero();
97        read + creation
98    }
99
100    #[inline]
101    pub fn cache_savings_ratio(&self) -> Option<f64> {
102        if !self.has_cache_read_metric() {
103            return None;
104        }
105        let read = self.cache_read_tokens_or_fallback() as f64;
106        let prompt = self.prompt_tokens as f64;
107        if prompt > 0.0 {
108            Some(read / prompt)
109        } else {
110            None
111        }
112    }
113}
114
115/// Provider-agnostic balance information for account status display.
116#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
117pub struct BalanceInfo {
118    /// Human-readable balance string (e.g. "100.00¥", "$50.00").
119    pub display: String,
120    /// Whether the account has sufficient balance for API calls.
121    pub is_available: bool,
122}
123
124/// DeepSeek-specific balance info from GET /user/balance
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct DeepSeekBalanceResponse {
127    pub is_available: bool,
128    pub balance_infos: Vec<DeepSeekCurrencyBalance>,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct DeepSeekCurrencyBalance {
133    pub currency: String,
134    pub total_balance: String,
135    #[serde(default)]
136    pub granted_balance: String,
137    #[serde(default)]
138    pub topped_up_balance: String,
139}
140
141impl From<DeepSeekBalanceResponse> for BalanceInfo {
142    fn from(resp: DeepSeekBalanceResponse) -> Self {
143        let display = resp
144            .balance_infos
145            .first()
146            .map(|b| {
147                let symbol = match b.currency.as_str() {
148                    "CNY" => "¥",
149                    "USD" => "$",
150                    _ => &b.currency,
151                };
152                format!("{}{}", b.total_balance, symbol)
153            })
154            .unwrap_or_else(|| "N/A".to_string());
155        BalanceInfo {
156            display,
157            is_available: resp.is_available,
158        }
159    }
160}
161
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// Builds a `Usage` with fixed token totals (1000 prompt / 200 completion)
    /// and the supplied cache counters.
    fn usage_with_cache(
        cached_prompt: Option<u32>,
        creation: Option<u32>,
        read: Option<u32>,
    ) -> Usage {
        Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: cached_prompt,
            cache_creation_tokens: creation,
            cache_read_tokens: read,
            iterations: None,
        }
    }

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        // No `cache_read_tokens`, so every read-based helper must use the
        // `cached_prompt_tokens` value instead.
        let usage = usage_with_cache(Some(600), Some(150), None);

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        // With no counters at all the helpers report "unknown" (`None`)
        // rather than a definite miss.
        let usage = usage_with_cache(None, None, None);

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}
206
207#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
208pub enum FinishReason {
209    #[default]
210    Stop,
211    Length,
212    ToolCalls,
213    ContentFilter,
214    Pause,
215    Refusal,
216    Error(String),
217}
218
219/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
220#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
221pub struct ToolCall {
222    /// Unique identifier for this tool call (e.g., "call_123")
223    pub id: String,
224
225    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
226    #[serde(rename = "type")]
227    pub call_type: String,
228
229    /// Function call details (for function-type tools)
230    #[serde(skip_serializing_if = "Option::is_none")]
231    pub function: Option<FunctionCall>,
232
233    /// Raw text payload (for custom freeform tools in GPT-5)
234    #[serde(skip_serializing_if = "Option::is_none")]
235    pub text: Option<String>,
236
237    /// Gemini-specific thought signature for maintaining reasoning context
238    #[serde(skip_serializing_if = "Option::is_none")]
239    pub thought_signature: Option<String>,
240}
241
242/// Function call within a tool call
243#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
244pub struct FunctionCall {
245    /// Optional namespace for grouped or deferred tools.
246    #[serde(default, skip_serializing_if = "Option::is_none")]
247    pub namespace: Option<String>,
248
249    /// The name of the function to call
250    pub name: String,
251
252    /// The arguments to pass to the function, as a JSON string
253    pub arguments: String,
254}
255
256impl ToolCall {
257    /// Create a new function tool call
258    pub fn function(id: String, name: String, arguments: String) -> Self {
259        Self::function_with_namespace(id, None, name, arguments)
260    }
261
262    /// Create a new function tool call with an optional namespace.
263    pub fn function_with_namespace(
264        id: String,
265        namespace: Option<String>,
266        name: String,
267        arguments: String,
268    ) -> Self {
269        Self {
270            id,
271            call_type: "function".to_owned(),
272            function: Some(FunctionCall {
273                namespace,
274                name,
275                arguments,
276            }),
277            text: None,
278            thought_signature: None,
279        }
280    }
281
282    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
283    pub fn custom(id: String, name: String, text: String) -> Self {
284        Self {
285            id,
286            call_type: "custom".to_owned(),
287            function: Some(FunctionCall {
288                namespace: None,
289                name,
290                arguments: text.clone(),
291            }),
292            text: Some(text),
293            thought_signature: None,
294        }
295    }
296
297    /// Returns true when this tool call uses GPT-5 custom/freeform semantics.
298    pub fn is_custom(&self) -> bool {
299        self.call_type == "custom"
300    }
301
302    /// Returns the tool name when the call includes function details.
303    pub fn tool_name(&self) -> Option<&str> {
304        self.function
305            .as_ref()
306            .map(|function| function.name.as_str())
307    }
308
309    /// Returns the raw payload text exactly as emitted by the model.
310    pub fn raw_input(&self) -> Option<&str> {
311        self.text.as_deref().or_else(|| {
312            self.function
313                .as_ref()
314                .map(|function| function.arguments.as_str())
315        })
316    }
317
318    /// Parse the arguments as JSON Value (for function-type tools)
319    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
320        if let Some(ref func) = self.function {
321            parse_tool_arguments(&func.arguments)
322        } else {
323            // Return an error by trying to parse invalid JSON
324            serde_json::from_str("")
325        }
326    }
327
328    /// Returns the execution payload for this tool call.
329    ///
330    /// Function tools keep their JSON semantics. Custom tools execute with their
331    /// raw text payload wrapped as a JSON string value so freeform inputs can
332    /// flow through the existing tool pipeline.
333    pub fn execution_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
334        if self.is_custom() {
335            return Ok(serde_json::Value::String(
336                self.raw_input().unwrap_or_default().to_string(),
337            ));
338        }
339
340        self.parsed_arguments()
341    }
342
343    /// Validate that this tool call is properly formed
344    pub fn validate(&self) -> Result<(), String> {
345        if self.id.is_empty() {
346            return Err("Tool call ID cannot be empty".to_owned());
347        }
348
349        match self.call_type.as_str() {
350            "function" => {
351                if let Some(func) = &self.function {
352                    if func.name.is_empty() {
353                        return Err("Function name cannot be empty".to_owned());
354                    }
355                    // Validate that arguments is valid JSON for function tools
356                    if let Err(e) = self.parsed_arguments() {
357                        return Err(format!("Invalid JSON in function arguments: {e}"));
358                    }
359                } else {
360                    return Err("Function tool call missing function details".to_owned());
361                }
362            }
363            "custom" => {
364                // For custom tools, we allow raw text payload without JSON validation
365                if let Some(func) = &self.function {
366                    if func.name.is_empty() {
367                        return Err("Custom tool name cannot be empty".to_owned());
368                    }
369                } else {
370                    return Err("Custom tool call missing function details".to_owned());
371                }
372            }
373            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
374        }
375
376        Ok(())
377    }
378}
379
380fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
381    let trimmed = raw_arguments.trim();
382    match serde_json::from_str(trimmed) {
383        Ok(parsed) => Ok(parsed),
384        Err(primary_error) => {
385            if let Some(candidate) = extract_balanced_json(trimmed)
386                && let Ok(parsed) = serde_json::from_str(candidate)
387            {
388                return Ok(parsed);
389            }
390            if let Some(candidate) = repair_tag_polluted_json(trimmed)
391                && let Ok(parsed) = serde_json::from_str(&candidate)
392            {
393                return Ok(parsed);
394            }
395            if let Some(repaired) = close_incomplete_json_prefix(trimmed)
396                && let Ok(parsed) = serde_json::from_str(&repaired)
397            {
398                return Ok(parsed);
399            }
400            Err(primary_error)
401        }
402    }
403}
404
/// Finds the first `{` or `[` in `input` and returns the slice running from
/// it to the bracket that balances it.
///
/// Only brackets of the same kind as the opening one are counted, and
/// anything inside a double-quoted string (including backslash-escaped
/// characters) is skipped. Returns `None` when there is no opening bracket or
/// the span never closes. The result is not validated as JSON.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let bytes = input.as_bytes();
    let (opener, closer) = match *bytes.get(start)? {
        b'{' => (b'{', b'}'),
        b'[' => (b'[', b']'),
        _ => return None,
    };

    // Every character that matters here is ASCII, and UTF-8 continuation
    // bytes never collide with ASCII, so scanning bytes is equivalent to
    // scanning chars.
    let mut nesting = 0usize;
    let mut inside_string = false;
    let mut skip_next = false;

    for (index, &byte) in bytes.iter().enumerate().skip(start) {
        if inside_string {
            if skip_next {
                skip_next = false;
            } else {
                match byte {
                    b'\\' => skip_next = true,
                    b'"' => inside_string = false,
                    _ => {}
                }
            }
            continue;
        }

        if byte == b'"' {
            inside_string = true;
        } else if byte == opener {
            nesting += 1;
        } else if byte == closer {
            nesting = nesting.saturating_sub(1);
            if nesting == 0 {
                return input.get(start..=index);
            }
        }
    }

    None
}
450
451fn repair_tag_polluted_json(input: &str) -> Option<String> {
452    let start = input.find(['{', '['])?;
453    let candidate = input.get(start..)?;
454    let boundary = find_provider_markup_boundary(candidate)?;
455    if boundary == 0 {
456        return None;
457    }
458
459    close_incomplete_json_prefix(candidate[..boundary].trim_end())
460}
461
/// Returns the byte offset of the earliest position in `input` at which one
/// of the known provider markup markers begins, or `None` if none occurs.
fn find_provider_markup_boundary(input: &str) -> Option<usize> {
    const PROVIDER_MARKERS: &[&str] = &[
        "<</",
        "</parameter>",
        "</invoke>",
        "</minimax:tool_call>",
        "<minimax:tool_call>",
        "<parameter name=\"",
        "<invoke name=\"",
        "<tool_call>",
        "</tool_call>",
    ];

    // Walk char boundaries so slicing at `offset` is always valid UTF-8.
    for (offset, _) in input.char_indices() {
        let Some(rest) = input.get(offset..) else {
            continue;
        };
        for marker in PROVIDER_MARKERS {
            if rest.starts_with(*marker) {
                return Some(offset);
            }
        }
    }

    None
}
483
/// Completes a truncated JSON prefix by closing whatever is still open.
///
/// The prefix is copied verbatim while tracking string state and the stack of
/// open `{` / `[` containers. At the end an unterminated string gets a closing
/// quote, and every open container is closed innermost-first, so
/// `{"a": [1, 2` becomes `{"a": [1, 2]}`.
///
/// Returns `None` for an empty prefix or when a closing bracket in the prefix
/// does not match the most recently opened container. The result is not
/// guaranteed to be valid JSON; callers are expected to parse it.
fn close_incomplete_json_prefix(prefix: &str) -> Option<String> {
    if prefix.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(prefix.len() + 8);
    let mut expected_closers = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for ch in prefix.chars() {
        repaired.push(ch);

        if in_string {
            if escaped {
                // The character after a backslash never ends the string.
                escaped = false;
                continue;
            }

            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => expected_closers.push('}'),
            '[' => expected_closers.push(']'),
            // A closer must match the innermost open container.
            '}' | ']' if expected_closers.pop() != Some(ch) => return None,
            _ => {}
        }
    }

    if in_string {
        repaired.push('"');
    }
    // Pop from the top of the stack so containers close in LIFO order;
    // emitting them in opening order would yield e.g. `[1, 2}]`.
    while let Some(closer) = expected_closers.pop() {
        repaired.push(closer);
    }

    Some(repaired)
}
530
531/// Universal LLM response structure
532#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
533pub struct LLMResponse {
534    /// The response content text
535    pub content: Option<String>,
536
537    /// Tool calls made by the model
538    pub tool_calls: Option<Vec<ToolCall>>,
539
540    /// The model that generated this response
541    pub model: String,
542
543    /// Token usage statistics
544    pub usage: Option<Usage>,
545
546    /// Why the response finished
547    pub finish_reason: FinishReason,
548
549    /// Reasoning content (for models that support it)
550    pub reasoning: Option<String>,
551
552    /// Detailed reasoning traces (for models that support it)
553    pub reasoning_details: Option<Vec<String>>,
554
555    /// Tool references for context
556    pub tool_references: Vec<String>,
557
558    /// Request ID from the provider
559    pub request_id: Option<String>,
560
561    /// Organization ID from the provider
562    pub organization_id: Option<String>,
563
564    /// Compaction summary content from Anthropic's server-side compaction.
565    /// Populated when `stop_reason` is `Pause` (from `"compaction"`).
566    /// The caller should pass this back in subsequent requests so the API
567    /// can drop prior messages before the compaction block.
568    pub compaction: Option<String>,
569}
570
571impl LLMResponse {
572    /// Create a new LLM response with mandatory fields
573    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
574        Self {
575            content: Some(content.into()),
576            tool_calls: None,
577            model: model.into(),
578            usage: None,
579            finish_reason: FinishReason::Stop,
580            reasoning: None,
581            reasoning_details: None,
582            tool_references: Vec::new(),
583            request_id: None,
584            organization_id: None,
585            compaction: None,
586        }
587    }
588
589    /// Get content or empty string
590    pub fn content_text(&self) -> &str {
591        self.content.as_deref().unwrap_or("")
592    }
593
594    /// Get content as String (clone)
595    pub fn content_string(&self) -> String {
596        self.content.clone().unwrap_or_default()
597    }
598}
599
600#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
601pub struct LLMErrorMetadata {
602    pub provider: Option<String>,
603    pub status: Option<u16>,
604    pub code: Option<String>,
605    pub request_id: Option<String>,
606    pub organization_id: Option<String>,
607    pub retry_after: Option<String>,
608    pub message: Option<String>,
609}
610
611impl LLMErrorMetadata {
612    /// Boxed constructor because metadata is always stored inside `Option<Box<LLMErrorMetadata>>`
613    /// in the LLMError enum variants.
614    #[must_use]
615    pub fn new(
616        provider: impl Into<String>,
617        status: Option<u16>,
618        code: Option<String>,
619        request_id: Option<String>,
620        organization_id: Option<String>,
621        retry_after: Option<String>,
622        message: Option<String>,
623    ) -> Box<Self> {
624        Box::new(Self {
625            provider: Some(provider.into()),
626            status,
627            code,
628            request_id,
629            organization_id,
630            retry_after,
631            message,
632        })
633    }
634}
635
636/// LLM error types with optional provider metadata
637#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
638#[serde(tag = "type", rename_all = "snake_case")]
639pub enum LLMError {
640    #[error("Authentication failed: {message}")]
641    Authentication {
642        message: String,
643        metadata: Option<Box<LLMErrorMetadata>>,
644    },
645    #[error("Rate limit exceeded")]
646    RateLimit {
647        metadata: Option<Box<LLMErrorMetadata>>,
648    },
649    #[error("Invalid request: {message}")]
650    InvalidRequest {
651        message: String,
652        metadata: Option<Box<LLMErrorMetadata>>,
653    },
654    #[error("Network error: {message}")]
655    Network {
656        message: String,
657        metadata: Option<Box<LLMErrorMetadata>>,
658    },
659    #[error("Provider error: {message}")]
660    Provider {
661        message: String,
662        metadata: Option<Box<LLMErrorMetadata>>,
663    },
664}
665
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Shorthand for a function-type tool call built from borrowed strings.
    fn function_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall::function(id.to_owned(), name.to_owned(), arguments.to_owned())
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = function_call(
            "call_read",
            "read_file",
            r#"{"path":"src/main.rs"} trailing text"#,
        );

        let parsed = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");

        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = function_call(
            "call_read",
            "read_file",
            "```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```",
        );

        let parsed = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");

        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_recovers_truncated_json_missing_closing_brace() {
        let call = function_call(
            "call_search",
            "unified_search",
            r#"{"action": "structural", "pattern": "context", "path": ".", "lang": "rust""#,
        );

        let parsed = call
            .parsed_arguments()
            .expect("truncated JSON missing closing brace should recover");

        let expected = json!({
            "action": "structural",
            "pattern": "context",
            "path": ".",
            "lang": "rust"
        });
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        // A dangling key with no value cannot be repaired by closing braces.
        let call = function_call(
            "call_read",
            "read_file",
            r#"{"path":"src/main.rs","limit""#,
        );

        assert!(call.parsed_arguments().is_err());
    }

    #[test]
    fn parsed_arguments_recovers_truncated_minimax_markup() {
        let call = function_call(
            "call_search",
            "unified_search",
            "{\"action\": \"grep\", \"pattern\": \"persistent_memory\", \"path\": \"vtcode-core/src</parameter>\n<</invoke>\n</minimax:tool_call>",
        );

        let parsed = call
            .parsed_arguments()
            .expect("minimax markup spillover should recover");

        let expected = json!({
            "action": "grep",
            "pattern": "persistent_memory",
            "path": "vtcode-core/src"
        });
        assert_eq!(parsed, expected);
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_owned(),
            Some("workspace".to_owned()),
            "read_file".to_owned(),
            r#"{"path":"src/main.rs"}"#.to_owned(),
        );

        let json = serde_json::to_value(&call).expect("tool call should serialize");
        let function = &json["function"];

        assert_eq!(function["namespace"], "workspace");
        assert_eq!(function["name"], "read_file");
    }

    #[test]
    fn custom_tool_call_exposes_raw_execution_arguments() {
        let patch = String::from("*** Begin Patch\n*** End Patch\n");
        let call = ToolCall::custom(
            "call_patch".to_owned(),
            "apply_patch".to_owned(),
            patch.clone(),
        );

        assert!(call.is_custom());
        assert_eq!(call.tool_name(), Some("apply_patch"));
        assert_eq!(call.raw_input(), Some(patch.as_str()));

        let execution = call.execution_arguments().expect("custom arguments");
        assert_eq!(execution, json!(patch));

        assert!(
            call.parsed_arguments().is_err(),
            "custom tool payload should stay freeform rather than JSON"
        );
    }
}
789}