Skip to main content

vtcode_commons/
llm.rs

1//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    Mistral,
12    OpenRouter,
13    Ollama,
14    ZAI,
15    Moonshot,
16    HuggingFace,
17    Minimax,
18    OpenCodeZen,
19    OpenCodeGo,
20}
21
22#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
23pub struct Usage {
24    pub prompt_tokens: u32,
25    pub completion_tokens: u32,
26    pub total_tokens: u32,
27    pub cached_prompt_tokens: Option<u32>,
28    pub cache_creation_tokens: Option<u32>,
29    pub cache_read_tokens: Option<u32>,
30}
31
32impl Usage {
33    #[inline]
34    fn has_cache_read_metric(&self) -> bool {
35        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
36    }
37
38    #[inline]
39    fn has_any_cache_metrics(&self) -> bool {
40        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
41    }
42
43    #[inline]
44    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
45        self.cache_read_tokens
46            .or(self.cached_prompt_tokens)
47            .unwrap_or(0)
48    }
49
50    #[inline]
51    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
52        self.cache_creation_tokens.unwrap_or(0)
53    }
54
55    #[inline]
56    pub fn cache_hit_rate(&self) -> Option<f64> {
57        if !self.has_any_cache_metrics() {
58            return None;
59        }
60        let read = self.cache_read_tokens_or_fallback() as f64;
61        let creation = self.cache_creation_tokens_or_zero() as f64;
62        let total = read + creation;
63        if total > 0.0 {
64            Some((read / total) * 100.0)
65        } else {
66            None
67        }
68    }
69
70    #[inline]
71    pub fn is_cache_hit(&self) -> Option<bool> {
72        self.has_any_cache_metrics()
73            .then(|| self.cache_read_tokens_or_fallback() > 0)
74    }
75
76    #[inline]
77    pub fn is_cache_miss(&self) -> Option<bool> {
78        self.has_any_cache_metrics().then(|| {
79            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
80        })
81    }
82
83    #[inline]
84    pub fn total_cache_tokens(&self) -> u32 {
85        let read = self.cache_read_tokens_or_fallback();
86        let creation = self.cache_creation_tokens_or_zero();
87        read + creation
88    }
89
90    #[inline]
91    pub fn cache_savings_ratio(&self) -> Option<f64> {
92        if !self.has_cache_read_metric() {
93            return None;
94        }
95        let read = self.cache_read_tokens_or_fallback() as f64;
96        let prompt = self.prompt_tokens as f64;
97        if prompt > 0.0 {
98            Some(read / prompt)
99        } else {
100            None
101        }
102    }
103}
104
105/// Provider-agnostic balance information for account status display.
106#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
107pub struct BalanceInfo {
108    /// Human-readable balance string (e.g. "100.00¥", "$50.00").
109    pub display: String,
110    /// Whether the account has sufficient balance for API calls.
111    pub is_available: bool,
112}
113
114/// DeepSeek-specific balance info from GET /user/balance
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct DeepSeekBalanceResponse {
117    pub is_available: bool,
118    pub balance_infos: Vec<DeepSeekCurrencyBalance>,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct DeepSeekCurrencyBalance {
123    pub currency: String,
124    pub total_balance: String,
125    #[serde(default)]
126    pub granted_balance: String,
127    #[serde(default)]
128    pub topped_up_balance: String,
129}
130
131impl From<DeepSeekBalanceResponse> for BalanceInfo {
132    fn from(resp: DeepSeekBalanceResponse) -> Self {
133        let display = resp
134            .balance_infos
135            .first()
136            .map(|b| {
137                let symbol = match b.currency.as_str() {
138                    "CNY" => "¥",
139                    "USD" => "$",
140                    _ => &b.currency,
141                };
142                format!("{}{}", b.total_balance, symbol)
143            })
144            .unwrap_or_else(|| "N/A".to_string());
145        BalanceInfo {
146            display,
147            is_available: resp.is_available,
148        }
149    }
150}
151
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// When only `cached_prompt_tokens` is reported it stands in for
    /// `cache_read_tokens` in every derived metric.
    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        let fallback_only = Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: Some(600),
            cache_creation_tokens: Some(150),
            ..Usage::default()
        };

        assert_eq!(fallback_only.cache_read_tokens_or_fallback(), 600);
        assert_eq!(fallback_only.cache_creation_tokens_or_zero(), 150);
        assert_eq!(fallback_only.total_cache_tokens(), 750);
        assert_eq!(fallback_only.is_cache_hit(), Some(true));
        assert_eq!(fallback_only.is_cache_miss(), Some(false));
        assert_eq!(fallback_only.cache_savings_ratio(), Some(0.6));
        assert_eq!(fallback_only.cache_hit_rate(), Some(80.0));
    }

    /// With no cache counters at all, the optional metrics stay `None`
    /// instead of collapsing to zero/false.
    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        let no_cache_data = Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            ..Usage::default()
        };

        assert_eq!(no_cache_data.total_cache_tokens(), 0);
        assert_eq!(no_cache_data.is_cache_hit(), None);
        assert_eq!(no_cache_data.is_cache_miss(), None);
        assert_eq!(no_cache_data.cache_savings_ratio(), None);
        assert_eq!(no_cache_data.cache_hit_rate(), None);
    }
}
194
195#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
196pub enum FinishReason {
197    #[default]
198    Stop,
199    Length,
200    ToolCalls,
201    ContentFilter,
202    Pause,
203    Refusal,
204    Error(String),
205}
206
207/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
208#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
209pub struct ToolCall {
210    /// Unique identifier for this tool call (e.g., "call_123")
211    pub id: String,
212
213    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
214    #[serde(rename = "type")]
215    pub call_type: String,
216
217    /// Function call details (for function-type tools)
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub function: Option<FunctionCall>,
220
221    /// Raw text payload (for custom freeform tools in GPT-5)
222    #[serde(skip_serializing_if = "Option::is_none")]
223    pub text: Option<String>,
224
225    /// Gemini-specific thought signature for maintaining reasoning context
226    #[serde(skip_serializing_if = "Option::is_none")]
227    pub thought_signature: Option<String>,
228}
229
230/// Function call within a tool call
231#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
232pub struct FunctionCall {
233    /// Optional namespace for grouped or deferred tools.
234    #[serde(default, skip_serializing_if = "Option::is_none")]
235    pub namespace: Option<String>,
236
237    /// The name of the function to call
238    pub name: String,
239
240    /// The arguments to pass to the function, as a JSON string
241    pub arguments: String,
242}
243
244impl ToolCall {
245    /// Create a new function tool call
246    pub fn function(id: String, name: String, arguments: String) -> Self {
247        Self::function_with_namespace(id, None, name, arguments)
248    }
249
250    /// Create a new function tool call with an optional namespace.
251    pub fn function_with_namespace(
252        id: String,
253        namespace: Option<String>,
254        name: String,
255        arguments: String,
256    ) -> Self {
257        Self {
258            id,
259            call_type: "function".to_owned(),
260            function: Some(FunctionCall {
261                namespace,
262                name,
263                arguments,
264            }),
265            text: None,
266            thought_signature: None,
267        }
268    }
269
270    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
271    pub fn custom(id: String, name: String, text: String) -> Self {
272        Self {
273            id,
274            call_type: "custom".to_owned(),
275            function: Some(FunctionCall {
276                namespace: None,
277                name,
278                arguments: text.clone(),
279            }),
280            text: Some(text),
281            thought_signature: None,
282        }
283    }
284
285    /// Returns true when this tool call uses GPT-5 custom/freeform semantics.
286    pub fn is_custom(&self) -> bool {
287        self.call_type == "custom"
288    }
289
290    /// Returns the tool name when the call includes function details.
291    pub fn tool_name(&self) -> Option<&str> {
292        self.function
293            .as_ref()
294            .map(|function| function.name.as_str())
295    }
296
297    /// Returns the raw payload text exactly as emitted by the model.
298    pub fn raw_input(&self) -> Option<&str> {
299        self.text.as_deref().or_else(|| {
300            self.function
301                .as_ref()
302                .map(|function| function.arguments.as_str())
303        })
304    }
305
306    /// Parse the arguments as JSON Value (for function-type tools)
307    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
308        if let Some(ref func) = self.function {
309            parse_tool_arguments(&func.arguments)
310        } else {
311            // Return an error by trying to parse invalid JSON
312            serde_json::from_str("")
313        }
314    }
315
316    /// Returns the execution payload for this tool call.
317    ///
318    /// Function tools keep their JSON semantics. Custom tools execute with their
319    /// raw text payload wrapped as a JSON string value so freeform inputs can
320    /// flow through the existing tool pipeline.
321    pub fn execution_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
322        if self.is_custom() {
323            return Ok(serde_json::Value::String(
324                self.raw_input().unwrap_or_default().to_string(),
325            ));
326        }
327
328        self.parsed_arguments()
329    }
330
331    /// Validate that this tool call is properly formed
332    pub fn validate(&self) -> Result<(), String> {
333        if self.id.is_empty() {
334            return Err("Tool call ID cannot be empty".to_owned());
335        }
336
337        match self.call_type.as_str() {
338            "function" => {
339                if let Some(func) = &self.function {
340                    if func.name.is_empty() {
341                        return Err("Function name cannot be empty".to_owned());
342                    }
343                    // Validate that arguments is valid JSON for function tools
344                    if let Err(e) = self.parsed_arguments() {
345                        return Err(format!("Invalid JSON in function arguments: {}", e));
346                    }
347                } else {
348                    return Err("Function tool call missing function details".to_owned());
349                }
350            }
351            "custom" => {
352                // For custom tools, we allow raw text payload without JSON validation
353                if let Some(func) = &self.function {
354                    if func.name.is_empty() {
355                        return Err("Custom tool name cannot be empty".to_owned());
356                    }
357                } else {
358                    return Err("Custom tool call missing function details".to_owned());
359                }
360            }
361            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
362        }
363
364        Ok(())
365    }
366}
367
368fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
369    let trimmed = raw_arguments.trim();
370    match serde_json::from_str(trimmed) {
371        Ok(parsed) => Ok(parsed),
372        Err(primary_error) => {
373            if let Some(candidate) = extract_balanced_json(trimmed)
374                && let Ok(parsed) = serde_json::from_str(candidate)
375            {
376                return Ok(parsed);
377            }
378            if let Some(candidate) = repair_tag_polluted_json(trimmed)
379                && let Ok(parsed) = serde_json::from_str(&candidate)
380            {
381                return Ok(parsed);
382            }
383            Err(primary_error)
384        }
385    }
386}
387
/// Returns the first balanced JSON object or array embedded in `input`.
///
/// Scanning starts at the first `{` or `[`. Only brackets of that same kind
/// are counted, and bracket characters inside JSON string literals (including
/// escaped quotes) are ignored. Returns `None` when there is no opening
/// bracket or the structure never closes.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let begin = input.find(['{', '['])?;
    let (open, close) = match input.as_bytes().get(begin).copied()? {
        b'{' => ('{', '}'),
        b'[' => ('[', ']'),
        _ => return None,
    };

    let mut nesting: usize = 0;
    let mut inside_string = false;
    let mut after_backslash = false;

    for (idx, c) in input[begin..].char_indices() {
        if inside_string {
            if after_backslash {
                after_backslash = false;
            } else if c == '\\' {
                after_backslash = true;
            } else if c == '"' {
                inside_string = false;
            }
            continue;
        }

        if c == '"' {
            inside_string = true;
        } else if c == open {
            nesting += 1;
        } else if c == close {
            nesting = nesting.saturating_sub(1);
            if nesting == 0 {
                return input.get(begin..begin + idx + c.len_utf8());
            }
        }
    }

    None
}
433
434fn repair_tag_polluted_json(input: &str) -> Option<String> {
435    let start = input.find(['{', '['])?;
436    let candidate = input.get(start..)?;
437    let boundary = find_provider_markup_boundary(candidate)?;
438    if boundary == 0 {
439        return None;
440    }
441
442    close_incomplete_json_prefix(candidate[..boundary].trim_end())
443}
444
/// Returns the byte offset of the earliest provider tool-call markup tag in
/// `input`, or `None` when none of the known markers appear.
///
/// The markers cover MiniMax-style `<invoke>` / `<parameter>` tags and generic
/// `<tool_call>` tags. Every marker begins with `<`, so each match offset is a
/// valid char boundary.
fn find_provider_markup_boundary(input: &str) -> Option<usize> {
    const PROVIDER_MARKERS: &[&str] = &[
        "<</",
        "</parameter>",
        "</invoke>",
        "</minimax:tool_call>",
        "<minimax:tool_call>",
        "<parameter name=\"",
        "<invoke name=\"",
        "<tool_call>",
        "</tool_call>",
    ];

    // The earliest first-occurrence across all markers is the boundary.
    PROVIDER_MARKERS
        .iter()
        .filter_map(|&marker| input.find(marker))
        .min()
}
466
/// Closes a truncated JSON prefix so it has a chance of parsing.
///
/// Walks `prefix` while tracking string literals and bracket nesting, then
/// appends whatever is needed to terminate it:
/// - a closing `"` if the prefix ends inside a string, after dropping a
///   dangling escape backslash;
/// - the pending `}` / `]` closers, innermost first.
///
/// Returns `None` for an empty prefix or when a closing bracket does not match
/// the most recently opened one. The result is not guaranteed to be valid JSON
/// (e.g. a trailing `,` or `:` is left as-is), so callers must still parse it.
fn close_incomplete_json_prefix(prefix: &str) -> Option<String> {
    if prefix.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(prefix.len() + 8);
    // Stack of closers; the innermost open bracket's closer is on top.
    let mut expected_closers = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for ch in prefix.chars() {
        repaired.push(ch);

        if in_string {
            if escaped {
                escaped = false;
                continue;
            }

            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => expected_closers.push('}'),
            '[' => expected_closers.push(']'),
            '}' | ']' => {
                if expected_closers.pop() != Some(ch) {
                    return None;
                }
            }
            _ => {}
        }
    }

    if in_string {
        if escaped {
            // A lone trailing backslash would escape the quote appended
            // below and leave the string unterminated; drop it instead.
            repaired.pop();
        }
        repaired.push('"');
    }
    // Close brackets innermost-first (top of stack first); draining from the
    // bottom would produce e.g. `{"a":[1}]` instead of `{"a":[1]}`.
    repaired.extend(expected_closers.iter().rev());

    Some(repaired)
}
516
517/// Universal LLM response structure
518#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
519pub struct LLMResponse {
520    /// The response content text
521    pub content: Option<String>,
522
523    /// Tool calls made by the model
524    pub tool_calls: Option<Vec<ToolCall>>,
525
526    /// The model that generated this response
527    pub model: String,
528
529    /// Token usage statistics
530    pub usage: Option<Usage>,
531
532    /// Why the response finished
533    pub finish_reason: FinishReason,
534
535    /// Reasoning content (for models that support it)
536    pub reasoning: Option<String>,
537
538    /// Detailed reasoning traces (for models that support it)
539    pub reasoning_details: Option<Vec<String>>,
540
541    /// Tool references for context
542    pub tool_references: Vec<String>,
543
544    /// Request ID from the provider
545    pub request_id: Option<String>,
546
547    /// Organization ID from the provider
548    pub organization_id: Option<String>,
549}
550
551impl LLMResponse {
552    /// Create a new LLM response with mandatory fields
553    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
554        Self {
555            content: Some(content.into()),
556            tool_calls: None,
557            model: model.into(),
558            usage: None,
559            finish_reason: FinishReason::Stop,
560            reasoning: None,
561            reasoning_details: None,
562            tool_references: Vec::new(),
563            request_id: None,
564            organization_id: None,
565        }
566    }
567
568    /// Get content or empty string
569    pub fn content_text(&self) -> &str {
570        self.content.as_deref().unwrap_or("")
571    }
572
573    /// Get content as String (clone)
574    pub fn content_string(&self) -> String {
575        self.content.clone().unwrap_or_default()
576    }
577}
578
579#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
580pub struct LLMErrorMetadata {
581    pub provider: Option<String>,
582    pub status: Option<u16>,
583    pub code: Option<String>,
584    pub request_id: Option<String>,
585    pub organization_id: Option<String>,
586    pub retry_after: Option<String>,
587    pub message: Option<String>,
588}
589
590impl LLMErrorMetadata {
591    /// Boxed constructor because metadata is always stored inside `Option<Box<LLMErrorMetadata>>`
592    /// in the LLMError enum variants.
593    #[must_use]
594    pub fn new(
595        provider: impl Into<String>,
596        status: Option<u16>,
597        code: Option<String>,
598        request_id: Option<String>,
599        organization_id: Option<String>,
600        retry_after: Option<String>,
601        message: Option<String>,
602    ) -> Box<Self> {
603        Box::new(Self {
604            provider: Some(provider.into()),
605            status,
606            code,
607            request_id,
608            organization_id,
609            retry_after,
610            message,
611        })
612    }
613}
614
615/// LLM error types with optional provider metadata
616#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
617#[serde(tag = "type", rename_all = "snake_case")]
618pub enum LLMError {
619    #[error("Authentication failed: {message}")]
620    Authentication {
621        message: String,
622        metadata: Option<Box<LLMErrorMetadata>>,
623    },
624    #[error("Rate limit exceeded")]
625    RateLimit {
626        metadata: Option<Box<LLMErrorMetadata>>,
627    },
628    #[error("Invalid request: {message}")]
629    InvalidRequest {
630        message: String,
631        metadata: Option<Box<LLMErrorMetadata>>,
632    },
633    #[error("Network error: {message}")]
634    Network {
635        message: String,
636        metadata: Option<Box<LLMErrorMetadata>>,
637    },
638    #[error("Provider error: {message}")]
639    Provider {
640        message: String,
641        metadata: Option<Box<LLMErrorMetadata>>,
642    },
643}
644
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Builds a plain function tool call from borrowed parts.
    fn function_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall::function(id.to_owned(), name.to_owned(), arguments.to_owned())
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = function_call(
            "call_read",
            "read_file",
            r#"{"path":"src/main.rs"} trailing text"#,
        );

        let recovered = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(recovered, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = function_call(
            "call_read",
            "read_file",
            "```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```",
        );

        let recovered = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(recovered, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let call = function_call("call_read", "read_file", r#"{"path":"src/main.rs""#);

        assert!(call.parsed_arguments().is_err());
    }

    #[test]
    fn parsed_arguments_recovers_truncated_minimax_markup() {
        let call = function_call(
            "call_search",
            "unified_search",
            "{\"action\": \"grep\", \"pattern\": \"persistent_memory\", \"path\": \"vtcode-core/src</parameter>\n<</invoke>\n</minimax:tool_call>",
        );

        let recovered = call
            .parsed_arguments()
            .expect("minimax markup spillover should recover");
        let expected = json!({
            "action": "grep",
            "pattern": "persistent_memory",
            "path": "vtcode-core/src"
        });
        assert_eq!(recovered, expected);
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_owned(),
            Some("workspace".to_owned()),
            "read_file".to_owned(),
            r#"{"path":"src/main.rs"}"#.to_owned(),
        );

        let serialized = serde_json::to_value(&call).expect("tool call should serialize");
        let function = &serialized["function"];
        assert_eq!(function["namespace"], "workspace");
        assert_eq!(function["name"], "read_file");
    }

    #[test]
    fn custom_tool_call_exposes_raw_execution_arguments() {
        let patch = String::from("*** Begin Patch\n*** End Patch\n");
        let call = ToolCall::custom(
            "call_patch".to_owned(),
            "apply_patch".to_owned(),
            patch.clone(),
        );

        assert!(call.is_custom());
        assert_eq!(call.tool_name(), Some("apply_patch"));
        assert_eq!(call.raw_input(), Some(patch.as_str()));

        let executed = call.execution_arguments().expect("custom arguments");
        assert_eq!(executed, json!(patch));
        assert!(
            call.parsed_arguments().is_err(),
            "custom tool payload should stay freeform rather than JSON"
        );
    }
}