Skip to main content

vtcode_commons/
llm.rs

1//! Core LLM types shared across the project
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7    Gemini,
8    OpenAI,
9    Anthropic,
10    DeepSeek,
11    OpenRouter,
12    Ollama,
13    ZAI,
14    Moonshot,
15    HuggingFace,
16    Minimax,
17}
18
19#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
20pub struct Usage {
21    pub prompt_tokens: u32,
22    pub completion_tokens: u32,
23    pub total_tokens: u32,
24    pub cached_prompt_tokens: Option<u32>,
25    pub cache_creation_tokens: Option<u32>,
26    pub cache_read_tokens: Option<u32>,
27}
28
29impl Usage {
30    #[inline]
31    pub fn cache_hit_rate(&self) -> Option<f64> {
32        let read = self.cache_read_tokens? as f64;
33        let creation = self.cache_creation_tokens? as f64;
34        let total = read + creation;
35        if total > 0.0 {
36            Some((read / total) * 100.0)
37        } else {
38            None
39        }
40    }
41
42    #[inline]
43    pub fn is_cache_hit(&self) -> Option<bool> {
44        Some(self.cache_read_tokens? > 0)
45    }
46
47    #[inline]
48    pub fn is_cache_miss(&self) -> Option<bool> {
49        Some(self.cache_creation_tokens? > 0 && self.cache_read_tokens? == 0)
50    }
51
52    #[inline]
53    pub fn total_cache_tokens(&self) -> u32 {
54        let read = self.cache_read_tokens.unwrap_or(0);
55        let creation = self.cache_creation_tokens.unwrap_or(0);
56        read + creation
57    }
58
59    #[inline]
60    pub fn cache_savings_ratio(&self) -> Option<f64> {
61        let read = self.cache_read_tokens? as f64;
62        let prompt = self.prompt_tokens as f64;
63        if prompt > 0.0 {
64            Some(read / prompt)
65        } else {
66            None
67        }
68    }
69}
70
71#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
72pub enum FinishReason {
73    #[default]
74    Stop,
75    Length,
76    ToolCalls,
77    ContentFilter,
78    Pause,
79    Refusal,
80    Error(String),
81}
82
83/// Universal tool call that matches OpenAI/Anthropic/Gemini specifications
84#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
85pub struct ToolCall {
86    /// Unique identifier for this tool call (e.g., "call_123")
87    pub id: String,
88
89    /// The type of tool call: "function", "custom" (GPT-5 freeform), or other
90    #[serde(rename = "type")]
91    pub call_type: String,
92
93    /// Function call details (for function-type tools)
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub function: Option<FunctionCall>,
96
97    /// Raw text payload (for custom freeform tools in GPT-5)
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub text: Option<String>,
100
101    /// Gemini-specific thought signature for maintaining reasoning context
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub thought_signature: Option<String>,
104}
105
106/// Function call within a tool call
107#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
108pub struct FunctionCall {
109    /// The name of the function to call
110    pub name: String,
111
112    /// The arguments to pass to the function, as a JSON string
113    pub arguments: String,
114}
115
116impl ToolCall {
117    /// Create a new function tool call
118    pub fn function(id: String, name: String, arguments: String) -> Self {
119        Self {
120            id,
121            call_type: "function".to_owned(),
122            function: Some(FunctionCall { name, arguments }),
123            text: None,
124            thought_signature: None,
125        }
126    }
127
128    /// Create a new custom tool call with raw text payload (GPT-5 freeform)
129    pub fn custom(id: String, name: String, text: String) -> Self {
130        Self {
131            id,
132            call_type: "custom".to_owned(),
133            function: Some(FunctionCall {
134                name,
135                arguments: text.clone(),
136            }),
137            text: Some(text),
138            thought_signature: None,
139        }
140    }
141
142    /// Parse the arguments as JSON Value (for function-type tools)
143    pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
144        if let Some(ref func) = self.function {
145            parse_tool_arguments(&func.arguments)
146        } else {
147            // Return an error by trying to parse invalid JSON
148            serde_json::from_str("")
149        }
150    }
151
152    /// Validate that this tool call is properly formed
153    pub fn validate(&self) -> Result<(), String> {
154        if self.id.is_empty() {
155            return Err("Tool call ID cannot be empty".to_owned());
156        }
157
158        match self.call_type.as_str() {
159            "function" => {
160                if let Some(func) = &self.function {
161                    if func.name.is_empty() {
162                        return Err("Function name cannot be empty".to_owned());
163                    }
164                    // Validate that arguments is valid JSON for function tools
165                    if let Err(e) = self.parsed_arguments() {
166                        return Err(format!("Invalid JSON in function arguments: {}", e));
167                    }
168                } else {
169                    return Err("Function tool call missing function details".to_owned());
170                }
171            }
172            "custom" => {
173                // For custom tools, we allow raw text payload without JSON validation
174                if let Some(func) = &self.function {
175                    if func.name.is_empty() {
176                        return Err("Custom tool name cannot be empty".to_owned());
177                    }
178                } else {
179                    return Err("Custom tool call missing function details".to_owned());
180                }
181            }
182            _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
183        }
184
185        Ok(())
186    }
187}
188
189fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
190    let trimmed = raw_arguments.trim();
191    match serde_json::from_str(trimmed) {
192        Ok(parsed) => Ok(parsed),
193        Err(primary_error) => {
194            if let Some(candidate) = extract_balanced_json(trimmed)
195                && let Ok(parsed) = serde_json::from_str(candidate)
196            {
197                return Ok(parsed);
198            }
199            Err(primary_error)
200        }
201    }
202}
203
/// Locates the first bracket-balanced JSON object or array embedded in `input`.
///
/// Scanning starts at the first `{` or `[`, and only that bracket kind is counted.
/// Brackets inside double-quoted strings, including an escaped `\"` within them, are ignored.
///
/// Returns the slice from the opening bracket through its matching closer. Returns `None`
/// when there is no opening bracket or it is never closed. The slice is not validated as JSON.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let begin = input.find(['{', '['])?;
    let bytes = input.as_bytes();
    let open = bytes.get(begin).copied()?;
    let close = match open {
        b'{' => b'}',
        b'[' => b']',
        _ => return None,
    };

    // Every byte we react to is ASCII, and UTF-8 continuation/lead bytes are never ASCII,
    // so a byte-level scan makes the same decisions as a char-level one.
    let mut nesting = 0usize;
    let mut inside_string = false;
    let mut after_backslash = false;

    for (idx, &byte) in bytes.iter().enumerate().skip(begin) {
        if inside_string {
            if after_backslash {
                after_backslash = false;
            } else if byte == b'\\' {
                after_backslash = true;
            } else if byte == b'"' {
                inside_string = false;
            }
            continue;
        }

        if byte == b'"' {
            inside_string = true;
        } else if byte == open {
            nesting += 1;
        } else if byte == close {
            nesting = nesting.saturating_sub(1);
            if nesting == 0 {
                // `idx` indexes an ASCII closer, so `idx + 1` is a char boundary.
                return input.get(begin..=idx);
            }
        }
    }

    None
}
249
250/// Universal LLM response structure
251#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
252pub struct LLMResponse {
253    /// The response content text
254    pub content: Option<String>,
255
256    /// Tool calls made by the model
257    pub tool_calls: Option<Vec<ToolCall>>,
258
259    /// The model that generated this response
260    pub model: String,
261
262    /// Token usage statistics
263    pub usage: Option<Usage>,
264
265    /// Why the response finished
266    pub finish_reason: FinishReason,
267
268    /// Reasoning content (for models that support it)
269    pub reasoning: Option<String>,
270
271    /// Detailed reasoning traces (for models that support it)
272    pub reasoning_details: Option<Vec<String>>,
273
274    /// Tool references for context
275    pub tool_references: Vec<String>,
276
277    /// Request ID from the provider
278    pub request_id: Option<String>,
279
280    /// Organization ID from the provider
281    pub organization_id: Option<String>,
282}
283
284impl LLMResponse {
285    /// Create a new LLM response with mandatory fields
286    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
287        Self {
288            content: Some(content.into()),
289            tool_calls: None,
290            model: model.into(),
291            usage: None,
292            finish_reason: FinishReason::Stop,
293            reasoning: None,
294            reasoning_details: None,
295            tool_references: Vec::new(),
296            request_id: None,
297            organization_id: None,
298        }
299    }
300
301    /// Get content or empty string
302    pub fn content_text(&self) -> &str {
303        self.content.as_deref().unwrap_or("")
304    }
305
306    /// Get content as String (clone)
307    pub fn content_string(&self) -> String {
308        self.content.clone().unwrap_or_default()
309    }
310}
311
312#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
313pub struct LLMErrorMetadata {
314    pub provider: Option<String>,
315    pub status: Option<u16>,
316    pub code: Option<String>,
317    pub request_id: Option<String>,
318    pub organization_id: Option<String>,
319    pub retry_after: Option<String>,
320    pub message: Option<String>,
321}
322
323impl LLMErrorMetadata {
324    pub fn new(
325        provider: impl Into<String>,
326        status: Option<u16>,
327        code: Option<String>,
328        request_id: Option<String>,
329        organization_id: Option<String>,
330        retry_after: Option<String>,
331        message: Option<String>,
332    ) -> Box<Self> {
333        Box::new(Self {
334            provider: Some(provider.into()),
335            status,
336            code,
337            request_id,
338            organization_id,
339            retry_after,
340            message,
341        })
342    }
343}
344
345/// LLM error types with optional provider metadata
346#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
347#[serde(tag = "type", rename_all = "snake_case")]
348pub enum LLMError {
349    #[error("Authentication failed: {message}")]
350    Authentication {
351        message: String,
352        metadata: Option<Box<LLMErrorMetadata>>,
353    },
354    #[error("Rate limit exceeded")]
355    RateLimit {
356        metadata: Option<Box<LLMErrorMetadata>>,
357    },
358    #[error("Invalid request: {message}")]
359    InvalidRequest {
360        message: String,
361        metadata: Option<Box<LLMErrorMetadata>>,
362    },
363    #[error("Network error: {message}")]
364    Network {
365        message: String,
366        metadata: Option<Box<LLMErrorMetadata>>,
367    },
368    #[error("Provider error: {message}")]
369    Provider {
370        message: String,
371        metadata: Option<Box<LLMErrorMetadata>>,
372    },
373}
374
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Builds a `read_file` function call whose arguments are `raw_args` verbatim.
    fn read_file_call(raw_args: &str) -> ToolCall {
        ToolCall::function(
            "call_read".to_owned(),
            "read_file".to_owned(),
            raw_args.to_owned(),
        )
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let call = read_file_call(r#"{"path":"src/main.rs"} trailing text"#);
        let value = call
            .parsed_arguments()
            .expect("arguments with trailing text should recover");
        assert_eq!(value, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let call = read_file_call("```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```");
        let value = call
            .parsed_arguments()
            .expect("code-fenced arguments should recover");
        assert_eq!(value, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let call = read_file_call(r#"{"path":"src/main.rs""#);
        assert!(call.parsed_arguments().is_err());
    }
}