Skip to main content

vtcode_commons/
llm.rs

1//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    Mistral,
12    OpenRouter,
13    Ollama,
14    ZAI,
15    Moonshot,
16    HuggingFace,
17    Minimax,
18    MiMo,
19    OpenCodeZen,
20    OpenCodeGo,
21    Qwen,
22    StepFun,
23    Poolside,
24}
25
26#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
27pub struct Usage {
28    pub prompt_tokens: u32,
29    pub completion_tokens: u32,
30    pub total_tokens: u32,
31    pub cached_prompt_tokens: Option<u32>,
32    pub cache_creation_tokens: Option<u32>,
33    pub cache_read_tokens: Option<u32>,
34}
35
36impl Usage {
37    #[inline]
38    fn has_cache_read_metric(&self) -> bool {
39        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
40    }
41
42    #[inline]
43    fn has_any_cache_metrics(&self) -> bool {
44        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
45    }
46
47    #[inline]
48    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
49        self.cache_read_tokens
50            .or(self.cached_prompt_tokens)
51            .unwrap_or(0)
52    }
53
54    #[inline]
55    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
56        self.cache_creation_tokens.unwrap_or(0)
57    }
58
59    #[inline]
60    pub fn cache_hit_rate(&self) -> Option<f64> {
61        if !self.has_any_cache_metrics() {
62            return None;
63        }
64        let read = self.cache_read_tokens_or_fallback() as f64;
65        let creation = self.cache_creation_tokens_or_zero() as f64;
66        let total = read + creation;
67        if total > 0.0 {
68            Some((read / total) * 100.0)
69        } else {
70            None
71        }
72    }
73
74    #[inline]
75    pub fn is_cache_hit(&self) -> Option<bool> {
76        self.has_any_cache_metrics()
77            .then(|| self.cache_read_tokens_or_fallback() > 0)
78    }
79
80    #[inline]
81    pub fn is_cache_miss(&self) -> Option<bool> {
82        self.has_any_cache_metrics().then(|| {
83            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
84        })
85    }
86
87    #[inline]
88    pub fn total_cache_tokens(&self) -> u32 {
89        let read = self.cache_read_tokens_or_fallback();
90        let creation = self.cache_creation_tokens_or_zero();
91        read + creation
92    }
93
94    #[inline]
95    pub fn cache_savings_ratio(&self) -> Option<f64> {
96        if !self.has_cache_read_metric() {
97            return None;
98        }
99        let read = self.cache_read_tokens_or_fallback() as f64;
100        let prompt = self.prompt_tokens as f64;
101        if prompt > 0.0 {
102            Some(read / prompt)
103        } else {
104            None
105        }
106    }
107}
108
109/// Provider-agnostic balance information for account status display.
110#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
111pub struct BalanceInfo {
112    /// Human-readable balance string (e.g. "100.00¥", "$50.00").
113    pub display: String,
114    /// Whether the account has sufficient balance for API calls.
115    pub is_available: bool,
116}
117
118/// DeepSeek-specific balance info from GET /user/balance
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct DeepSeekBalanceResponse {
121    pub is_available: bool,
122    pub balance_infos: Vec<DeepSeekCurrencyBalance>,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct DeepSeekCurrencyBalance {
127    pub currency: String,
128    pub total_balance: String,
129    #[serde(default)]
130    pub granted_balance: String,
131    #[serde(default)]
132    pub topped_up_balance: String,
133}
134
135impl From<DeepSeekBalanceResponse> for BalanceInfo {
136    fn from(resp: DeepSeekBalanceResponse) -> Self {
137        let display = resp
138            .balance_infos
139            .first()
140            .map(|b| {
141                let symbol = match b.currency.as_str() {
142                    "CNY" => "¥",
143                    "USD" => "$",
144                    _ => &b.currency,
145                };
146                format!("{}{}", b.total_balance, symbol)
147            })
148            .unwrap_or_else(|| "N/A".to_string());
149        BalanceInfo {
150            display,
151            is_available: resp.is_available,
152        }
153    }
154}
155
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// Usage with fixed prompt/completion totals and the given cache fields.
    fn usage_with_cache(
        cached_prompt_tokens: Option<u32>,
        cache_creation_tokens: Option<u32>,
        cache_read_tokens: Option<u32>,
    ) -> Usage {
        Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens,
            cache_creation_tokens,
            cache_read_tokens,
        }
    }

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        let usage = usage_with_cache(Some(600), Some(150), None);

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        let usage = usage_with_cache(None, None, None);

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}
198
199#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
200pub enum FinishReason {
201    #[default]
202    Stop,
203    Length,
204    ToolCalls,
205    ContentFilter,
206    Pause,
207    Refusal,
208    Error(String),
209}
210
211/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
212#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
213pub struct ToolCall {
214    /// Unique identifier for this tool call (e.g., "call_123")
215    pub id: String,
216
217    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
218    #[serde(rename = "type")]
219    pub call_type: String,
220
221    /// Function call details (for function-type tools)
222    #[serde(skip_serializing_if = "Option::is_none")]
223    pub function: Option<FunctionCall>,
224
225    /// Raw text payload (for custom freeform tools in GPT-5)
226    #[serde(skip_serializing_if = "Option::is_none")]
227    pub text: Option<String>,
228
229    /// Gemini-specific thought signature for maintaining reasoning context
230    #[serde(skip_serializing_if = "Option::is_none")]
231    pub thought_signature: Option<String>,
232}
233
234/// Function call within a tool call
235#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
236pub struct FunctionCall {
237    /// Optional namespace for grouped or deferred tools.
238    #[serde(default, skip_serializing_if = "Option::is_none")]
239    pub namespace: Option<String>,
240
241    /// The name of the function to call
242    pub name: String,
243
244    /// The arguments to pass to the function, as a JSON string
245    pub arguments: String,
246}
247
248impl ToolCall {
249    /// Create a new function tool call
250    pub fn function(id: String, name: String, arguments: String) -> Self {
251        Self::function_with_namespace(id, None, name, arguments)
252    }
253
254    /// Create a new function tool call with an optional namespace.
255    pub fn function_with_namespace(
256        id: String,
257        namespace: Option<String>,
258        name: String,
259        arguments: String,
260    ) -> Self {
261        Self {
262            id,
263            call_type: "function".to_owned(),
264            function: Some(FunctionCall {
265                namespace,
266                name,
267                arguments,
268            }),
269            text: None,
270            thought_signature: None,
271        }
272    }
273
274    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
275    pub fn custom(id: String, name: String, text: String) -> Self {
276        Self {
277            id,
278            call_type: "custom".to_owned(),
279            function: Some(FunctionCall {
280                namespace: None,
281                name,
282                arguments: text.clone(),
283            }),
284            text: Some(text),
285            thought_signature: None,
286        }
287    }
288
289    /// Returns true when this tool call uses GPT-5 custom/freeform semantics.
290    pub fn is_custom(&self) -> bool {
291        self.call_type == "custom"
292    }
293
294    /// Returns the tool name when the call includes function details.
295    pub fn tool_name(&self) -> Option<&str> {
296        self.function
297            .as_ref()
298            .map(|function| function.name.as_str())
299    }
300
301    /// Returns the raw payload text exactly as emitted by the model.
302    pub fn raw_input(&self) -> Option<&str> {
303        self.text.as_deref().or_else(|| {
304            self.function
305                .as_ref()
306                .map(|function| function.arguments.as_str())
307        })
308    }
309
310    /// Parse the arguments as JSON Value (for function-type tools)
311    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
312        if let Some(ref func) = self.function {
313            parse_tool_arguments(&func.arguments)
314        } else {
315            // Return an error by trying to parse invalid JSON
316            serde_json::from_str("")
317        }
318    }
319
320    /// Returns the execution payload for this tool call.
321    ///
322    /// Function tools keep their JSON semantics. Custom tools execute with their
323    /// raw text payload wrapped as a JSON string value so freeform inputs can
324    /// flow through the existing tool pipeline.
325    pub fn execution_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
326        if self.is_custom() {
327            return Ok(serde_json::Value::String(
328                self.raw_input().unwrap_or_default().to_string(),
329            ));
330        }
331
332        self.parsed_arguments()
333    }
334
335    /// Validate that this tool call is properly formed
336    pub fn validate(&self) -> Result<(), String> {
337        if self.id.is_empty() {
338            return Err("Tool call ID cannot be empty".to_owned());
339        }
340
341        match self.call_type.as_str() {
342            "function" => {
343                if let Some(func) = &self.function {
344                    if func.name.is_empty() {
345                        return Err("Function name cannot be empty".to_owned());
346                    }
347                    // Validate that arguments is valid JSON for function tools
348                    if let Err(e) = self.parsed_arguments() {
349                        return Err(format!("Invalid JSON in function arguments: {}", e));
350                    }
351                } else {
352                    return Err("Function tool call missing function details".to_owned());
353                }
354            }
355            "custom" => {
356                // For custom tools, we allow raw text payload without JSON validation
357                if let Some(func) = &self.function {
358                    if func.name.is_empty() {
359                        return Err("Custom tool name cannot be empty".to_owned());
360                    }
361                } else {
362                    return Err("Custom tool call missing function details".to_owned());
363                }
364            }
365            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
366        }
367
368        Ok(())
369    }
370}
371
372fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
373    let trimmed = raw_arguments.trim();
374    match serde_json::from_str(trimmed) {
375        Ok(parsed) => Ok(parsed),
376        Err(primary_error) => {
377            if let Some(candidate) = extract_balanced_json(trimmed)
378                && let Ok(parsed) = serde_json::from_str(candidate)
379            {
380                return Ok(parsed);
381            }
382            if let Some(candidate) = repair_tag_polluted_json(trimmed)
383                && let Ok(parsed) = serde_json::from_str(&candidate)
384            {
385                return Ok(parsed);
386            }
387            Err(primary_error)
388        }
389    }
390}
391
/// Return the first balanced JSON object or array embedded in `input`.
///
/// Scanning starts at the first `{` or `[`. Only that bracket kind is counted,
/// and brackets inside string literals (including escaped quotes) are ignored.
/// Returns `None` when no opening bracket exists or the span never closes.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let begin = input.find(['{', '['])?;
    let bytes = input.as_bytes();
    let open = *bytes.get(begin)?;
    let close = match open {
        b'{' => b'}',
        b'[' => b']',
        _ => return None,
    };

    // All delimiters of interest are ASCII, and UTF-8 continuation bytes are
    // never ASCII, so scanning bytes sees the same structure as scanning chars.
    let mut level = 0usize;
    let mut quoted = false;
    let mut skip_next = false;

    for (idx, &byte) in bytes.iter().enumerate().skip(begin) {
        if quoted {
            if skip_next {
                skip_next = false;
            } else if byte == b'\\' {
                skip_next = true;
            } else if byte == b'"' {
                quoted = false;
            }
            continue;
        }

        if byte == b'"' {
            quoted = true;
        } else if byte == open {
            level += 1;
        } else if byte == close {
            level = level.saturating_sub(1);
            if level == 0 {
                return input.get(begin..=idx);
            }
        }
    }

    None
}
437
438fn repair_tag_polluted_json(input: &str) -> Option<String> {
439    let start = input.find(['{', '['])?;
440    let candidate = input.get(start..)?;
441    let boundary = find_provider_markup_boundary(candidate)?;
442    if boundary == 0 {
443        return None;
444    }
445
446    close_incomplete_json_prefix(candidate[..boundary].trim_end())
447}
448
/// Byte offset of the first provider tool-call markup marker in `input`.
///
/// The markers are fragments of XML-like tool-call syntax (such as
/// `</parameter>` or `<minimax:tool_call>`) that can spill into argument text.
/// Returns `None` when no marker occurs.
fn find_provider_markup_boundary(input: &str) -> Option<usize> {
    const PROVIDER_MARKERS: &[&str] = &[
        "<</",
        "</parameter>",
        "</invoke>",
        "</minimax:tool_call>",
        "<minimax:tool_call>",
        "<parameter name=\"",
        "<invoke name=\"",
        "<tool_call>",
        "</tool_call>",
    ];

    // Every marker begins with `<`, so only those positions can match.
    input
        .match_indices('<')
        .map(|(offset, _)| offset)
        .find(|&offset| {
            let rest = &input[offset..];
            PROVIDER_MARKERS.iter().any(|marker| rest.starts_with(marker))
        })
}
470
/// Close a truncated JSON prefix so it has a chance to parse.
///
/// Walks `prefix` while tracking string state and open brackets. It then
/// appends a closing quote (if a string is still open) and the missing closers
/// in innermost-first order, e.g. `{"a":[1,2` becomes `{"a":[1,2]}`.
///
/// Returns `None` for an empty prefix or when a closing bracket does not match
/// the most recently opened one. The result is not guaranteed to be valid
/// JSON (a trailing `,` or `:` is left in place), so callers must still parse it.
fn close_incomplete_json_prefix(prefix: &str) -> Option<String> {
    if prefix.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(prefix.len() + 8);
    // Stack of closers; the last element belongs to the innermost open bracket.
    let mut expected_closers = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for ch in prefix.chars() {
        repaired.push(ch);

        if in_string {
            if escaped {
                escaped = false;
                continue;
            }

            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => expected_closers.push('}'),
            '[' => expected_closers.push(']'),
            '}' | ']' => {
                if expected_closers.pop() != Some(ch) {
                    return None;
                }
            }
            _ => {}
        }
    }

    if in_string {
        if escaped {
            // A dangling backslash would escape the quote appended below and
            // leave the string unterminated; drop it.
            repaired.pop();
        }
        repaired.push('"');
    }
    // Close innermost brackets first (LIFO), i.e. pop order of the stack.
    repaired.extend(expected_closers.into_iter().rev());

    Some(repaired)
}
520
521/// Universal LLM response structure
522#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
523pub struct LLMResponse {
524    /// The response content text
525    pub content: Option<String>,
526
527    /// Tool calls made by the model
528    pub tool_calls: Option<Vec<ToolCall>>,
529
530    /// The model that generated this response
531    pub model: String,
532
533    /// Token usage statistics
534    pub usage: Option<Usage>,
535
536    /// Why the response finished
537    pub finish_reason: FinishReason,
538
539    /// Reasoning content (for models that support it)
540    pub reasoning: Option<String>,
541
542    /// Detailed reasoning traces (for models that support it)
543    pub reasoning_details: Option<Vec<String>>,
544
545    /// Tool references for context
546    pub tool_references: Vec<String>,
547
548    /// Request ID from the provider
549    pub request_id: Option<String>,
550
551    /// Organization ID from the provider
552    pub organization_id: Option<String>,
553
554    /// Compaction summary content from Anthropic's server-side compaction.
555    /// Populated when `stop_reason` is `Pause` (from `"compaction"`).
556    /// The caller should pass this back in subsequent requests so the API
557    /// can drop prior messages before the compaction block.
558    pub compaction: Option<String>,
559}
560
561impl LLMResponse {
562    /// Create a new LLM response with mandatory fields
563    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
564        Self {
565            content: Some(content.into()),
566            tool_calls: None,
567            model: model.into(),
568            usage: None,
569            finish_reason: FinishReason::Stop,
570            reasoning: None,
571            reasoning_details: None,
572            tool_references: Vec::new(),
573            request_id: None,
574            organization_id: None,
575            compaction: None,
576        }
577    }
578
579    /// Get content or empty string
580    pub fn content_text(&self) -> &str {
581        self.content.as_deref().unwrap_or("")
582    }
583
584    /// Get content as String (clone)
585    pub fn content_string(&self) -> String {
586        self.content.clone().unwrap_or_default()
587    }
588}
589
590#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
591pub struct LLMErrorMetadata {
592    pub provider: Option<String>,
593    pub status: Option<u16>,
594    pub code: Option<String>,
595    pub request_id: Option<String>,
596    pub organization_id: Option<String>,
597    pub retry_after: Option<String>,
598    pub message: Option<String>,
599}
600
601impl LLMErrorMetadata {
602    /// Boxed constructor because metadata is always stored inside `Option<Box<LLMErrorMetadata>>`
603    /// in the LLMError enum variants.
604    #[must_use]
605    pub fn new(
606        provider: impl Into<String>,
607        status: Option<u16>,
608        code: Option<String>,
609        request_id: Option<String>,
610        organization_id: Option<String>,
611        retry_after: Option<String>,
612        message: Option<String>,
613    ) -> Box<Self> {
614        Box::new(Self {
615            provider: Some(provider.into()),
616            status,
617            code,
618            request_id,
619            organization_id,
620            retry_after,
621            message,
622        })
623    }
624}
625
626/// LLM error types with optional provider metadata
627#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
628#[serde(tag = "type", rename_all = "snake_case")]
629pub enum LLMError {
630    #[error("Authentication failed: {message}")]
631    Authentication {
632        message: String,
633        metadata: Option<Box<LLMErrorMetadata>>,
634    },
635    #[error("Rate limit exceeded")]
636    RateLimit {
637        metadata: Option<Box<LLMErrorMetadata>>,
638    },
639    #[error("Invalid request: {message}")]
640    InvalidRequest {
641        message: String,
642        metadata: Option<Box<LLMErrorMetadata>>,
643    },
644    #[error("Network error: {message}")]
645    Network {
646        message: String,
647        metadata: Option<Box<LLMErrorMetadata>>,
648    },
649    #[error("Provider error: {message}")]
650    Provider {
651        message: String,
652        metadata: Option<Box<LLMErrorMetadata>>,
653    },
654}
655
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Shorthand for a function tool call built from string slices.
    fn function_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall::function(id.to_string(), name.to_string(), arguments.to_string())
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = function_call(
            "call_read",
            "read_file",
            r#"{"path":"src/main.rs"} trailing text"#,
        );

        let parsed = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = function_call(
            "call_read",
            "read_file",
            "```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```",
        );

        let parsed = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let call = function_call("call_read", "read_file", r#"{"path":"src/main.rs""#);

        assert!(call.parsed_arguments().is_err());
    }

    #[test]
    fn parsed_arguments_recovers_truncated_minimax_markup() {
        let call = function_call(
            "call_search",
            "unified_search",
            "{\"action\": \"grep\", \"pattern\": \"persistent_memory\", \"path\": \"vtcode-core/src</parameter>\n<</invoke>\n</minimax:tool_call>",
        );

        let parsed = call
            .parsed_arguments()
            .expect("minimax markup spillover should recover");
        let expected = json!({
            "action": "grep",
            "pattern": "persistent_memory",
            "path": "vtcode-core/src"
        });
        assert_eq!(parsed, expected);
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_string(),
            Some("workspace".to_string()),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"}"#.to_string(),
        );

        let serialized = serde_json::to_value(&call).expect("tool call should serialize");
        assert_eq!(serialized["function"]["namespace"], "workspace");
        assert_eq!(serialized["function"]["name"], "read_file");
    }

    #[test]
    fn custom_tool_call_exposes_raw_execution_arguments() {
        let patch = "*** Begin Patch\n*** End Patch\n".to_string();
        let call = ToolCall::custom(
            "call_patch".to_string(),
            "apply_patch".to_string(),
            patch.clone(),
        );

        assert!(call.is_custom());
        assert_eq!(call.tool_name(), Some("apply_patch"));
        assert_eq!(call.raw_input(), Some(patch.as_str()));
        let executed = call.execution_arguments().expect("custom arguments");
        assert_eq!(executed, json!(patch));
        assert!(
            call.parsed_arguments().is_err(),
            "custom tool payload should stay freeform rather than JSON"
        );
    }
}