Skip to main content

vtcode_commons/
llm.rs

//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    Mistral,
12    OpenRouter,
13    Ollama,
14    LlamaCpp,
15    ZAI,
16    Moonshot,
17    HuggingFace,
18    Minimax,
19    MiMo,
20    OpenCodeZen,
21    OpenCodeGo,
22    Qwen,
23    StepFun,
24    Poolside,
25}
26
27#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
28pub struct Usage {
29    pub prompt_tokens: u32,
30    pub completion_tokens: u32,
31    pub total_tokens: u32,
32    pub cached_prompt_tokens: Option<u32>,
33    pub cache_creation_tokens: Option<u32>,
34    pub cache_read_tokens: Option<u32>,
35}
36
37impl Usage {
38    #[inline]
39    fn has_cache_read_metric(&self) -> bool {
40        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
41    }
42
43    #[inline]
44    fn has_any_cache_metrics(&self) -> bool {
45        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
46    }
47
48    #[inline]
49    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
50        self.cache_read_tokens
51            .or(self.cached_prompt_tokens)
52            .unwrap_or(0)
53    }
54
55    #[inline]
56    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
57        self.cache_creation_tokens.unwrap_or(0)
58    }
59
60    #[inline]
61    pub fn cache_hit_rate(&self) -> Option<f64> {
62        if !self.has_any_cache_metrics() {
63            return None;
64        }
65        let read = self.cache_read_tokens_or_fallback() as f64;
66        let creation = self.cache_creation_tokens_or_zero() as f64;
67        let total = read + creation;
68        if total > 0.0 {
69            Some((read / total) * 100.0)
70        } else {
71            None
72        }
73    }
74
75    #[inline]
76    pub fn is_cache_hit(&self) -> Option<bool> {
77        self.has_any_cache_metrics()
78            .then(|| self.cache_read_tokens_or_fallback() > 0)
79    }
80
81    #[inline]
82    pub fn is_cache_miss(&self) -> Option<bool> {
83        self.has_any_cache_metrics().then(|| {
84            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
85        })
86    }
87
88    #[inline]
89    pub fn total_cache_tokens(&self) -> u32 {
90        let read = self.cache_read_tokens_or_fallback();
91        let creation = self.cache_creation_tokens_or_zero();
92        read + creation
93    }
94
95    #[inline]
96    pub fn cache_savings_ratio(&self) -> Option<f64> {
97        if !self.has_cache_read_metric() {
98            return None;
99        }
100        let read = self.cache_read_tokens_or_fallback() as f64;
101        let prompt = self.prompt_tokens as f64;
102        if prompt > 0.0 {
103            Some(read / prompt)
104        } else {
105            None
106        }
107    }
108}
109
110/// Provider-agnostic balance information for account status display.
111#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
112pub struct BalanceInfo {
113    /// Human-readable balance string (e.g. "100.00¥", "$50.00").
114    pub display: String,
115    /// Whether the account has sufficient balance for API calls.
116    pub is_available: bool,
117}
118
119/// DeepSeek-specific balance info from GET /user/balance
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct DeepSeekBalanceResponse {
122    pub is_available: bool,
123    pub balance_infos: Vec<DeepSeekCurrencyBalance>,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct DeepSeekCurrencyBalance {
128    pub currency: String,
129    pub total_balance: String,
130    #[serde(default)]
131    pub granted_balance: String,
132    #[serde(default)]
133    pub topped_up_balance: String,
134}
135
136impl From<DeepSeekBalanceResponse> for BalanceInfo {
137    fn from(resp: DeepSeekBalanceResponse) -> Self {
138        let display = resp
139            .balance_infos
140            .first()
141            .map(|b| {
142                let symbol = match b.currency.as_str() {
143                    "CNY" => "¥",
144                    "USD" => "$",
145                    _ => &b.currency,
146                };
147                format!("{}{}", b.total_balance, symbol)
148            })
149            .unwrap_or_else(|| "N/A".to_string());
150        BalanceInfo {
151            display,
152            is_available: resp.is_available,
153        }
154    }
155}
156
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// Builds a `Usage` with fixed token totals (1000 prompt / 200 completion)
    /// and the supplied cache fields.
    fn usage_with_cache(
        cached_prompt_tokens: Option<u32>,
        cache_creation_tokens: Option<u32>,
        cache_read_tokens: Option<u32>,
    ) -> Usage {
        Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens,
            cache_creation_tokens,
            cache_read_tokens,
        }
    }

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        // No explicit read count: the cached-prompt figure must stand in for it.
        let usage = usage_with_cache(Some(600), Some(150), None);

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        // With every cache field absent the Option-returning helpers stay `None`.
        let usage = usage_with_cache(None, None, None);

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}
199
200#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
201pub enum FinishReason {
202    #[default]
203    Stop,
204    Length,
205    ToolCalls,
206    ContentFilter,
207    Pause,
208    Refusal,
209    Error(String),
210}
211
212/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
213#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
214pub struct ToolCall {
215    /// Unique identifier for this tool call (e.g., "call_123")
216    pub id: String,
217
218    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
219    #[serde(rename = "type")]
220    pub call_type: String,
221
222    /// Function call details (for function-type tools)
223    #[serde(skip_serializing_if = "Option::is_none")]
224    pub function: Option<FunctionCall>,
225
226    /// Raw text payload (for custom freeform tools in GPT-5)
227    #[serde(skip_serializing_if = "Option::is_none")]
228    pub text: Option<String>,
229
230    /// Gemini-specific thought signature for maintaining reasoning context
231    #[serde(skip_serializing_if = "Option::is_none")]
232    pub thought_signature: Option<String>,
233}
234
235/// Function call within a tool call
236#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
237pub struct FunctionCall {
238    /// Optional namespace for grouped or deferred tools.
239    #[serde(default, skip_serializing_if = "Option::is_none")]
240    pub namespace: Option<String>,
241
242    /// The name of the function to call
243    pub name: String,
244
245    /// The arguments to pass to the function, as a JSON string
246    pub arguments: String,
247}
248
249impl ToolCall {
250    /// Create a new function tool call
251    pub fn function(id: String, name: String, arguments: String) -> Self {
252        Self::function_with_namespace(id, None, name, arguments)
253    }
254
255    /// Create a new function tool call with an optional namespace.
256    pub fn function_with_namespace(
257        id: String,
258        namespace: Option<String>,
259        name: String,
260        arguments: String,
261    ) -> Self {
262        Self {
263            id,
264            call_type: "function".to_owned(),
265            function: Some(FunctionCall {
266                namespace,
267                name,
268                arguments,
269            }),
270            text: None,
271            thought_signature: None,
272        }
273    }
274
275    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
276    pub fn custom(id: String, name: String, text: String) -> Self {
277        Self {
278            id,
279            call_type: "custom".to_owned(),
280            function: Some(FunctionCall {
281                namespace: None,
282                name,
283                arguments: text.clone(),
284            }),
285            text: Some(text),
286            thought_signature: None,
287        }
288    }
289
290    /// Returns true when this tool call uses GPT-5 custom/freeform semantics.
291    pub fn is_custom(&self) -> bool {
292        self.call_type == "custom"
293    }
294
295    /// Returns the tool name when the call includes function details.
296    pub fn tool_name(&self) -> Option<&str> {
297        self.function
298            .as_ref()
299            .map(|function| function.name.as_str())
300    }
301
302    /// Returns the raw payload text exactly as emitted by the model.
303    pub fn raw_input(&self) -> Option<&str> {
304        self.text.as_deref().or_else(|| {
305            self.function
306                .as_ref()
307                .map(|function| function.arguments.as_str())
308        })
309    }
310
311    /// Parse the arguments as JSON Value (for function-type tools)
312    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
313        if let Some(ref func) = self.function {
314            parse_tool_arguments(&func.arguments)
315        } else {
316            // Return an error by trying to parse invalid JSON
317            serde_json::from_str("")
318        }
319    }
320
321    /// Returns the execution payload for this tool call.
322    ///
323    /// Function tools keep their JSON semantics. Custom tools execute with their
324    /// raw text payload wrapped as a JSON string value so freeform inputs can
325    /// flow through the existing tool pipeline.
326    pub fn execution_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
327        if self.is_custom() {
328            return Ok(serde_json::Value::String(
329                self.raw_input().unwrap_or_default().to_string(),
330            ));
331        }
332
333        self.parsed_arguments()
334    }
335
336    /// Validate that this tool call is properly formed
337    pub fn validate(&self) -> Result<(), String> {
338        if self.id.is_empty() {
339            return Err("Tool call ID cannot be empty".to_owned());
340        }
341
342        match self.call_type.as_str() {
343            "function" => {
344                if let Some(func) = &self.function {
345                    if func.name.is_empty() {
346                        return Err("Function name cannot be empty".to_owned());
347                    }
348                    // Validate that arguments is valid JSON for function tools
349                    if let Err(e) = self.parsed_arguments() {
350                        return Err(format!("Invalid JSON in function arguments: {}", e));
351                    }
352                } else {
353                    return Err("Function tool call missing function details".to_owned());
354                }
355            }
356            "custom" => {
357                // For custom tools, we allow raw text payload without JSON validation
358                if let Some(func) = &self.function {
359                    if func.name.is_empty() {
360                        return Err("Custom tool name cannot be empty".to_owned());
361                    }
362                } else {
363                    return Err("Custom tool call missing function details".to_owned());
364                }
365            }
366            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
367        }
368
369        Ok(())
370    }
371}
372
373fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
374    let trimmed = raw_arguments.trim();
375    match serde_json::from_str(trimmed) {
376        Ok(parsed) => Ok(parsed),
377        Err(primary_error) => {
378            if let Some(candidate) = extract_balanced_json(trimmed)
379                && let Ok(parsed) = serde_json::from_str(candidate)
380            {
381                return Ok(parsed);
382            }
383            if let Some(candidate) = repair_tag_polluted_json(trimmed)
384                && let Ok(parsed) = serde_json::from_str(&candidate)
385            {
386                return Ok(parsed);
387            }
388            Err(primary_error)
389        }
390    }
391}
392
/// Returns the first balanced `{…}` or `[…]` span found in `input`.
///
/// Scanning starts at the first `{` or `[`. Only brackets of that same kind
/// are counted; the other bracket kind is ignored. Brackets inside
/// double-quoted strings (with backslash escapes honoured) do not count.
/// Returns `None` when there is no opening bracket or the span never closes.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let bytes = input.as_bytes();
    let (opening, closing) = match *bytes.get(start)? {
        b'{' => (b'{', b'}'),
        b'[' => (b'[', b']'),
        _ => return None,
    };

    let mut depth = 0usize;
    let mut inside_string = false;
    let mut after_backslash = false;

    // Every delimiter of interest is ASCII, and bytes of multi-byte UTF-8
    // sequences are all >= 0x80, so a byte scan sees the same delimiters as a
    // char scan would.
    for (index, &byte) in bytes.iter().enumerate().skip(start) {
        if inside_string {
            if after_backslash {
                after_backslash = false;
            } else if byte == b'\\' {
                after_backslash = true;
            } else if byte == b'"' {
                inside_string = false;
            }
            continue;
        }

        if byte == b'"' {
            inside_string = true;
        } else if byte == opening {
            depth += 1;
        } else if byte == closing {
            depth = depth.saturating_sub(1);
            if depth == 0 {
                // The closer is a single ASCII byte, so the inclusive range
                // ends on a char boundary.
                return input.get(start..=index);
            }
        }
    }

    None
}
438
439fn repair_tag_polluted_json(input: &str) -> Option<String> {
440    let start = input.find(['{', '['])?;
441    let candidate = input.get(start..)?;
442    let boundary = find_provider_markup_boundary(candidate)?;
443    if boundary == 0 {
444        return None;
445    }
446
447    close_incomplete_json_prefix(candidate[..boundary].trim_end())
448}
449
/// Returns the byte offset of the earliest provider markup marker in `input`,
/// or `None` when no marker occurs.
fn find_provider_markup_boundary(input: &str) -> Option<usize> {
    const PROVIDER_MARKERS: &[&str] = &[
        "<</",
        "</parameter>",
        "</invoke>",
        "</minimax:tool_call>",
        "<minimax:tool_call>",
        "<parameter name=\"",
        "<invoke name=\"",
        "<tool_call>",
        "</tool_call>",
    ];

    // The smallest first-occurrence offset across all markers is the first
    // position at which any marker begins.
    PROVIDER_MARKERS
        .iter()
        .filter_map(|marker| input.find(*marker))
        .min()
}
471
/// Completes a truncated JSON prefix by closing whatever is still open.
///
/// The prefix is copied verbatim. An unterminated string gets a closing
/// quote, and every `{` / `[` still open gets its matching closer appended,
/// innermost first.
///
/// Returns `None` for an empty prefix, or when a `}` / `]` outside a string
/// does not match the most recently opened bracket. The result is not
/// guaranteed to be valid JSON (a trailing comma, for example, is left
/// alone); callers are expected to parse it.
fn close_incomplete_json_prefix(prefix: &str) -> Option<String> {
    if prefix.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(prefix.len() + 8);
    // Stack of closers owed, in opening order (innermost is last).
    let mut expected_closers: Vec<char> = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for ch in prefix.chars() {
        repaired.push(ch);

        if in_string {
            if escaped {
                escaped = false;
                continue;
            }

            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => expected_closers.push('}'),
            '[' => expected_closers.push(']'),
            '}' | ']' => {
                if expected_closers.pop() != Some(ch) {
                    return None;
                }
            }
            _ => {}
        }
    }

    if in_string {
        if escaped {
            // The prefix was cut right after a backslash. Keeping it would
            // escape the quote added below and leave the string open.
            repaired.pop();
        }
        repaired.push('"');
    }

    // Close innermost-first. Emitting in opening order would turn
    // `{"a": [1` into the invalid `{"a": [1}]`.
    repaired.extend(expected_closers.iter().rev());

    Some(repaired)
}
521
522/// Universal LLM response structure
523#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
524pub struct LLMResponse {
525    /// The response content text
526    pub content: Option<String>,
527
528    /// Tool calls made by the model
529    pub tool_calls: Option<Vec<ToolCall>>,
530
531    /// The model that generated this response
532    pub model: String,
533
534    /// Token usage statistics
535    pub usage: Option<Usage>,
536
537    /// Why the response finished
538    pub finish_reason: FinishReason,
539
540    /// Reasoning content (for models that support it)
541    pub reasoning: Option<String>,
542
543    /// Detailed reasoning traces (for models that support it)
544    pub reasoning_details: Option<Vec<String>>,
545
546    /// Tool references for context
547    pub tool_references: Vec<String>,
548
549    /// Request ID from the provider
550    pub request_id: Option<String>,
551
552    /// Organization ID from the provider
553    pub organization_id: Option<String>,
554
555    /// Compaction summary content from Anthropic's server-side compaction.
556    /// Populated when `stop_reason` is `Pause` (from `"compaction"`).
557    /// The caller should pass this back in subsequent requests so the API
558    /// can drop prior messages before the compaction block.
559    pub compaction: Option<String>,
560}
561
562impl LLMResponse {
563    /// Create a new LLM response with mandatory fields
564    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
565        Self {
566            content: Some(content.into()),
567            tool_calls: None,
568            model: model.into(),
569            usage: None,
570            finish_reason: FinishReason::Stop,
571            reasoning: None,
572            reasoning_details: None,
573            tool_references: Vec::new(),
574            request_id: None,
575            organization_id: None,
576            compaction: None,
577        }
578    }
579
580    /// Get content or empty string
581    pub fn content_text(&self) -> &str {
582        self.content.as_deref().unwrap_or("")
583    }
584
585    /// Get content as String (clone)
586    pub fn content_string(&self) -> String {
587        self.content.clone().unwrap_or_default()
588    }
589}
590
591#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
592pub struct LLMErrorMetadata {
593    pub provider: Option<String>,
594    pub status: Option<u16>,
595    pub code: Option<String>,
596    pub request_id: Option<String>,
597    pub organization_id: Option<String>,
598    pub retry_after: Option<String>,
599    pub message: Option<String>,
600}
601
602impl LLMErrorMetadata {
603    /// Boxed constructor because metadata is always stored inside `Option<Box<LLMErrorMetadata>>`
604    /// in the LLMError enum variants.
605    #[must_use]
606    pub fn new(
607        provider: impl Into<String>,
608        status: Option<u16>,
609        code: Option<String>,
610        request_id: Option<String>,
611        organization_id: Option<String>,
612        retry_after: Option<String>,
613        message: Option<String>,
614    ) -> Box<Self> {
615        Box::new(Self {
616            provider: Some(provider.into()),
617            status,
618            code,
619            request_id,
620            organization_id,
621            retry_after,
622            message,
623        })
624    }
625}
626
627/// LLM error types with optional provider metadata
628#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
629#[serde(tag = "type", rename_all = "snake_case")]
630pub enum LLMError {
631    #[error("Authentication failed: {message}")]
632    Authentication {
633        message: String,
634        metadata: Option<Box<LLMErrorMetadata>>,
635    },
636    #[error("Rate limit exceeded")]
637    RateLimit {
638        metadata: Option<Box<LLMErrorMetadata>>,
639    },
640    #[error("Invalid request: {message}")]
641    InvalidRequest {
642        message: String,
643        metadata: Option<Box<LLMErrorMetadata>>,
644    },
645    #[error("Network error: {message}")]
646    Network {
647        message: String,
648        metadata: Option<Box<LLMErrorMetadata>>,
649    },
650    #[error("Provider error: {message}")]
651    Provider {
652        message: String,
653        metadata: Option<Box<LLMErrorMetadata>>,
654    },
655}
656
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Builds a function-type tool call from string slices.
    fn function_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall::function(id.to_string(), name.to_string(), arguments.to_string())
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = function_call(
            "call_read",
            "read_file",
            r#"{"path":"src/main.rs"} trailing text"#,
        );

        let parsed = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");

        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = function_call(
            "call_read",
            "read_file",
            "```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```",
        );

        let parsed = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");

        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        // No provider markup follows the truncated object, so no repair applies.
        let call = function_call("call_read", "read_file", r#"{"path":"src/main.rs""#);

        assert!(call.parsed_arguments().is_err());
    }

    #[test]
    fn parsed_arguments_recovers_truncated_minimax_markup() {
        let call = function_call(
            "call_search",
            "unified_search",
            "{\"action\": \"grep\", \"pattern\": \"persistent_memory\", \"path\": \"vtcode-core/src</parameter>\n<</invoke>\n</minimax:tool_call>",
        );

        let parsed = call
            .parsed_arguments()
            .expect("minimax markup spillover should recover");

        let expected = json!({
            "action": "grep",
            "pattern": "persistent_memory",
            "path": "vtcode-core/src"
        });
        assert_eq!(parsed, expected);
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_string(),
            Some("workspace".to_string()),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"}"#.to_string(),
        );

        let json = serde_json::to_value(&call).expect("tool call should serialize");
        let function = &json["function"];

        assert_eq!(function["namespace"], "workspace");
        assert_eq!(function["name"], "read_file");
    }

    #[test]
    fn custom_tool_call_exposes_raw_execution_arguments() {
        let patch = "*** Begin Patch\n*** End Patch\n".to_string();
        let call = ToolCall::custom(
            "call_patch".to_string(),
            "apply_patch".to_string(),
            patch.clone(),
        );

        assert!(call.is_custom());
        assert_eq!(call.tool_name(), Some("apply_patch"));
        assert_eq!(call.raw_input(), Some(patch.as_str()));

        let execution = call.execution_arguments().expect("custom arguments");
        assert_eq!(execution, json!(patch));

        assert!(
            call.parsed_arguments().is_err(),
            "custom tool payload should stay freeform rather than JSON"
        );
    }
}