Skip to main content

vtcode_core/llm/
provider_base.rs

1//! Base traits and utilities for LLM providers
2//!
3//! This module provides shared functionality to eliminate duplicate code
4//! across the 15+ LLM provider implementations.
5
6use anyhow::{Context, Result};
7use async_trait::async_trait;
8use futures::StreamExt;
9use reqwest::Client as HttpClient;
10use serde_json::Value;
11use std::time::Duration;
12
13use crate::config::TimeoutsConfig;
14use crate::llm::provider::{LLMError, LLMRequest, LLMStreamEvent};
15
/// Fallback limit (two minutes) for an ordinary provider request, used by
/// `BaseProviderConfig::from_options` when the configured ceiling yields no duration.
pub const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Fallback limit (five minutes) for a streaming provider response, used by
/// `BaseProviderConfig::from_options` when the configured ceiling yields no duration.
pub const DEFAULT_STREAM_TIMEOUT: Duration = Duration::from_secs(5 * 60);
19
20/// Base configuration shared by all providers
21#[derive(Debug, Clone)]
22pub struct BaseProviderConfig {
23    pub api_key: String,
24    pub base_url: String,
25    pub model: String,
26    pub http_client: HttpClient,
27    pub prompt_cache_enabled: bool,
28    pub request_timeout: Duration,
29    pub stream_timeout: Duration,
30}
31
32impl BaseProviderConfig {
33    /// Create base configuration from common parameters
34    pub fn from_options(
35        api_key: Option<String>,
36        model: Option<String>,
37        base_url: Option<String>,
38        default_model: &'static str,
39        default_url: &'static str,
40        env_var: &'static str,
41        timeouts: Option<TimeoutsConfig>,
42    ) -> Result<Self> {
43        let api_key_value = api_key.unwrap_or_default();
44        let model_value = model.unwrap_or_else(|| default_model.to_string());
45        let base_url_value = Self::resolve_base_url(base_url, default_url, env_var)?;
46
47        let timeout_config = timeouts.unwrap_or_default();
48        let http_timeout = timeout_config
49            .ceiling_duration(timeout_config.streaming_ceiling_seconds)
50            .unwrap_or(DEFAULT_REQUEST_TIMEOUT);
51        let http_client = HttpClient::builder()
52            .timeout(http_timeout)
53            .build()
54            .context("Failed to build HTTP client")?;
55
56        Ok(Self {
57            api_key: api_key_value,
58            base_url: base_url_value,
59            model: model_value,
60            http_client,
61            prompt_cache_enabled: false,
62            request_timeout: http_timeout,
63            stream_timeout: timeout_config
64                .ceiling_duration(timeout_config.streaming_ceiling_seconds)
65                .unwrap_or(DEFAULT_STREAM_TIMEOUT),
66        })
67    }
68
69    /// Resolve base URL with environment variable fallback
70    fn resolve_base_url(
71        base_url: Option<String>,
72        default_url: &'static str,
73        env_var: &'static str,
74    ) -> Result<String> {
75        if let Some(url) = base_url {
76            Ok(url.trim().to_string())
77        } else if let Ok(env_val) = std::env::var(env_var) {
78            Ok(env_val.trim().to_string())
79        } else {
80            Ok(default_url.to_string())
81        }
82    }
83
84    /// Validate that required API key is present
85    pub fn validate_api_key(&self) -> Result<()> {
86        if self.api_key.is_empty() {
87            anyhow::bail!("API key is required")
88        }
89        Ok(())
90    }
91}
92
93/// Trait for providers that support standard OpenAI-compatible APIs
94#[async_trait]
95pub trait OpenAICompatibleProvider: Send + Sync {
96    fn provider_name(&self) -> &'static str;
97    fn supports_prompt_caching(&self) -> bool;
98
99    /// Parse request from OpenAI format
100    fn parse_openai_request(&self, value: &Value, default_model: &str) -> Option<LLMRequest> {
101        crate::llm::utils::parse_chat_request_openai_format(value, default_model)
102    }
103
104    /// Serialize messages to OpenAI format
105    fn serialize_openai_messages(&self, request: &LLMRequest) -> Value {
106        use crate::llm::providers::common::serialize_messages_openai_format;
107        match serialize_messages_openai_format(request, self.provider_name()) {
108            Ok(messages) => serde_json::json!({ "messages": messages }),
109            Err(_) => serde_json::json!({ "messages": [] }),
110        }
111    }
112
113    /// Parse response from OpenAI format
114    fn parse_openai_response(
115        &self,
116        response: Value,
117        model: String,
118        include_cache: bool,
119    ) -> Result<crate::llm::provider::LLMResponse> {
120        crate::llm::utils::parse_response_openai_format(
121            response,
122            self.provider_name(),
123            model,
124            include_cache,
125            None,
126        )
127    }
128}
129
130/// Shared error handling utilities
131pub struct ErrorHandler {
132    _provider_name: &'static str,
133}
134
135impl ErrorHandler {
136    pub fn new(provider_name: &'static str) -> Self {
137        Self {
138            _provider_name: provider_name,
139        }
140    }
141
142    /// Handle HTTP errors consistently across providers
143    pub fn handle_http_error(&self, status: reqwest::StatusCode, error_text: &str) -> LLMError {
144        use reqwest::StatusCode;
145
146        let error_message = match status {
147            StatusCode::UNAUTHORIZED => "Authentication failed: Invalid API key".to_string(),
148            StatusCode::TOO_MANY_REQUESTS => "Rate limit exceeded".to_string(),
149            StatusCode::BAD_REQUEST => format!("Bad request: {}", error_text.trim()),
150            s if s.as_u16() == 402 => "Insufficient balance".to_string(),
151            _ => format!("HTTP {}: {}", status, error_text.trim()),
152        };
153
154        let formatted_error =
155            crate::llm::error_display::format_llm_error(self._provider_name, &error_message);
156
157        // Handle different error types based on status code
158        if status == StatusCode::TOO_MANY_REQUESTS {
159            LLMError::RateLimit { metadata: None }
160        } else {
161            LLMError::Provider {
162                message: formatted_error,
163                metadata: None,
164            }
165        }
166    }
167
168    /// Handle request validation errors
169    pub fn validate_request(&self, request: &LLMRequest) -> Result<()> {
170        if request.messages.is_empty() {
171            anyhow::bail!("Request must contain at least one message")
172        }
173
174        if request.model.is_empty() {
175            anyhow::bail!("Request must specify a model")
176        }
177
178        // Check if model is supported (this would need to be customized per provider)
179        if !self.is_model_supported(&request.model) {
180            anyhow::bail!("Unsupported model: {}", request.model)
181        }
182
183        Ok(())
184    }
185
186    /// Check if model is supported (default implementation, override as needed)
187    fn is_model_supported(&self, model: &str) -> bool {
188        // Default implementation assumes all models are supported
189        // Individual providers should override this with their specific model lists
190        !model.is_empty()
191    }
192}
193
194/// Shared streaming utilities
195pub struct StreamProcessor {
196    provider_name: &'static str,
197    supports_reasoning: bool,
198}
199
200impl StreamProcessor {
201    pub fn new(provider_name: &'static str, supports_reasoning: bool) -> Self {
202        Self {
203            provider_name,
204            supports_reasoning,
205        }
206    }
207
208    /// Process SSE stream chunk consistently
209    pub fn process_stream_chunk(&self, chunk: &str) -> Vec<LLMStreamEvent> {
210        let mut events = Vec::new();
211
212        for line in chunk.lines() {
213            let line = line.trim();
214            if line.is_empty() {
215                continue;
216            }
217
218            if let Some(data) = line.strip_prefix("data: ") {
219                if data == "[DONE]" {
220                    // Stream completion indicated by DONE marker
221                    continue;
222                }
223
224                match serde_json::from_str::<Value>(data) {
225                    Ok(json) => {
226                        if let Some(event) = self.parse_stream_event(json) {
227                            events.push(event);
228                        }
229                    }
230                    Err(_) => {
231                        // Skip invalid JSON
232                        continue;
233                    }
234                }
235            }
236        }
237
238        events
239    }
240
241    /// Parse individual stream event (override for provider-specific logic)
242    fn parse_stream_event(&self, json: Value) -> Option<LLMStreamEvent> {
243        // Default implementation for OpenAI-compatible providers
244        crate::llm::utils::parse_stream_event_openai_format(json, self.provider_name)
245    }
246
247    /// Extract reasoning content if supported
248    pub fn extract_reasoning(&self, content: &str) -> (Vec<String>, Option<String>) {
249        if !self.supports_reasoning {
250            return (Vec::new(), None);
251        }
252
253        // Default implementation - providers can override
254        crate::llm::utils::extract_reasoning_content(content)
255    }
256}
257
258/// Unified authentication header handling
259pub struct AuthHandler {
260    auth_type: AuthType,
261    api_key: String,
262}
263
264#[derive(Debug, Clone, Copy)]
265pub enum AuthType {
266    BearerToken,
267    ApiKeyHeader(&'static str),
268    QueryParam(&'static str),
269}
270
271impl AuthHandler {
272    pub fn new(auth_type: AuthType, api_key: String) -> Self {
273        Self { auth_type, api_key }
274    }
275
276    /// Apply authentication to request builder
277    pub fn apply_auth(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
278        match self.auth_type {
279            AuthType::BearerToken => builder.bearer_auth(&self.api_key),
280            AuthType::ApiKeyHeader(header_name) => builder.header(header_name, &self.api_key),
281            AuthType::QueryParam(param_name) => builder.query(&[(param_name, &self.api_key)]),
282        }
283    }
284}
285
286/// Shared request/response processing utilities
287pub struct RequestProcessor {
288    provider_name: &'static str,
289}
290
291impl RequestProcessor {
292    pub fn new(provider_name: &'static str) -> Self {
293        Self { provider_name }
294    }
295
296    /// Build HTTP request with consistent error handling
297    pub async fn build_request(
298        &self,
299        client: &HttpClient,
300        method: reqwest::Method,
301        url: String,
302        auth: Option<&AuthHandler>,
303        body: Option<Value>,
304    ) -> Result<reqwest::RequestBuilder> {
305        let mut builder = client.request(method, &url);
306
307        if let Some(auth_handler) = auth {
308            builder = auth_handler.apply_auth(builder);
309        }
310
311        builder = builder
312            .header("Content-Type", "application/json")
313            .header("User-Agent", "VT Code/1.0");
314
315        if let Some(body_value) = body {
316            builder = builder.json(&body_value);
317        }
318
319        Ok(builder)
320    }
321
322    /// Handle response with consistent error processing
323    pub async fn handle_response(&self, response: reqwest::Response) -> Result<Value> {
324        let status = response.status();
325
326        if !status.is_success() {
327            let error_text = response.text().await.unwrap_or_default();
328            let error_handler = ErrorHandler::new(self.provider_name);
329            return Err(error_handler.handle_http_error(status, &error_text).into());
330        }
331
332        let response_text = response
333            .text()
334            .await
335            .context("Failed to read response body")?;
336
337        serde_json::from_str(&response_text).context("Failed to parse JSON response")
338    }
339
340    /// Handle streaming response
341    pub async fn handle_stream_response(
342        &self,
343        response: reqwest::Response,
344    ) -> Result<impl futures::Stream<Item = Result<String>>> {
345        let status = response.status();
346
347        if !status.is_success() {
348            let error_text = response.text().await.unwrap_or_default();
349            let error_handler = ErrorHandler::new(self.provider_name);
350            return Err(error_handler.handle_http_error(status, &error_text).into());
351        }
352
353        Ok(response.bytes_stream().map(|result| {
354            result
355                .map(|bytes| String::from_utf8_lossy(&bytes).to_string())
356                .map_err(|e| anyhow::anyhow!("Stream error: {}", e))
357        }))
358    }
359}
360
361/// Common model resolution utilities
362pub struct ModelResolver {
363    #[expect(dead_code)]
364    provider_name: &'static str,
365    default_model: &'static str,
366    supported_models: &'static [&'static str],
367}
368
369impl ModelResolver {
370    pub fn new(
371        provider_name: &'static str,
372        default_model: &'static str,
373        supported_models: &'static [&'static str],
374    ) -> Self {
375        Self {
376            provider_name,
377            default_model,
378            supported_models,
379        }
380    }
381
382    /// Resolve model with fallback to default
383    pub fn resolve_model(&self, model: Option<String>) -> String {
384        model.unwrap_or_else(|| self.default_model.to_string())
385    }
386
387    /// Validate model is supported
388    pub fn validate_model(&self, model: &str) -> Result<()> {
389        if self.supported_models.is_empty() {
390            // If no specific supported models listed, accept any non-empty model
391            if model.is_empty() {
392                anyhow::bail!("Model cannot be empty")
393            }
394            return Ok(());
395        }
396
397        if !self.supported_models.contains(&model) {
398            anyhow::bail!(
399                "Unsupported model: {}. Supported models: {:?}",
400                model,
401                self.supported_models
402            )
403        }
404
405        Ok(())
406    }
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    /// Base-URL env var that no environment is expected to define. Base-URL resolution
    /// then deterministically falls through to the default. The previous name,
    /// `TEST_API_KEY`, is a plausible real variable.
    const UNSET_BASE_URL_ENV: &str = "VTCODE_PROVIDER_BASE_TEST_UNSET_BASE_URL";

    #[test]
    fn test_base_provider_config() {
        let config = BaseProviderConfig::from_options(
            Some("test_key".to_string()),
            Some("test_model".to_string()),
            None,
            "default_model",
            "https://api.example.com",
            UNSET_BASE_URL_ENV,
            None,
        )
        .unwrap();

        assert_eq!(config.api_key, "test_key");
        assert_eq!(config.model, "test_model");
        assert_eq!(config.base_url, "https://api.example.com");
        assert!(!config.prompt_cache_enabled);
        config.validate_api_key().unwrap();
    }

    #[test]
    fn test_base_provider_config_defaults_and_trimming() {
        let config = BaseProviderConfig::from_options(
            None,
            None,
            Some("  https://custom.example.com/v1  ".to_string()),
            "default_model",
            "https://api.example.com",
            UNSET_BASE_URL_ENV,
            None,
        )
        .unwrap();

        assert_eq!(config.model, "default_model");
        assert_eq!(config.base_url, "https://custom.example.com/v1");
        assert!(config.validate_api_key().is_err());
    }

    #[test]
    fn test_error_handler() {
        let handler = ErrorHandler::new("test_provider");

        let unauthorized =
            handler.handle_http_error(reqwest::StatusCode::UNAUTHORIZED, "Invalid API key");
        let rate_limited = handler.handle_http_error(reqwest::StatusCode::TOO_MANY_REQUESTS, "");
        let payment = handler.handle_http_error(reqwest::StatusCode::PAYMENT_REQUIRED, "");
        let bad_request = handler.handle_http_error(reqwest::StatusCode::BAD_REQUEST, "  ");

        assert!(matches!(
            unauthorized,
            LLMError::Provider {
                message: _,
                metadata: _
            }
        ));
        assert!(matches!(rate_limited, LLMError::RateLimit { metadata: _ }));
        assert!(matches!(payment, LLMError::Provider { .. }));
        assert!(matches!(bad_request, LLMError::Provider { .. }));
    }

    #[test]
    fn test_stream_processor_skips_non_event_lines() {
        let processor = StreamProcessor::new("test_provider", false);

        let events =
            processor.process_stream_chunk(": keep-alive\n\ndata: [DONE]\ndata: not json\n");
        assert!(events.is_empty());

        assert_eq!(processor.extract_reasoning("anything"), (Vec::new(), None));
    }

    #[test]
    fn test_model_resolver() {
        let resolver = ModelResolver::new("test_provider", "default-model", &["model1", "model2"]);

        assert_eq!(resolver.resolve_model(None), "default-model");
        assert_eq!(resolver.resolve_model(Some("custom".to_string())), "custom");

        resolver.validate_model("model1").unwrap();
        assert!(resolver.validate_model("unsupported").is_err());
    }

    #[test]
    fn test_model_resolver_without_allow_list() {
        let resolver = ModelResolver::new("test_provider", "default-model", &[]);

        assert!(resolver.validate_model("anything").is_ok());
        assert!(resolver.validate_model("").is_err());
    }
}