Skip to main content

vtcode_commons/
llm.rs

1//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    OpenRouter,
12    Ollama,
13    ZAI,
14    Moonshot,
15    HuggingFace,
16    Minimax,
17}
18
19#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
20pub struct Usage {
21    pub prompt_tokens: u32,
22    pub completion_tokens: u32,
23    pub total_tokens: u32,
24    pub cached_prompt_tokens: Option<u32>,
25    pub cache_creation_tokens: Option<u32>,
26    pub cache_read_tokens: Option<u32>,
27}
28
29impl Usage {
30    #[inline]
31    fn has_cache_read_metric(&self) -> bool {
32        self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
33    }
34
35    #[inline]
36    fn has_any_cache_metrics(&self) -> bool {
37        self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
38    }
39
40    #[inline]
41    pub fn cache_read_tokens_or_fallback(&self) -> u32 {
42        self.cache_read_tokens
43            .or(self.cached_prompt_tokens)
44            .unwrap_or(0)
45    }
46
47    #[inline]
48    pub fn cache_creation_tokens_or_zero(&self) -> u32 {
49        self.cache_creation_tokens.unwrap_or(0)
50    }
51
52    #[inline]
53    pub fn cache_hit_rate(&self) -> Option<f64> {
54        if !self.has_any_cache_metrics() {
55            return None;
56        }
57        let read = self.cache_read_tokens_or_fallback() as f64;
58        let creation = self.cache_creation_tokens_or_zero() as f64;
59        let total = read + creation;
60        if total > 0.0 {
61            Some((read / total) * 100.0)
62        } else {
63            None
64        }
65    }
66
67    #[inline]
68    pub fn is_cache_hit(&self) -> Option<bool> {
69        self.has_any_cache_metrics()
70            .then(|| self.cache_read_tokens_or_fallback() > 0)
71    }
72
73    #[inline]
74    pub fn is_cache_miss(&self) -> Option<bool> {
75        self.has_any_cache_metrics().then(|| {
76            self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
77        })
78    }
79
80    #[inline]
81    pub fn total_cache_tokens(&self) -> u32 {
82        let read = self.cache_read_tokens_or_fallback();
83        let creation = self.cache_creation_tokens_or_zero();
84        read + creation
85    }
86
87    #[inline]
88    pub fn cache_savings_ratio(&self) -> Option<f64> {
89        if !self.has_cache_read_metric() {
90            return None;
91        }
92        let read = self.cache_read_tokens_or_fallback() as f64;
93        let prompt = self.prompt_tokens as f64;
94        if prompt > 0.0 {
95            Some(read / prompt)
96        } else {
97            None
98        }
99    }
100}
101
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// Builds a `Usage` with 1000 prompt / 200 completion / 1200 total tokens
    /// and the given cache fields.
    fn sample_usage(
        cached_prompt_tokens: Option<u32>,
        cache_creation_tokens: Option<u32>,
        cache_read_tokens: Option<u32>,
    ) -> Usage {
        Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens,
            cache_creation_tokens,
            cache_read_tokens,
        }
    }

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        let fallback = sample_usage(Some(600), Some(150), None);

        assert_eq!(fallback.cache_read_tokens_or_fallback(), 600);
        assert_eq!(fallback.cache_creation_tokens_or_zero(), 150);
        assert_eq!(fallback.total_cache_tokens(), 750);
        assert_eq!(fallback.is_cache_hit(), Some(true));
        assert_eq!(fallback.is_cache_miss(), Some(false));
        assert_eq!(fallback.cache_savings_ratio(), Some(0.6));
        assert_eq!(fallback.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        let unknown = sample_usage(None, None, None);

        assert_eq!(unknown.total_cache_tokens(), 0);
        assert_eq!(unknown.is_cache_hit(), None);
        assert_eq!(unknown.is_cache_miss(), None);
        assert_eq!(unknown.cache_savings_ratio(), None);
        assert_eq!(unknown.cache_hit_rate(), None);
    }
}
144
145#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
146pub enum FinishReason {
147    #[default]
148    Stop,
149    Length,
150    ToolCalls,
151    ContentFilter,
152    Pause,
153    Refusal,
154    Error(String),
155}
156
157/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
158#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
159pub struct ToolCall {
160    /// Unique identifier for this tool call (e.g., "call_123")
161    pub id: String,
162
163    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
164    #[serde(rename = "type")]
165    pub call_type: String,
166
167    /// Function call details (for function-type tools)
168    #[serde(skip_serializing_if = "Option::is_none")]
169    pub function: Option<FunctionCall>,
170
171    /// Raw text payload (for custom freeform tools in GPT-5)
172    #[serde(skip_serializing_if = "Option::is_none")]
173    pub text: Option<String>,
174
175    /// Gemini-specific thought signature for maintaining reasoning context
176    #[serde(skip_serializing_if = "Option::is_none")]
177    pub thought_signature: Option<String>,
178}
179
180/// Function call within a tool call
181#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
182pub struct FunctionCall {
183    /// Optional namespace for grouped or deferred tools.
184    #[serde(default, skip_serializing_if = "Option::is_none")]
185    pub namespace: Option<String>,
186
187    /// The name of the function to call
188    pub name: String,
189
190    /// The arguments to pass to the function, as a JSON string
191    pub arguments: String,
192}
193
194impl ToolCall {
195    /// Create a new function tool call
196    pub fn function(id: String, name: String, arguments: String) -> Self {
197        Self::function_with_namespace(id, None, name, arguments)
198    }
199
200    /// Create a new function tool call with an optional namespace.
201    pub fn function_with_namespace(
202        id: String,
203        namespace: Option<String>,
204        name: String,
205        arguments: String,
206    ) -> Self {
207        Self {
208            id,
209            call_type: "function".to_owned(),
210            function: Some(FunctionCall {
211                namespace,
212                name,
213                arguments,
214            }),
215            text: None,
216            thought_signature: None,
217        }
218    }
219
220    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
221    pub fn custom(id: String, name: String, text: String) -> Self {
222        Self {
223            id,
224            call_type: "custom".to_owned(),
225            function: Some(FunctionCall {
226                namespace: None,
227                name,
228                arguments: text.clone(),
229            }),
230            text: Some(text),
231            thought_signature: None,
232        }
233    }
234
235    /// Parse the arguments as JSON Value (for function-type tools)
236    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
237        if let Some(ref func) = self.function {
238            parse_tool_arguments(&func.arguments)
239        } else {
240            // Return an error by trying to parse invalid JSON
241            serde_json::from_str("")
242        }
243    }
244
245    /// Validate that this tool call is properly formed
246    pub fn validate(&self) -> Result<(), String> {
247        if self.id.is_empty() {
248            return Err("Tool call ID cannot be empty".to_owned());
249        }
250
251        match self.call_type.as_str() {
252            "function" => {
253                if let Some(func) = &self.function {
254                    if func.name.is_empty() {
255                        return Err("Function name cannot be empty".to_owned());
256                    }
257                    // Validate that arguments is valid JSON for function tools
258                    if let Err(e) = self.parsed_arguments() {
259                        return Err(format!("Invalid JSON in function arguments: {}", e));
260                    }
261                } else {
262                    return Err("Function tool call missing function details".to_owned());
263                }
264            }
265            "custom" => {
266                // For custom tools, we allow raw text payload without JSON validation
267                if let Some(func) = &self.function {
268                    if func.name.is_empty() {
269                        return Err("Custom tool name cannot be empty".to_owned());
270                    }
271                } else {
272                    return Err("Custom tool call missing function details".to_owned());
273                }
274            }
275            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
276        }
277
278        Ok(())
279    }
280}
281
282fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
283    let trimmed = raw_arguments.trim();
284    match serde_json::from_str(trimmed) {
285        Ok(parsed) => Ok(parsed),
286        Err(primary_error) => {
287            if let Some(candidate) = extract_balanced_json(trimmed)
288                && let Ok(parsed) = serde_json::from_str(candidate)
289            {
290                return Ok(parsed);
291            }
292            Err(primary_error)
293        }
294    }
295}
296
/// Locate the first balanced JSON object or array embedded in `input`.
///
/// Scanning starts at the first `{` or `[`. Only that bracket kind is counted
/// for nesting, and brackets inside double-quoted strings (honouring backslash
/// escapes) are ignored. Returns the slice from the opening bracket through its
/// matching close, or `None` if there is no opener or it is never closed.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let tail = &input[start..];
    let (open, close) = match tail.chars().next()? {
        '{' => ('{', '}'),
        '[' => ('[', ']'),
        _ => return None,
    };

    let mut depth = 0usize;
    let mut inside_string = false;
    let mut skip_next = false;

    for (idx, c) in tail.char_indices() {
        if inside_string {
            match c {
                // The character right after a backslash is escaped; ignore it.
                _ if skip_next => skip_next = false,
                '\\' => skip_next = true,
                '"' => inside_string = false,
                _ => {}
            }
            continue;
        }

        if c == '"' {
            inside_string = true;
        } else if c == open {
            depth += 1;
        } else if c == close {
            depth = depth.saturating_sub(1);
            if depth == 0 {
                return tail.get(..idx + c.len_utf8());
            }
        }
    }

    None
}
342
343/// Universal LLM response structure
344#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
345pub struct LLMResponse {
346    /// The response content text
347    pub content: Option<String>,
348
349    /// Tool calls made by the model
350    pub tool_calls: Option<Vec<ToolCall>>,
351
352    /// The model that generated this response
353    pub model: String,
354
355    /// Token usage statistics
356    pub usage: Option<Usage>,
357
358    /// Why the response finished
359    pub finish_reason: FinishReason,
360
361    /// Reasoning content (for models that support it)
362    pub reasoning: Option<String>,
363
364    /// Detailed reasoning traces (for models that support it)
365    pub reasoning_details: Option<Vec<String>>,
366
367    /// Tool references for context
368    pub tool_references: Vec<String>,
369
370    /// Request ID from the provider
371    pub request_id: Option<String>,
372
373    /// Organization ID from the provider
374    pub organization_id: Option<String>,
375}
376
377impl LLMResponse {
378    /// Create a new LLM response with mandatory fields
379    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
380        Self {
381            content: Some(content.into()),
382            tool_calls: None,
383            model: model.into(),
384            usage: None,
385            finish_reason: FinishReason::Stop,
386            reasoning: None,
387            reasoning_details: None,
388            tool_references: Vec::new(),
389            request_id: None,
390            organization_id: None,
391        }
392    }
393
394    /// Get content or empty string
395    pub fn content_text(&self) -> &str {
396        self.content.as_deref().unwrap_or("")
397    }
398
399    /// Get content as String (clone)
400    pub fn content_string(&self) -> String {
401        self.content.clone().unwrap_or_default()
402    }
403}
404
405#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
406pub struct LLMErrorMetadata {
407    pub provider: Option<String>,
408    pub status: Option<u16>,
409    pub code: Option<String>,
410    pub request_id: Option<String>,
411    pub organization_id: Option<String>,
412    pub retry_after: Option<String>,
413    pub message: Option<String>,
414}
415
416impl LLMErrorMetadata {
417    pub fn new(
418        provider: impl Into<String>,
419        status: Option<u16>,
420        code: Option<String>,
421        request_id: Option<String>,
422        organization_id: Option<String>,
423        retry_after: Option<String>,
424        message: Option<String>,
425    ) -> Box<Self> {
426        Box::new(Self {
427            provider: Some(provider.into()),
428            status,
429            code,
430            request_id,
431            organization_id,
432            retry_after,
433            message,
434        })
435    }
436}
437
438/// LLM error types with optional provider metadata
439#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
440#[serde(tag = "type", rename_all = "snake_case")]
441pub enum LLMError {
442    #[error("Authentication failed: {message}")]
443    Authentication {
444        message: String,
445        metadata: Option<Box<LLMErrorMetadata>>,
446    },
447    #[error("Rate limit exceeded")]
448    RateLimit {
449        metadata: Option<Box<LLMErrorMetadata>>,
450    },
451    #[error("Invalid request: {message}")]
452    InvalidRequest {
453        message: String,
454        metadata: Option<Box<LLMErrorMetadata>>,
455    },
456    #[error("Network error: {message}")]
457    Network {
458        message: String,
459        metadata: Option<Box<LLMErrorMetadata>>,
460    },
461    #[error("Provider error: {message}")]
462    Provider {
463        message: String,
464        metadata: Option<Box<LLMErrorMetadata>>,
465    },
466}
467
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Builds the `read_file` function call used by the parsing tests.
    fn read_file_call(arguments: &str) -> ToolCall {
        ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            arguments.to_string(),
        )
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let parsed = read_file_call(r#"{"path":"src/main.rs"} trailing text"#)
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let parsed = read_file_call("```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```")
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let truncated = read_file_call(r#"{"path":"src/main.rs""#);
        assert!(truncated.parsed_arguments().is_err());
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let namespaced = ToolCall::function_with_namespace(
            "call_read".to_string(),
            Some("workspace".to_string()),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"}"#.to_string(),
        );

        let serialized = serde_json::to_value(&namespaced).expect("tool call should serialize");
        let function = &serialized["function"];
        assert_eq!(function["namespace"], "workspace");
        assert_eq!(function["name"], "read_file");
    }
}
525}