Skip to main content

vtcode_commons/
llm.rs

1//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    OpenRouter,
12    Ollama,
13    ZAI,
14    Moonshot,
15    HuggingFace,
16    Minimax,
17    LiteLLM,
18}
19
20#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
21pub struct Usage {
22    pub prompt_tokens: u32,
23    pub completion_tokens: u32,
24    pub total_tokens: u32,
25    pub cached_prompt_tokens: Option<u32>,
26    pub cache_creation_tokens: Option<u32>,
27    pub cache_read_tokens: Option<u32>,
28}
29
30impl Usage {
31    #[inline]
32    fn has_cache_read_metric(&self) -> bool {
33        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
34    }
35
36    #[inline]
37    fn has_any_cache_metrics(&self) -> bool {
38        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
39    }
40
41    #[inline]
42    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
43        self.cache_read_tokens
44            .or(self.cached_prompt_tokens)
45            .unwrap_or(0)
46    }
47
48    #[inline]
49    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
50        self.cache_creation_tokens.unwrap_or(0)
51    }
52
53    #[inline]
54    pub fn cache_hit_rate(&self) -> Option<f64> {
55        if !self.has_any_cache_metrics() {
56            return None;
57        }
58        let read = self.cache_read_tokens_or_fallback() as f64;
59        let creation = self.cache_creation_tokens_or_zero() as f64;
60        let total = read + creation;
61        if total > 0.0 {
62            Some((read / total) * 100.0)
63        } else {
64            None
65        }
66    }
67
68    #[inline]
69    pub fn is_cache_hit(&self) -> Option<bool> {
70        self.has_any_cache_metrics()
71            .then(|| self.cache_read_tokens_or_fallback() > 0)
72    }
73
74    #[inline]
75    pub fn is_cache_miss(&self) -> Option<bool> {
76        self.has_any_cache_metrics().then(|| {
77            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
78        })
79    }
80
81    #[inline]
82    pub fn total_cache_tokens(&self) -> u32 {
83        let read = self.cache_read_tokens_or_fallback();
84        let creation = self.cache_creation_tokens_or_zero();
85        read + creation
86    }
87
88    #[inline]
89    pub fn cache_savings_ratio(&self) -> Option<f64> {
90        if !self.has_cache_read_metric() {
91            return None;
92        }
93        let read = self.cache_read_tokens_or_fallback() as f64;
94        let prompt = self.prompt_tokens as f64;
95        if prompt > 0.0 {
96            Some(read / prompt)
97        } else {
98            None
99        }
100    }
101}
102
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// Usage with fixed prompt/completion counts; each test picks the cache fields.
    fn usage_with_cache(
        cached_prompt_tokens: Option<u32>,
        cache_creation_tokens: Option<u32>,
        cache_read_tokens: Option<u32>,
    ) -> Usage {
        Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens,
            cache_creation_tokens,
            cache_read_tokens,
        }
    }

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        // Only the legacy `cached_prompt_tokens` field carries the read count.
        let usage = usage_with_cache(Some(600), Some(150), None);

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        // No cache fields at all: boolean/ratio helpers must report "unknown".
        let usage = usage_with_cache(None, None, None);

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}
145
146#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
147pub enum FinishReason {
148    #[default]
149    Stop,
150    Length,
151    ToolCalls,
152    ContentFilter,
153    Pause,
154    Refusal,
155    Error(String),
156}
157
158/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
159#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
160pub struct ToolCall {
161    /// Unique identifier for this tool call (e.g., "call_123")
162    pub id: String,
163
164    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
165    #[serde(rename = "type")]
166    pub call_type: String,
167
168    /// Function call details (for function-type tools)
169    #[serde(skip_serializing_if = "Option::is_none")]
170    pub function: Option<FunctionCall>,
171
172    /// Raw text payload (for custom freeform tools in GPT-5)
173    #[serde(skip_serializing_if = "Option::is_none")]
174    pub text: Option<String>,
175
176    /// Gemini-specific thought signature for maintaining reasoning context
177    #[serde(skip_serializing_if = "Option::is_none")]
178    pub thought_signature: Option<String>,
179}
180
181/// Function call within a tool call
182#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
183pub struct FunctionCall {
184    /// Optional namespace for grouped or deferred tools.
185    #[serde(default, skip_serializing_if = "Option::is_none")]
186    pub namespace: Option<String>,
187
188    /// The name of the function to call
189    pub name: String,
190
191    /// The arguments to pass to the function, as a JSON string
192    pub arguments: String,
193}
194
195impl ToolCall {
196    /// Create a new function tool call
197    pub fn function(id: String, name: String, arguments: String) -> Self {
198        Self::function_with_namespace(id, None, name, arguments)
199    }
200
201    /// Create a new function tool call with an optional namespace.
202    pub fn function_with_namespace(
203        id: String,
204        namespace: Option<String>,
205        name: String,
206        arguments: String,
207    ) -> Self {
208        Self {
209            id,
210            call_type: "function".to_owned(),
211            function: Some(FunctionCall {
212                namespace,
213                name,
214                arguments,
215            }),
216            text: None,
217            thought_signature: None,
218        }
219    }
220
221    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
222    pub fn custom(id: String, name: String, text: String) -> Self {
223        Self {
224            id,
225            call_type: "custom".to_owned(),
226            function: Some(FunctionCall {
227                namespace: None,
228                name,
229                arguments: text.clone(),
230            }),
231            text: Some(text),
232            thought_signature: None,
233        }
234    }
235
236    /// Parse the arguments as JSON Value (for function-type tools)
237    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
238        if let Some(ref func) = self.function {
239            parse_tool_arguments(&func.arguments)
240        } else {
241            // Return an error by trying to parse invalid JSON
242            serde_json::from_str("")
243        }
244    }
245
246    /// Validate that this tool call is properly formed
247    pub fn validate(&self) -> Result<(), String> {
248        if self.id.is_empty() {
249            return Err("Tool call ID cannot be empty".to_owned());
250        }
251
252        match self.call_type.as_str() {
253            "function" => {
254                if let Some(func) = &self.function {
255                    if func.name.is_empty() {
256                        return Err("Function name cannot be empty".to_owned());
257                    }
258                    // Validate that arguments is valid JSON for function tools
259                    if let Err(e) = self.parsed_arguments() {
260                        return Err(format!("Invalid JSON in function arguments: {}", e));
261                    }
262                } else {
263                    return Err("Function tool call missing function details".to_owned());
264                }
265            }
266            "custom" => {
267                // For custom tools, we allow raw text payload without JSON validation
268                if let Some(func) = &self.function {
269                    if func.name.is_empty() {
270                        return Err("Custom tool name cannot be empty".to_owned());
271                    }
272                } else {
273                    return Err("Custom tool call missing function details".to_owned());
274                }
275            }
276            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
277        }
278
279        Ok(())
280    }
281}
282
283fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
284    let trimmed = raw_arguments.trim();
285    match serde_json::from_str(trimmed) {
286        Ok(parsed) => Ok(parsed),
287        Err(primary_error) => {
288            if let Some(candidate) = extract_balanced_json(trimmed)
289                && let Ok(parsed) = serde_json::from_str(candidate)
290            {
291                return Ok(parsed);
292            }
293            Err(primary_error)
294        }
295    }
296}
297
/// Return the first bracket-balanced span in `input`.
///
/// The span starts at the first `{` or `[` and ends at its matching closer.
/// Only the bracket kind that opened the span is counted. Brackets inside JSON
/// string literals (including after `\` escapes) are ignored. Returns `None`
/// when `input` has no opening bracket or the span never closes.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let tail = &input[start..];
    let (open, close) = if tail.starts_with('{') {
        ('{', '}')
    } else {
        ('[', ']')
    };

    let mut nesting = 0usize;
    let mut inside_string = false;
    let mut after_backslash = false;

    for (idx, c) in tail.char_indices() {
        if inside_string {
            match (after_backslash, c) {
                (true, _) => after_backslash = false,
                (false, '\\') => after_backslash = true,
                (false, '"') => inside_string = false,
                _ => {}
            }
            continue;
        }

        if c == '"' {
            inside_string = true;
        } else if c == open {
            nesting += 1;
        } else if c == close {
            nesting = nesting.saturating_sub(1);
            if nesting == 0 {
                return Some(&tail[..idx + c.len_utf8()]);
            }
        }
    }

    None
}
343
344/// Universal LLM response structure
345#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
346pub struct LLMResponse {
347    /// The response content text
348    pub content: Option<String>,
349
350    /// Tool calls made by the model
351    pub tool_calls: Option<Vec<ToolCall>>,
352
353    /// The model that generated this response
354    pub model: String,
355
356    /// Token usage statistics
357    pub usage: Option<Usage>,
358
359    /// Why the response finished
360    pub finish_reason: FinishReason,
361
362    /// Reasoning content (for models that support it)
363    pub reasoning: Option<String>,
364
365    /// Detailed reasoning traces (for models that support it)
366    pub reasoning_details: Option<Vec<String>>,
367
368    /// Tool references for context
369    pub tool_references: Vec<String>,
370
371    /// Request ID from the provider
372    pub request_id: Option<String>,
373
374    /// Organization ID from the provider
375    pub organization_id: Option<String>,
376}
377
378impl LLMResponse {
379    /// Create a new LLM response with mandatory fields
380    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
381        Self {
382            content: Some(content.into()),
383            tool_calls: None,
384            model: model.into(),
385            usage: None,
386            finish_reason: FinishReason::Stop,
387            reasoning: None,
388            reasoning_details: None,
389            tool_references: Vec::new(),
390            request_id: None,
391            organization_id: None,
392        }
393    }
394
395    /// Get content or empty string
396    pub fn content_text(&self) -> &str {
397        self.content.as_deref().unwrap_or("")
398    }
399
400    /// Get content as String (clone)
401    pub fn content_string(&self) -> String {
402        self.content.clone().unwrap_or_default()
403    }
404}
405
406#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
407pub struct LLMErrorMetadata {
408    pub provider: Option<String>,
409    pub status: Option<u16>,
410    pub code: Option<String>,
411    pub request_id: Option<String>,
412    pub organization_id: Option<String>,
413    pub retry_after: Option<String>,
414    pub message: Option<String>,
415}
416
417impl LLMErrorMetadata {
418    pub fn new(
419        provider: impl Into<String>,
420        status: Option<u16>,
421        code: Option<String>,
422        request_id: Option<String>,
423        organization_id: Option<String>,
424        retry_after: Option<String>,
425        message: Option<String>,
426    ) -> Box<Self> {
427        Box::new(Self {
428            provider: Some(provider.into()),
429            status,
430            code,
431            request_id,
432            organization_id,
433            retry_after,
434            message,
435        })
436    }
437}
438
439/// LLM error types with optional provider metadata
440#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
441#[serde(tag = "type", rename_all = "snake_case")]
442pub enum LLMError {
443    #[error("Authentication failed: {message}")]
444    Authentication {
445        message: String,
446        metadata: Option<Box<LLMErrorMetadata>>,
447    },
448    #[error("Rate limit exceeded")]
449    RateLimit {
450        metadata: Option<Box<LLMErrorMetadata>>,
451    },
452    #[error("Invalid request: {message}")]
453    InvalidRequest {
454        message: String,
455        metadata: Option<Box<LLMErrorMetadata>>,
456    },
457    #[error("Network error: {message}")]
458    Network {
459        message: String,
460        metadata: Option<Box<LLMErrorMetadata>>,
461    },
462    #[error("Provider error: {message}")]
463    Provider {
464        message: String,
465        metadata: Option<Box<LLMErrorMetadata>>,
466    },
467}
468
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// A `read_file` function call with the given raw argument text.
    fn read_file_call(arguments: &str) -> ToolCall {
        ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            arguments.to_string(),
        )
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = read_file_call(r#"{"path":"src/main.rs"} trailing text"#);

        let parsed = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = read_file_call("```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```");

        let parsed = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        // The closing brace is missing, so nothing balanced can be recovered.
        let call = read_file_call(r#"{"path":"src/main.rs""#);

        assert!(call.parsed_arguments().is_err());
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_string(),
            Some("workspace".to_string()),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"}"#.to_string(),
        );

        let value = serde_json::to_value(&call).expect("tool call should serialize");
        let function = &value["function"];
        assert_eq!(function["namespace"], "workspace");
        assert_eq!(function["name"], "read_file");
    }
}