Skip to main content

vtcode_commons/
llm.rs

1//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    Mistral,
12    OpenRouter,
13    Ollama,
14    ZAI,
15    Moonshot,
16    HuggingFace,
17    Minimax,
18    MiMo,
19    OpenCodeZen,
20    OpenCodeGo,
21    Qwen,
22    Poolside,
23}
24
25#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
26pub struct Usage {
27    pub prompt_tokens: u32,
28    pub completion_tokens: u32,
29    pub total_tokens: u32,
30    pub cached_prompt_tokens: Option<u32>,
31    pub cache_creation_tokens: Option<u32>,
32    pub cache_read_tokens: Option<u32>,
33}
34
35impl Usage {
36    #[inline]
37    fn has_cache_read_metric(&self) -> bool {
38        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
39    }
40
41    #[inline]
42    fn has_any_cache_metrics(&self) -> bool {
43        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
44    }
45
46    #[inline]
47    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
48        self.cache_read_tokens
49            .or(self.cached_prompt_tokens)
50            .unwrap_or(0)
51    }
52
53    #[inline]
54    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
55        self.cache_creation_tokens.unwrap_or(0)
56    }
57
58    #[inline]
59    pub fn cache_hit_rate(&self) -> Option<f64> {
60        if !self.has_any_cache_metrics() {
61            return None;
62        }
63        let read = self.cache_read_tokens_or_fallback() as f64;
64        let creation = self.cache_creation_tokens_or_zero() as f64;
65        let total = read + creation;
66        if total > 0.0 {
67            Some((read / total) * 100.0)
68        } else {
69            None
70        }
71    }
72
73    #[inline]
74    pub fn is_cache_hit(&self) -> Option<bool> {
75        self.has_any_cache_metrics()
76            .then(|| self.cache_read_tokens_or_fallback() > 0)
77    }
78
79    #[inline]
80    pub fn is_cache_miss(&self) -> Option<bool> {
81        self.has_any_cache_metrics().then(|| {
82            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
83        })
84    }
85
86    #[inline]
87    pub fn total_cache_tokens(&self) -> u32 {
88        let read = self.cache_read_tokens_or_fallback();
89        let creation = self.cache_creation_tokens_or_zero();
90        read + creation
91    }
92
93    #[inline]
94    pub fn cache_savings_ratio(&self) -> Option<f64> {
95        if !self.has_cache_read_metric() {
96            return None;
97        }
98        let read = self.cache_read_tokens_or_fallback() as f64;
99        let prompt = self.prompt_tokens as f64;
100        if prompt > 0.0 {
101            Some(read / prompt)
102        } else {
103            None
104        }
105    }
106}
107
108/// Provider-agnostic balance information for account status display.
109#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
110pub struct BalanceInfo {
111    /// Human-readable balance string (e.g. "100.00¥", "$50.00").
112    pub display: String,
113    /// Whether the account has sufficient balance for API calls.
114    pub is_available: bool,
115}
116
117/// DeepSeek-specific balance info from GET /user/balance
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct DeepSeekBalanceResponse {
120    pub is_available: bool,
121    pub balance_infos: Vec<DeepSeekCurrencyBalance>,
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct DeepSeekCurrencyBalance {
126    pub currency: String,
127    pub total_balance: String,
128    #[serde(default)]
129    pub granted_balance: String,
130    #[serde(default)]
131    pub topped_up_balance: String,
132}
133
134impl From<DeepSeekBalanceResponse> for BalanceInfo {
135    fn from(resp: DeepSeekBalanceResponse) -> Self {
136        let display = resp
137            .balance_infos
138            .first()
139            .map(|b| {
140                let symbol = match b.currency.as_str() {
141                    "CNY" => "¥",
142                    "USD" => "$",
143                    _ => &b.currency,
144                };
145                format!("{}{}", b.total_balance, symbol)
146            })
147            .unwrap_or_else(|| "N/A".to_string());
148        BalanceInfo {
149            display,
150            is_available: resp.is_available,
151        }
152    }
153}
154
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// Builds a usage record with fixed token totals and the given cache counters.
    fn usage_with_cache(
        cached_prompt_tokens: Option<u32>,
        cache_creation_tokens: Option<u32>,
        cache_read_tokens: Option<u32>,
    ) -> Usage {
        Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens,
            cache_creation_tokens,
            cache_read_tokens,
        }
    }

    /// With no explicit read counter, `cached_prompt_tokens` stands in for it.
    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        let sample = usage_with_cache(Some(600), Some(150), None);

        assert_eq!(sample.cache_read_tokens_or_fallback(), 600);
        assert_eq!(sample.cache_creation_tokens_or_zero(), 150);
        assert_eq!(sample.total_cache_tokens(), 750);
        assert_eq!(sample.is_cache_hit(), Some(true));
        assert_eq!(sample.is_cache_miss(), Some(false));
        assert_eq!(sample.cache_savings_ratio(), Some(0.6));
        assert_eq!(sample.cache_hit_rate(), Some(80.0));
    }

    /// When no cache counter is reported, the tri-state helpers stay `None`.
    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        let sample = usage_with_cache(None, None, None);

        assert_eq!(sample.total_cache_tokens(), 0);
        assert_eq!(sample.is_cache_hit(), None);
        assert_eq!(sample.is_cache_miss(), None);
        assert_eq!(sample.cache_savings_ratio(), None);
        assert_eq!(sample.cache_hit_rate(), None);
    }
}
197
198#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
199pub enum FinishReason {
200    #[default]
201    Stop,
202    Length,
203    ToolCalls,
204    ContentFilter,
205    Pause,
206    Refusal,
207    Error(String),
208}
209
210/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
211#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
212pub struct ToolCall {
213    /// Unique identifier for this tool call (e.g., "call_123")
214    pub id: String,
215
216    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
217    #[serde(rename = "type")]
218    pub call_type: String,
219
220    /// Function call details (for function-type tools)
221    #[serde(skip_serializing_if = "Option::is_none")]
222    pub function: Option<FunctionCall>,
223
224    /// Raw text payload (for custom freeform tools in GPT-5)
225    #[serde(skip_serializing_if = "Option::is_none")]
226    pub text: Option<String>,
227
228    /// Gemini-specific thought signature for maintaining reasoning context
229    #[serde(skip_serializing_if = "Option::is_none")]
230    pub thought_signature: Option<String>,
231}
232
233/// Function call within a tool call
234#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
235pub struct FunctionCall {
236    /// Optional namespace for grouped or deferred tools.
237    #[serde(default, skip_serializing_if = "Option::is_none")]
238    pub namespace: Option<String>,
239
240    /// The name of the function to call
241    pub name: String,
242
243    /// The arguments to pass to the function, as a JSON string
244    pub arguments: String,
245}
246
247impl ToolCall {
248    /// Create a new function tool call
249    pub fn function(id: String, name: String, arguments: String) -> Self {
250        Self::function_with_namespace(id, None, name, arguments)
251    }
252
253    /// Create a new function tool call with an optional namespace.
254    pub fn function_with_namespace(
255        id: String,
256        namespace: Option<String>,
257        name: String,
258        arguments: String,
259    ) -> Self {
260        Self {
261            id,
262            call_type: "function".to_owned(),
263            function: Some(FunctionCall {
264                namespace,
265                name,
266                arguments,
267            }),
268            text: None,
269            thought_signature: None,
270        }
271    }
272
273    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
274    pub fn custom(id: String, name: String, text: String) -> Self {
275        Self {
276            id,
277            call_type: "custom".to_owned(),
278            function: Some(FunctionCall {
279                namespace: None,
280                name,
281                arguments: text.clone(),
282            }),
283            text: Some(text),
284            thought_signature: None,
285        }
286    }
287
288    /// Returns true when this tool call uses GPT-5 custom/freeform semantics.
289    pub fn is_custom(&self) -> bool {
290        self.call_type == "custom"
291    }
292
293    /// Returns the tool name when the call includes function details.
294    pub fn tool_name(&self) -> Option<&str> {
295        self.function
296            .as_ref()
297            .map(|function| function.name.as_str())
298    }
299
300    /// Returns the raw payload text exactly as emitted by the model.
301    pub fn raw_input(&self) -> Option<&str> {
302        self.text.as_deref().or_else(|| {
303            self.function
304                .as_ref()
305                .map(|function| function.arguments.as_str())
306        })
307    }
308
309    /// Parse the arguments as JSON Value (for function-type tools)
310    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
311        if let Some(ref func) = self.function {
312            parse_tool_arguments(&func.arguments)
313        } else {
314            // Return an error by trying to parse invalid JSON
315            serde_json::from_str("")
316        }
317    }
318
319    /// Returns the execution payload for this tool call.
320    ///
321    /// Function tools keep their JSON semantics. Custom tools execute with their
322    /// raw text payload wrapped as a JSON string value so freeform inputs can
323    /// flow through the existing tool pipeline.
324    pub fn execution_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
325        if self.is_custom() {
326            return Ok(serde_json::Value::String(
327                self.raw_input().unwrap_or_default().to_string(),
328            ));
329        }
330
331        self.parsed_arguments()
332    }
333
334    /// Validate that this tool call is properly formed
335    pub fn validate(&self) -> Result<(), String> {
336        if self.id.is_empty() {
337            return Err("Tool call ID cannot be empty".to_owned());
338        }
339
340        match self.call_type.as_str() {
341            "function" => {
342                if let Some(func) = &self.function {
343                    if func.name.is_empty() {
344                        return Err("Function name cannot be empty".to_owned());
345                    }
346                    // Validate that arguments is valid JSON for function tools
347                    if let Err(e) = self.parsed_arguments() {
348                        return Err(format!("Invalid JSON in function arguments: {}", e));
349                    }
350                } else {
351                    return Err("Function tool call missing function details".to_owned());
352                }
353            }
354            "custom" => {
355                // For custom tools, we allow raw text payload without JSON validation
356                if let Some(func) = &self.function {
357                    if func.name.is_empty() {
358                        return Err("Custom tool name cannot be empty".to_owned());
359                    }
360                } else {
361                    return Err("Custom tool call missing function details".to_owned());
362                }
363            }
364            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
365        }
366
367        Ok(())
368    }
369}
370
371fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
372    let trimmed = raw_arguments.trim();
373    match serde_json::from_str(trimmed) {
374        Ok(parsed) => Ok(parsed),
375        Err(primary_error) => {
376            if let Some(candidate) = extract_balanced_json(trimmed)
377                && let Ok(parsed) = serde_json::from_str(candidate)
378            {
379                return Ok(parsed);
380            }
381            if let Some(candidate) = repair_tag_polluted_json(trimmed)
382                && let Ok(parsed) = serde_json::from_str(&candidate)
383            {
384                return Ok(parsed);
385            }
386            Err(primary_error)
387        }
388    }
389}
390
/// Returns the first bracket-balanced `{…}` or `[…]` span in `input`, if any.
///
/// Scanning starts at the first `{` or `[`. Only brackets of that same kind
/// are counted, and brackets inside double-quoted strings (with backslash
/// escapes) are ignored. Returns `None` when there is no opening bracket or
/// the span never closes. The result is not guaranteed to be valid JSON.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let tail = input.as_bytes().get(start..)?;
    let (opener, closer) = match tail.first()? {
        b'{' => (b'{', b'}'),
        b'[' => (b'[', b']'),
        _ => return None,
    };

    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    // Every character of interest is ASCII, and UTF-8 continuation bytes never
    // collide with ASCII, so a byte scan sees the same structure as a char scan.
    for (offset, &byte) in tail.iter().enumerate() {
        if in_string {
            if escaped {
                escaped = false;
            } else if byte == b'\\' {
                escaped = true;
            } else if byte == b'"' {
                in_string = false;
            }
            continue;
        }

        if byte == b'"' {
            in_string = true;
        } else if byte == opener {
            depth += 1;
        } else if byte == closer {
            depth = depth.saturating_sub(1);
            if depth == 0 {
                return input.get(start..=start + offset);
            }
        }
    }

    None
}
436
437fn repair_tag_polluted_json(input: &str) -> Option<String> {
438    let start = input.find(['{', '['])?;
439    let candidate = input.get(start..)?;
440    let boundary = find_provider_markup_boundary(candidate)?;
441    if boundary == 0 {
442        return None;
443    }
444
445    close_incomplete_json_prefix(candidate[..boundary].trim_end())
446}
447
448fn find_provider_markup_boundary(input: &str) -> Option<usize> {
449    const PROVIDER_MARKERS: &[&str] = &[
450        "<</",
451        "</parameter>",
452        "</invoke>",
453        "</minimax:tool_call>",
454        "<minimax:tool_call>",
455        "<parameter name=\"",
456        "<invoke name=\"",
457        "<tool_call>",
458        "</tool_call>",
459    ];
460
461    input.char_indices().find_map(|(offset, _)| {
462        let rest = input.get(offset..)?;
463        PROVIDER_MARKERS
464            .iter()
465            .any(|marker| rest.starts_with(marker))
466            .then_some(offset)
467    })
468}
469
/// Completes a truncated JSON prefix by closing whatever it left open.
///
/// Walks `prefix` while tracking string state and the stack of open `{`/`[`.
/// At the end, an unterminated string gets its closing quote and every open
/// container is closed, innermost first.
///
/// Returns `None` when `prefix` is empty or contains a closing bracket that
/// does not match the innermost open container. The result is a best-effort
/// repair and is not guaranteed to be valid JSON (e.g. a trailing comma is
/// left as-is); callers are expected to parse it.
fn close_incomplete_json_prefix(prefix: &str) -> Option<String> {
    if prefix.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(prefix.len() + 8);
    let mut expected_closers: Vec<char> = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for ch in prefix.chars() {
        repaired.push(ch);

        if in_string {
            if escaped {
                escaped = false;
                continue;
            }

            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => expected_closers.push('}'),
            '[' => expected_closers.push(']'),
            '}' | ']' => {
                if expected_closers.pop() != Some(ch) {
                    return None;
                }
            }
            _ => {}
        }
    }

    if in_string {
        if escaped {
            // The prefix ended on a lone backslash; dropping it keeps the quote
            // appended below from being read as an escaped `\"`.
            repaired.pop();
        }
        repaired.push('"');
    }

    // Containers must be closed in reverse order of opening (innermost first).
    repaired.extend(expected_closers.into_iter().rev());

    Some(repaired)
}
519
520/// Universal LLM response structure
521#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
522pub struct LLMResponse {
523    /// The response content text
524    pub content: Option<String>,
525
526    /// Tool calls made by the model
527    pub tool_calls: Option<Vec<ToolCall>>,
528
529    /// The model that generated this response
530    pub model: String,
531
532    /// Token usage statistics
533    pub usage: Option<Usage>,
534
535    /// Why the response finished
536    pub finish_reason: FinishReason,
537
538    /// Reasoning content (for models that support it)
539    pub reasoning: Option<String>,
540
541    /// Detailed reasoning traces (for models that support it)
542    pub reasoning_details: Option<Vec<String>>,
543
544    /// Tool references for context
545    pub tool_references: Vec<String>,
546
547    /// Request ID from the provider
548    pub request_id: Option<String>,
549
550    /// Organization ID from the provider
551    pub organization_id: Option<String>,
552
553    /// Compaction summary content from Anthropic's server-side compaction.
554    /// Populated when `stop_reason` is `Pause` (from `"compaction"`).
555    /// The caller should pass this back in subsequent requests so the API
556    /// can drop prior messages before the compaction block.
557    pub compaction: Option<String>,
558}
559
560impl LLMResponse {
561    /// Create a new LLM response with mandatory fields
562    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
563        Self {
564            content: Some(content.into()),
565            tool_calls: None,
566            model: model.into(),
567            usage: None,
568            finish_reason: FinishReason::Stop,
569            reasoning: None,
570            reasoning_details: None,
571            tool_references: Vec::new(),
572            request_id: None,
573            organization_id: None,
574            compaction: None,
575        }
576    }
577
578    /// Get content or empty string
579    pub fn content_text(&self) -> &str {
580        self.content.as_deref().unwrap_or("")
581    }
582
583    /// Get content as String (clone)
584    pub fn content_string(&self) -> String {
585        self.content.clone().unwrap_or_default()
586    }
587}
588
589#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
590pub struct LLMErrorMetadata {
591    pub provider: Option<String>,
592    pub status: Option<u16>,
593    pub code: Option<String>,
594    pub request_id: Option<String>,
595    pub organization_id: Option<String>,
596    pub retry_after: Option<String>,
597    pub message: Option<String>,
598}
599
600impl LLMErrorMetadata {
601    /// Boxed constructor because metadata is always stored inside `Option<Box<LLMErrorMetadata>>`
602    /// in the LLMError enum variants.
603    #[must_use]
604    pub fn new(
605        provider: impl Into<String>,
606        status: Option<u16>,
607        code: Option<String>,
608        request_id: Option<String>,
609        organization_id: Option<String>,
610        retry_after: Option<String>,
611        message: Option<String>,
612    ) -> Box<Self> {
613        Box::new(Self {
614            provider: Some(provider.into()),
615            status,
616            code,
617            request_id,
618            organization_id,
619            retry_after,
620            message,
621        })
622    }
623}
624
625/// LLM error types with optional provider metadata
626#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
627#[serde(tag = "type", rename_all = "snake_case")]
628pub enum LLMError {
629    #[error("Authentication failed: {message}")]
630    Authentication {
631        message: String,
632        metadata: Option<Box<LLMErrorMetadata>>,
633    },
634    #[error("Rate limit exceeded")]
635    RateLimit {
636        metadata: Option<Box<LLMErrorMetadata>>,
637    },
638    #[error("Invalid request: {message}")]
639    InvalidRequest {
640        message: String,
641        metadata: Option<Box<LLMErrorMetadata>>,
642    },
643    #[error("Network error: {message}")]
644    Network {
645        message: String,
646        metadata: Option<Box<LLMErrorMetadata>>,
647    },
648    #[error("Provider error: {message}")]
649    Provider {
650        message: String,
651        metadata: Option<Box<LLMErrorMetadata>>,
652    },
653}
654
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Builds a `read_file` function call carrying the given raw argument text.
    fn read_file_call(raw_arguments: &str) -> ToolCall {
        ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            raw_arguments.to_string(),
        )
    }

    /// Text after a complete JSON object is ignored.
    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let value = read_file_call(r#"{"path":"src/main.rs"} trailing text"#)
            .parsed_arguments()
            .expect("arguments with trailing text should recover");

        assert_eq!(value, json!({"path":"src/main.rs"}));
    }

    /// A Markdown code fence around the JSON is tolerated.
    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let value = read_file_call("```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```")
            .parsed_arguments()
            .expect("code-fenced arguments should recover");

        assert_eq!(value, json!({"path":"src/lib.rs","limit":25}));
    }

    /// Truncated JSON with no provider markup is not repaired.
    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let outcome = read_file_call(r#"{"path":"src/main.rs""#).parsed_arguments();

        assert!(outcome.is_err());
    }

    /// JSON cut off by spilled Minimax markup is closed and parsed.
    #[test]
    fn parsed_arguments_recovers_truncated_minimax_markup() {
        let polluted = "{\"action\": \"grep\", \"pattern\": \"persistent_memory\", \"path\": \"vtcode-core/src</parameter>\n<</invoke>\n</minimax:tool_call>";
        let search_call = ToolCall::function(
            "call_search".to_string(),
            "unified_search".to_string(),
            polluted.to_string(),
        );

        let value = search_call
            .parsed_arguments()
            .expect("minimax markup spillover should recover");

        let expected = json!({
            "action": "grep",
            "pattern": "persistent_memory",
            "path": "vtcode-core/src"
        });
        assert_eq!(value, expected);
    }

    /// A provided namespace appears under `function.namespace` when serialized.
    #[test]
    fn function_call_serializes_optional_namespace() {
        let namespaced = ToolCall::function_with_namespace(
            "call_read".to_string(),
            Some("workspace".to_string()),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"}"#.to_string(),
        );

        let serialized = serde_json::to_value(&namespaced).expect("tool call should serialize");

        assert_eq!(serialized["function"]["namespace"], "workspace");
        assert_eq!(serialized["function"]["name"], "read_file");
    }

    /// Custom calls expose their raw text and never go through JSON parsing.
    #[test]
    fn custom_tool_call_exposes_raw_execution_arguments() {
        let payload = "*** Begin Patch\n*** End Patch\n".to_string();
        let patch_call = ToolCall::custom(
            "call_patch".to_string(),
            "apply_patch".to_string(),
            payload.clone(),
        );

        assert!(patch_call.is_custom());
        assert_eq!(patch_call.tool_name(), Some("apply_patch"));
        assert_eq!(patch_call.raw_input(), Some(payload.as_str()));
        assert_eq!(
            patch_call.execution_arguments().expect("custom arguments"),
            json!(payload)
        );
        assert!(
            patch_call.parsed_arguments().is_err(),
            "custom tool payload should stay freeform rather than JSON"
        );
    }
}
751            call.parsed_arguments().is_err(),
752            "custom tool payload should stay freeform rather than JSON"
753        );
754    }
755}