// vtcode_core/llm/provider_builder.rs

1use crate::config::TimeoutsConfig;
2use crate::config::core::PromptCachingConfig;
3use crate::llm::provider::{LLMError, LLMProvider};
4use std::marker::PhantomData;
5
6/// Generic provider builder to eliminate duplicate provider creation patterns
7pub struct ProviderBuilder<T> {
8    api_key: Option<String>,
9    model: Option<String>,
10    base_url: Option<String>,
11    prompt_cache: Option<PromptCachingConfig>,
12    timeouts: Option<TimeoutsConfig>,
13    _phantom: PhantomData<T>,
14}
15
16impl<T> Default for ProviderBuilder<T> {
17    fn default() -> Self {
18        Self {
19            api_key: None,
20            model: None,
21            base_url: None,
22            prompt_cache: None,
23            timeouts: None,
24            _phantom: PhantomData,
25        }
26    }
27}
28
29impl<T> ProviderBuilder<T>
30where
31    T: ProviderConfig,
32{
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    pub fn api_key(mut self, api_key: String) -> Self {
38        self.api_key = Some(api_key);
39        self
40    }
41
42    pub fn model(mut self, model: String) -> Self {
43        self.model = Some(model);
44        self
45    }
46
47    pub fn base_url(mut self, base_url: String) -> Self {
48        self.base_url = Some(base_url);
49        self
50    }
51
52    pub fn prompt_cache(mut self, prompt_cache: PromptCachingConfig) -> Self {
53        self.prompt_cache = Some(prompt_cache);
54        self
55    }
56
57    pub fn timeouts(mut self, timeouts: TimeoutsConfig) -> Self {
58        self.timeouts = Some(timeouts);
59        self
60    }
61
62    pub fn try_build(self) -> Result<Box<dyn LLMProvider>, LLMError> {
63        crate::llm::provider_config::create_provider_unified(
64            T::PROVIDER_KEY,
65            self.api_key,
66            self.model,
67            self.base_url,
68            self.prompt_cache,
69            self.timeouts,
70        )
71    }
72
73    pub fn build(self) -> Box<dyn LLMProvider> {
74        match self.try_build() {
75            Ok(provider) => provider,
76            Err(error) => unreachable!(
77                "provider builder invariant violated for `{}`: {}",
78                T::PROVIDER_KEY,
79                error
80            ),
81        }
82    }
83}
84
85/// Trait for provider-specific configuration and creation
86pub trait ProviderConfig {
87    const PROVIDER_KEY: &'static str;
88    const DISPLAY_NAME: &'static str;
89    const DEFAULT_MODEL: &'static str;
90    const API_BASE_URL: &'static str;
91    const BASE_URL_ENV_VAR: Option<&'static str>;
92
93    fn create_provider(
94        api_key: String,
95        model: String,
96        base_url: String,
97        prompt_cache_enabled: bool,
98        prompt_cache_settings: Self::PromptCacheSettings,
99        timeouts: TimeoutsConfig,
100    ) -> Box<dyn LLMProvider>
101    where
102        Self::PromptCacheSettings: Send + Sync + 'static,
103    {
104        let _ = prompt_cache_settings;
105        let prompt_cache = prompt_cache_enabled.then(|| PromptCachingConfig {
106            enabled: true,
107            ..Default::default()
108        });
109
110        match crate::llm::provider_config::create_provider_unified(
111            Self::PROVIDER_KEY,
112            (!api_key.trim().is_empty()).then_some(api_key),
113            (!model.trim().is_empty()).then_some(model),
114            (!base_url.trim().is_empty()).then_some(base_url),
115            prompt_cache,
116            Some(timeouts),
117        ) {
118            Ok(provider) => provider,
119            Err(error) => unreachable!(
120                "provider config invariant violated for `{}`: {}",
121                Self::PROVIDER_KEY,
122                error
123            ),
124        }
125    }
126
127    type PromptCacheSettings: Clone + Default + Send + Sync + 'static;
128}
129
130/// HTTP client pool to avoid creating new clients for each provider
131mod http_client_pool {
132    use crate::config::TimeoutsConfig;
133    use hashbrown::HashMap;
134    use once_cell::sync::Lazy;
135    use reqwest::Client as HttpClient;
136    use std::sync::{Arc, RwLock};
137    use std::time::Duration;
138
139    type HttpClientPool = Arc<RwLock<HashMap<String, Arc<HttpClient>>>>;
140
141    static CLIENT_POOL: Lazy<HttpClientPool> = Lazy::new(|| {
142        let mut pool = HashMap::new();
143
144        // Default client
145        pool.insert("default".to_string(), Arc::new(HttpClient::new()));
146
147        // Timeout-configured clients
148        pool.insert(
149            "timeout_30s".to_string(),
150            Arc::new(
151                HttpClient::builder()
152                    .timeout(Duration::from_secs(30))
153                    .build()
154                    .unwrap_or_else(|error| {
155                        tracing::warn!(
156                            error = %error,
157                            "Failed to build 30s timeout HTTP client; falling back to default client"
158                        );
159                        HttpClient::new()
160                    }),
161            ),
162        );
163
164        pool.insert(
165            "timeout_120s".to_string(),
166            Arc::new(
167                HttpClient::builder()
168                    .timeout(Duration::from_secs(120))
169                    .build()
170                    .unwrap_or_else(|error| {
171                        tracing::warn!(
172                            error = %error,
173                            "Failed to build 120s timeout HTTP client; falling back to default client"
174                        );
175                        HttpClient::new()
176                    }),
177            ),
178        );
179
180        Arc::new(RwLock::new(pool))
181    });
182
183    pub fn get_http_client(key: &str) -> Arc<HttpClient> {
184        let pool_guard = CLIENT_POOL.read();
185        let pool = match pool_guard {
186            Ok(guard) => guard,
187            Err(poisoned) => {
188                tracing::warn!("HTTP client pool poisoned; continuing with recovered state");
189                poisoned.into_inner()
190            }
191        };
192
193        if let Some(client) = pool.get(key).cloned() {
194            return client;
195        }
196
197        if let Some(default_client) = pool.get("default").cloned() {
198            return default_client;
199        }
200
201        tracing::warn!("HTTP client pool missing default client; constructing transient client");
202        Arc::new(HttpClient::new())
203    }
204
205    pub fn get_http_client_for_timeouts(timeouts: &TimeoutsConfig) -> Arc<HttpClient> {
206        let key = if timeouts.default_ceiling_seconds >= 120 {
207            "timeout_120s"
208        } else if timeouts.default_ceiling_seconds >= 30 {
209            "timeout_30s"
210        } else {
211            "default"
212        };
213        get_http_client(key)
214    }
215}
216
217pub use http_client_pool::{get_http_client, get_http_client_for_timeouts};