1use crate::proxy::{LlmMessage, LlmProvider, LlmRequest, LlmResponse, LlmRole, LlmUsage};
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13
/// LLM provider backed by the OpenRouter unified API.
pub struct OpenRouterProvider {
    /// Shared HTTP client reused across requests.
    client: Client,
    /// OpenRouter API key, sent as a `Bearer` token on every request.
    api_key: String,
    /// API root URL, e.g. "https://openrouter.ai/api/v1" (no trailing slash).
    base_url: String,
    /// Optional value sent as the `HTTP-Referer` header (OpenRouter app attribution).
    site_url: Option<String>,
    /// Optional value sent as the `X-Title` header (human-readable app name).
    app_name: Option<String>,
}
23
24impl OpenRouterProvider {
25 pub fn new(api_key: String) -> Self {
26 Self {
27 client: Client::new(),
28 api_key,
29 base_url: "https://openrouter.ai/api/v1".to_string(),
30 site_url: Some("https://github.com/8b-is/smart-tree".to_string()),
31 app_name: Some("Smart Tree".to_string()),
32 }
33 }
34
35 pub fn with_config(
37 api_key: String,
38 site_url: Option<String>,
39 app_name: Option<String>,
40 ) -> Self {
41 Self {
42 client: Client::new(),
43 api_key,
44 base_url: "https://openrouter.ai/api/v1".to_string(),
45 site_url,
46 app_name,
47 }
48 }
49}
50
51impl Default for OpenRouterProvider {
52 fn default() -> Self {
53 let api_key = std::env::var("OPENROUTER_API_KEY").unwrap_or_default();
54 Self::new(api_key)
55 }
56}
57
58#[async_trait]
59impl LlmProvider for OpenRouterProvider {
60 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
61 let url = format!("{}/chat/completions", self.base_url);
62
63 let model = if request.model.is_empty() || request.model == "default" {
65 "anthropic/claude-3-haiku".to_string()
66 } else {
67 request.model.clone()
68 };
69
70 let openrouter_request = OpenRouterChatRequest {
71 model,
72 messages: request.messages.into_iter().map(Into::into).collect(),
73 temperature: request.temperature,
74 max_tokens: request.max_tokens,
75 stream: request.stream,
76 };
77
78 let mut req = self
79 .client
80 .post(&url)
81 .header("Authorization", format!("Bearer {}", self.api_key))
82 .header("Content-Type", "application/json");
83
84 if let Some(site) = &self.site_url {
86 req = req.header("HTTP-Referer", site);
87 }
88 if let Some(app) = &self.app_name {
89 req = req.header("X-Title", app);
90 }
91
92 let response = req
93 .json(&openrouter_request)
94 .send()
95 .await
96 .context("Failed to send request to OpenRouter")?;
97
98 if !response.status().is_success() {
99 let error_text = response.text().await?;
100 return Err(anyhow::anyhow!("OpenRouter API error: {}", error_text));
101 }
102
103 let openrouter_response: OpenRouterChatResponse = response.json().await?;
104
105 let content = openrouter_response
106 .choices
107 .first()
108 .map(|c| c.message.content.clone())
109 .unwrap_or_default();
110
111 Ok(LlmResponse {
112 content,
113 model: openrouter_response
114 .model
115 .unwrap_or_else(|| "unknown".to_string()),
116 usage: openrouter_response.usage.map(Into::into),
117 })
118 }
119
120 fn name(&self) -> &'static str {
121 "OpenRouter"
122 }
123}
124
/// Request body for OpenRouter's OpenAI-compatible `/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct OpenRouterChatRequest {
    /// Model slug, e.g. "anthropic/claude-3-haiku".
    model: String,
    /// Conversation history in wire format.
    messages: Vec<OpenRouterMessage>,
    /// Sampling temperature; omitted from JSON when `None` so the API default applies.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Completion token cap; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    /// Whether to request a streamed response.
    stream: bool,
}
135
/// One chat message in the OpenAI-style wire format (used in both directions).
#[derive(Debug, Serialize, Deserialize)]
struct OpenRouterMessage {
    /// "system", "user", or "assistant".
    role: String,
    /// Message text.
    content: String,
}
141
142impl From<LlmMessage> for OpenRouterMessage {
143 fn from(msg: LlmMessage) -> Self {
144 Self {
145 role: match msg.role {
146 LlmRole::System => "system".to_string(),
147 LlmRole::User => "user".to_string(),
148 LlmRole::Assistant => "assistant".to_string(),
149 },
150 content: msg.content,
151 }
152 }
153}
154
/// Response body from OpenRouter's `/chat/completions` endpoint.
/// Only the fields this provider consumes are modeled.
#[derive(Debug, Deserialize)]
struct OpenRouterChatResponse {
    /// Model that actually served the request, when reported.
    model: Option<String>,
    /// Candidate completions; this provider reads only the first.
    choices: Vec<OpenRouterChoice>,
    /// Token accounting, when reported.
    usage: Option<OpenRouterUsage>,
}
161
/// One completion candidate; wraps the generated message.
#[derive(Debug, Deserialize)]
struct OpenRouterChoice {
    message: OpenRouterMessage,
}
166
/// Token usage accounting reported by OpenRouter.
#[derive(Debug, Deserialize)]
struct OpenRouterUsage {
    /// Tokens consumed by the prompt (input).
    prompt_tokens: usize,
    /// Tokens generated in the completion (output).
    completion_tokens: usize,
    /// Sum of prompt and completion tokens.
    total_tokens: usize,
}
173
174impl From<OpenRouterUsage> for LlmUsage {
175 fn from(usage: OpenRouterUsage) -> Self {
176 Self {
177 prompt_tokens: usage.prompt_tokens,
178 completion_tokens: usage.completion_tokens,
179 total_tokens: usage.total_tokens,
180 }
181 }
182}
183
/// Well-known OpenRouter model slugs, usable as the `model` field of a request.
pub mod models {
    /// Anthropic Claude 3 Opus.
    pub const CLAUDE_3_OPUS: &str = "anthropic/claude-3-opus";
    /// Anthropic Claude 3 Sonnet.
    pub const CLAUDE_3_SONNET: &str = "anthropic/claude-3-sonnet";
    /// Anthropic Claude 3 Haiku (also the provider's fallback default).
    pub const CLAUDE_3_HAIKU: &str = "anthropic/claude-3-haiku";
    /// OpenAI GPT-4 Turbo.
    pub const GPT_4_TURBO: &str = "openai/gpt-4-turbo";
    /// OpenAI GPT-4o.
    pub const GPT_4O: &str = "openai/gpt-4o";
    /// OpenAI GPT-4o mini.
    pub const GPT_4O_MINI: &str = "openai/gpt-4o-mini";
    /// Meta Llama 3 70B instruct.
    pub const LLAMA_3_70B: &str = "meta-llama/llama-3-70b-instruct";
    /// Mistral Mixtral 8x7B instruct.
    pub const MIXTRAL_8X7B: &str = "mistralai/mixtral-8x7b-instruct";
    /// Google Gemini Pro.
    pub const GEMINI_PRO: &str = "google/gemini-pro";
    /// DeepSeek Coder.
    pub const DEEPSEEK_CODER: &str = "deepseek/deepseek-coder";
    /// Qwen 72B chat.
    pub const QWEN_72B: &str = "qwen/qwen-72b-chat";
}