ultrafast_models_sdk/providers/openrouter.rs

1use crate::error::ProviderError;
2use crate::models::{
3    AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
4    ImageRequest, ImageResponse, SpeechRequest, SpeechResponse, StreamChunk,
5};
6use crate::providers::{HealthStatus, Provider, ProviderConfig, ProviderHealth, StreamResult};
7use async_stream::stream;
8use std::collections::HashMap;
9use std::time::Instant;
10
11use super::http_client::{map_error_response, AuthStrategy, HttpProviderClient};
12
/// OpenRouter provider implementation (OpenAI-compatible API)
pub struct OpenRouterProvider {
    // Shared HTTP client pre-configured with base URL, default headers, and
    // bearer authentication (see `OpenRouterProvider::new`).
    client: HttpProviderClient,
    // Retained config; `model_mapping` is consulted on every request to
    // translate caller-facing model names (see `map_model`).
    config: ProviderConfig,
}
18
19impl OpenRouterProvider {
20    pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
21        // Allow custom headers like HTTP-Referer, X-Title to be passed via config.headers
22        let client = HttpProviderClient::new(
23            config.timeout,
24            config.base_url.clone(),
25            "https://openrouter.ai/api/v1",
26            &config.headers,
27            AuthStrategy::Bearer {
28                token: config.api_key.clone(),
29            },
30        )?;
31
32        Ok(Self { client, config })
33    }
34
35    fn map_model(&self, model: &str) -> String {
36        self.config
37            .model_mapping
38            .get(model)
39            .cloned()
40            .unwrap_or_else(|| model.to_string())
41    }
42}
43
44#[async_trait::async_trait]
45impl Provider for OpenRouterProvider {
    /// Stable identifier for this provider, used for routing and diagnostics.
    fn name(&self) -> &str {
        "openrouter"
    }
49
    /// OpenRouter exposes OpenAI-style SSE streaming (see `stream_chat_completion`).
    fn supports_streaming(&self) -> bool {
        true
    }
53
    /// Function/tool calling is passed through the OpenAI-compatible payload.
    fn supports_function_calling(&self) -> bool {
        true
    }
57
58    fn supported_models(&self) -> Vec<String> {
59        // Leave generic; OpenRouter aggregates many models. Users can override mapping.
60        vec![
61            "openrouter/gpt-4o".to_string(),
62            "openrouter/gpt-4-turbo".to_string(),
63        ]
64    }
65
66    async fn chat_completion(
67        &self,
68        mut request: ChatRequest,
69    ) -> Result<ChatResponse, ProviderError> {
70        request.model = self.map_model(&request.model);
71        let chat_response: ChatResponse =
72            self.client.post_json("/chat/completions", &request).await?;
73        Ok(chat_response)
74    }
75
76    async fn stream_chat_completion(
77        &self,
78        mut request: ChatRequest,
79    ) -> Result<StreamResult, ProviderError> {
80        request.model = self.map_model(&request.model);
81        request.stream = Some(true);
82
83        let response = self
84            .client
85            .post_json_raw("/chat/completions", &request)
86            .await?;
87        if !response.status().is_success() {
88            return Err(map_error_response(response).await);
89        }
90
91        let stream = Box::pin(stream! {
92            let mut bytes_stream = response.bytes_stream();
93            let mut buffer = String::new();
94
95            while let Some(chunk_result) = futures::StreamExt::next(&mut bytes_stream).await {
96                match chunk_result {
97                    Ok(chunk) => {
98                        let chunk_str = String::from_utf8_lossy(&chunk);
99                        buffer.push_str(&chunk_str);
100
101                        while let Some(line_end) = buffer.find('\n') {
102                            let line = buffer[..line_end].trim().to_string();
103                            buffer = buffer[line_end + 1..].to_string();
104
105                            if let Some(json_str) = line.strip_prefix("data: ") {
106                                if json_str == "[DONE]" {
107                                    return;
108                                }
109
110                                match serde_json::from_str::<StreamChunk>(json_str) {
111                                    Ok(stream_chunk) => yield Ok(stream_chunk),
112                                    Err(e) => yield Err(ProviderError::Serialization(e)),
113                                }
114                            }
115                        }
116                    }
117                    Err(e) => yield Err(ProviderError::Http(e)),
118                }
119            }
120        });
121
122        Ok(stream)
123    }
124
125    async fn embedding(
126        &self,
127        mut request: EmbeddingRequest,
128    ) -> Result<EmbeddingResponse, ProviderError> {
129        request.model = self.map_model(&request.model);
130        // OpenRouter expects OpenAI-style embeddings endpoint, but some upstream models may not support it
131        let resp = self.client.post_json_raw("/embeddings", &request).await?;
132        let status = resp.status();
133        if !status.is_success() {
134            return Err(map_error_response(resp).await);
135        }
136        let text = resp.text().await?;
137        match serde_json::from_str::<EmbeddingResponse>(&text) {
138            Ok(er) => Ok(er),
139            Err(_) => Err(ProviderError::Api {
140                code: status.as_u16(),
141                message: text,
142            }),
143        }
144    }
145
146    async fn image_generation(
147        &self,
148        mut request: ImageRequest,
149    ) -> Result<ImageResponse, ProviderError> {
150        if let Some(ref model) = request.model {
151            request.model = Some(self.map_model(model));
152        }
153
154        // Fall back to 405 mapping if not supported
155        let resp = self
156            .client
157            .post_json_raw("/images/generations", &request)
158            .await?;
159        if resp.status().as_u16() == 405 {
160            return Err(ProviderError::Configuration {
161                message: "Image generation not supported by OpenRouter for selected model"
162                    .to_string(),
163            });
164        }
165        let status = resp.status();
166        if !status.is_success() {
167            return Err(map_error_response(resp).await);
168        }
169        let text = resp.text().await?;
170        match serde_json::from_str::<ImageResponse>(&text) {
171            Ok(er) => Ok(er),
172            Err(_) => Err(ProviderError::Api {
173                code: status.as_u16(),
174                message: text,
175            }),
176        }
177    }
178
179    async fn audio_transcription(
180        &self,
181        mut request: AudioRequest,
182    ) -> Result<AudioResponse, ProviderError> {
183        request.model = self.map_model(&request.model);
184
185        let form = reqwest::multipart::Form::new()
186            .part(
187                "file",
188                reqwest::multipart::Part::bytes(request.file)
189                    .file_name("audio.mp3")
190                    .mime_str("audio/mpeg")?,
191            )
192            .text("model", request.model);
193
194        let form = if let Some(language) = request.language {
195            form.text("language", language)
196        } else {
197            form
198        };
199
200        let form = if let Some(prompt) = request.prompt {
201            form.text("prompt", prompt)
202        } else {
203            form
204        };
205
206        let response = self
207            .client
208            .post_multipart("/audio/transcriptions", form)
209            .await?;
210        if !response.status().is_success() {
211            return Err(map_error_response(response).await);
212        }
213        let audio_response: AudioResponse = response.json().await?;
214        Ok(audio_response)
215    }
216
217    async fn text_to_speech(
218        &self,
219        mut request: SpeechRequest,
220    ) -> Result<SpeechResponse, ProviderError> {
221        request.model = self.map_model(&request.model);
222
223        let response = self.client.post_json_raw("/audio/speech", &request).await?;
224        if !response.status().is_success() {
225            return Err(map_error_response(response).await);
226        }
227
228        let content_type = response
229            .headers()
230            .get("content-type")
231            .and_then(|ct| ct.to_str().ok())
232            .unwrap_or("audio/mpeg")
233            .to_string();
234
235        let audio_bytes = response.bytes().await?;
236
237        Ok(SpeechResponse {
238            audio: audio_bytes.to_vec(),
239            content_type,
240        })
241    }
242
243    async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
244        let start = Instant::now();
245        let response = self.client.get_json::<serde_json::Value>("/models").await;
246        let latency_ms = start.elapsed().as_millis() as u64;
247
248        match response {
249            Ok(_) => Ok(ProviderHealth {
250                status: HealthStatus::Healthy,
251                latency_ms: Some(latency_ms),
252                error_rate: 0.0,
253                last_check: chrono::Utc::now(),
254                details: HashMap::new(),
255            }),
256            Err(e) => {
257                let mut details = HashMap::new();
258                details.insert("error".to_string(), e.to_string());
259                Ok(ProviderHealth {
260                    status: HealthStatus::Degraded,
261                    latency_ms: Some(latency_ms),
262                    error_rate: 1.0,
263                    last_check: chrono::Utc::now(),
264                    details,
265                })
266            }
267        }
268    }
269}