Skip to main content

vtcode_core/llm/providers/openrouter/provider/
mod.rs

1#![allow(clippy::collapsible_if)]
2
3use crate::config::TimeoutsConfig;
4use crate::config::constants::{env_vars, models, urls};
5use crate::config::core::{
6    AnthropicConfig, ModelConfig, OpenRouterPromptCacheSettings, PromptCachingConfig,
7};
8use crate::config::models::ModelId;
9use crate::llm::error_display;
10use crate::llm::provider::{LLMError, LLMRequest, Message, MessageRole, ToolChoice};
11use crate::llm::providers::common::{
12    extract_prompt_cache_settings, override_base_url, resolve_model,
13};
14use reqwest::{Client as HttpClient, Response, StatusCode};
15use serde_json::Value;
16use std::borrow::Cow;
17use std::str::FromStr;
18
/// Value sent in the `HTTP-Referer` header of every request issued by `dispatch_request`.
const OPENROUTER_REFERER: &str = "https://github.com/vinhnx/vtcode";
/// Value sent in the `X-OpenRouter-Title` header of every request issued by `dispatch_request`.
const OPENROUTER_TITLE: &str = "VT Code";
/// Comma-separated value sent in the `X-OpenRouter-Categories` header of every request.
const OPENROUTER_CATEGORIES: &str = "agents,coding";
22
23mod client_impl;
24mod parsing;
25mod provider_impl;
26#[cfg(test)]
27mod tests;
28
/// LLM provider that sends chat-completion requests to an OpenRouter-compatible endpoint.
pub struct OpenRouterProvider {
    /// Credential attached to each request as a bearer token.
    api_key: String,
    /// HTTP client used to POST request payloads.
    http_client: HttpClient,
    /// Base URL; `/chat/completions` is appended to it when building the request URL.
    base_url: String,
    /// Model used when a request's own `model` field is blank.
    model: String,
    /// Whether prompt caching is enabled (always `false` when built via `new_with_client`).
    prompt_cache_enabled: bool,
    /// OpenRouter-specific prompt cache settings.
    prompt_cache_settings: OpenRouterPromptCacheSettings,
    /// Optional per-model configuration supplied by the caller; not read in this part of the file.
    model_behavior: Option<ModelConfig>,
}
38
39impl OpenRouterProvider {
40    pub fn new(api_key: String) -> Self {
41        Self::with_model_internal(
42            api_key,
43            models::openrouter::DEFAULT_MODEL.to_string(),
44            None,
45            None,
46            TimeoutsConfig::default(),
47            None,
48        )
49    }
50
51    pub fn with_model(api_key: String, model: String) -> Self {
52        Self::with_model_internal(api_key, model, None, None, TimeoutsConfig::default(), None)
53    }
54
55    pub fn new_with_client(
56        api_key: String,
57        model: String,
58        http_client: reqwest::Client,
59        base_url: String,
60        _timeouts: TimeoutsConfig,
61    ) -> Self {
62        Self {
63            api_key,
64            http_client,
65            base_url,
66            model,
67            prompt_cache_enabled: false,
68            prompt_cache_settings: OpenRouterPromptCacheSettings::default(),
69            model_behavior: None,
70        }
71    }
72
73    pub fn from_config(
74        api_key: Option<String>,
75        model: Option<String>,
76        base_url: Option<String>,
77        prompt_cache: Option<PromptCachingConfig>,
78        timeouts: Option<TimeoutsConfig>,
79        _anthropic: Option<AnthropicConfig>,
80        model_behavior: Option<ModelConfig>,
81    ) -> Self {
82        let api_key_value = api_key.unwrap_or_default();
83        let model_value = resolve_model(model, models::openrouter::DEFAULT_MODEL);
84
85        Self::with_model_internal(
86            api_key_value,
87            model_value,
88            prompt_cache,
89            base_url,
90            timeouts.unwrap_or_default(),
91            model_behavior,
92        )
93    }
94
95    fn with_model_internal(
96        api_key: String,
97        model: String,
98        prompt_cache: Option<PromptCachingConfig>,
99        base_url: Option<String>,
100        timeouts: TimeoutsConfig,
101        model_behavior: Option<ModelConfig>,
102    ) -> Self {
103        use crate::llm::http_client::HttpClientFactory;
104        let (prompt_cache_enabled, prompt_cache_settings) = extract_prompt_cache_settings(
105            prompt_cache,
106            |p| &p.openrouter,
107            |cfg, settings| cfg.enabled && settings.enabled,
108        );
109
110        Self {
111            api_key,
112            http_client: HttpClientFactory::for_llm(&timeouts),
113            base_url: override_base_url(
114                urls::OPENROUTER_API_BASE,
115                base_url,
116                Some(env_vars::OPENROUTER_BASE_URL),
117            ),
118            model,
119            prompt_cache_enabled,
120            prompt_cache_settings,
121            model_behavior,
122        }
123    }
124
125    pub(super) fn resolve_model<'a>(&'a self, request: &'a LLMRequest) -> &'a str {
126        if request.model.trim().is_empty() {
127            self.model.as_str()
128        } else {
129            request.model.as_str()
130        }
131    }
132
133    fn request_includes_tools(request: &LLMRequest) -> bool {
134        request
135            .tools
136            .as_ref()
137            .map(|tools| !tools.is_empty())
138            .unwrap_or(false)
139    }
140
141    fn enforce_tool_capabilities<'a>(&'a self, request: &'a LLMRequest) -> Cow<'a, LLMRequest> {
142        let resolved_model = self.resolve_model(request);
143        let tools_requested = Self::request_includes_tools(request);
144        let tool_restricted = if let Ok(model_id) = ModelId::from_str(resolved_model) {
145            !model_id.supports_tool_calls()
146        } else {
147            models::openrouter::TOOL_UNAVAILABLE_MODELS.contains(&resolved_model)
148        };
149
150        if tools_requested && tool_restricted {
151            Cow::Owned(Self::tool_free_request(request))
152        } else {
153            Cow::Borrowed(request)
154        }
155    }
156
157    fn tool_free_request(original: &LLMRequest) -> LLMRequest {
158        let mut sanitized = original.clone();
159        sanitized.tools = None;
160        sanitized.tool_choice = Some(ToolChoice::None);
161        sanitized.parallel_tool_calls = None;
162        sanitized.parallel_tool_config = None;
163
164        let mut normalized_messages: Vec<Message> = Vec::with_capacity(original.messages.len());
165
166        for message in &original.messages {
167            match message.role {
168                MessageRole::Assistant => {
169                    let mut cleaned = message.clone();
170                    cleaned.tool_calls = None;
171                    cleaned.tool_call_id = None;
172
173                    let content_text = cleaned.content.as_text();
174                    let has_content = !content_text.trim().is_empty();
175                    if has_content || cleaned.reasoning.is_some() {
176                        normalized_messages.push(cleaned);
177                    }
178                }
179                MessageRole::Tool => {
180                    let content_text = message.content.as_text();
181                    if content_text.trim().is_empty() {
182                        continue;
183                    }
184
185                    let mut converted = Message::user(content_text.into_owned());
186                    converted.reasoning = message.reasoning.clone();
187                    normalized_messages.push(converted);
188                }
189                _ => {
190                    normalized_messages.push(message.clone());
191                }
192            }
193        }
194
195        sanitized.messages = normalized_messages;
196        sanitized
197    }
198
199    fn build_provider_payload(&self, request: &LLMRequest) -> Result<(Value, String), LLMError> {
200        Ok((
201            self.convert_to_openrouter_format(request)?,
202            format!("{}/chat/completions", self.base_url),
203        ))
204    }
205
206    async fn dispatch_request(&self, url: &str, payload: &Value) -> Result<Response, LLMError> {
207        self.http_client
208            .post(url)
209            .bearer_auth(&self.api_key)
210            .header("HTTP-Referer", OPENROUTER_REFERER)
211            .header("X-OpenRouter-Title", OPENROUTER_TITLE)
212            .header("X-OpenRouter-Categories", OPENROUTER_CATEGORIES)
213            .json(payload)
214            .send()
215            .await
216            .map_err(|e| {
217                let formatted_error =
218                    error_display::format_llm_error("OpenRouter", &format!("Network error: {}", e));
219                LLMError::Network {
220                    message: formatted_error,
221                    metadata: None,
222                }
223            })
224    }
225
226    async fn send_with_tool_fallback(
227        &self,
228        request: &LLMRequest,
229        stream_override: Option<bool>,
230    ) -> Result<Response, LLMError> {
231        let adjusted_request = self.enforce_tool_capabilities(request);
232        let request_ref = adjusted_request.as_ref();
233        let request_with_tools = Self::request_includes_tools(request_ref);
234
235        let (mut payload, url) = self.build_provider_payload(request_ref)?;
236        if let Some(stream_flag) = stream_override {
237            payload["stream"] = Value::Bool(stream_flag);
238        }
239
240        let response = self.dispatch_request(&url, &payload).await?;
241        if response.status().is_success() {
242            return Ok(response);
243        }
244
245        let status = response.status();
246        let error_text = response.text().await.unwrap_or_default();
247
248        if status.as_u16() == 429 || error_text.contains("quota") {
249            return Err(LLMError::RateLimit { metadata: None });
250        }
251
252        if request_with_tools
253            && status == StatusCode::NOT_FOUND
254            && error_text.contains("No endpoints found that support tool use")
255        {
256            let fallback_request = Self::tool_free_request(request_ref);
257            let (mut fallback_payload, fallback_url) =
258                self.build_provider_payload(&fallback_request)?;
259            if let Some(stream_flag) = stream_override {
260                fallback_payload["stream"] = Value::Bool(stream_flag);
261            }
262
263            let fallback_response = self
264                .dispatch_request(&fallback_url, &fallback_payload)
265                .await?;
266            if fallback_response.status().is_success() {
267                return Ok(fallback_response);
268            }
269
270            let fallback_status = fallback_response.status();
271            let fallback_text = fallback_response.text().await.unwrap_or_default();
272
273            if fallback_status.as_u16() == 429 || fallback_text.contains("quota") {
274                return Err(LLMError::RateLimit { metadata: None });
275            }
276
277            let combined_error = format!(
278                "HTTP {}: {} | Tool fallback failed with HTTP {}: {}",
279                status, error_text, fallback_status, fallback_text
280            );
281            let formatted_error = error_display::format_llm_error("OpenRouter", &combined_error);
282            return Err(LLMError::Provider {
283                message: formatted_error,
284                metadata: None,
285            });
286        }
287
288        // Use unified error parsing for consistent error categorization
289        use crate::llm::providers::error_handling::parse_api_error;
290        Err(parse_api_error("OpenRouter", status, &error_text))
291    }
292}