vtcode_core/llm/
provider_builder.rs1use crate::config::TimeoutsConfig;
2use crate::config::core::PromptCachingConfig;
3use crate::llm::provider::{LLMError, LLMProvider};
4use std::marker::PhantomData;
5
6pub struct ProviderBuilder<T> {
8 api_key: Option<String>,
9 model: Option<String>,
10 base_url: Option<String>,
11 prompt_cache: Option<PromptCachingConfig>,
12 timeouts: Option<TimeoutsConfig>,
13 _phantom: PhantomData<T>,
14}
15
16impl<T> Default for ProviderBuilder<T> {
17 fn default() -> Self {
18 Self {
19 api_key: None,
20 model: None,
21 base_url: None,
22 prompt_cache: None,
23 timeouts: None,
24 _phantom: PhantomData,
25 }
26 }
27}
28
29impl<T> ProviderBuilder<T>
30where
31 T: ProviderConfig,
32{
33 pub fn new() -> Self {
34 Self::default()
35 }
36
37 pub fn api_key(mut self, api_key: String) -> Self {
38 self.api_key = Some(api_key);
39 self
40 }
41
42 pub fn model(mut self, model: String) -> Self {
43 self.model = Some(model);
44 self
45 }
46
47 pub fn base_url(mut self, base_url: String) -> Self {
48 self.base_url = Some(base_url);
49 self
50 }
51
52 pub fn prompt_cache(mut self, prompt_cache: PromptCachingConfig) -> Self {
53 self.prompt_cache = Some(prompt_cache);
54 self
55 }
56
57 pub fn timeouts(mut self, timeouts: TimeoutsConfig) -> Self {
58 self.timeouts = Some(timeouts);
59 self
60 }
61
62 pub fn try_build(self) -> Result<Box<dyn LLMProvider>, LLMError> {
63 crate::llm::provider_config::create_provider_unified(
64 T::PROVIDER_KEY,
65 self.api_key,
66 self.model,
67 self.base_url,
68 self.prompt_cache,
69 self.timeouts,
70 )
71 }
72
73 pub fn build(self) -> Box<dyn LLMProvider> {
74 match self.try_build() {
75 Ok(provider) => provider,
76 Err(error) => unreachable!(
77 "provider builder invariant violated for `{}`: {}",
78 T::PROVIDER_KEY,
79 error
80 ),
81 }
82 }
83}
84
85pub trait ProviderConfig {
87 const PROVIDER_KEY: &'static str;
88 const DISPLAY_NAME: &'static str;
89 const DEFAULT_MODEL: &'static str;
90 const API_BASE_URL: &'static str;
91 const BASE_URL_ENV_VAR: Option<&'static str>;
92
93 fn create_provider(
94 api_key: String,
95 model: String,
96 base_url: String,
97 prompt_cache_enabled: bool,
98 prompt_cache_settings: Self::PromptCacheSettings,
99 timeouts: TimeoutsConfig,
100 ) -> Box<dyn LLMProvider>
101 where
102 Self::PromptCacheSettings: Send + Sync + 'static,
103 {
104 let _ = prompt_cache_settings;
105 let prompt_cache = prompt_cache_enabled.then(|| PromptCachingConfig {
106 enabled: true,
107 ..Default::default()
108 });
109
110 match crate::llm::provider_config::create_provider_unified(
111 Self::PROVIDER_KEY,
112 (!api_key.trim().is_empty()).then_some(api_key),
113 (!model.trim().is_empty()).then_some(model),
114 (!base_url.trim().is_empty()).then_some(base_url),
115 prompt_cache,
116 Some(timeouts),
117 ) {
118 Ok(provider) => provider,
119 Err(error) => unreachable!(
120 "provider config invariant violated for `{}`: {}",
121 Self::PROVIDER_KEY,
122 error
123 ),
124 }
125 }
126
127 type PromptCacheSettings: Clone + Default + Send + Sync + 'static;
128}
129
130mod http_client_pool {
132 use crate::config::TimeoutsConfig;
133 use hashbrown::HashMap;
134 use once_cell::sync::Lazy;
135 use reqwest::Client as HttpClient;
136 use std::sync::{Arc, RwLock};
137 use std::time::Duration;
138
139 type HttpClientPool = Arc<RwLock<HashMap<String, Arc<HttpClient>>>>;
140
141 static CLIENT_POOL: Lazy<HttpClientPool> = Lazy::new(|| {
142 let mut pool = HashMap::new();
143
144 pool.insert("default".to_string(), Arc::new(HttpClient::new()));
146
147 pool.insert(
149 "timeout_30s".to_string(),
150 Arc::new(
151 HttpClient::builder()
152 .timeout(Duration::from_secs(30))
153 .build()
154 .unwrap_or_else(|error| {
155 tracing::warn!(
156 error = %error,
157 "Failed to build 30s timeout HTTP client; falling back to default client"
158 );
159 HttpClient::new()
160 }),
161 ),
162 );
163
164 pool.insert(
165 "timeout_120s".to_string(),
166 Arc::new(
167 HttpClient::builder()
168 .timeout(Duration::from_secs(120))
169 .build()
170 .unwrap_or_else(|error| {
171 tracing::warn!(
172 error = %error,
173 "Failed to build 120s timeout HTTP client; falling back to default client"
174 );
175 HttpClient::new()
176 }),
177 ),
178 );
179
180 Arc::new(RwLock::new(pool))
181 });
182
183 pub fn get_http_client(key: &str) -> Arc<HttpClient> {
184 let pool_guard = CLIENT_POOL.read();
185 let pool = match pool_guard {
186 Ok(guard) => guard,
187 Err(poisoned) => {
188 tracing::warn!("HTTP client pool poisoned; continuing with recovered state");
189 poisoned.into_inner()
190 }
191 };
192
193 if let Some(client) = pool.get(key).cloned() {
194 return client;
195 }
196
197 if let Some(default_client) = pool.get("default").cloned() {
198 return default_client;
199 }
200
201 tracing::warn!("HTTP client pool missing default client; constructing transient client");
202 Arc::new(HttpClient::new())
203 }
204
205 pub fn get_http_client_for_timeouts(timeouts: &TimeoutsConfig) -> Arc<HttpClient> {
206 let key = if timeouts.default_ceiling_seconds >= 120 {
207 "timeout_120s"
208 } else if timeouts.default_ceiling_seconds >= 30 {
209 "timeout_30s"
210 } else {
211 "default"
212 };
213 get_http_client(key)
214 }
215}
216
217pub use http_client_pool::{get_http_client, get_http_client_for_timeouts};