Skip to main content

vtcode_commons/
llm.rs

//! Core LLM types shared across the project.
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    OpenRouter,
12    Ollama,
13    ZAI,
14    Moonshot,
15    HuggingFace,
16    Minimax,
17}
18
19#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
20pub struct Usage {
21    pub prompt_tokens: u32,
22    pub completion_tokens: u32,
23    pub total_tokens: u32,
24    pub cached_prompt_tokens: Option<u32>,
25    pub cache_creation_tokens: Option<u32>,
26    pub cache_read_tokens: Option<u32>,
27}
28
29impl Usage {
30    #[inline]
31    fn has_cache_read_metric(&self) -> bool {
32        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
33    }
34
35    #[inline]
36    fn has_any_cache_metrics(&self) -> bool {
37        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
38    }
39
40    #[inline]
41    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
42        self.cache_read_tokens
43            .or(self.cached_prompt_tokens)
44            .unwrap_or(0)
45    }
46
47    #[inline]
48    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
49        self.cache_creation_tokens.unwrap_or(0)
50    }
51
52    #[inline]
53    pub fn cache_hit_rate(&self) -> Option<f64> {
54        if !self.has_any_cache_metrics() {
55            return None;
56        }
57        let read = self.cache_read_tokens_or_fallback() as f64;
58        let creation = self.cache_creation_tokens_or_zero() as f64;
59        let total = read + creation;
60        if total > 0.0 {
61            Some((read / total) * 100.0)
62        } else {
63            None
64        }
65    }
66
67    #[inline]
68    pub fn is_cache_hit(&self) -> Option<bool> {
69        self.has_any_cache_metrics()
70            .then(|| self.cache_read_tokens_or_fallback() > 0)
71    }
72
73    #[inline]
74    pub fn is_cache_miss(&self) -> Option<bool> {
75        self.has_any_cache_metrics().then(|| {
76            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
77        })
78    }
79
80    #[inline]
81    pub fn total_cache_tokens(&self) -> u32 {
82        let read = self.cache_read_tokens_or_fallback();
83        let creation = self.cache_creation_tokens_or_zero();
84        read + creation
85    }
86
87    #[inline]
88    pub fn cache_savings_ratio(&self) -> Option<f64> {
89        if !self.has_cache_read_metric() {
90            return None;
91        }
92        let read = self.cache_read_tokens_or_fallback() as f64;
93        let prompt = self.prompt_tokens as f64;
94        if prompt > 0.0 {
95            Some(read / prompt)
96        } else {
97            None
98        }
99    }
100}
101
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        let usage = Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: Some(600),
            cache_creation_tokens: Some(150),
            cache_read_tokens: None,
        };

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        let usage = Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: None,
            cache_creation_tokens: None,
            cache_read_tokens: None,
        };

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }

    /// `cache_read_tokens` wins over the `cached_prompt_tokens` fallback.
    #[test]
    fn cache_read_tokens_take_precedence_over_cached_prompt_tokens() {
        let usage = Usage {
            prompt_tokens: 1_000,
            cached_prompt_tokens: Some(600),
            cache_read_tokens: Some(250),
            ..Usage::default()
        };

        assert_eq!(usage.cache_read_tokens_or_fallback(), 250);
        assert_eq!(usage.total_cache_tokens(), 250);
        assert_eq!(usage.cache_savings_ratio(), Some(0.25));
        assert_eq!(usage.cache_hit_rate(), Some(100.0));
    }

    /// Creation without any reads is reported as a miss, not a hit.
    #[test]
    fn cache_creation_without_reads_is_a_miss() {
        let usage = Usage {
            prompt_tokens: 400,
            cache_creation_tokens: Some(400),
            ..Usage::default()
        };

        assert_eq!(usage.is_cache_hit(), Some(false));
        assert_eq!(usage.is_cache_miss(), Some(true));
        assert_eq!(usage.cache_hit_rate(), Some(0.0));
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.total_cache_tokens(), 400);
    }

    /// Reported-but-zero metrics are neither hits nor misses and yield no ratios.
    #[test]
    fn zero_counts_yield_no_ratios() {
        let usage = Usage {
            cache_read_tokens: Some(0),
            ..Usage::default()
        };

        assert_eq!(usage.is_cache_hit(), Some(false));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_hit_rate(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
    }
}
144
145#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
146pub enum FinishReason {
147    #[default]
148    Stop,
149    Length,
150    ToolCalls,
151    ContentFilter,
152    Pause,
153    Refusal,
154    Error(String),
155}
156
157/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
158#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
159pub struct ToolCall {
160    /// Unique identifier for this tool call (e.g., "call_123")
161    pub id: String,
162
163    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
164    #[serde(rename = "type")]
165    pub call_type: String,
166
167    /// Function call details (for function-type tools)
168    #[serde(skip_serializing_if = "Option::is_none")]
169    pub function: Option<FunctionCall>,
170
171    /// Raw text payload (for custom freeform tools in GPT-5)
172    #[serde(skip_serializing_if = "Option::is_none")]
173    pub text: Option<String>,
174
175    /// Gemini-specific thought signature for maintaining reasoning context
176    #[serde(skip_serializing_if = "Option::is_none")]
177    pub thought_signature: Option<String>,
178}
179
180/// Function call within a tool call
181#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
182pub struct FunctionCall {
183    /// Optional namespace for grouped or deferred tools.
184    #[serde(default, skip_serializing_if = "Option::is_none")]
185    pub namespace: Option<String>,
186
187    /// The name of the function to call
188    pub name: String,
189
190    /// The arguments to pass to the function, as a JSON string
191    pub arguments: String,
192}
193
194impl ToolCall {
195    /// Create a new function tool call
196    pub fn function(id: String, name: String, arguments: String) -> Self {
197        Self::function_with_namespace(id, None, name, arguments)
198    }
199
200    /// Create a new function tool call with an optional namespace.
201    pub fn function_with_namespace(
202        id: String,
203        namespace: Option<String>,
204        name: String,
205        arguments: String,
206    ) -> Self {
207        Self {
208            id,
209            call_type: "function".to_owned(),
210            function: Some(FunctionCall {
211                namespace,
212                name,
213                arguments,
214            }),
215            text: None,
216            thought_signature: None,
217        }
218    }
219
220    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
221    pub fn custom(id: String, name: String, text: String) -> Self {
222        Self {
223            id,
224            call_type: "custom".to_owned(),
225            function: Some(FunctionCall {
226                namespace: None,
227                name,
228                arguments: text.clone(),
229            }),
230            text: Some(text),
231            thought_signature: None,
232        }
233    }
234
235    /// Returns true when this tool call uses GPT-5 custom/freeform semantics.
236    pub fn is_custom(&self) -> bool {
237        self.call_type == "custom"
238    }
239
240    /// Returns the tool name when the call includes function details.
241    pub fn tool_name(&self) -> Option<&str> {
242        self.function
243            .as_ref()
244            .map(|function| function.name.as_str())
245    }
246
247    /// Returns the raw payload text exactly as emitted by the model.
248    pub fn raw_input(&self) -> Option<&str> {
249        self.text.as_deref().or_else(|| {
250            self.function
251                .as_ref()
252                .map(|function| function.arguments.as_str())
253        })
254    }
255
256    /// Parse the arguments as JSON Value (for function-type tools)
257    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
258        if let Some(ref func) = self.function {
259            parse_tool_arguments(&func.arguments)
260        } else {
261            // Return an error by trying to parse invalid JSON
262            serde_json::from_str("")
263        }
264    }
265
266    /// Returns the execution payload for this tool call.
267    ///
268    /// Function tools keep their JSON semantics. Custom tools execute with their
269    /// raw text payload wrapped as a JSON string value so freeform inputs can
270    /// flow through the existing tool pipeline.
271    pub fn execution_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
272        if self.is_custom() {
273            return Ok(serde_json::Value::String(
274                self.raw_input().unwrap_or_default().to_string(),
275            ));
276        }
277
278        self.parsed_arguments()
279    }
280
281    /// Validate that this tool call is properly formed
282    pub fn validate(&self) -> Result<(), String> {
283        if self.id.is_empty() {
284            return Err("Tool call ID cannot be empty".to_owned());
285        }
286
287        match self.call_type.as_str() {
288            "function" => {
289                if let Some(func) = &self.function {
290                    if func.name.is_empty() {
291                        return Err("Function name cannot be empty".to_owned());
292                    }
293                    // Validate that arguments is valid JSON for function tools
294                    if let Err(e) = self.parsed_arguments() {
295                        return Err(format!("Invalid JSON in function arguments: {}", e));
296                    }
297                } else {
298                    return Err("Function tool call missing function details".to_owned());
299                }
300            }
301            "custom" => {
302                // For custom tools, we allow raw text payload without JSON validation
303                if let Some(func) = &self.function {
304                    if func.name.is_empty() {
305                        return Err("Custom tool name cannot be empty".to_owned());
306                    }
307                } else {
308                    return Err("Custom tool call missing function details".to_owned());
309                }
310            }
311            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
312        }
313
314        Ok(())
315    }
316}
317
318fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
319    let trimmed = raw_arguments.trim();
320    match serde_json::from_str(trimmed) {
321        Ok(parsed) => Ok(parsed),
322        Err(primary_error) => {
323            if let Some(candidate) = extract_balanced_json(trimmed)
324                && let Ok(parsed) = serde_json::from_str(candidate)
325            {
326                return Ok(parsed);
327            }
328            if let Some(candidate) = repair_tag_polluted_json(trimmed)
329                && let Ok(parsed) = serde_json::from_str(&candidate)
330            {
331                return Ok(parsed);
332            }
333            Err(primary_error)
334        }
335    }
336}
337
/// Returns the first bracket-balanced JSON-looking slice of `input`.
///
/// Scanning starts at the first `{` or `[`. Only that bracket kind counts
/// toward nesting depth. Characters inside string literals, including
/// backslash-escaped quotes, are ignored.
///
/// Returns `None` when there is no opening bracket or the brackets never
/// balance.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let (open, close) = match *input.as_bytes().get(start)? {
        b'{' => (b'{', b'}'),
        b'[' => (b'[', b']'),
        _ => return None,
    };

    // Every structural character is ASCII and never appears inside a UTF-8
    // multi-byte sequence, so scanning bytes is equivalent to scanning chars.
    let mut nesting = 0usize;
    let mut inside_string = false;
    let mut skip_next = false;

    for (offset, &byte) in input.as_bytes()[start..].iter().enumerate() {
        if inside_string {
            if skip_next {
                skip_next = false;
            } else if byte == b'\\' {
                skip_next = true;
            } else if byte == b'"' {
                inside_string = false;
            }
            continue;
        }

        if byte == b'"' {
            inside_string = true;
        } else if byte == open {
            nesting += 1;
        } else if byte == close {
            nesting = nesting.saturating_sub(1);
            if nesting == 0 {
                return input.get(start..start + offset + 1);
            }
        }
    }

    None
}
383
384fn repair_tag_polluted_json(input: &str) -> Option<String> {
385    let start = input.find(['{', '['])?;
386    let candidate = input.get(start..)?;
387    let boundary = find_provider_markup_boundary(candidate)?;
388    if boundary == 0 {
389        return None;
390    }
391
392    close_incomplete_json_prefix(candidate[..boundary].trim_end())
393}
394
/// Returns the byte offset of the earliest provider tool-call markup marker
/// in `input`, or `None` if no marker appears.
fn find_provider_markup_boundary(input: &str) -> Option<usize> {
    const PROVIDER_MARKERS: &[&str] = &[
        "<</",
        "</parameter>",
        "</invoke>",
        "</minimax:tool_call>",
        "<minimax:tool_call>",
        "<parameter name=\"",
        "<invoke name=\"",
        "<tool_call>",
        "</tool_call>",
    ];

    // The earliest position where any marker starts equals the minimum of each
    // marker's first occurrence.
    PROVIDER_MARKERS
        .iter()
        .filter_map(|marker| input.find(*marker))
        .min()
}
416
/// Turns a truncated JSON prefix into a syntactically closed candidate.
///
/// The prefix is copied verbatim while tracking string literals (with
/// backslash escapes) and the stack of open `{` / `[` brackets. Afterwards:
///
/// - An unterminated string is closed with `"`. If the prefix stopped right
///   after a backslash, that dangling escape is dropped first; otherwise the
///   appended quote would itself be escaped and the string would stay open.
/// - Outside a string, a trailing `,` (ignoring trailing whitespace) is
///   removed, because JSON forbids a comma directly before a closing bracket.
/// - Every still-open bracket is closed, innermost first.
///
/// Returns `None` for an empty prefix, or when a closing bracket does not
/// match the innermost open one. The result is only a candidate; callers
/// still need to parse it.
fn close_incomplete_json_prefix(prefix: &str) -> Option<String> {
    if prefix.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(prefix.len() + 8);
    let mut expected_closers = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for ch in prefix.chars() {
        repaired.push(ch);

        if in_string {
            if escaped {
                escaped = false;
                continue;
            }

            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => expected_closers.push('}'),
            '[' => expected_closers.push(']'),
            '}' | ']' => {
                if expected_closers.pop() != Some(ch) {
                    return None;
                }
            }
            _ => {}
        }
    }

    if in_string {
        if escaped {
            // `escaped` can only still be set if the last char pushed was `\`.
            repaired.pop();
        }
        repaired.push('"');
    } else {
        let trimmed_len = repaired.trim_end().len();
        if repaired[..trimmed_len].ends_with(',') {
            repaired.truncate(trimmed_len - 1);
        }
    }

    while let Some(closer) = expected_closers.pop() {
        repaired.push(closer);
    }

    Some(repaired)
}
466
467/// Universal LLM response structure
468#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
469pub struct LLMResponse {
470    /// The response content text
471    pub content: Option<String>,
472
473    /// Tool calls made by the model
474    pub tool_calls: Option<Vec<ToolCall>>,
475
476    /// The model that generated this response
477    pub model: String,
478
479    /// Token usage statistics
480    pub usage: Option<Usage>,
481
482    /// Why the response finished
483    pub finish_reason: FinishReason,
484
485    /// Reasoning content (for models that support it)
486    pub reasoning: Option<String>,
487
488    /// Detailed reasoning traces (for models that support it)
489    pub reasoning_details: Option<Vec<String>>,
490
491    /// Tool references for context
492    pub tool_references: Vec<String>,
493
494    /// Request ID from the provider
495    pub request_id: Option<String>,
496
497    /// Organization ID from the provider
498    pub organization_id: Option<String>,
499}
500
501impl LLMResponse {
502    /// Create a new LLM response with mandatory fields
503    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
504        Self {
505            content: Some(content.into()),
506            tool_calls: None,
507            model: model.into(),
508            usage: None,
509            finish_reason: FinishReason::Stop,
510            reasoning: None,
511            reasoning_details: None,
512            tool_references: Vec::new(),
513            request_id: None,
514            organization_id: None,
515        }
516    }
517
518    /// Get content or empty string
519    pub fn content_text(&self) -> &str {
520        self.content.as_deref().unwrap_or("")
521    }
522
523    /// Get content as String (clone)
524    pub fn content_string(&self) -> String {
525        self.content.clone().unwrap_or_default()
526    }
527}
528
529#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
530pub struct LLMErrorMetadata {
531    pub provider: Option<String>,
532    pub status: Option<u16>,
533    pub code: Option<String>,
534    pub request_id: Option<String>,
535    pub organization_id: Option<String>,
536    pub retry_after: Option<String>,
537    pub message: Option<String>,
538}
539
540impl LLMErrorMetadata {
541    pub fn new(
542        provider: impl Into<String>,
543        status: Option<u16>,
544        code: Option<String>,
545        request_id: Option<String>,
546        organization_id: Option<String>,
547        retry_after: Option<String>,
548        message: Option<String>,
549    ) -> Box<Self> {
550        Box::new(Self {
551            provider: Some(provider.into()),
552            status,
553            code,
554            request_id,
555            organization_id,
556            retry_after,
557            message,
558        })
559    }
560}
561
562/// LLM error types with optional provider metadata
563#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
564#[serde(tag = "type", rename_all = "snake_case")]
565pub enum LLMError {
566    #[error("Authentication failed: {message}")]
567    Authentication {
568        message: String,
569        metadata: Option<Box<LLMErrorMetadata>>,
570    },
571    #[error("Rate limit exceeded")]
572    RateLimit {
573        metadata: Option<Box<LLMErrorMetadata>>,
574    },
575    #[error("Invalid request: {message}")]
576    InvalidRequest {
577        message: String,
578        metadata: Option<Box<LLMErrorMetadata>>,
579    },
580    #[error("Network error: {message}")]
581    Network {
582        message: String,
583        metadata: Option<Box<LLMErrorMetadata>>,
584    },
585    #[error("Provider error: {message}")]
586    Provider {
587        message: String,
588        metadata: Option<Box<LLMErrorMetadata>>,
589    },
590}
591
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"} trailing text"#.to_string(),
        );

        let parsed = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            "```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```".to_string(),
        );

        let parsed = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let call = ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            r#"{"path":"src/main.rs""#.to_string(),
        );

        assert!(call.parsed_arguments().is_err());
    }

    #[test]
    fn parsed_arguments_recovers_truncated_minimax_markup() {
        let call = ToolCall::function(
            "call_search".to_string(),
            "unified_search".to_string(),
            "{\"action\": \"grep\", \"pattern\": \"persistent_memory\", \"path\": \"vtcode-core/src</parameter>\n<</invoke>\n</minimax:tool_call>".to_string(),
        );

        let parsed = call
            .parsed_arguments()
            .expect("minimax markup spillover should recover");
        assert_eq!(
            parsed,
            json!({
                "action": "grep",
                "pattern": "persistent_memory",
                "path": "vtcode-core/src"
            })
        );
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_string(),
            Some("workspace".to_string()),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"}"#.to_string(),
        );

        let json = serde_json::to_value(&call).expect("tool call should serialize");
        assert_eq!(json["function"]["namespace"], "workspace");
        assert_eq!(json["function"]["name"], "read_file");
    }

    #[test]
    fn custom_tool_call_exposes_raw_execution_arguments() {
        let patch = "*** Begin Patch\n*** End Patch\n".to_string();
        let call = ToolCall::custom(
            "call_patch".to_string(),
            "apply_patch".to_string(),
            patch.clone(),
        );

        assert!(call.is_custom());
        assert_eq!(call.tool_name(), Some("apply_patch"));
        assert_eq!(call.raw_input(), Some(patch.as_str()));
        assert_eq!(
            call.execution_arguments().expect("custom arguments"),
            json!(patch)
        );
        assert!(
            call.parsed_arguments().is_err(),
            "custom tool payload should stay freeform rather than JSON"
        );
    }

    /// A call without function details cannot yield parsed arguments.
    #[test]
    fn parsed_arguments_errors_without_function_details() {
        let call = ToolCall {
            function: None,
            ..ToolCall::function(
                "call_read".to_string(),
                "read_file".to_string(),
                r#"{"path":"src/main.rs"}"#.to_string(),
            )
        };

        assert!(call.parsed_arguments().is_err());
    }

    /// `validate` accepts well-formed calls and reports each malformed shape.
    #[test]
    fn validate_reports_malformed_tool_calls() {
        let valid = ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"}"#.to_string(),
        );
        assert_eq!(valid.validate(), Ok(()));

        let no_id = ToolCall {
            id: String::new(),
            ..valid.clone()
        };
        assert_eq!(
            no_id.validate(),
            Err("Tool call ID cannot be empty".to_owned())
        );

        let no_function = ToolCall {
            function: None,
            ..valid.clone()
        };
        assert_eq!(
            no_function.validate(),
            Err("Function tool call missing function details".to_owned())
        );

        let unnamed = ToolCall::function("call_x".to_string(), String::new(), "{}".to_string());
        assert_eq!(
            unnamed.validate(),
            Err("Function name cannot be empty".to_owned())
        );

        let bad_json = ToolCall::function(
            "call_x".to_string(),
            "read_file".to_string(),
            "not json".to_string(),
        );
        assert!(
            bad_json
                .validate()
                .unwrap_err()
                .starts_with("Invalid JSON in function arguments: ")
        );

        let unsupported = ToolCall {
            call_type: "mystery".to_string(),
            ..valid.clone()
        };
        assert_eq!(
            unsupported.validate(),
            Err("Unsupported tool call type: mystery".to_owned())
        );

        let custom = ToolCall::custom(
            "call_patch".to_string(),
            "apply_patch".to_string(),
            "not json at all".to_string(),
        );
        assert_eq!(custom.validate(), Ok(()));

        let custom_missing = ToolCall {
            function: None,
            ..custom.clone()
        };
        assert_eq!(
            custom_missing.validate(),
            Err("Custom tool call missing function details".to_owned())
        );
    }
}