Skip to main content

vtcode_core/llm/providers/
base.rs

1//! Base trait and common implementations for LLM providers
2//!
3//! This module provides a unified foundation for all LLM providers to eliminate
4//! code duplication across provider implementations.
5
6use crate::llm::provider::{LLMError, LLMRequest, LLMResponse, Message, ToolDefinition};
7use async_trait::async_trait;
8use hashbrown::HashMap;
9use reqwest::{Client as HttpClient, StatusCode};
10use serde_json::Value;
11use std::sync::{Arc, LazyLock, Mutex};
12use std::time::Duration;
13use tokio::sync::{OwnedSemaphorePermit, Semaphore};
14use tokio::time::{sleep, timeout};
15
/// Permit count given to each per-model semaphore, i.e. the maximum number
/// of requests for one model that may hold a permit at the same time.
const DEFAULT_MAX_INFLIGHT_PER_MODEL: usize = 4;
/// Longest time a request waits for a per-model permit before giving up.
const RATE_LIMIT_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(10);

/// Process-wide registry of per-model semaphores, keyed by model name.
/// The map itself is created lazily on first access.
static MODEL_LIMITERS: LazyLock<Mutex<HashMap<String, Arc<Semaphore>>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));
21
/// Base configuration shared by all providers.
///
/// `Debug` is implemented by hand instead of derived so the API key is never
/// written verbatim into logs, error chains, or panic messages.
#[derive(Clone)]
pub struct ProviderConfig {
    /// Secret credential for the provider. Redacted in `Debug` output.
    pub api_key: String,
    /// Base URL of the provider API.
    pub base_url: String,
    /// Model identifier; also used as the key for per-model concurrency limiting.
    pub model: String,
    /// Timeout handed to the HTTP client when it is built.
    pub timeout: Duration,
    /// Number of retries performed after the initial attempt
    /// (total attempts = `max_retries + 1`).
    pub max_retries: u32,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Still reveal whether a key is configured at all: that is useful when
        // diagnosing authentication failures and leaks nothing about the secret.
        let api_key = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };

        f.debug_struct("ProviderConfig")
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}
31
32impl ProviderConfig {
33    /// Create provider config with sensible defaults
34    pub fn new(api_key: String, base_url: String, model: String) -> Self {
35        Self {
36            api_key,
37            base_url,
38            model,
39            timeout: Duration::from_secs(120),
40            max_retries: 3,
41        }
42    }
43
44    /// Build HTTP client with provider-specific configuration
45    pub fn build_http_client(&self) -> Result<HttpClient, LLMError> {
46        use crate::llm::http_client::HttpClientFactory;
47        Ok(HttpClientFactory::with_timeouts(
48            self.timeout,
49            Duration::from_secs(30),
50        ))
51    }
52}
53
54/// Common HTTP error handling for all providers
55pub fn handle_http_error(status: StatusCode, error_text: &str, _model: &str) -> LLMError {
56    match status {
57        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => LLMError::Authentication {
58            message: format!("Authentication failed ({}): {}", status, error_text),
59            metadata: None,
60        },
61        StatusCode::TOO_MANY_REQUESTS => LLMError::RateLimit { metadata: None },
62        StatusCode::REQUEST_TIMEOUT => LLMError::Network {
63            message: format!("Request timeout ({}): {}", status, error_text),
64            metadata: None,
65        },
66        _ if status.is_server_error() => LLMError::Provider {
67            message: format!("Server error ({}): {}", status, error_text),
68            metadata: None,
69        },
70        _ => LLMError::Network {
71            message: format!("HTTP error ({}): {}", status, error_text),
72            metadata: None,
73        },
74    }
75}
76
77/// Check if error indicates model not found (common across providers)
78pub fn is_model_not_found(status: StatusCode, error_text: &str) -> bool {
79    status == StatusCode::NOT_FOUND
80        || error_text.contains("model_not_found")
81        || (error_text.to_ascii_lowercase().contains("does not exist")
82            && error_text.to_ascii_lowercase().contains("model"))
83}
84
85/// Common request building utilities
86pub mod request_builder {
87    use super::*;
88
89    /// Build standard headers for API requests
90    pub fn build_headers(
91        api_key: &str,
92        provider_headers: Option<Vec<(&str, &str)>>,
93    ) -> reqwest::header::HeaderMap {
94        use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
95
96        let mut headers = HeaderMap::new();
97        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
98
99        // Default authorization header (can be overridden by providers)
100        if let Ok(auth_value) = HeaderValue::from_str(&format!("Bearer {}", api_key)) {
101            headers.insert(AUTHORIZATION, auth_value);
102        }
103
104        // Add provider-specific headers
105        if let Some(custom_headers) = provider_headers {
106            for (key, value) in custom_headers {
107                if let (Ok(name), Ok(val)) = (
108                    HeaderName::from_bytes(key.as_bytes()),
109                    HeaderValue::from_str(value),
110                ) {
111                    headers.insert(name, val);
112                }
113            }
114        }
115
116        headers
117    }
118
119    /// Convert tools to OpenAI-compatible format (used by many providers)
120    pub fn serialize_tools_openai(tools: &[ToolDefinition]) -> Option<Vec<Value>> {
121        if tools.is_empty() {
122            return None;
123        }
124        Some(tools.iter().map(|tool| serde_json::json!(tool)).collect())
125    }
126
127    /// Build standard request body structure
128    pub fn build_request_body(
129        messages: &[Message],
130        model: &str,
131        max_tokens: Option<u32>,
132        temperature: Option<f32>,
133        tools: Option<Vec<Value>>,
134        stream: bool,
135        reasoning_effort: Option<String>,
136    ) -> Value {
137        let mut body = serde_json::json!({
138            "model": model,
139            "messages": messages.iter().map(|msg| serde_json::json!({
140                "role": msg.role.to_string().to_lowercase(),
141                "content": msg.content,
142            })).collect::<Vec<_>>(),
143        });
144
145        if let Some(max_tokens_val) = max_tokens {
146            body["max_tokens"] = serde_json::json!(max_tokens_val);
147        }
148
149        if let Some(temp) = temperature {
150            body["temperature"] = serde_json::json!(temp);
151        }
152
153        if let Some(val) = tools {
154            body["tools"] = serde_json::json!(val);
155        }
156
157        if let Some(effort) = reasoning_effort {
158            body["reasoning_effort"] = serde_json::json!(effort);
159        }
160
161        if stream {
162            body["stream"] = serde_json::json!(true);
163        }
164
165        body
166    }
167}
168
169/// Base provider trait with common functionality
170#[async_trait]
171pub trait BaseProvider: Send + Sync {
172    /// Get provider configuration
173    fn config(&self) -> &ProviderConfig;
174
175    /// Build HTTP request for the provider
176    fn build_request(&self, request: &LLMRequest) -> Result<reqwest::Request, LLMError>;
177
178    /// Parse response from the provider
179    fn parse_response(&self, response: Value) -> Result<LLMResponse, LLMError>;
180
181    /// Execute LLM request with common error handling and retry logic
182    async fn execute_request(&self, request: LLMRequest) -> Result<LLMResponse, LLMError> {
183        let _permit = acquire_model_permit(&self.config().model).await?;
184        let client = self.config().build_http_client()?;
185        let max_retries = self.config().max_retries;
186
187        let mut last_error = None;
188
189        for attempt in 0..=max_retries {
190            match self.build_request(&request) {
191                Ok(http_request) => {
192                    match client.execute(http_request).await {
193                        Ok(response) => {
194                            let status = response.status();
195
196                            match response.text().await {
197                                Ok(text) => {
198                                    // Try to parse as JSON first
199                                    match serde_json::from_str::<Value>(&text) {
200                                        Ok(json_value) => {
201                                            // Check for provider-specific error format
202                                            if let Some(error_obj) = json_value.get("error") {
203                                                let error_text = error_obj.to_string();
204                                                if attempt < max_retries
205                                                    && should_retry_status(status)
206                                                {
207                                                    sleep(backoff_duration(attempt)).await;
208                                                    last_error = Some(handle_http_error(
209                                                        status,
210                                                        &error_text,
211                                                        &self.config().model,
212                                                    ));
213                                                    continue;
214                                                }
215                                                return Err(handle_http_error(
216                                                    status,
217                                                    &error_text,
218                                                    &self.config().model,
219                                                ));
220                                            }
221
222                                            // Success - parse response
223                                            return self.parse_response(json_value);
224                                        }
225                                        Err(_) => {
226                                            // Not JSON - treat as error text
227                                            if attempt < max_retries && should_retry_status(status)
228                                            {
229                                                sleep(backoff_duration(attempt)).await;
230                                                last_error = Some(handle_http_error(
231                                                    status,
232                                                    &text,
233                                                    &self.config().model,
234                                                ));
235                                                continue;
236                                            }
237                                            return Err(handle_http_error(
238                                                status,
239                                                &text,
240                                                &self.config().model,
241                                            ));
242                                        }
243                                    }
244                                }
245                                Err(e) => {
246                                    let error = LLMError::Network {
247                                        message: format!("Failed to read response: {}", e),
248                                        metadata: None,
249                                    };
250                                    if attempt < max_retries {
251                                        last_error = Some(error);
252                                        continue;
253                                    }
254                                    return Err(error);
255                                }
256                            }
257                        }
258                        Err(e) => {
259                            let error = LLMError::Network {
260                                message: format!("Request failed: {}", e),
261                                metadata: None,
262                            };
263                            if attempt < max_retries {
264                                sleep(backoff_duration(attempt)).await;
265                                last_error = Some(error);
266                                continue;
267                            }
268                            return Err(error);
269                        }
270                    }
271                }
272                Err(e) => {
273                    if attempt < max_retries {
274                        last_error = Some(e);
275                        continue;
276                    }
277                    return Err(e);
278                }
279            }
280        }
281
282        // All retries exhausted
283        Err(last_error.unwrap_or_else(|| LLMError::Network {
284            message: "All retries exhausted".to_string(),
285            metadata: None,
286        }))
287    }
288}
289
290/// Determine if a status code should trigger a retry
291fn should_retry_status(status: StatusCode) -> bool {
292    matches!(
293        status,
294        StatusCode::REQUEST_TIMEOUT
295            | StatusCode::TOO_MANY_REQUESTS
296            | StatusCode::INTERNAL_SERVER_ERROR
297            | StatusCode::BAD_GATEWAY
298            | StatusCode::SERVICE_UNAVAILABLE
299            | StatusCode::GATEWAY_TIMEOUT
300    )
301}
302
/// Compute the delay before the retry that follows `attempt` (0-based).
///
/// The delay starts at 200 ms and doubles with each attempt; it is clamped to
/// at most 5 seconds so repeated failures do not hammer the provider nor stall
/// the caller indefinitely.
fn backoff_duration(attempt: u32) -> Duration {
    const BASE_MS: u64 = 200;
    const MAX_MS: u64 = 5_000;
    const MAX_EXPONENT: u32 = 5;

    // With the exponent capped at 5 the shift yields at most 6_400, so it
    // cannot overflow.
    let exponent = attempt.min(MAX_EXPONENT);
    let delay_ms = (BASE_MS << exponent).min(MAX_MS);
    Duration::from_millis(delay_ms)
}
310
311fn limiter_for_model(model: &str) -> Arc<Semaphore> {
312    if let Ok(mut guard) = MODEL_LIMITERS.lock() {
313        guard
314            .entry(model.to_string())
315            .or_insert_with(|| Arc::new(Semaphore::new(DEFAULT_MAX_INFLIGHT_PER_MODEL)))
316            .clone()
317    } else {
318        Arc::new(Semaphore::new(DEFAULT_MAX_INFLIGHT_PER_MODEL))
319    }
320}
321
322async fn acquire_model_permit(model: &str) -> Result<OwnedSemaphorePermit, LLMError> {
323    let limiter = limiter_for_model(model);
324    match timeout(RATE_LIMIT_ACQUIRE_TIMEOUT, limiter.acquire_owned()).await {
325        Ok(Ok(permit)) => Ok(permit),
326        Ok(Err(_)) => Err(LLMError::RateLimit { metadata: None }),
327        Err(_) => Err(LLMError::RateLimit { metadata: None }),
328    }
329}