1use anyhow::{Context, Result};
7use async_trait::async_trait;
8use futures::StreamExt;
9use reqwest::Client as HttpClient;
10use serde_json::Value;
11use std::time::Duration;
12
13use crate::config::TimeoutsConfig;
14use crate::llm::provider::{LLMError, LLMRequest, LLMStreamEvent};
15
/// Fallback HTTP client timeout (two minutes), used when no streaming ceiling is configured.
pub const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Fallback streaming timeout (five minutes), used when no streaming ceiling is configured.
pub const DEFAULT_STREAM_TIMEOUT: Duration = Duration::from_secs(5 * 60);
19
20#[derive(Debug, Clone)]
22pub struct BaseProviderConfig {
23 pub api_key: String,
24 pub base_url: String,
25 pub model: String,
26 pub http_client: HttpClient,
27 pub prompt_cache_enabled: bool,
28 pub request_timeout: Duration,
29 pub stream_timeout: Duration,
30}
31
32impl BaseProviderConfig {
33 pub fn from_options(
35 api_key: Option<String>,
36 model: Option<String>,
37 base_url: Option<String>,
38 default_model: &'static str,
39 default_url: &'static str,
40 env_var: &'static str,
41 timeouts: Option<TimeoutsConfig>,
42 ) -> Result<Self> {
43 let api_key_value = api_key.unwrap_or_default();
44 let model_value = model.unwrap_or_else(|| default_model.to_string());
45 let base_url_value = Self::resolve_base_url(base_url, default_url, env_var)?;
46
47 let timeout_config = timeouts.unwrap_or_default();
48 let http_timeout = timeout_config
49 .ceiling_duration(timeout_config.streaming_ceiling_seconds)
50 .unwrap_or(DEFAULT_REQUEST_TIMEOUT);
51 let http_client = HttpClient::builder()
52 .timeout(http_timeout)
53 .build()
54 .context("Failed to build HTTP client")?;
55
56 Ok(Self {
57 api_key: api_key_value,
58 base_url: base_url_value,
59 model: model_value,
60 http_client,
61 prompt_cache_enabled: false,
62 request_timeout: http_timeout,
63 stream_timeout: timeout_config
64 .ceiling_duration(timeout_config.streaming_ceiling_seconds)
65 .unwrap_or(DEFAULT_STREAM_TIMEOUT),
66 })
67 }
68
69 fn resolve_base_url(
71 base_url: Option<String>,
72 default_url: &'static str,
73 env_var: &'static str,
74 ) -> Result<String> {
75 if let Some(url) = base_url {
76 Ok(url.trim().to_string())
77 } else if let Ok(env_val) = std::env::var(env_var) {
78 Ok(env_val.trim().to_string())
79 } else {
80 Ok(default_url.to_string())
81 }
82 }
83
84 pub fn validate_api_key(&self) -> Result<()> {
86 if self.api_key.is_empty() {
87 anyhow::bail!("API key is required")
88 }
89 Ok(())
90 }
91}
92
93#[async_trait]
95pub trait OpenAICompatibleProvider: Send + Sync {
96 fn provider_name(&self) -> &'static str;
97 fn supports_prompt_caching(&self) -> bool;
98
99 fn parse_openai_request(&self, value: &Value, default_model: &str) -> Option<LLMRequest> {
101 crate::llm::utils::parse_chat_request_openai_format(value, default_model)
102 }
103
104 fn serialize_openai_messages(&self, request: &LLMRequest) -> Value {
106 use crate::llm::providers::common::serialize_messages_openai_format;
107 match serialize_messages_openai_format(request, self.provider_name()) {
108 Ok(messages) => serde_json::json!({ "messages": messages }),
109 Err(_) => serde_json::json!({ "messages": [] }),
110 }
111 }
112
113 fn parse_openai_response(
115 &self,
116 response: Value,
117 model: String,
118 include_cache: bool,
119 ) -> Result<crate::llm::provider::LLMResponse> {
120 crate::llm::utils::parse_response_openai_format(
121 response,
122 self.provider_name(),
123 model,
124 include_cache,
125 None,
126 )
127 }
128}
129
130pub struct ErrorHandler {
132 _provider_name: &'static str,
133}
134
135impl ErrorHandler {
136 pub fn new(provider_name: &'static str) -> Self {
137 Self {
138 _provider_name: provider_name,
139 }
140 }
141
142 pub fn handle_http_error(&self, status: reqwest::StatusCode, error_text: &str) -> LLMError {
144 use reqwest::StatusCode;
145
146 let error_message = match status {
147 StatusCode::UNAUTHORIZED => "Authentication failed: Invalid API key".to_string(),
148 StatusCode::TOO_MANY_REQUESTS => "Rate limit exceeded".to_string(),
149 StatusCode::BAD_REQUEST => format!("Bad request: {}", error_text.trim()),
150 s if s.as_u16() == 402 => "Insufficient balance".to_string(),
151 _ => format!("HTTP {}: {}", status, error_text.trim()),
152 };
153
154 let formatted_error =
155 crate::llm::error_display::format_llm_error(self._provider_name, &error_message);
156
157 if status == StatusCode::TOO_MANY_REQUESTS {
159 LLMError::RateLimit { metadata: None }
160 } else {
161 LLMError::Provider {
162 message: formatted_error,
163 metadata: None,
164 }
165 }
166 }
167
168 pub fn validate_request(&self, request: &LLMRequest) -> Result<()> {
170 if request.messages.is_empty() {
171 anyhow::bail!("Request must contain at least one message")
172 }
173
174 if request.model.is_empty() {
175 anyhow::bail!("Request must specify a model")
176 }
177
178 if !self.is_model_supported(&request.model) {
180 anyhow::bail!("Unsupported model: {}", request.model)
181 }
182
183 Ok(())
184 }
185
186 fn is_model_supported(&self, model: &str) -> bool {
188 !model.is_empty()
191 }
192}
193
194pub struct StreamProcessor {
196 provider_name: &'static str,
197 supports_reasoning: bool,
198}
199
200impl StreamProcessor {
201 pub fn new(provider_name: &'static str, supports_reasoning: bool) -> Self {
202 Self {
203 provider_name,
204 supports_reasoning,
205 }
206 }
207
208 pub fn process_stream_chunk(&self, chunk: &str) -> Vec<LLMStreamEvent> {
210 let mut events = Vec::new();
211
212 for line in chunk.lines() {
213 let line = line.trim();
214 if line.is_empty() {
215 continue;
216 }
217
218 if let Some(data) = line.strip_prefix("data: ") {
219 if data == "[DONE]" {
220 continue;
222 }
223
224 match serde_json::from_str::<Value>(data) {
225 Ok(json) => {
226 if let Some(event) = self.parse_stream_event(json) {
227 events.push(event);
228 }
229 }
230 Err(_) => {
231 continue;
233 }
234 }
235 }
236 }
237
238 events
239 }
240
241 fn parse_stream_event(&self, json: Value) -> Option<LLMStreamEvent> {
243 crate::llm::utils::parse_stream_event_openai_format(json, self.provider_name)
245 }
246
247 pub fn extract_reasoning(&self, content: &str) -> (Vec<String>, Option<String>) {
249 if !self.supports_reasoning {
250 return (Vec::new(), None);
251 }
252
253 crate::llm::utils::extract_reasoning_content(content)
255 }
256}
257
258pub struct AuthHandler {
260 auth_type: AuthType,
261 api_key: String,
262}
263
/// How an API key is attached to a request.
///
/// Equality and hashing are derived so callers can compare or key on auth strategies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthType {
    /// `Authorization: Bearer <key>`.
    BearerToken,
    /// A custom header with the given name, carrying the raw key.
    ApiKeyHeader(&'static str),
    /// A URL query parameter with the given name, carrying the raw key.
    QueryParam(&'static str),
}
270
271impl AuthHandler {
272 pub fn new(auth_type: AuthType, api_key: String) -> Self {
273 Self { auth_type, api_key }
274 }
275
276 pub fn apply_auth(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
278 match self.auth_type {
279 AuthType::BearerToken => builder.bearer_auth(&self.api_key),
280 AuthType::ApiKeyHeader(header_name) => builder.header(header_name, &self.api_key),
281 AuthType::QueryParam(param_name) => builder.query(&[(param_name, &self.api_key)]),
282 }
283 }
284}
285
286pub struct RequestProcessor {
288 provider_name: &'static str,
289}
290
291impl RequestProcessor {
292 pub fn new(provider_name: &'static str) -> Self {
293 Self { provider_name }
294 }
295
296 pub async fn build_request(
298 &self,
299 client: &HttpClient,
300 method: reqwest::Method,
301 url: String,
302 auth: Option<&AuthHandler>,
303 body: Option<Value>,
304 ) -> Result<reqwest::RequestBuilder> {
305 let mut builder = client.request(method, &url);
306
307 if let Some(auth_handler) = auth {
308 builder = auth_handler.apply_auth(builder);
309 }
310
311 builder = builder
312 .header("Content-Type", "application/json")
313 .header("User-Agent", "VT Code/1.0");
314
315 if let Some(body_value) = body {
316 builder = builder.json(&body_value);
317 }
318
319 Ok(builder)
320 }
321
322 pub async fn handle_response(&self, response: reqwest::Response) -> Result<Value> {
324 let status = response.status();
325
326 if !status.is_success() {
327 let error_text = response.text().await.unwrap_or_default();
328 let error_handler = ErrorHandler::new(self.provider_name);
329 return Err(error_handler.handle_http_error(status, &error_text).into());
330 }
331
332 let response_text = response
333 .text()
334 .await
335 .context("Failed to read response body")?;
336
337 serde_json::from_str(&response_text).context("Failed to parse JSON response")
338 }
339
340 pub async fn handle_stream_response(
342 &self,
343 response: reqwest::Response,
344 ) -> Result<impl futures::Stream<Item = Result<String>>> {
345 let status = response.status();
346
347 if !status.is_success() {
348 let error_text = response.text().await.unwrap_or_default();
349 let error_handler = ErrorHandler::new(self.provider_name);
350 return Err(error_handler.handle_http_error(status, &error_text).into());
351 }
352
353 Ok(response.bytes_stream().map(|result| {
354 result
355 .map(|bytes| String::from_utf8_lossy(&bytes).to_string())
356 .map_err(|e| anyhow::anyhow!("Stream error: {}", e))
357 }))
358 }
359}
360
361pub struct ModelResolver {
363 #[expect(dead_code)]
364 provider_name: &'static str,
365 default_model: &'static str,
366 supported_models: &'static [&'static str],
367}
368
369impl ModelResolver {
370 pub fn new(
371 provider_name: &'static str,
372 default_model: &'static str,
373 supported_models: &'static [&'static str],
374 ) -> Self {
375 Self {
376 provider_name,
377 default_model,
378 supported_models,
379 }
380 }
381
382 pub fn resolve_model(&self, model: Option<String>) -> String {
384 model.unwrap_or_else(|| self.default_model.to_string())
385 }
386
387 pub fn validate_model(&self, model: &str) -> Result<()> {
389 if self.supported_models.is_empty() {
390 if model.is_empty() {
392 anyhow::bail!("Model cannot be empty")
393 }
394 return Ok(());
395 }
396
397 if !self.supported_models.contains(&model) {
398 anyhow::bail!(
399 "Unsupported model: {}. Supported models: {:?}",
400 model,
401 self.supported_models
402 )
403 }
404
405 Ok(())
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_base_provider_config() {
        let config = BaseProviderConfig::from_options(
            Some("test_key".to_string()),
            Some("test_model".to_string()),
            None,
            "default_model",
            "https://api.example.com",
            "TEST_API_KEY",
            None,
        )
        .unwrap();

        assert_eq!(config.api_key, "test_key");
        assert_eq!(config.model, "test_model");
        assert_eq!(config.base_url, "https://api.example.com");
    }

    #[test]
    fn test_base_provider_config_defaults_and_missing_key() {
        // An explicit base URL wins over the env var, so the env var is never consulted.
        let config = BaseProviderConfig::from_options(
            None,
            None,
            Some("  https://explicit.example.com  ".to_string()),
            "default_model",
            "https://fallback.example.com",
            "VTCODE_TEST_UNUSED_BASE_URL",
            None,
        )
        .unwrap();

        assert_eq!(config.model, "default_model");
        assert_eq!(config.base_url, "https://explicit.example.com");
        assert!(config.validate_api_key().is_err());
    }

    #[test]
    fn test_error_handler() {
        let handler = ErrorHandler::new("test_provider");

        let unauthorized =
            handler.handle_http_error(reqwest::StatusCode::UNAUTHORIZED, "Invalid API key");
        let rate_limited = handler.handle_http_error(reqwest::StatusCode::TOO_MANY_REQUESTS, "");

        assert!(matches!(
            unauthorized,
            LLMError::Provider {
                message: _,
                metadata: _
            }
        ));
        assert!(matches!(rate_limited, LLMError::RateLimit { metadata: _ }));
    }

    #[test]
    fn test_error_handler_maps_other_statuses_to_provider_errors() {
        let handler = ErrorHandler::new("test_provider");

        for status in [
            reqwest::StatusCode::BAD_REQUEST,
            reqwest::StatusCode::PAYMENT_REQUIRED,
            reqwest::StatusCode::INTERNAL_SERVER_ERROR,
        ] {
            assert!(matches!(
                handler.handle_http_error(status, "boom"),
                LLMError::Provider { .. }
            ));
        }
    }

    #[test]
    fn test_model_resolver() {
        let resolver = ModelResolver::new("test_provider", "default-model", &["model1", "model2"]);

        assert_eq!(resolver.resolve_model(None), "default-model");
        assert_eq!(resolver.resolve_model(Some("custom".to_string())), "custom");

        resolver.validate_model("model1").unwrap();
        assert!(resolver.validate_model("unsupported").is_err());
    }

    #[test]
    fn test_model_resolver_without_allow_list() {
        let resolver = ModelResolver::new("test_provider", "default-model", &[]);

        resolver.validate_model("anything").unwrap();
        assert!(resolver.validate_model("").is_err());
    }
}