Skip to main content

vtcode_commons/
llm.rs

1//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
/// Enumerates the LLM backends this module knows how to name.
///
/// The type is `Copy` and comparable, and it carries no serde rename
/// attributes, so each variant (de)serializes under the exact identifier
/// written below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackendKind {
    Gemini,
    OpenAI,
    Anthropic,
    DeepSeek,
    Mistral,
    OpenRouter,
    Ollama,
    LlamaCpp,
    ZAI,
    Moonshot,
    HuggingFace,
    Minimax,
    MiMo,
    OpenCodeZen,
    OpenCodeGo,
    Qwen,
    StepFun,
    Evolink,
    Poolside,
}
27
/// Token accounting for a single LLM response, including optional
/// prompt-cache counters.
///
/// The three cache fields are `Option`s so that "not reported" (`None`) can be
/// told apart from "reported as zero" (`Some(0)`); the helper methods on this
/// type rely on that distinction.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct Usage {
    /// Token count attributed to the prompt.
    pub prompt_tokens: u32,
    /// Token count attributed to the completion.
    pub completion_tokens: u32,
    /// Total token count as stored by the producer of this value; this type
    /// does not check that it equals prompt + completion.
    pub total_tokens: u32,
    /// Cached prompt token count; used as the cache-read figure whenever
    /// `cache_read_tokens` is `None`.
    pub cached_prompt_tokens: Option<u32>,
    /// Cache-creation token count, if one was reported.
    pub cache_creation_tokens: Option<u32>,
    /// Cache-read token count, if one was reported. Takes precedence over
    /// `cached_prompt_tokens`.
    pub cache_read_tokens: Option<u32>,
}
37
38impl Usage {
39    #[inline]
40    fn has_cache_read_metric(&self) -> bool {
41        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
42    }
43
44    #[inline]
45    fn has_any_cache_metrics(&self) -> bool {
46        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
47    }
48
49    #[inline]
50    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
51        self.cache_read_tokens
52            .or(self.cached_prompt_tokens)
53            .unwrap_or(0)
54    }
55
56    #[inline]
57    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
58        self.cache_creation_tokens.unwrap_or(0)
59    }
60
61    #[inline]
62    pub fn cache_hit_rate(&self) -> Option<f64> {
63        if !self.has_any_cache_metrics() {
64            return None;
65        }
66        let read = self.cache_read_tokens_or_fallback() as f64;
67        let creation = self.cache_creation_tokens_or_zero() as f64;
68        let total = read + creation;
69        if total > 0.0 {
70            Some((read / total) * 100.0)
71        } else {
72            None
73        }
74    }
75
76    #[inline]
77    pub fn is_cache_hit(&self) -> Option<bool> {
78        self.has_any_cache_metrics()
79            .then(|| self.cache_read_tokens_or_fallback() > 0)
80    }
81
82    #[inline]
83    pub fn is_cache_miss(&self) -> Option<bool> {
84        self.has_any_cache_metrics().then(|| {
85            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
86        })
87    }
88
89    #[inline]
90    pub fn total_cache_tokens(&self) -> u32 {
91        let read = self.cache_read_tokens_or_fallback();
92        let creation = self.cache_creation_tokens_or_zero();
93        read + creation
94    }
95
96    #[inline]
97    pub fn cache_savings_ratio(&self) -> Option<f64> {
98        if !self.has_cache_read_metric() {
99            return None;
100        }
101        let read = self.cache_read_tokens_or_fallback() as f64;
102        let prompt = self.prompt_tokens as f64;
103        if prompt > 0.0 {
104            Some(read / prompt)
105        } else {
106            None
107        }
108    }
109}
110
/// Provider-agnostic balance information for account status display.
///
/// Provider-specific responses are converted into this shape (see the `From`
/// implementations in this module).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct BalanceInfo {
    /// Human-readable balance string (e.g. "100.00¥", "$50.00").
    /// The exact layout is decided by whichever conversion produced the value.
    pub display: String,
    /// Whether the account has sufficient balance for API calls.
    pub is_available: bool,
}
119
/// DeepSeek-specific balance info from GET /user/balance
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeepSeekBalanceResponse {
    /// Availability flag; copied unchanged into `BalanceInfo::is_available`
    /// by the `From` conversion.
    pub is_available: bool,
    /// Per-currency balances. Only the first entry is used when building the
    /// display string for `BalanceInfo`.
    pub balance_infos: Vec<DeepSeekCurrencyBalance>,
}
126
/// One per-currency entry inside a `DeepSeekBalanceResponse`.
///
/// All amounts are kept as the strings received; nothing in this module
/// parses them as numbers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeepSeekCurrencyBalance {
    /// Currency code. "CNY" and "USD" are mapped to symbols when displayed;
    /// any other value is shown as-is.
    pub currency: String,
    /// Total balance, as a string.
    pub total_balance: String,
    /// Granted balance; defaults to an empty string when the field is absent.
    /// NOTE(review): presumably promotional/granted credit — confirm against
    /// the DeepSeek API docs.
    #[serde(default)]
    pub granted_balance: String,
    /// Topped-up balance; defaults to an empty string when the field is absent.
    #[serde(default)]
    pub topped_up_balance: String,
}
136
137impl From<DeepSeekBalanceResponse> for BalanceInfo {
138    fn from(resp: DeepSeekBalanceResponse) -> Self {
139        let display = resp
140            .balance_infos
141            .first()
142            .map(|b| {
143                let symbol = match b.currency.as_str() {
144                    "CNY" => "¥",
145                    "USD" => "$",
146                    _ => &b.currency,
147                };
148                format!("{}{}", b.total_balance, symbol)
149            })
150            .unwrap_or_else(|| "N/A".to_string());
151        BalanceInfo {
152            display,
153            is_available: resp.is_available,
154        }
155    }
156}
157
// Unit tests for the cache helper methods on `Usage`.
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    // With `cache_read_tokens` absent, every read-based helper must use
    // `cached_prompt_tokens` (600) as the read figure.
    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        let usage = Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: Some(600),
            cache_creation_tokens: Some(150),
            cache_read_tokens: None,
        };

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        // 600 read / 1000 prompt tokens.
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        // 600 read / (600 read + 150 created) = 80%.
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    // With no cache counters at all, the Option-returning helpers must report
    // "unknown" (`None`) rather than a definite hit or miss.
    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        let usage = Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: None,
            cache_creation_tokens: None,
            cache_read_tokens: None,
        };

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}
200
/// Why a model response ended. The default value is [`FinishReason::Stop`].
///
/// NOTE(review): how each provider's own stop/finish codes map onto these
/// variants is decided outside this file — confirm against the provider
/// adapters before relying on a specific mapping.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum FinishReason {
    #[default]
    Stop,
    Length,
    ToolCalls,
    ContentFilter,
    Pause,
    Refusal,
    /// Carries a free-form string payload (presumably an error description —
    /// verify against producers).
    Error(String),
}
212
/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
///
/// When serialized, `call_type` is written under the key `type`, and the
/// optional fields are omitted entirely while they are `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique identifier for this tool call (e.g., "call_123")
    pub id: String,

    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
    #[serde(rename = "type")]
    pub call_type: String,

    /// Function call details (for function-type tools)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCall>,

    /// Raw text payload (for custom freeform tools in GPT-5)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Gemini-specific thought signature for maintaining reasoning context
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
235
/// Function call within a tool call
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Optional namespace for grouped or deferred tools.
    /// Missing on input deserializes to `None`; `None` is omitted on output.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub namespace: Option<String>,

    /// The name of the function to call
    pub name: String,

    /// The arguments to pass to the function, as a JSON string.
    /// The string is stored verbatim; it is only parsed on demand.
    pub arguments: String,
}
249
250impl ToolCall {
251    /// Create a new function tool call
252    pub fn function(id: String, name: String, arguments: String) -> Self {
253        Self::function_with_namespace(id, None, name, arguments)
254    }
255
256    /// Create a new function tool call with an optional namespace.
257    pub fn function_with_namespace(
258        id: String,
259        namespace: Option<String>,
260        name: String,
261        arguments: String,
262    ) -> Self {
263        Self {
264            id,
265            call_type: "function".to_owned(),
266            function: Some(FunctionCall {
267                namespace,
268                name,
269                arguments,
270            }),
271            text: None,
272            thought_signature: None,
273        }
274    }
275
276    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
277    pub fn custom(id: String, name: String, text: String) -> Self {
278        Self {
279            id,
280            call_type: "custom".to_owned(),
281            function: Some(FunctionCall {
282                namespace: None,
283                name,
284                arguments: text.clone(),
285            }),
286            text: Some(text),
287            thought_signature: None,
288        }
289    }
290
291    /// Returns true when this tool call uses GPT-5 custom/freeform semantics.
292    pub fn is_custom(&self) -> bool {
293        self.call_type == "custom"
294    }
295
296    /// Returns the tool name when the call includes function details.
297    pub fn tool_name(&self) -> Option<&str> {
298        self.function
299            .as_ref()
300            .map(|function| function.name.as_str())
301    }
302
303    /// Returns the raw payload text exactly as emitted by the model.
304    pub fn raw_input(&self) -> Option<&str> {
305        self.text.as_deref().or_else(|| {
306            self.function
307                .as_ref()
308                .map(|function| function.arguments.as_str())
309        })
310    }
311
312    /// Parse the arguments as JSON Value (for function-type tools)
313    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
314        if let Some(ref func) = self.function {
315            parse_tool_arguments(&func.arguments)
316        } else {
317            // Return an error by trying to parse invalid JSON
318            serde_json::from_str("")
319        }
320    }
321
322    /// Returns the execution payload for this tool call.
323    ///
324    /// Function tools keep their JSON semantics. Custom tools execute with their
325    /// raw text payload wrapped as a JSON string value so freeform inputs can
326    /// flow through the existing tool pipeline.
327    pub fn execution_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
328        if self.is_custom() {
329            return Ok(serde_json::Value::String(
330                self.raw_input().unwrap_or_default().to_string(),
331            ));
332        }
333
334        self.parsed_arguments()
335    }
336
337    /// Validate that this tool call is properly formed
338    pub fn validate(&self) -> Result<(), String> {
339        if self.id.is_empty() {
340            return Err("Tool call ID cannot be empty".to_owned());
341        }
342
343        match self.call_type.as_str() {
344            "function" => {
345                if let Some(func) = &self.function {
346                    if func.name.is_empty() {
347                        return Err("Function name cannot be empty".to_owned());
348                    }
349                    // Validate that arguments is valid JSON for function tools
350                    if let Err(e) = self.parsed_arguments() {
351                        return Err(format!("Invalid JSON in function arguments: {}", e));
352                    }
353                } else {
354                    return Err("Function tool call missing function details".to_owned());
355                }
356            }
357            "custom" => {
358                // For custom tools, we allow raw text payload without JSON validation
359                if let Some(func) = &self.function {
360                    if func.name.is_empty() {
361                        return Err("Custom tool name cannot be empty".to_owned());
362                    }
363                } else {
364                    return Err("Custom tool call missing function details".to_owned());
365                }
366            }
367            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
368        }
369
370        Ok(())
371    }
372}
373
374fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
375    let trimmed = raw_arguments.trim();
376    match serde_json::from_str(trimmed) {
377        Ok(parsed) => Ok(parsed),
378        Err(primary_error) => {
379            if let Some(candidate) = extract_balanced_json(trimmed)
380                && let Ok(parsed) = serde_json::from_str(candidate)
381            {
382                return Ok(parsed);
383            }
384            if let Some(candidate) = repair_tag_polluted_json(trimmed)
385                && let Ok(parsed) = serde_json::from_str(&candidate)
386            {
387                return Ok(parsed);
388            }
389            if let Some(repaired) = close_incomplete_json_prefix(trimmed)
390                && let Ok(parsed) = serde_json::from_str(&repaired)
391            {
392                return Ok(parsed);
393            }
394            Err(primary_error)
395        }
396    }
397}
398
/// Returns the first balanced `{...}` or `[...]` span in `input`, if any.
///
/// Scanning starts at the first `{` or `[`. Only the delimiter pair matching
/// that first character is counted for depth; the other bracket kind is
/// ignored. Delimiters inside double-quoted strings (with backslash escapes)
/// are skipped. Returns `None` when there is no opening delimiter or the span
/// never closes. The result is only a candidate: it is not validated as JSON.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let bytes = input.as_bytes();
    let (open, close) = match bytes.get(start).copied()? {
        b'{' => (b'{', b'}'),
        b'[' => (b'[', b']'),
        _ => return None,
    };

    let mut depth = 0usize;
    let mut inside_string = false;
    let mut after_backslash = false;

    // Every byte we care about is ASCII, so scanning bytes is equivalent to
    // scanning chars: UTF-8 continuation bytes can never match them.
    for (index, &byte) in bytes.iter().enumerate().skip(start) {
        if inside_string {
            if after_backslash {
                after_backslash = false;
            } else if byte == b'\\' {
                after_backslash = true;
            } else if byte == b'"' {
                inside_string = false;
            }
            continue;
        }

        if byte == b'"' {
            inside_string = true;
        } else if byte == open {
            depth += 1;
        } else if byte == close {
            depth = depth.saturating_sub(1);
            if depth == 0 {
                return input.get(start..=index);
            }
        }
    }

    None
}
444
445fn repair_tag_polluted_json(input: &str) -> Option<String> {
446    let start = input.find(['{', '['])?;
447    let candidate = input.get(start..)?;
448    let boundary = find_provider_markup_boundary(candidate)?;
449    if boundary == 0 {
450        return None;
451    }
452
453    close_incomplete_json_prefix(candidate[..boundary].trim_end())
454}
455
/// Returns the byte offset of the earliest provider markup marker in `input`,
/// or `None` when no marker occurs.
fn find_provider_markup_boundary(input: &str) -> Option<usize> {
    const PROVIDER_MARKERS: &[&str] = &[
        "<</",
        "</parameter>",
        "</invoke>",
        "</minimax:tool_call>",
        "<minimax:tool_call>",
        "<parameter name=\"",
        "<invoke name=\"",
        "<tool_call>",
        "</tool_call>",
    ];

    // The earliest position at which any marker starts is the minimum of each
    // marker's first occurrence.
    PROVIDER_MARKERS
        .iter()
        .filter_map(|marker| input.find(*marker))
        .min()
}
477
/// Closes a truncated JSON prefix by terminating an open string and appending
/// the missing `}` / `]` closers.
///
/// Returns `None` for an empty prefix, or when the prefix contains a closing
/// delimiter that does not match the most recently opened container. The
/// result is only a repair attempt: callers must still parse it, since a
/// prefix truncated mid-key or mid-value may not become valid JSON.
fn close_incomplete_json_prefix(prefix: &str) -> Option<String> {
    if prefix.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(prefix.len() + 8);
    // Stack of closers still owed, innermost container last.
    let mut expected_closers: Vec<char> = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for ch in prefix.chars() {
        repaired.push(ch);

        if in_string {
            if escaped {
                escaped = false;
                continue;
            }

            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => expected_closers.push('}'),
            '[' => expected_closers.push(']'),
            '}' | ']' => {
                if expected_closers.pop() != Some(ch) {
                    return None;
                }
            }
            _ => {}
        }
    }

    if in_string {
        if escaped {
            // The prefix was cut right after a backslash. Appending `"` now
            // would only produce an escaped quote and leave the string open,
            // so drop the dangling backslash first.
            repaired.pop();
        }
        repaired.push('"');
    }

    // Close containers innermost-first; emitting them in push order would
    // produce mismatched output such as `{"a": [1}]`.
    repaired.extend(expected_closers.into_iter().rev());

    Some(repaired)
}
527
/// Universal LLM response structure
///
/// Derives `Default`, so a value with every field empty and
/// `finish_reason` set to `FinishReason::Stop` is available via
/// `LLMResponse::default()`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct LLMResponse {
    /// The response content text
    pub content: Option<String>,

    /// Tool calls made by the model
    pub tool_calls: Option<Vec<ToolCall>>,

    /// The model that generated this response
    pub model: String,

    /// Token usage statistics
    pub usage: Option<Usage>,

    /// Why the response finished
    pub finish_reason: FinishReason,

    /// Reasoning content (for models that support it)
    pub reasoning: Option<String>,

    /// Detailed reasoning traces (for models that support it)
    pub reasoning_details: Option<Vec<String>>,

    /// Tool references for context
    pub tool_references: Vec<String>,

    /// Request ID from the provider
    pub request_id: Option<String>,

    /// Organization ID from the provider
    pub organization_id: Option<String>,

    /// Compaction summary content from Anthropic's server-side compaction.
    /// Populated when `finish_reason` is `FinishReason::Pause` (mapped from
    /// the provider stop reason `"compaction"`).
    /// The caller should pass this back in subsequent requests so the API
    /// can drop prior messages before the compaction block.
    pub compaction: Option<String>,
}
567
568impl LLMResponse {
569    /// Create a new LLM response with mandatory fields
570    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
571        Self {
572            content: Some(content.into()),
573            tool_calls: None,
574            model: model.into(),
575            usage: None,
576            finish_reason: FinishReason::Stop,
577            reasoning: None,
578            reasoning_details: None,
579            tool_references: Vec::new(),
580            request_id: None,
581            organization_id: None,
582            compaction: None,
583        }
584    }
585
586    /// Get content or empty string
587    pub fn content_text(&self) -> &str {
588        self.content.as_deref().unwrap_or("")
589    }
590
591    /// Get content as String (clone)
592    pub fn content_string(&self) -> String {
593        self.content.clone().unwrap_or_default()
594    }
595}
596
/// Optional provider details attached to an [`LLMError`].
///
/// Every field is optional. The semantics noted below are inferred from the
/// field names and types only — confirm against the code that fills them in.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct LLMErrorMetadata {
    /// Name of the provider the error came from.
    pub provider: Option<String>,
    /// Numeric status (presumably the HTTP status code — verify).
    pub status: Option<u16>,
    /// Provider-specific error code, kept as a string.
    pub code: Option<String>,
    /// Request ID associated with the failed call.
    pub request_id: Option<String>,
    /// Organization ID associated with the failed call.
    pub organization_id: Option<String>,
    /// Retry hint, kept as the raw string (not parsed into a duration here).
    pub retry_after: Option<String>,
    /// Error message text.
    pub message: Option<String>,
}
607
608impl LLMErrorMetadata {
609    /// Boxed constructor because metadata is always stored inside `Option<Box<LLMErrorMetadata>>`
610    /// in the LLMError enum variants.
611    #[must_use]
612    pub fn new(
613        provider: impl Into<String>,
614        status: Option<u16>,
615        code: Option<String>,
616        request_id: Option<String>,
617        organization_id: Option<String>,
618        retry_after: Option<String>,
619        message: Option<String>,
620    ) -> Box<Self> {
621        Box::new(Self {
622            provider: Some(provider.into()),
623            status,
624            code,
625            request_id,
626            organization_id,
627            retry_after,
628            message,
629        })
630    }
631}
632
/// LLM error types with optional provider metadata
///
/// Serialized as an internally tagged enum: a `type` field holds the
/// snake_case variant name. The `Display` text of each variant comes from its
/// `#[error(...)]` attribute; metadata is never part of that text.
#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum LLMError {
    /// Authentication failure, with a message.
    #[error("Authentication failed: {message}")]
    Authentication {
        message: String,
        metadata: Option<Box<LLMErrorMetadata>>,
    },
    /// Rate limit hit. This variant has no message of its own; any detail
    /// lives in `metadata`.
    #[error("Rate limit exceeded")]
    RateLimit {
        metadata: Option<Box<LLMErrorMetadata>>,
    },
    /// The request was rejected as invalid, with a message.
    #[error("Invalid request: {message}")]
    InvalidRequest {
        message: String,
        metadata: Option<Box<LLMErrorMetadata>>,
    },
    /// Network-level failure, with a message.
    #[error("Network error: {message}")]
    Network {
        message: String,
        metadata: Option<Box<LLMErrorMetadata>>,
    },
    /// Provider-side failure, with a message.
    #[error("Provider error: {message}")]
    Provider {
        message: String,
        metadata: Option<Box<LLMErrorMetadata>>,
    },
}
662
// Unit tests for `ToolCall` argument parsing, recovery and serialization.
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    // Valid JSON followed by stray text: the balanced-span extraction should
    // recover the object.
    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"} trailing text"#.to_string(),
        );

        let parsed = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    // JSON wrapped in a Markdown code fence should still parse.
    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            "```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```".to_string(),
        );

        let parsed = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    // An object whose final `}` is missing should be closed and parsed.
    #[test]
    fn parsed_arguments_recovers_truncated_json_missing_closing_brace() {
        let call = ToolCall::function(
            "call_search".to_string(),
            "unified_search".to_string(),
            r#"{"action": "structural", "pattern": "context", "path": ".", "lang": "rust""#.to_string(),
        );

        let parsed = call
            .parsed_arguments()
            .expect("truncated JSON missing closing brace should recover");
        assert_eq!(
            parsed,
            json!({
                "action": "structural",
                "pattern": "context",
                "path": ".",
                "lang": "rust"
            })
        );
    }

    // Truncation after a key with no value cannot be repaired by appending
    // closers, so parsing must still fail.
    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let call = ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            r#"{"path":"src/main.rs","limit""#.to_string(),
        );

        assert!(call.parsed_arguments().is_err());
    }

    // Provider markup spilled into an unterminated string value: the text is
    // cut at the first marker and then closed.
    #[test]
    fn parsed_arguments_recovers_truncated_minimax_markup() {
        let call = ToolCall::function(
            "call_search".to_string(),
            "unified_search".to_string(),
            "{\"action\": \"grep\", \"pattern\": \"persistent_memory\", \"path\": \"vtcode-core/src</parameter>\n<</invoke>\n</minimax:tool_call>".to_string(),
        );

        let parsed = call
            .parsed_arguments()
            .expect("minimax markup spillover should recover");
        assert_eq!(
            parsed,
            json!({
                "action": "grep",
                "pattern": "persistent_memory",
                "path": "vtcode-core/src"
            })
        );
    }

    // A namespace, when set, must appear under `function.namespace` in the
    // serialized form.
    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_string(),
            Some("workspace".to_string()),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"}"#.to_string(),
        );

        let json = serde_json::to_value(&call).expect("tool call should serialize");
        assert_eq!(json["function"]["namespace"], "workspace");
        assert_eq!(json["function"]["name"], "read_file");
    }

    // Custom calls expose their payload verbatim: execution arguments are the
    // raw text as a JSON string, while JSON parsing of that text fails.
    #[test]
    fn custom_tool_call_exposes_raw_execution_arguments() {
        let patch = "*** Begin Patch\n*** End Patch\n".to_string();
        let call = ToolCall::custom(
            "call_patch".to_string(),
            "apply_patch".to_string(),
            patch.clone(),
        );

        assert!(call.is_custom());
        assert_eq!(call.tool_name(), Some("apply_patch"));
        assert_eq!(call.raw_input(), Some(patch.as_str()));
        assert_eq!(
            call.execution_arguments().expect("custom arguments"),
            json!(patch)
        );
        assert!(
            call.parsed_arguments().is_err(),
            "custom tool payload should stay freeform rather than JSON"
        );
    }
}