vtcode_core/llm/providers/
base.rs1use crate::llm::provider::{LLMError, LLMRequest, LLMResponse, Message, ToolDefinition};
7use async_trait::async_trait;
8use hashbrown::HashMap;
9use reqwest::{Client as HttpClient, StatusCode};
10use serde_json::Value;
11use std::sync::{Arc, LazyLock, Mutex};
12use std::time::Duration;
13use tokio::sync::{OwnedSemaphorePermit, Semaphore};
14use tokio::time::{sleep, timeout};
15
16const DEFAULT_MAX_INFLIGHT_PER_MODEL: usize = 4;
17const RATE_LIMIT_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(10);
18
19static MODEL_LIMITERS: LazyLock<Mutex<HashMap<String, Arc<Semaphore>>>> =
20 LazyLock::new(|| Mutex::new(HashMap::new()));
21
/// Settings a provider needs to issue requests: credentials, endpoint,
/// model, timeout and retry budget.
///
/// `Debug` is implemented by hand (instead of derived) so that the API key
/// is never written to logs or panic messages.
#[derive(Clone)]
pub struct ProviderConfig {
    /// Provider API credential; redacted in `Debug` output.
    pub api_key: String,
    /// Base URL of the provider endpoint.
    pub base_url: String,
    /// Model identifier; also used as the key for per-model rate limiting.
    pub model: String,
    /// Timeout handed to the HTTP client factory when a client is built.
    pub timeout: Duration,
    /// Number of retries attempted after the initial request.
    pub max_retries: u32,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderConfig")
            // Never print the secret itself, only a placeholder.
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}
31
32impl ProviderConfig {
33 pub fn new(api_key: String, base_url: String, model: String) -> Self {
35 Self {
36 api_key,
37 base_url,
38 model,
39 timeout: Duration::from_secs(120),
40 max_retries: 3,
41 }
42 }
43
44 pub fn build_http_client(&self) -> Result<HttpClient, LLMError> {
46 use crate::llm::http_client::HttpClientFactory;
47 Ok(HttpClientFactory::with_timeouts(
48 self.timeout,
49 Duration::from_secs(30),
50 ))
51 }
52}
53
54pub fn handle_http_error(status: StatusCode, error_text: &str, _model: &str) -> LLMError {
56 match status {
57 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => LLMError::Authentication {
58 message: format!("Authentication failed ({}): {}", status, error_text),
59 metadata: None,
60 },
61 StatusCode::TOO_MANY_REQUESTS => LLMError::RateLimit { metadata: None },
62 StatusCode::REQUEST_TIMEOUT => LLMError::Network {
63 message: format!("Request timeout ({}): {}", status, error_text),
64 metadata: None,
65 },
66 _ if status.is_server_error() => LLMError::Provider {
67 message: format!("Server error ({}): {}", status, error_text),
68 metadata: None,
69 },
70 _ => LLMError::Network {
71 message: format!("HTTP error ({}): {}", status, error_text),
72 metadata: None,
73 },
74 }
75}
76
77pub fn is_model_not_found(status: StatusCode, error_text: &str) -> bool {
79 status == StatusCode::NOT_FOUND
80 || error_text.contains("model_not_found")
81 || (error_text.to_ascii_lowercase().contains("does not exist")
82 && error_text.to_ascii_lowercase().contains("model"))
83}
84
85pub mod request_builder {
87 use super::*;
88
89 pub fn build_headers(
91 api_key: &str,
92 provider_headers: Option<Vec<(&str, &str)>>,
93 ) -> reqwest::header::HeaderMap {
94 use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
95
96 let mut headers = HeaderMap::new();
97 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
98
99 if let Ok(auth_value) = HeaderValue::from_str(&format!("Bearer {}", api_key)) {
101 headers.insert(AUTHORIZATION, auth_value);
102 }
103
104 if let Some(custom_headers) = provider_headers {
106 for (key, value) in custom_headers {
107 if let (Ok(name), Ok(val)) = (
108 HeaderName::from_bytes(key.as_bytes()),
109 HeaderValue::from_str(value),
110 ) {
111 headers.insert(name, val);
112 }
113 }
114 }
115
116 headers
117 }
118
119 pub fn serialize_tools_openai(tools: &[ToolDefinition]) -> Option<Vec<Value>> {
121 if tools.is_empty() {
122 return None;
123 }
124 Some(tools.iter().map(|tool| serde_json::json!(tool)).collect())
125 }
126
127 pub fn build_request_body(
129 messages: &[Message],
130 model: &str,
131 max_tokens: Option<u32>,
132 temperature: Option<f32>,
133 tools: Option<Vec<Value>>,
134 stream: bool,
135 reasoning_effort: Option<String>,
136 ) -> Value {
137 let mut body = serde_json::json!({
138 "model": model,
139 "messages": messages.iter().map(|msg| serde_json::json!({
140 "role": msg.role.to_string().to_lowercase(),
141 "content": msg.content,
142 })).collect::<Vec<_>>(),
143 });
144
145 if let Some(max_tokens_val) = max_tokens {
146 body["max_tokens"] = serde_json::json!(max_tokens_val);
147 }
148
149 if let Some(temp) = temperature {
150 body["temperature"] = serde_json::json!(temp);
151 }
152
153 if let Some(val) = tools {
154 body["tools"] = serde_json::json!(val);
155 }
156
157 if let Some(effort) = reasoning_effort {
158 body["reasoning_effort"] = serde_json::json!(effort);
159 }
160
161 if stream {
162 body["stream"] = serde_json::json!(true);
163 }
164
165 body
166 }
167}
168
169#[async_trait]
171pub trait BaseProvider: Send + Sync {
172 fn config(&self) -> &ProviderConfig;
174
175 fn build_request(&self, request: &LLMRequest) -> Result<reqwest::Request, LLMError>;
177
178 fn parse_response(&self, response: Value) -> Result<LLMResponse, LLMError>;
180
181 async fn execute_request(&self, request: LLMRequest) -> Result<LLMResponse, LLMError> {
183 let _permit = acquire_model_permit(&self.config().model).await?;
184 let client = self.config().build_http_client()?;
185 let max_retries = self.config().max_retries;
186
187 let mut last_error = None;
188
189 for attempt in 0..=max_retries {
190 match self.build_request(&request) {
191 Ok(http_request) => {
192 match client.execute(http_request).await {
193 Ok(response) => {
194 let status = response.status();
195
196 match response.text().await {
197 Ok(text) => {
198 match serde_json::from_str::<Value>(&text) {
200 Ok(json_value) => {
201 if let Some(error_obj) = json_value.get("error") {
203 let error_text = error_obj.to_string();
204 if attempt < max_retries
205 && should_retry_status(status)
206 {
207 sleep(backoff_duration(attempt)).await;
208 last_error = Some(handle_http_error(
209 status,
210 &error_text,
211 &self.config().model,
212 ));
213 continue;
214 }
215 return Err(handle_http_error(
216 status,
217 &error_text,
218 &self.config().model,
219 ));
220 }
221
222 return self.parse_response(json_value);
224 }
225 Err(_) => {
226 if attempt < max_retries && should_retry_status(status)
228 {
229 sleep(backoff_duration(attempt)).await;
230 last_error = Some(handle_http_error(
231 status,
232 &text,
233 &self.config().model,
234 ));
235 continue;
236 }
237 return Err(handle_http_error(
238 status,
239 &text,
240 &self.config().model,
241 ));
242 }
243 }
244 }
245 Err(e) => {
246 let error = LLMError::Network {
247 message: format!("Failed to read response: {}", e),
248 metadata: None,
249 };
250 if attempt < max_retries {
251 last_error = Some(error);
252 continue;
253 }
254 return Err(error);
255 }
256 }
257 }
258 Err(e) => {
259 let error = LLMError::Network {
260 message: format!("Request failed: {}", e),
261 metadata: None,
262 };
263 if attempt < max_retries {
264 sleep(backoff_duration(attempt)).await;
265 last_error = Some(error);
266 continue;
267 }
268 return Err(error);
269 }
270 }
271 }
272 Err(e) => {
273 if attempt < max_retries {
274 last_error = Some(e);
275 continue;
276 }
277 return Err(e);
278 }
279 }
280 }
281
282 Err(last_error.unwrap_or_else(|| LLMError::Network {
284 message: "All retries exhausted".to_string(),
285 metadata: None,
286 }))
287 }
288}
289
290fn should_retry_status(status: StatusCode) -> bool {
292 matches!(
293 status,
294 StatusCode::REQUEST_TIMEOUT
295 | StatusCode::TOO_MANY_REQUESTS
296 | StatusCode::INTERNAL_SERVER_ERROR
297 | StatusCode::BAD_GATEWAY
298 | StatusCode::SERVICE_UNAVAILABLE
299 | StatusCode::GATEWAY_TIMEOUT
300 )
301}
302
/// Returns the delay to wait before retrying after the given zero-based
/// attempt.
///
/// The delay starts at 200 ms, doubles with each attempt, and is capped at
/// 5 seconds.
fn backoff_duration(attempt: u32) -> Duration {
    const BASE_MS: u64 = 200;
    const CEILING_MS: u64 = 5_000;
    const MAX_EXPONENT: u32 = 5;

    // Clamping the exponent keeps the shift and the product far from overflow.
    let exponent = attempt.min(MAX_EXPONENT);
    let multiplier = 1_u64 << exponent;

    Duration::from_millis((BASE_MS * multiplier).min(CEILING_MS))
}
310
311fn limiter_for_model(model: &str) -> Arc<Semaphore> {
312 if let Ok(mut guard) = MODEL_LIMITERS.lock() {
313 guard
314 .entry(model.to_string())
315 .or_insert_with(|| Arc::new(Semaphore::new(DEFAULT_MAX_INFLIGHT_PER_MODEL)))
316 .clone()
317 } else {
318 Arc::new(Semaphore::new(DEFAULT_MAX_INFLIGHT_PER_MODEL))
319 }
320}
321
322async fn acquire_model_permit(model: &str) -> Result<OwnedSemaphorePermit, LLMError> {
323 let limiter = limiter_for_model(model);
324 match timeout(RATE_LIMIT_ACQUIRE_TIMEOUT, limiter.acquire_owned()).await {
325 Ok(Ok(permit)) => Ok(permit),
326 Ok(Err(_)) => Err(LLMError::RateLimit { metadata: None }),
327 Err(_) => Err(LLMError::RateLimit { metadata: None }),
328 }
329}