ultrafast_models_sdk/providers/openrouter.rs

use crate::error::ProviderError;
use crate::models::{
    AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
    ImageRequest, ImageResponse, SpeechRequest, SpeechResponse, StreamChunk,
};
use crate::providers::{HealthStatus, Provider, ProviderConfig, ProviderHealth, StreamResult};
use async_stream::stream;
use std::collections::HashMap;
use std::time::Instant;

use super::http_client::{map_error_response, AuthStrategy, HttpProviderClient};

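/// Provider implementation for OpenRouter, an aggregator that exposes many
/// upstream models behind a single OpenAI-compatible HTTP API.
///
/// A minimal usage sketch (hypothetical; `build_config` stands in for
/// however the host application assembles its `ProviderConfig`):
///
/// ```ignore
/// let provider = OpenRouterProvider::new(build_config())?;
/// let response = provider.chat_completion(request).await?;
/// ```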
pub struct OpenRouterProvider {
    client: HttpProviderClient,
    config: ProviderConfig,
}

impl OpenRouterProvider {
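    /// Constructs the provider, using `https://openrouter.ai/api/v1` as the
    /// default base URL and authenticating with a `Bearer` token taken from
    /// `config.api_key`.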
    pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
        let client = HttpProviderClient::new(
            config.timeout,
            config.base_url.clone(),
            "https://openrouter.ai/api/v1",
            &config.headers,
            AuthStrategy::Bearer {
                token: config.api_key.clone(),
            },
        )?;

        Ok(Self { client, config })
    }

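    /// Resolves a caller-facing model name through `config.model_mapping`,
    /// falling back to the original name when no mapping entry exists.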
    fn map_model(&self, model: &str) -> String {
        self.config
            .model_mapping
            .get(model)
            .cloned()
            .unwrap_or_else(|| model.to_string())
    }
}

#[async_trait::async_trait]
impl Provider for OpenRouterProvider {
    fn name(&self) -> &str {
        "openrouter"
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_function_calling(&self) -> bool {
        true
    }

    fn supported_models(&self) -> Vec<String> {
        // A small static subset; OpenRouter's full catalog is served by GET /models.
        vec![
            "openrouter/gpt-4o".to_string(),
            "openrouter/gpt-4-turbo".to_string(),
        ]
    }

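    /// Sends a non-streaming chat completion, rewriting the model name
    /// through `map_model` before dispatch to `/chat/completions`.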
    async fn chat_completion(
        &self,
        mut request: ChatRequest,
    ) -> Result<ChatResponse, ProviderError> {
        request.model = self.map_model(&request.model);
        let chat_response: ChatResponse =
            self.client.post_json("/chat/completions", &request).await?;
        Ok(chat_response)
    }

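    /// Streams a chat completion as Server-Sent Events: the body is buffered
    /// line by line, each `data: ` payload is decoded into a `StreamChunk`,
    /// and the literal `[DONE]` sentinel ends the stream.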
    async fn stream_chat_completion(
        &self,
        mut request: ChatRequest,
    ) -> Result<StreamResult, ProviderError> {
        request.model = self.map_model(&request.model);
        request.stream = Some(true);

        let response = self
            .client
            .post_json_raw("/chat/completions", &request)
            .await?;
        if !response.status().is_success() {
            return Err(map_error_response(response).await);
        }

        let stream = Box::pin(stream! {
            let mut bytes_stream = response.bytes_stream();
            let mut buffer = String::new();

            while let Some(chunk_result) = futures::StreamExt::next(&mut bytes_stream).await {
                match chunk_result {
                    Ok(chunk) => {
                        // Network chunks may split SSE lines arbitrarily, so
                        // accumulate into a buffer and only parse complete lines.
                        let chunk_str = String::from_utf8_lossy(&chunk);
                        buffer.push_str(&chunk_str);

                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].trim().to_string();
                            buffer = buffer[line_end + 1..].to_string();

                            if let Some(json_str) = line.strip_prefix("data: ") {
                                // OpenAI-style SSE terminates the stream with a
                                // literal "[DONE]" payload rather than JSON.
                                if json_str == "[DONE]" {
                                    return;
                                }

                                match serde_json::from_str::<StreamChunk>(json_str) {
                                    Ok(stream_chunk) => yield Ok(stream_chunk),
                                    Err(e) => yield Err(ProviderError::Serialization(e)),
                                }
                            }
                        }
                    }
                    Err(e) => yield Err(ProviderError::Http(e)),
                }
            }
        });

        Ok(stream)
    }

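    /// Requests embeddings via `/embeddings`. A success status whose body
    /// fails to parse as an `EmbeddingResponse` is surfaced as an API error
    /// carrying the raw response text.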
    async fn embedding(
        &self,
        mut request: EmbeddingRequest,
    ) -> Result<EmbeddingResponse, ProviderError> {
        request.model = self.map_model(&request.model);
        let resp = self.client.post_json_raw("/embeddings", &request).await?;
        let status = resp.status();
        if !status.is_success() {
            return Err(map_error_response(resp).await);
        }
        let text = resp.text().await?;
        match serde_json::from_str::<EmbeddingResponse>(&text) {
            Ok(er) => Ok(er),
            Err(_) => Err(ProviderError::Api {
                code: status.as_u16(),
                message: text,
            }),
        }
    }

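    /// Generates images via `/images/generations`. A `405 Method Not Allowed`
    /// reply is treated as the selected model lacking image support and is
    /// reported as a configuration error.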
    async fn image_generation(
        &self,
        mut request: ImageRequest,
    ) -> Result<ImageResponse, ProviderError> {
        if let Some(ref model) = request.model {
            request.model = Some(self.map_model(model));
        }

        let resp = self
            .client
            .post_json_raw("/images/generations", &request)
            .await?;
        if resp.status().as_u16() == 405 {
            return Err(ProviderError::Configuration {
                message: "Image generation not supported by OpenRouter for selected model"
                    .to_string(),
            });
        }
        let status = resp.status();
        if !status.is_success() {
            return Err(map_error_response(resp).await);
        }
        let text = resp.text().await?;
        match serde_json::from_str::<ImageResponse>(&text) {
            Ok(er) => Ok(er),
            Err(_) => Err(ProviderError::Api {
                code: status.as_u16(),
                message: text,
            }),
        }
    }

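    /// Transcribes audio by uploading the file as `multipart/form-data` to
    /// `/audio/transcriptions`, attaching the optional `language` and
    /// `prompt` fields when present.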
    async fn audio_transcription(
        &self,
        mut request: AudioRequest,
    ) -> Result<AudioResponse, ProviderError> {
        request.model = self.map_model(&request.model);

        // The upload is always labeled as MP3 audio, regardless of the
        // actual format of `request.file`.
        let form = reqwest::multipart::Form::new()
            .part(
                "file",
                reqwest::multipart::Part::bytes(request.file)
                    .file_name("audio.mp3")
                    .mime_str("audio/mpeg")?,
            )
            .text("model", request.model);

        let form = if let Some(language) = request.language {
            form.text("language", language)
        } else {
            form
        };

        let form = if let Some(prompt) = request.prompt {
            form.text("prompt", prompt)
        } else {
            form
        };

        let response = self
            .client
            .post_multipart("/audio/transcriptions", form)
            .await?;
        if !response.status().is_success() {
            return Err(map_error_response(response).await);
        }
        let audio_response: AudioResponse = response.json().await?;
        Ok(audio_response)
    }

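    /// Synthesizes speech via `/audio/speech`, returning the raw audio bytes
    /// along with the response `content-type` (defaulting to `audio/mpeg`).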
    async fn text_to_speech(
        &self,
        mut request: SpeechRequest,
    ) -> Result<SpeechResponse, ProviderError> {
        request.model = self.map_model(&request.model);

        let response = self.client.post_json_raw("/audio/speech", &request).await?;
        if !response.status().is_success() {
            return Err(map_error_response(response).await);
        }

        let content_type = response
            .headers()
            .get("content-type")
            .and_then(|ct| ct.to_str().ok())
            .unwrap_or("audio/mpeg")
            .to_string();

        let audio_bytes = response.bytes().await?;

        Ok(SpeechResponse {
            audio: audio_bytes.to_vec(),
            content_type,
        })
    }

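    /// Probes `/models` and reports round-trip latency. Errors downgrade the
    /// status to `Degraded` and record the message in `details`; the check
    /// itself always returns `Ok`.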
    async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
        let start = Instant::now();
        let response = self.client.get_json::<serde_json::Value>("/models").await;
        let latency_ms = start.elapsed().as_millis() as u64;

        match response {
            Ok(_) => Ok(ProviderHealth {
                status: HealthStatus::Healthy,
                latency_ms: Some(latency_ms),
                error_rate: 0.0,
                last_check: chrono::Utc::now(),
                details: HashMap::new(),
            }),
            Err(e) => {
                let mut details = HashMap::new();
                details.insert("error".to_string(), e.to_string());
                Ok(ProviderHealth {
                    status: HealthStatus::Degraded,
                    latency_ms: Some(latency_ms),
                    error_rate: 1.0,
                    last_check: chrono::Utc::now(),
                    details,
                })
            }
        }
    }
}