Skip to main content

st/proxy/
openrouter.rs

//! 🌐 OpenRouter Provider Implementation
//!
//! "One API to rule them all!" - The Cheet 😺
//!
//! OpenRouter provides unified access to 100+ LLM models via an OpenAI-compatible API.
//! See <https://openrouter.ai/docs>.
7
8use crate::proxy::{LlmMessage, LlmProvider, LlmRequest, LlmResponse, LlmRole, LlmUsage};
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13
14pub struct OpenRouterProvider {
15    client: Client,
16    api_key: String,
17    base_url: String,
18    /// Optional site URL for OpenRouter analytics
19    site_url: Option<String>,
20    /// Optional app name for OpenRouter analytics
21    app_name: Option<String>,
22}
23
24impl OpenRouterProvider {
25    pub fn new(api_key: String) -> Self {
26        Self {
27            client: Client::new(),
28            api_key,
29            base_url: "https://openrouter.ai/api/v1".to_string(),
30            site_url: Some("https://github.com/8b-is/smart-tree".to_string()),
31            app_name: Some("Smart Tree".to_string()),
32        }
33    }
34
35    /// Create with custom configuration
36    pub fn with_config(
37        api_key: String,
38        site_url: Option<String>,
39        app_name: Option<String>,
40    ) -> Self {
41        Self {
42            client: Client::new(),
43            api_key,
44            base_url: "https://openrouter.ai/api/v1".to_string(),
45            site_url,
46            app_name,
47        }
48    }
49}
50
51impl Default for OpenRouterProvider {
52    fn default() -> Self {
53        let api_key = std::env::var("OPENROUTER_API_KEY").unwrap_or_default();
54        Self::new(api_key)
55    }
56}
57
58#[async_trait]
59impl LlmProvider for OpenRouterProvider {
60    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
61        let url = format!("{}/chat/completions", self.base_url);
62
63        // Default to a good free/cheap model if none specified
64        let model = if request.model.is_empty() || request.model == "default" {
65            "anthropic/claude-3-haiku".to_string()
66        } else {
67            request.model.clone()
68        };
69
70        let openrouter_request = OpenRouterChatRequest {
71            model,
72            messages: request.messages.into_iter().map(Into::into).collect(),
73            temperature: request.temperature,
74            max_tokens: request.max_tokens,
75            stream: request.stream,
76        };
77
78        let mut req = self
79            .client
80            .post(&url)
81            .header("Authorization", format!("Bearer {}", self.api_key))
82            .header("Content-Type", "application/json");
83
84        // Add optional OpenRouter headers
85        if let Some(site) = &self.site_url {
86            req = req.header("HTTP-Referer", site);
87        }
88        if let Some(app) = &self.app_name {
89            req = req.header("X-Title", app);
90        }
91
92        let response = req
93            .json(&openrouter_request)
94            .send()
95            .await
96            .context("Failed to send request to OpenRouter")?;
97
98        if !response.status().is_success() {
99            let error_text = response.text().await?;
100            return Err(anyhow::anyhow!("OpenRouter API error: {}", error_text));
101        }
102
103        let openrouter_response: OpenRouterChatResponse = response.json().await?;
104
105        let content = openrouter_response
106            .choices
107            .first()
108            .map(|c| c.message.content.clone())
109            .unwrap_or_default();
110
111        Ok(LlmResponse {
112            content,
113            model: openrouter_response
114                .model
115                .unwrap_or_else(|| "unknown".to_string()),
116            usage: openrouter_response.usage.map(Into::into),
117        })
118    }
119
120    fn name(&self) -> &'static str {
121        "OpenRouter"
122    }
123}
124
125#[derive(Debug, Serialize)]
126struct OpenRouterChatRequest {
127    model: String,
128    messages: Vec<OpenRouterMessage>,
129    #[serde(skip_serializing_if = "Option::is_none")]
130    temperature: Option<f32>,
131    #[serde(skip_serializing_if = "Option::is_none")]
132    max_tokens: Option<usize>,
133    stream: bool,
134}
135
136#[derive(Debug, Serialize, Deserialize)]
137struct OpenRouterMessage {
138    role: String,
139    content: String,
140}
141
142impl From<LlmMessage> for OpenRouterMessage {
143    fn from(msg: LlmMessage) -> Self {
144        Self {
145            role: match msg.role {
146                LlmRole::System => "system".to_string(),
147                LlmRole::User => "user".to_string(),
148                LlmRole::Assistant => "assistant".to_string(),
149            },
150            content: msg.content,
151        }
152    }
153}
154
155#[derive(Debug, Deserialize)]
156struct OpenRouterChatResponse {
157    model: Option<String>,
158    choices: Vec<OpenRouterChoice>,
159    usage: Option<OpenRouterUsage>,
160}
161
162#[derive(Debug, Deserialize)]
163struct OpenRouterChoice {
164    message: OpenRouterMessage,
165}
166
167#[derive(Debug, Deserialize)]
168struct OpenRouterUsage {
169    prompt_tokens: usize,
170    completion_tokens: usize,
171    total_tokens: usize,
172}
173
174impl From<OpenRouterUsage> for LlmUsage {
175    fn from(usage: OpenRouterUsage) -> Self {
176        Self {
177            prompt_tokens: usage.prompt_tokens,
178            completion_tokens: usage.completion_tokens,
179            total_tokens: usage.total_tokens,
180        }
181    }
182}
183
/// Popular OpenRouter model slugs for quick access.
///
/// Each value is a `vendor/model` identifier, passed through unchanged as the
/// `model` field of a chat completions request.
pub mod models {
    // --- Anthropic ---
    /// Anthropic Claude 3 Opus.
    pub const CLAUDE_3_OPUS: &str = "anthropic/claude-3-opus";
    /// Anthropic Claude 3 Sonnet.
    pub const CLAUDE_3_SONNET: &str = "anthropic/claude-3-sonnet";
    /// Anthropic Claude 3 Haiku.
    pub const CLAUDE_3_HAIKU: &str = "anthropic/claude-3-haiku";

    // --- OpenAI ---
    /// OpenAI GPT-4 Turbo.
    pub const GPT_4_TURBO: &str = "openai/gpt-4-turbo";
    /// OpenAI GPT-4o.
    pub const GPT_4O: &str = "openai/gpt-4o";
    /// OpenAI GPT-4o mini.
    pub const GPT_4O_MINI: &str = "openai/gpt-4o-mini";

    // --- Open-weight and other vendors ---
    /// Meta Llama 3 70B Instruct.
    pub const LLAMA_3_70B: &str = "meta-llama/llama-3-70b-instruct";
    /// Mistral Mixtral 8x7B Instruct.
    pub const MIXTRAL_8X7B: &str = "mistralai/mixtral-8x7b-instruct";
    /// Google Gemini Pro.
    pub const GEMINI_PRO: &str = "google/gemini-pro";
    /// DeepSeek Coder.
    pub const DEEPSEEK_CODER: &str = "deepseek/deepseek-coder";
    /// Qwen 72B Chat.
    pub const QWEN_72B: &str = "qwen/qwen-72b-chat";
}