Skip to main content

vtcode_core/llm/providers/openrouter/provider/mod.rs

1#![allow(clippy::collapsible_if)]
2
3use tracing::warn;
4
5use crate::config::TimeoutsConfig;
6use crate::config::constants::{env_vars, models, urls};
7use crate::config::core::{
8    AnthropicConfig, ModelConfig, OpenRouterPromptCacheSettings, PromptCachingConfig,
9};
10use crate::config::models::ModelId;
11use crate::llm::error_display;
12use crate::llm::provider::{LLMError, LLMRequest, Message, MessageRole, ToolChoice};
13use crate::llm::providers::common::{
14    extract_prompt_cache_settings, override_base_url, resolve_model,
15};
16use reqwest::{Client as HttpClient, Response, StatusCode};
17use serde_json::Value;
18use std::borrow::Cow;
19use std::str::FromStr;
20
/// Value sent in the `HTTP-Referer` header on every OpenRouter request.
const OPENROUTER_REFERER: &str = "https://github.com/vinhnx/vtcode";
/// Value sent in the `X-OpenRouter-Title` header on every OpenRouter request.
const OPENROUTER_TITLE: &str = "VT Code";
/// Value sent in the `X-OpenRouter-Categories` header on every OpenRouter request.
const OPENROUTER_CATEGORIES: &str = "agents,coding";
24
25mod client_impl;
26mod parsing;
27mod provider_impl;
28#[cfg(test)]
29mod tests;
30
31pub struct OpenRouterProvider {
32    api_key: String,
33    http_client: HttpClient,
34    base_url: String,
35    model: String,
36    prompt_cache_enabled: bool,
37    prompt_cache_settings: OpenRouterPromptCacheSettings,
38    model_behavior: Option<ModelConfig>,
39}
40
41impl OpenRouterProvider {
42    pub fn new(api_key: String) -> Self {
43        Self::with_model_internal(
44            api_key,
45            models::openrouter::DEFAULT_MODEL.to_string(),
46            None,
47            None,
48            TimeoutsConfig::default(),
49            None,
50        )
51    }
52
53    pub fn with_model(api_key: String, model: String) -> Self {
54        Self::with_model_internal(api_key, model, None, None, TimeoutsConfig::default(), None)
55    }
56
57    pub fn new_with_client(
58        api_key: String,
59        model: String,
60        http_client: reqwest::Client,
61        base_url: String,
62        _timeouts: TimeoutsConfig,
63    ) -> Self {
64        Self {
65            api_key,
66            http_client,
67            base_url,
68            model,
69            prompt_cache_enabled: false,
70            prompt_cache_settings: OpenRouterPromptCacheSettings::default(),
71            model_behavior: None,
72        }
73    }
74
75    pub fn from_config(
76        api_key: Option<String>,
77        model: Option<String>,
78        base_url: Option<String>,
79        prompt_cache: Option<PromptCachingConfig>,
80        timeouts: Option<TimeoutsConfig>,
81        _anthropic: Option<AnthropicConfig>,
82        model_behavior: Option<ModelConfig>,
83    ) -> Self {
84        let api_key_value = api_key.unwrap_or_default();
85        let model_value = resolve_model(model, models::openrouter::DEFAULT_MODEL);
86
87        Self::with_model_internal(
88            api_key_value,
89            model_value,
90            prompt_cache,
91            base_url,
92            timeouts.unwrap_or_default(),
93            model_behavior,
94        )
95    }
96
97    fn with_model_internal(
98        api_key: String,
99        model: String,
100        prompt_cache: Option<PromptCachingConfig>,
101        base_url: Option<String>,
102        timeouts: TimeoutsConfig,
103        model_behavior: Option<ModelConfig>,
104    ) -> Self {
105        use crate::llm::http_client::HttpClientFactory;
106        let (prompt_cache_enabled, prompt_cache_settings) = extract_prompt_cache_settings(
107            prompt_cache,
108            |p| &p.openrouter,
109            |cfg, settings| cfg.enabled && settings.enabled,
110        );
111
112        Self {
113            api_key,
114            http_client: HttpClientFactory::for_llm(&timeouts),
115            base_url: override_base_url(
116                urls::OPENROUTER_API_BASE,
117                base_url,
118                Some(env_vars::OPENROUTER_BASE_URL),
119            ),
120            model,
121            prompt_cache_enabled,
122            prompt_cache_settings,
123            model_behavior,
124        }
125    }
126
127    pub(super) fn resolve_model<'a>(&'a self, request: &'a LLMRequest) -> &'a str {
128        if request.model.trim().is_empty() {
129            self.model.as_str()
130        } else {
131            request.model.as_str()
132        }
133    }
134
135    fn request_includes_tools(request: &LLMRequest) -> bool {
136        request
137            .tools
138            .as_ref()
139            .map(|tools| !tools.is_empty())
140            .unwrap_or(false)
141    }
142
143    fn enforce_tool_capabilities<'a>(&'a self, request: &'a LLMRequest) -> Cow<'a, LLMRequest> {
144        let resolved_model = self.resolve_model(request);
145        let tools_requested = Self::request_includes_tools(request);
146        let tool_restricted = if let Ok(model_id) = ModelId::from_str(resolved_model) {
147            !model_id.supports_tool_calls()
148        } else {
149            models::openrouter::TOOL_UNAVAILABLE_MODELS.contains(&resolved_model)
150        };
151
152        if tools_requested && tool_restricted {
153            Cow::Owned(Self::tool_free_request(request))
154        } else {
155            Cow::Borrowed(request)
156        }
157    }
158
159    fn tool_free_request(original: &LLMRequest) -> LLMRequest {
160        let mut sanitized = original.clone();
161        sanitized.tools = None;
162        sanitized.tool_choice = Some(ToolChoice::None);
163        sanitized.parallel_tool_calls = None;
164        sanitized.parallel_tool_config = None;
165
166        let mut normalized_messages: Vec<Message> = Vec::with_capacity(original.messages.len());
167
168        for message in &original.messages {
169            match message.role {
170                MessageRole::Assistant => {
171                    let mut cleaned = message.clone();
172                    cleaned.tool_calls = None;
173                    cleaned.tool_call_id = None;
174
175                    let content_text = cleaned.content.as_text();
176                    let has_content = !content_text.trim().is_empty();
177                    if has_content || cleaned.reasoning.is_some() {
178                        normalized_messages.push(cleaned);
179                    }
180                }
181                MessageRole::Tool => {
182                    let content_text = message.content.as_text();
183                    if content_text.trim().is_empty() {
184                        continue;
185                    }
186
187                    let mut converted = Message::user(content_text.into_owned());
188                    converted.reasoning = message.reasoning.clone();
189                    normalized_messages.push(converted);
190                }
191                _ => {
192                    normalized_messages.push(message.clone());
193                }
194            }
195        }
196
197        sanitized.messages = normalized_messages;
198        sanitized
199    }
200
201    fn request_includes_images(request: &LLMRequest) -> bool {
202        request.messages.iter().any(|msg| msg.content.has_images())
203    }
204
205    fn image_free_request(original: &LLMRequest) -> LLMRequest {
206        let mut sanitized = original.clone();
207        for message in &mut sanitized.messages {
208            if let Some(text_only) = message.content.without_images() {
209                message.content = text_only;
210            }
211        }
212        sanitized
213    }
214
215    /// Retry a request with a fallback payload. Returns `Ok(Some(response))` on
216    /// success, `Err` on rate-limit, and `Ok(None)` when the fallback also fails
217    /// (caller assembles the combined error).
218    async fn retry_with_fallback(
219        &self,
220        original_status: StatusCode,
221        original_error: &str,
222        fallback_request: &LLMRequest,
223        stream_override: Option<bool>,
224        label: &str,
225    ) -> Result<Option<Response>, LLMError> {
226        let (mut fallback_payload, fallback_url) = self.build_provider_payload(fallback_request)?;
227        if let Some(stream_flag) = stream_override {
228            fallback_payload["stream"] = Value::Bool(stream_flag);
229        }
230
231        let fallback_response = self
232            .dispatch_request(&fallback_url, &fallback_payload)
233            .await?;
234        if fallback_response.status().is_success() {
235            return Ok(Some(fallback_response));
236        }
237
238        let fallback_status = fallback_response.status();
239        let fallback_text = fallback_response.text().await.unwrap_or_default();
240
241        if fallback_status.as_u16() == 429 || fallback_text.contains("quota") {
242            return Err(LLMError::RateLimit { metadata: None });
243        }
244
245        let combined_error = format!(
246            "HTTP {}: {} | {} fallback failed with HTTP {}: {}",
247            original_status, original_error, label, fallback_status, fallback_text
248        );
249        let formatted_error = error_display::format_llm_error("OpenRouter", &combined_error);
250        Err(LLMError::Provider {
251            message: formatted_error,
252            metadata: None,
253        })
254    }
255
256    /// Attempt a feature-specific fallback when the provider rejects a request.
257    ///
258    /// Returns `Ok(Some(response))` if the fallback succeeds, `Ok(None)` if the
259    /// condition doesn't match (caller should try the next fallback), and `Err`
260    /// for rate-limit or combined error failures.
261    async fn try_feature_fallback(
262        &self,
263        request: &LLMRequest,
264        status: StatusCode,
265        error_text: &str,
266        stream_override: Option<bool>,
267        has_feature: fn(&LLMRequest) -> bool,
268        error_match: &str,
269        warn_message: &str,
270        strip_feature: fn(&LLMRequest) -> LLMRequest,
271        label: &str,
272    ) -> Result<Option<Response>, LLMError> {
273        if has_feature(request)
274            && status == StatusCode::NOT_FOUND
275            && error_text.contains(error_match)
276        {
277            warn!("{}", warn_message);
278            let fallback_request = strip_feature(request);
279            return self
280                .retry_with_fallback(
281                    status,
282                    error_text,
283                    &fallback_request,
284                    stream_override,
285                    label,
286                )
287                .await;
288        }
289        Ok(None)
290    }
291
292    fn build_provider_payload(&self, request: &LLMRequest) -> Result<(Value, String), LLMError> {
293        Ok((
294            self.convert_to_openrouter_format(request)?,
295            format!("{}/chat/completions", self.base_url),
296        ))
297    }
298
299    async fn dispatch_request(&self, url: &str, payload: &Value) -> Result<Response, LLMError> {
300        self.http_client
301            .post(url)
302            .bearer_auth(&self.api_key)
303            .header("HTTP-Referer", OPENROUTER_REFERER)
304            .header("X-OpenRouter-Title", OPENROUTER_TITLE)
305            .header("X-OpenRouter-Categories", OPENROUTER_CATEGORIES)
306            .json(payload)
307            .send()
308            .await
309            .map_err(|e| {
310                let formatted_error =
311                    error_display::format_llm_error("OpenRouter", &format!("Network error: {}", e));
312                LLMError::Network {
313                    message: formatted_error,
314                    metadata: None,
315                }
316            })
317    }
318
319    async fn send_with_fallback(
320        &self,
321        request: &LLMRequest,
322        stream_override: Option<bool>,
323    ) -> Result<Response, LLMError> {
324        let adjusted_request = self.enforce_tool_capabilities(request);
325        let request_ref = adjusted_request.as_ref();
326
327        let (mut payload, url) = self.build_provider_payload(request_ref)?;
328        if let Some(stream_flag) = stream_override {
329            payload["stream"] = Value::Bool(stream_flag);
330        }
331
332        let response = self.dispatch_request(&url, &payload).await?;
333        if response.status().is_success() {
334            return Ok(response);
335        }
336
337        let status = response.status();
338        let error_text = response.text().await.unwrap_or_default();
339
340        if status.as_u16() == 429 || error_text.contains("quota") {
341            return Err(LLMError::RateLimit { metadata: None });
342        }
343
344        if let Some(resp) = self
345            .try_feature_fallback(
346                request_ref,
347                status,
348                &error_text,
349                stream_override,
350                Self::request_includes_tools,
351                "No endpoints found that support tool use",
352                "OpenRouter endpoint does not support tool use; retrying without tools",
353                Self::tool_free_request,
354                "Tool",
355            )
356            .await?
357        {
358            return Ok(resp);
359        }
360
361        if let Some(resp) = self
362            .try_feature_fallback(
363                request_ref,
364                status,
365                &error_text,
366                stream_override,
367                Self::request_includes_images,
368                "No endpoints found that support image input",
369                "OpenRouter endpoint does not support image input; retrying without images",
370                Self::image_free_request,
371                "Image",
372            )
373            .await?
374        {
375            return Ok(resp);
376        }
377
378        // Use unified error parsing for consistent error categorization
379        use crate::llm::providers::error_handling::parse_api_error;
380        Err(parse_api_error("OpenRouter", status, &error_text))
381    }
382}