Skip to main content

vtcode_core/llm/providers/
poolside.rs

1use async_trait::async_trait;
2use reqwest::Client as HttpClient;
3use serde_json::{Map, Value};
4
5use crate::config::TimeoutsConfig;
6use crate::config::constants::{env_vars, models, urls};
7use crate::config::core::{AnthropicConfig, ModelConfig, PromptCachingConfig};
8use crate::llm::error_display;
9use crate::llm::provider::{LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream};
10
11use super::{
12    common::{
13        ensure_model, extract_prompt_cache_settings_default, impl_llm_client, override_base_url,
14        parse_json_response, parse_response_openai_format, resolve_model,
15        serialize_messages_openai_format, serialize_tools_openai_format,
16        spawn_openai_compatible_stream, validate_supported_models,
17    },
18    error_handling::handle_openai_http_error,
19};
20
/// Human-readable provider label, passed to the shared error/response helpers.
const PROVIDER_NAME: &str = "Poolside";
/// Lowercase provider identifier, returned by `LLMProvider::name` and passed to the
/// shared serialization/validation helpers.
const PROVIDER_KEY: &str = "poolside";
23
24pub struct PoolsideProvider {
25    api_key: String,
26    http_client: HttpClient,
27    base_url: String,
28    model: String,
29    prompt_cache_enabled: bool,
30    #[allow(dead_code)]
31    model_behavior: Option<ModelConfig>,
32}
33
34impl PoolsideProvider {
35    pub fn new(api_key: String) -> Self {
36        Self::with_model_internal(
37            api_key,
38            models::poolside::DEFAULT_MODEL.to_string(),
39            None,
40            None,
41            TimeoutsConfig::default(),
42            None,
43        )
44    }
45
46    pub fn with_model(api_key: String, model: String) -> Self {
47        Self::with_model_internal(api_key, model, None, None, TimeoutsConfig::default(), None)
48    }
49
50    pub fn new_with_client(
51        api_key: String,
52        model: String,
53        http_client: reqwest::Client,
54        base_url: String,
55        _timeouts: TimeoutsConfig,
56    ) -> Self {
57        Self {
58            api_key,
59            http_client,
60            base_url,
61            model,
62            prompt_cache_enabled: false,
63            model_behavior: None,
64        }
65    }
66
67    pub fn from_config(
68        api_key: Option<String>,
69        model: Option<String>,
70        base_url: Option<String>,
71        prompt_cache: Option<PromptCachingConfig>,
72        timeouts: Option<TimeoutsConfig>,
73        _anthropic: Option<AnthropicConfig>,
74        model_behavior: Option<ModelConfig>,
75    ) -> Self {
76        let api_key_value = api_key
77            .filter(|k| !k.trim().is_empty())
78            .or_else(|| {
79                std::env::var("POOLSIDE_API_KEY")
80                    .ok()
81                    .filter(|k| !k.trim().is_empty())
82            })
83            .unwrap_or_default();
84
85        Self::with_model_internal(
86            api_key_value,
87            resolve_model(model, models::poolside::DEFAULT_MODEL),
88            prompt_cache,
89            base_url,
90            timeouts.unwrap_or_default(),
91            model_behavior,
92        )
93    }
94
95    fn with_model_internal(
96        api_key: String,
97        model: String,
98        prompt_cache: Option<PromptCachingConfig>,
99        base_url: Option<String>,
100        timeouts: TimeoutsConfig,
101        model_behavior: Option<ModelConfig>,
102    ) -> Self {
103        use crate::llm::http_client::HttpClientFactory;
104
105        let (prompt_cache_enabled, _) =
106            extract_prompt_cache_settings_default(prompt_cache, PROVIDER_KEY);
107
108        Self {
109            api_key,
110            http_client: HttpClientFactory::for_llm(&timeouts),
111            base_url: override_base_url(
112                urls::POOLSIDE_API_BASE,
113                base_url,
114                Some(env_vars::POOLSIDE_BASE_URL),
115            ),
116            model,
117            prompt_cache_enabled,
118            model_behavior,
119        }
120    }
121
122    fn convert_to_poolside_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
123        let mut payload = Map::with_capacity(12);
124
125        payload.insert("model".to_owned(), Value::String(request.model.clone()));
126
127        let mut messages = self.serialize_messages(request)?;
128
129        if let Some(system_prompt) = &request.system_prompt {
130            let trimmed = system_prompt.trim();
131            if !trimmed.is_empty() {
132                messages.insert(0, serde_json::json!({"role": "system", "content": trimmed}));
133            }
134        }
135
136        payload.insert("messages".to_owned(), Value::Array(messages));
137
138        if let Some(max_tokens) = request.max_tokens {
139            payload.insert(
140                "max_tokens".to_owned(),
141                Value::Number(serde_json::Number::from(max_tokens as u64)),
142            );
143        }
144
145        if let Some(temperature) = request.temperature {
146            payload.insert(
147                "temperature".to_owned(),
148                Value::Number(serde_json::Number::from_f64(temperature as f64).ok_or(
149                    LLMError::InvalidRequest {
150                        message: "invalid temperature value".to_string(),
151                        metadata: None,
152                    },
153                )?),
154            );
155        }
156
157        if let Some(top_p) = request.top_p {
158            payload.insert(
159                "top_p".to_owned(),
160                Value::Number(serde_json::Number::from_f64(top_p as f64).ok_or(
161                    LLMError::InvalidRequest {
162                        message: "invalid top_p value".to_string(),
163                        metadata: None,
164                    },
165                )?),
166            );
167        }
168
169        if request.stream {
170            payload.insert("stream".to_string(), Value::Bool(true));
171            payload.insert(
172                "stream_options".to_string(),
173                serde_json::json!({"include_usage": true}),
174            );
175        }
176
177        if let Some(tools) = &request.tools
178            && let Some(serialized_tools) = serialize_tools_openai_format(tools)
179        {
180            payload.insert("tools".to_string(), Value::Array(serialized_tools));
181        }
182
183        if let Some(choice) = &request.tool_choice {
184            payload.insert(
185                "tool_choice".to_string(),
186                choice.to_provider_format(PROVIDER_KEY),
187            );
188        }
189
190        if let Some(meta) = &request.metadata
191            && let Some(user_id) = meta.get("user_id").and_then(|v| v.as_str())
192        {
193            payload.insert("user_id".to_owned(), Value::String(user_id.to_owned()));
194        }
195
196        Ok(Value::Object(payload))
197    }
198
199    async fn send_request(&self, payload: &Value) -> Result<reqwest::Response, LLMError> {
200        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
201
202        self.http_client
203            .post(&url)
204            .header("Authorization", format!("Bearer {}", self.api_key))
205            .json(payload)
206            .send()
207            .await
208            .map_err(|e| LLMError::Network {
209                message: error_display::format_llm_error(
210                    PROVIDER_NAME,
211                    &format!("network error: {}", e),
212                ),
213                metadata: None,
214            })
215    }
216
217    fn serialize_messages(&self, request: &LLMRequest) -> Result<Vec<Value>, LLMError> {
218        serialize_messages_openai_format(request, PROVIDER_KEY)
219    }
220
221    fn parse_response(&self, response_json: Value, model: String) -> Result<LLMResponse, LLMError> {
222        let reasoning_extractor = |_message: &Value, _choice: &Value| -> Option<String> { None };
223
224        parse_response_openai_format(
225            response_json,
226            PROVIDER_NAME,
227            model,
228            self.prompt_cache_enabled,
229            Some(reasoning_extractor),
230        )
231    }
232}
233
234#[async_trait]
235impl LLMProvider for PoolsideProvider {
236    fn name(&self) -> &str {
237        PROVIDER_KEY
238    }
239
240    fn supports_streaming(&self) -> bool {
241        true
242    }
243
244    fn supports_tools(&self, _model: &str) -> bool {
245        true
246    }
247
248    fn supports_structured_output(&self, _model: &str) -> bool {
249        true
250    }
251
252    fn supports_vision(&self, _model: &str) -> bool {
253        false
254    }
255
256    fn supports_reasoning(&self, _model: &str) -> bool {
257        true
258    }
259
260    fn supports_reasoning_effort(&self, _model: &str) -> bool {
261        false
262    }
263
264    fn effective_context_size(&self, _model: &str) -> usize {
265        131_072
266    }
267
268    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
269        let model = ensure_model(&mut request, &self.model);
270
271        let payload = self.convert_to_poolside_format(&request)?;
272        let response = self.send_request(&payload).await?;
273        let response =
274            handle_openai_http_error(response, PROVIDER_NAME, "POOLSIDE_API_KEY").await?;
275
276        let response_json = parse_json_response(response, PROVIDER_NAME).await?;
277        self.parse_response(response_json, model)
278    }
279
280    async fn stream(&self, mut request: LLMRequest) -> Result<LLMStream, LLMError> {
281        ensure_model(&mut request, &self.model);
282        self.validate_request(&request)?;
283        request.stream = true;
284        let model = request.model.clone();
285
286        let payload = self.convert_to_poolside_format(&request)?;
287        let response = self.send_request(&payload).await?;
288        let response =
289            handle_openai_http_error(response, PROVIDER_NAME, "POOLSIDE_API_KEY").await?;
290
291        Ok(spawn_openai_compatible_stream(
292            response,
293            PROVIDER_NAME,
294            model,
295            None,
296            super::shared::OpenAiDeltaOrder::ContentFirst,
297        ))
298    }
299
300    fn supported_models(&self) -> Vec<String> {
301        models::poolside::SUPPORTED_MODELS
302            .iter()
303            .map(|model| model.to_string())
304            .collect()
305    }
306
307    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
308        validate_supported_models(
309            request,
310            PROVIDER_NAME,
311            PROVIDER_KEY,
312            models::poolside::SUPPORTED_MODELS,
313        )
314    }
315}
316
// NOTE(review): `impl_llm_client!` is imported from `super::common`; presumably it
// generates the shared LLM client implementation for `PoolsideProvider` — confirm
// against the macro definition.
impl_llm_client!(PoolsideProvider);