Skip to main content

vtcode_commons/
llm.rs

1//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    Mistral,
12    OpenRouter,
13    Ollama,
14    LlamaCpp,
15    ZAI,
16    Moonshot,
17    HuggingFace,
18    Minimax,
19    MiMo,
20    OpenCodeZen,
21    OpenCodeGo,
22    Qwen,
23    StepFun,
24    Evolink,
25    Poolside,
26}
27
28#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
29pub struct Usage {
30    pub prompt_tokens: u32,
31    pub completion_tokens: u32,
32    pub total_tokens: u32,
33    pub cached_prompt_tokens: Option<u32>,
34    pub cache_creation_tokens: Option<u32>,
35    pub cache_read_tokens: Option<u32>,
36    /// Per-iteration token usage for Anthropic server-side fallback and compaction.
37    /// Each entry represents one sampling pass (message, fallback_message, or compaction).
38    #[serde(default, skip_serializing_if = "Option::is_none")]
39    pub iterations: Option<Vec<serde_json::Value>>,
40}
41
42impl Usage {
43    #[inline]
44    fn has_cache_read_metric(&self) -> bool {
45        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
46    }
47
48    #[inline]
49    fn has_any_cache_metrics(&self) -> bool {
50        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
51    }
52
53    #[inline]
54    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
55        self.cache_read_tokens
56            .or(self.cached_prompt_tokens)
57            .unwrap_or(0)
58    }
59
60    #[inline]
61    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
62        self.cache_creation_tokens.unwrap_or(0)
63    }
64
65    #[inline]
66    pub fn cache_hit_rate(&self) -> Option<f64> {
67        if !self.has_any_cache_metrics() {
68            return None;
69        }
70        let read = self.cache_read_tokens_or_fallback() as f64;
71        let creation = self.cache_creation_tokens_or_zero() as f64;
72        let total = read + creation;
73        if total > 0.0 {
74            Some((read / total) * 100.0)
75        } else {
76            None
77        }
78    }
79
80    #[inline]
81    pub fn is_cache_hit(&self) -> Option<bool> {
82        self.has_any_cache_metrics()
83            .then(|| self.cache_read_tokens_or_fallback() > 0)
84    }
85
86    #[inline]
87    pub fn is_cache_miss(&self) -> Option<bool> {
88        self.has_any_cache_metrics().then(|| {
89            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
90        })
91    }
92
93    #[inline]
94    pub fn total_cache_tokens(&self) -> u32 {
95        let read = self.cache_read_tokens_or_fallback();
96        let creation = self.cache_creation_tokens_or_zero();
97        read + creation
98    }
99
100    #[inline]
101    pub fn cache_savings_ratio(&self) -> Option<f64> {
102        if !self.has_cache_read_metric() {
103            return None;
104        }
105        let read = self.cache_read_tokens_or_fallback() as f64;
106        let prompt = self.prompt_tokens as f64;
107        if prompt > 0.0 {
108            Some(read / prompt)
109        } else {
110            None
111        }
112    }
113}
114
115/// Provider-agnostic balance information for account status display.
116#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
117pub struct BalanceInfo {
118    /// Human-readable balance string (e.g. "100.00¥", "$50.00").
119    pub display: String,
120    /// Whether the account has sufficient balance for API calls.
121    pub is_available: bool,
122}
123
124/// DeepSeek-specific balance info from GET /user/balance
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct DeepSeekBalanceResponse {
127    pub is_available: bool,
128    pub balance_infos: Vec<DeepSeekCurrencyBalance>,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct DeepSeekCurrencyBalance {
133    pub currency: String,
134    pub total_balance: String,
135    #[serde(default)]
136    pub granted_balance: String,
137    #[serde(default)]
138    pub topped_up_balance: String,
139}
140
141impl From<DeepSeekBalanceResponse> for BalanceInfo {
142    fn from(resp: DeepSeekBalanceResponse) -> Self {
143        let display = resp
144            .balance_infos
145            .first()
146            .map(|b| {
147                let symbol = match b.currency.as_str() {
148                    "CNY" => "¥",
149                    "USD" => "$",
150                    _ => &b.currency,
151                };
152                format!("{}{}", b.total_balance, symbol)
153            })
154            .unwrap_or_else(|| "N/A".to_string());
155        BalanceInfo {
156            display,
157            is_available: resp.is_available,
158        }
159    }
160}
161
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// When `cache_read_tokens` is absent, the helpers must use
    /// `cached_prompt_tokens` as the cache-read counter.
    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        let usage = Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: Some(600),
            cache_creation_tokens: Some(150),
            cache_read_tokens: None,
            // Struct literals must name every field; `iterations` is not
            // relevant to the cache helpers.
            iterations: None,
        };

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    /// Without any cache counter the helpers must report "unknown" (`None`)
    /// rather than a hit or a miss.
    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        let usage = Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: None,
            cache_creation_tokens: None,
            cache_read_tokens: None,
            iterations: None,
        };

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}
204
205#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
206pub enum FinishReason {
207    #[default]
208    Stop,
209    Length,
210    ToolCalls,
211    ContentFilter,
212    Pause,
213    Refusal,
214    Error(String),
215}
216
217/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
218#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
219pub struct ToolCall {
220    /// Unique identifier for this tool call (e.g., "call_123")
221    pub id: String,
222
223    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
224    #[serde(rename = "type")]
225    pub call_type: String,
226
227    /// Function call details (for function-type tools)
228    #[serde(skip_serializing_if = "Option::is_none")]
229    pub function: Option<FunctionCall>,
230
231    /// Raw text payload (for custom freeform tools in GPT-5)
232    #[serde(skip_serializing_if = "Option::is_none")]
233    pub text: Option<String>,
234
235    /// Gemini-specific thought signature for maintaining reasoning context
236    #[serde(skip_serializing_if = "Option::is_none")]
237    pub thought_signature: Option<String>,
238}
239
240/// Function call within a tool call
241#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
242pub struct FunctionCall {
243    /// Optional namespace for grouped or deferred tools.
244    #[serde(default, skip_serializing_if = "Option::is_none")]
245    pub namespace: Option<String>,
246
247    /// The name of the function to call
248    pub name: String,
249
250    /// The arguments to pass to the function, as a JSON string
251    pub arguments: String,
252}
253
254impl ToolCall {
255    /// Create a new function tool call
256    pub fn function(id: String, name: String, arguments: String) -> Self {
257        Self::function_with_namespace(id, None, name, arguments)
258    }
259
260    /// Create a new function tool call with an optional namespace.
261    pub fn function_with_namespace(
262        id: String,
263        namespace: Option<String>,
264        name: String,
265        arguments: String,
266    ) -> Self {
267        Self {
268            id,
269            call_type: "function".to_owned(),
270            function: Some(FunctionCall {
271                namespace,
272                name,
273                arguments,
274            }),
275            text: None,
276            thought_signature: None,
277        }
278    }
279
280    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
281    pub fn custom(id: String, name: String, text: String) -> Self {
282        Self {
283            id,
284            call_type: "custom".to_owned(),
285            function: Some(FunctionCall {
286                namespace: None,
287                name,
288                arguments: text.clone(),
289            }),
290            text: Some(text),
291            thought_signature: None,
292        }
293    }
294
295    /// Returns true when this tool call uses GPT-5 custom/freeform semantics.
296    pub fn is_custom(&self) -> bool {
297        self.call_type == "custom"
298    }
299
300    /// Returns the tool name when the call includes function details.
301    pub fn tool_name(&self) -> Option<&str> {
302        self.function
303            .as_ref()
304            .map(|function| function.name.as_str())
305    }
306
307    /// Returns the raw payload text exactly as emitted by the model.
308    pub fn raw_input(&self) -> Option<&str> {
309        self.text.as_deref().or_else(|| {
310            self.function
311                .as_ref()
312                .map(|function| function.arguments.as_str())
313        })
314    }
315
316    /// Parse the arguments as JSON Value (for function-type tools)
317    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
318        if let Some(ref func) = self.function {
319            parse_tool_arguments(&func.arguments)
320        } else {
321            // Return an error by trying to parse invalid JSON
322            serde_json::from_str("")
323        }
324    }
325
326    /// Returns the execution payload for this tool call.
327    ///
328    /// Function tools keep their JSON semantics. Custom tools execute with their
329    /// raw text payload wrapped as a JSON string value so freeform inputs can
330    /// flow through the existing tool pipeline.
331    pub fn execution_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
332        if self.is_custom() {
333            return Ok(serde_json::Value::String(
334                self.raw_input().unwrap_or_default().to_string(),
335            ));
336        }
337
338        self.parsed_arguments()
339    }
340
341    /// Validate that this tool call is properly formed
342    pub fn validate(&self) -> Result<(), String> {
343        if self.id.is_empty() {
344            return Err("Tool call ID cannot be empty".to_owned());
345        }
346
347        match self.call_type.as_str() {
348            "function" => {
349                if let Some(func) = &self.function {
350                    if func.name.is_empty() {
351                        return Err("Function name cannot be empty".to_owned());
352                    }
353                    // Validate that arguments is valid JSON for function tools
354                    if let Err(e) = self.parsed_arguments() {
355                        return Err(format!("Invalid JSON in function arguments: {}", e));
356                    }
357                } else {
358                    return Err("Function tool call missing function details".to_owned());
359                }
360            }
361            "custom" => {
362                // For custom tools, we allow raw text payload without JSON validation
363                if let Some(func) = &self.function {
364                    if func.name.is_empty() {
365                        return Err("Custom tool name cannot be empty".to_owned());
366                    }
367                } else {
368                    return Err("Custom tool call missing function details".to_owned());
369                }
370            }
371            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
372        }
373
374        Ok(())
375    }
376}
377
378fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
379    let trimmed = raw_arguments.trim();
380    match serde_json::from_str(trimmed) {
381        Ok(parsed) => Ok(parsed),
382        Err(primary_error) => {
383            if let Some(candidate) = extract_balanced_json(trimmed)
384                && let Ok(parsed) = serde_json::from_str(candidate)
385            {
386                return Ok(parsed);
387            }
388            if let Some(candidate) = repair_tag_polluted_json(trimmed)
389                && let Ok(parsed) = serde_json::from_str(&candidate)
390            {
391                return Ok(parsed);
392            }
393            if let Some(repaired) = close_incomplete_json_prefix(trimmed)
394                && let Ok(parsed) = serde_json::from_str(&repaired)
395            {
396                return Ok(parsed);
397            }
398            Err(primary_error)
399        }
400    }
401}
402
/// Returns the first balanced JSON object or array embedded in `input`.
///
/// Scanning starts at the first `{` or `[`. Only brackets of that same kind
/// are counted, and brackets inside string literals (with backslash escapes
/// honoured) are ignored. Returns `None` when there is no opening bracket or
/// when the opening bracket is never balanced.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let bytes = input.as_bytes();
    let (open, close) = match *bytes.get(start)? {
        b'{' => (b'{', b'}'),
        b'[' => (b'[', b']'),
        _ => return None,
    };

    let mut nesting = 0usize;
    let mut inside_string = false;
    let mut skip_next = false;

    // Every delimiter of interest is ASCII, and ASCII bytes never occur inside
    // a multi-byte UTF-8 sequence, so a byte scan sees exactly the same
    // delimiters as a char scan.
    for (index, &byte) in bytes.iter().enumerate().skip(start) {
        if inside_string {
            if skip_next {
                skip_next = false;
            } else if byte == b'\\' {
                skip_next = true;
            } else if byte == b'"' {
                inside_string = false;
            }
            continue;
        }

        if byte == b'"' {
            inside_string = true;
        } else if byte == open {
            nesting += 1;
        } else if byte == close {
            nesting = nesting.saturating_sub(1);
            if nesting == 0 {
                return input.get(start..=index);
            }
        }
    }

    None
}
448
449fn repair_tag_polluted_json(input: &str) -> Option<String> {
450    let start = input.find(['{', '['])?;
451    let candidate = input.get(start..)?;
452    let boundary = find_provider_markup_boundary(candidate)?;
453    if boundary == 0 {
454        return None;
455    }
456
457    close_incomplete_json_prefix(candidate[..boundary].trim_end())
458}
459
/// Returns the byte offset of the earliest provider markup marker in `input`,
/// or `None` when no marker occurs.
fn find_provider_markup_boundary(input: &str) -> Option<usize> {
    const PROVIDER_MARKERS: &[&str] = &[
        "<</",
        "</parameter>",
        "</invoke>",
        "</minimax:tool_call>",
        "<minimax:tool_call>",
        "<parameter name=\"",
        "<invoke name=\"",
        "<tool_call>",
        "</tool_call>",
    ];

    // The earliest position at which any marker starts is the smallest of the
    // markers' first-occurrence offsets.
    PROVIDER_MARKERS
        .iter()
        .filter_map(|marker| input.find(*marker))
        .min()
}
481
/// Completes a truncated JSON prefix by closing whatever is still open.
///
/// The prefix is scanned while tracking string literals (with backslash
/// escapes) and a stack of open `{` / `[` containers. An unterminated string
/// gets a closing quote, then every open container is closed innermost first.
///
/// Returns `None` for an empty prefix, or when a closing bracket does not
/// match the most recently opened container. The result is not guaranteed to
/// be valid JSON; callers are expected to parse it.
fn close_incomplete_json_prefix(prefix: &str) -> Option<String> {
    if prefix.is_empty() {
        return None;
    }

    let mut expected_closers = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for ch in prefix.chars() {
        if in_string {
            if escaped {
                escaped = false;
                continue;
            }

            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => expected_closers.push('}'),
            '[' => expected_closers.push(']'),
            '}' | ']' => {
                if expected_closers.pop() != Some(ch) {
                    return None;
                }
            }
            _ => {}
        }
    }

    let mut repaired = String::with_capacity(prefix.len() + 1 + expected_closers.len());
    repaired.push_str(prefix);
    if in_string {
        repaired.push('"');
    }
    // Pop from the top of the stack so the innermost container is closed
    // first; emitting front-to-back would produce e.g. `[1}]` for `{"a": [1`.
    while let Some(closer) = expected_closers.pop() {
        repaired.push(closer);
    }

    Some(repaired)
}
531
532/// Universal LLM response structure
533#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
534pub struct LLMResponse {
535    /// The response content text
536    pub content: Option<String>,
537
538    /// Tool calls made by the model
539    pub tool_calls: Option<Vec<ToolCall>>,
540
541    /// The model that generated this response
542    pub model: String,
543
544    /// Token usage statistics
545    pub usage: Option<Usage>,
546
547    /// Why the response finished
548    pub finish_reason: FinishReason,
549
550    /// Reasoning content (for models that support it)
551    pub reasoning: Option<String>,
552
553    /// Detailed reasoning traces (for models that support it)
554    pub reasoning_details: Option<Vec<String>>,
555
556    /// Tool references for context
557    pub tool_references: Vec<String>,
558
559    /// Request ID from the provider
560    pub request_id: Option<String>,
561
562    /// Organization ID from the provider
563    pub organization_id: Option<String>,
564
565    /// Compaction summary content from Anthropic's server-side compaction.
566    /// Populated when `stop_reason` is `Pause` (from `"compaction"`).
567    /// The caller should pass this back in subsequent requests so the API
568    /// can drop prior messages before the compaction block.
569    pub compaction: Option<String>,
570}
571
572impl LLMResponse {
573    /// Create a new LLM response with mandatory fields
574    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
575        Self {
576            content: Some(content.into()),
577            tool_calls: None,
578            model: model.into(),
579            usage: None,
580            finish_reason: FinishReason::Stop,
581            reasoning: None,
582            reasoning_details: None,
583            tool_references: Vec::new(),
584            request_id: None,
585            organization_id: None,
586            compaction: None,
587        }
588    }
589
590    /// Get content or empty string
591    pub fn content_text(&self) -> &str {
592        self.content.as_deref().unwrap_or("")
593    }
594
595    /// Get content as String (clone)
596    pub fn content_string(&self) -> String {
597        self.content.clone().unwrap_or_default()
598    }
599}
600
601#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
602pub struct LLMErrorMetadata {
603    pub provider: Option<String>,
604    pub status: Option<u16>,
605    pub code: Option<String>,
606    pub request_id: Option<String>,
607    pub organization_id: Option<String>,
608    pub retry_after: Option<String>,
609    pub message: Option<String>,
610}
611
612impl LLMErrorMetadata {
613    /// Boxed constructor because metadata is always stored inside `Option<Box<LLMErrorMetadata>>`
614    /// in the LLMError enum variants.
615    #[must_use]
616    pub fn new(
617        provider: impl Into<String>,
618        status: Option<u16>,
619        code: Option<String>,
620        request_id: Option<String>,
621        organization_id: Option<String>,
622        retry_after: Option<String>,
623        message: Option<String>,
624    ) -> Box<Self> {
625        Box::new(Self {
626            provider: Some(provider.into()),
627            status,
628            code,
629            request_id,
630            organization_id,
631            retry_after,
632            message,
633        })
634    }
635}
636
637/// LLM error types with optional provider metadata
638#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
639#[serde(tag = "type", rename_all = "snake_case")]
640pub enum LLMError {
641    #[error("Authentication failed: {message}")]
642    Authentication {
643        message: String,
644        metadata: Option<Box<LLMErrorMetadata>>,
645    },
646    #[error("Rate limit exceeded")]
647    RateLimit {
648        metadata: Option<Box<LLMErrorMetadata>>,
649    },
650    #[error("Invalid request: {message}")]
651    InvalidRequest {
652        message: String,
653        metadata: Option<Box<LLMErrorMetadata>>,
654    },
655    #[error("Network error: {message}")]
656    Network {
657        message: String,
658        metadata: Option<Box<LLMErrorMetadata>>,
659    },
660    #[error("Provider error: {message}")]
661    Provider {
662        message: String,
663        metadata: Option<Box<LLMErrorMetadata>>,
664    },
665}
666
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Builds a function-type tool call from borrowed strings.
    fn function_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall::function(id.to_owned(), name.to_owned(), arguments.to_owned())
    }

    /// Text following a complete JSON object is ignored.
    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = function_call(
            "call_read",
            "read_file",
            r#"{"path":"src/main.rs"} trailing text"#,
        );

        let parsed = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    /// JSON wrapped in a Markdown code fence is extracted.
    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = function_call(
            "call_read",
            "read_file",
            "```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```",
        );

        let parsed = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    /// A missing final `}` is supplied by the recovery path.
    #[test]
    fn parsed_arguments_recovers_truncated_json_missing_closing_brace() {
        let call = function_call(
            "call_search",
            "unified_search",
            r#"{"action": "structural", "pattern": "context", "path": ".", "lang": "rust""#,
        );

        let parsed = call
            .parsed_arguments()
            .expect("truncated JSON missing closing brace should recover");
        let expected = json!({
            "action": "structural",
            "pattern": "context",
            "path": ".",
            "lang": "rust"
        });
        assert_eq!(parsed, expected);
    }

    /// A key without a value cannot be repaired and must stay an error.
    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let call = function_call("call_read", "read_file", r#"{"path":"src/main.rs","limit""#);

        assert!(call.parsed_arguments().is_err());
    }

    /// Provider markup that spills into a string value is cut off and the
    /// remaining prefix is closed.
    #[test]
    fn parsed_arguments_recovers_truncated_minimax_markup() {
        let call = function_call(
            "call_search",
            "unified_search",
            "{\"action\": \"grep\", \"pattern\": \"persistent_memory\", \"path\": \"vtcode-core/src</parameter>\n<</invoke>\n</minimax:tool_call>",
        );

        let parsed = call
            .parsed_arguments()
            .expect("minimax markup spillover should recover");
        let expected = json!({
            "action": "grep",
            "pattern": "persistent_memory",
            "path": "vtcode-core/src"
        });
        assert_eq!(parsed, expected);
    }

    /// The namespace is serialized alongside the function name when present.
    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_owned(),
            Some("workspace".to_owned()),
            "read_file".to_owned(),
            r#"{"path":"src/main.rs"}"#.to_owned(),
        );

        let json = serde_json::to_value(&call).expect("tool call should serialize");
        assert_eq!(json["function"]["namespace"], "workspace");
        assert_eq!(json["function"]["name"], "read_file");
    }

    /// Custom calls expose their payload as raw text and do not parse as JSON.
    #[test]
    fn custom_tool_call_exposes_raw_execution_arguments() {
        let patch = "*** Begin Patch\n*** End Patch\n";
        let call = ToolCall::custom(
            "call_patch".to_owned(),
            "apply_patch".to_owned(),
            patch.to_owned(),
        );

        assert!(call.is_custom());
        assert_eq!(call.tool_name(), Some("apply_patch"));
        assert_eq!(call.raw_input(), Some(patch));
        assert_eq!(
            call.execution_arguments().expect("custom arguments"),
            json!(patch)
        );
        assert!(
            call.parsed_arguments().is_err(),
            "custom tool payload should stay freeform rather than JSON"
        );
    }
}