// vtcode_commons/llm.rs
1//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    OpenRouter,
12    Ollama,
13    ZAI,
14    Moonshot,
15    HuggingFace,
16    Minimax,
17    LiteLLM,
18}
19
/// Token usage statistics reported by a provider for a single request.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct Usage {
    /// Tokens consumed by the prompt (input side).
    pub prompt_tokens: u32,
    /// Tokens produced by the model (output side).
    pub completion_tokens: u32,
    /// Total tokens for the request as reported by the provider.
    pub total_tokens: u32,
    /// Prompt tokens served from cache; used by the helper methods below as a
    /// fallback when `cache_read_tokens` is absent.
    pub cached_prompt_tokens: Option<u32>,
    /// Tokens written into the provider's prompt cache.
    pub cache_creation_tokens: Option<u32>,
    /// Tokens read back from the provider's prompt cache.
    pub cache_read_tokens: Option<u32>,
}
29
30impl Usage {
31    #[inline]
32    fn has_cache_read_metric(&self) -> bool {
33        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
34    }
35
36    #[inline]
37    fn has_any_cache_metrics(&self) -> bool {
38        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
39    }
40
41    #[inline]
42    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
43        self.cache_read_tokens
44            .or(self.cached_prompt_tokens)
45            .unwrap_or(0)
46    }
47
48    #[inline]
49    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
50        self.cache_creation_tokens.unwrap_or(0)
51    }
52
53    #[inline]
54    pub fn cache_hit_rate(&self) -> Option<f64> {
55        if !self.has_any_cache_metrics() {
56            return None;
57        }
58        let read = self.cache_read_tokens_or_fallback() as f64;
59        let creation = self.cache_creation_tokens_or_zero() as f64;
60        let total = read + creation;
61        if total > 0.0 {
62            Some((read / total) * 100.0)
63        } else {
64            None
65        }
66    }
67
68    #[inline]
69    pub fn is_cache_hit(&self) -> Option<bool> {
70        self.has_any_cache_metrics()
71            .then(|| self.cache_read_tokens_or_fallback() > 0)
72    }
73
74    #[inline]
75    pub fn is_cache_miss(&self) -> Option<bool> {
76        self.has_any_cache_metrics().then(|| {
77            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
78        })
79    }
80
81    #[inline]
82    pub fn total_cache_tokens(&self) -> u32 {
83        let read = self.cache_read_tokens_or_fallback();
84        let creation = self.cache_creation_tokens_or_zero();
85        read + creation
86    }
87
88    #[inline]
89    pub fn cache_savings_ratio(&self) -> Option<f64> {
90        if !self.has_cache_read_metric() {
91            return None;
92        }
93        let read = self.cache_read_tokens_or_fallback() as f64;
94        let prompt = self.prompt_tokens as f64;
95        if prompt > 0.0 {
96            Some(read / prompt)
97        } else {
98            None
99        }
100    }
101}
102
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        // `cache_read_tokens` is absent, so helpers must fall back to
        // `cached_prompt_tokens` (600) for the read side.
        let usage = Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: Some(600),
            cache_creation_tokens: Some(150),
            ..Usage::default()
        };

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        // With every cache field unset, tri-state helpers must stay `None`
        // rather than reporting a definite hit or miss.
        let usage = Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            ..Usage::default()
        };

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}
145
/// Reason a provider reported for ending generation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum FinishReason {
    /// Natural end of the response (the default).
    #[default]
    Stop,
    /// Output was cut off by a length/token limit.
    Length,
    /// The model stopped in order to invoke tools.
    ToolCalls,
    /// Output was blocked by a content filter.
    ContentFilter,
    /// Generation was paused.
    Pause,
    /// The model refused to respond.
    Refusal,
    /// Provider-reported error, carrying its message.
    Error(String),
}
157
/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique identifier for this tool call (e.g., "call_123")
    pub id: String,

    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
    #[serde(rename = "type")]
    pub call_type: String,

    /// Function call details (for function-type tools)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCall>,

    /// Raw text payload (for custom freeform tools in GPT-5); for custom
    /// calls this mirrors `function.arguments`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Gemini-specific thought signature for maintaining reasoning context
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
180
/// Function call within a tool call.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
    /// The name of the function to call
    pub name: String,

    /// The arguments to pass to the function, as a JSON string
    /// (custom/freeform calls may store non-JSON raw text here)
    pub arguments: String,
}
190
191impl ToolCall {
192    /// Create a new function tool call
193    pub fn function(id: String, name: String, arguments: String) -> Self {
194        Self {
195            id,
196            call_type: "function".to_owned(),
197            function: Some(FunctionCall { name, arguments }),
198            text: None,
199            thought_signature: None,
200        }
201    }
202
203    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
204    pub fn custom(id: String, name: String, text: String) -> Self {
205        Self {
206            id,
207            call_type: "custom".to_owned(),
208            function: Some(FunctionCall {
209                name,
210                arguments: text.clone(),
211            }),
212            text: Some(text),
213            thought_signature: None,
214        }
215    }
216
217    /// Parse the arguments as JSON Value (for function-type tools)
218    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
219        if let Some(ref func) = self.function {
220            parse_tool_arguments(&func.arguments)
221        } else {
222            // Return an error by trying to parse invalid JSON
223            serde_json::from_str("")
224        }
225    }
226
227    /// Validate that this tool call is properly formed
228    pub fn validate(&self) -> Result<(), String> {
229        if self.id.is_empty() {
230            return Err("Tool call ID cannot be empty".to_owned());
231        }
232
233        match self.call_type.as_str() {
234            "function" => {
235                if let Some(func) = &self.function {
236                    if func.name.is_empty() {
237                        return Err("Function name cannot be empty".to_owned());
238                    }
239                    // Validate that arguments is valid JSON for function tools
240                    if let Err(e) = self.parsed_arguments() {
241                        return Err(format!("Invalid JSON in function arguments: {}", e));
242                    }
243                } else {
244                    return Err("Function tool call missing function details".to_owned());
245                }
246            }
247            "custom" => {
248                // For custom tools, we allow raw text payload without JSON validation
249                if let Some(func) = &self.function {
250                    if func.name.is_empty() {
251                        return Err("Custom tool name cannot be empty".to_owned());
252                    }
253                } else {
254                    return Err("Custom tool call missing function details".to_owned());
255                }
256            }
257            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
258        }
259
260        Ok(())
261    }
262}
263
264fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
265    let trimmed = raw_arguments.trim();
266    match serde_json::from_str(trimmed) {
267        Ok(parsed) => Ok(parsed),
268        Err(primary_error) => {
269            if let Some(candidate) = extract_balanced_json(trimmed)
270                && let Ok(parsed) = serde_json::from_str(candidate)
271            {
272                return Ok(parsed);
273            }
274            Err(primary_error)
275        }
276    }
277}
278
/// Find the first balanced JSON object or array embedded in `input`.
///
/// Scans from the first `{` or `[`, tracking the nesting depth of that same
/// bracket kind while skipping over string literals (honoring `\"` escapes),
/// and returns the balanced slice once the opening bracket is closed.
/// Returns `None` when no opener exists or it is never balanced.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let (open, close) = match input.as_bytes().get(start).copied()? {
        b'{' => ('{', '}'),
        b'[' => ('[', ']'),
        _ => return None,
    };

    let mut depth: usize = 0;
    let mut inside_string = false;
    let mut skip_escaped = false;

    for (offset, ch) in input[start..].char_indices() {
        if inside_string {
            if skip_escaped {
                // Previous char was a backslash; this char is consumed.
                skip_escaped = false;
            } else if ch == '\\' {
                skip_escaped = true;
            } else if ch == '"' {
                inside_string = false;
            }
        } else if ch == '"' {
            inside_string = true;
        } else if ch == open {
            depth += 1;
        } else if ch == close {
            depth = depth.saturating_sub(1);
            if depth == 0 {
                // Balanced: slice from the opener through this closer.
                return input.get(start..start + offset + ch.len_utf8());
            }
        }
    }

    None
}
324
/// Universal LLM response structure shared across provider backends.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct LLMResponse {
    /// The response content text
    pub content: Option<String>,

    /// Tool calls made by the model
    pub tool_calls: Option<Vec<ToolCall>>,

    /// The model that generated this response
    pub model: String,

    /// Token usage statistics
    pub usage: Option<Usage>,

    /// Why the response finished
    pub finish_reason: FinishReason,

    /// Reasoning content (for models that support it)
    pub reasoning: Option<String>,

    /// Detailed reasoning traces (for models that support it)
    pub reasoning_details: Option<Vec<String>>,

    /// Tool references for context
    pub tool_references: Vec<String>,

    /// Request ID from the provider
    pub request_id: Option<String>,

    /// Organization ID from the provider
    pub organization_id: Option<String>,
}
358
359impl LLMResponse {
360    /// Create a new LLM response with mandatory fields
361    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
362        Self {
363            content: Some(content.into()),
364            tool_calls: None,
365            model: model.into(),
366            usage: None,
367            finish_reason: FinishReason::Stop,
368            reasoning: None,
369            reasoning_details: None,
370            tool_references: Vec::new(),
371            request_id: None,
372            organization_id: None,
373        }
374    }
375
376    /// Get content or empty string
377    pub fn content_text(&self) -> &str {
378        self.content.as_deref().unwrap_or("")
379    }
380
381    /// Get content as String (clone)
382    pub fn content_string(&self) -> String {
383        self.content.clone().unwrap_or_default()
384    }
385}
386
/// Provider-supplied metadata attached to an `LLMError` for diagnostics.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct LLMErrorMetadata {
    /// Name of the provider that produced the error.
    pub provider: Option<String>,
    /// HTTP status code, when the error came from an HTTP response.
    pub status: Option<u16>,
    /// Provider-specific error code.
    pub code: Option<String>,
    /// Request ID reported by the provider.
    pub request_id: Option<String>,
    /// Organization ID reported by the provider.
    pub organization_id: Option<String>,
    /// Raw `Retry-After` value, when the provider supplied one.
    pub retry_after: Option<String>,
    /// Raw provider error message.
    pub message: Option<String>,
}
397
impl LLMErrorMetadata {
    /// Construct boxed metadata for attaching to an `LLMError` variant.
    ///
    /// Returns `Box<Self>` because the `LLMError` variants store
    /// `Option<Box<LLMErrorMetadata>>`, keeping the error enum itself small.
    pub fn new(
        provider: impl Into<String>,
        status: Option<u16>,
        code: Option<String>,
        request_id: Option<String>,
        organization_id: Option<String>,
        retry_after: Option<String>,
        message: Option<String>,
    ) -> Box<Self> {
        Box::new(Self {
            provider: Some(provider.into()),
            status,
            code,
            request_id,
            organization_id,
            retry_after,
            message,
        })
    }
}
419
/// LLM error types with optional provider metadata.
///
/// Metadata is boxed (`Box<LLMErrorMetadata>`) so the enum stays small even
/// though the metadata struct has many fields.
#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum LLMError {
    /// Credentials were rejected or missing.
    #[error("Authentication failed: {message}")]
    Authentication {
        message: String,
        metadata: Option<Box<LLMErrorMetadata>>,
    },
    /// The provider throttled the request.
    #[error("Rate limit exceeded")]
    RateLimit {
        metadata: Option<Box<LLMErrorMetadata>>,
    },
    /// The request itself was malformed or rejected.
    #[error("Invalid request: {message}")]
    InvalidRequest {
        message: String,
        metadata: Option<Box<LLMErrorMetadata>>,
    },
    /// Transport-level failure reaching the provider.
    #[error("Network error: {message}")]
    Network {
        message: String,
        metadata: Option<Box<LLMErrorMetadata>>,
    },
    /// Provider-side failure not covered by the other variants.
    #[error("Provider error: {message}")]
    Provider {
        message: String,
        metadata: Option<Box<LLMErrorMetadata>>,
    },
}
449
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Build a `read_file` function call with the given raw argument payload.
    fn read_file_call(arguments: &str) -> ToolCall {
        ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            arguments.to_string(),
        )
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let parsed = read_file_call(r#"{"path":"src/main.rs"} trailing text"#)
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let parsed = read_file_call("```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```")
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let call = read_file_call(r#"{"path":"src/main.rs""#);
        assert!(call.parsed_arguments().is_err());
    }
}
493}