//! OpenAI provider implementation (ultrafast_models_sdk/providers/openai.rs).

1use crate::error::ProviderError;
2use crate::models::{
3    AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
4    ImageRequest, ImageResponse, SpeechRequest, SpeechResponse, StreamChunk,
5};
6use crate::providers::{HealthStatus, Provider, ProviderConfig, ProviderHealth, StreamResult};
7use async_stream::stream;
8
9use super::http_client::{map_error_response, AuthStrategy, HttpProviderClient};
10
11use std::collections::HashMap;
12use std::time::Instant;
13
/// Provider backed by the OpenAI REST API (`https://api.openai.com/v1`).
pub struct OpenAIProvider {
    // HTTP client pre-configured with base URL, default headers and bearer auth.
    client: HttpProviderClient,
    // Retained configuration; consulted at request time for model-name mapping.
    config: ProviderConfig,
}
18
19impl OpenAIProvider {
20    pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
21        let client = HttpProviderClient::new(
22            config.timeout,
23            config.base_url.clone(),
24            "https://api.openai.com/v1",
25            &config.headers,
26            AuthStrategy::Bearer {
27                token: config.api_key.clone(),
28            },
29        )?;
30
31        Ok(Self { client, config })
32    }
33
34    fn map_model(&self, model: &str) -> String {
35        self.config
36            .model_mapping
37            .get(model)
38            .cloned()
39            .unwrap_or_else(|| model.to_string())
40    }
41}
42
43#[async_trait::async_trait]
44impl Provider for OpenAIProvider {
45    fn name(&self) -> &str {
46        "openai"
47    }
48
49    fn supports_streaming(&self) -> bool {
50        true
51    }
52
53    fn supports_function_calling(&self) -> bool {
54        true
55    }
56
57    fn supported_models(&self) -> Vec<String> {
58        vec![
59            "gpt-4".to_string(),
60            "gpt-4-turbo".to_string(),
61            "gpt-4-turbo-preview".to_string(),
62            "gpt-3.5-turbo".to_string(),
63            "gpt-3.5-turbo-16k".to_string(),
64            "text-embedding-ada-002".to_string(),
65            "text-embedding-3-small".to_string(),
66            "text-embedding-3-large".to_string(),
67            "dall-e-2".to_string(),
68            "dall-e-3".to_string(),
69            "whisper-1".to_string(),
70            "tts-1".to_string(),
71            "tts-1-hd".to_string(),
72        ]
73    }
74
75    async fn chat_completion(
76        &self,
77        mut request: ChatRequest,
78    ) -> Result<ChatResponse, ProviderError> {
79        request.model = self.map_model(&request.model);
80
81        let chat_response: ChatResponse =
82            self.client.post_json("/chat/completions", &request).await?;
83        Ok(chat_response)
84    }
85
86    async fn stream_chat_completion(
87        &self,
88        mut request: ChatRequest,
89    ) -> Result<StreamResult, ProviderError> {
90        request.model = self.map_model(&request.model);
91        request.stream = Some(true);
92
93        let response = self
94            .client
95            .post_json_raw("/chat/completions", &request)
96            .await?;
97        if !response.status().is_success() {
98            return Err(map_error_response(response).await);
99        }
100
101        let stream = Box::pin(stream! {
102            let mut bytes_stream = response.bytes_stream();
103            let mut buffer = String::new();
104
105            while let Some(chunk_result) = futures::StreamExt::next(&mut bytes_stream).await {
106                match chunk_result {
107                    Ok(chunk) => {
108                        let chunk_str = String::from_utf8_lossy(&chunk);
109                        buffer.push_str(&chunk_str);
110
111                        while let Some(line_end) = buffer.find('\n') {
112                            let line = buffer[..line_end].trim().to_string();
113                            buffer = buffer[line_end + 1..].to_string();
114
115                            if let Some(json_str) = line.strip_prefix("data: ") {
116                                if json_str == "[DONE]" {
117                                    return;
118                                }
119
120                                match serde_json::from_str::<StreamChunk>(json_str) {
121                                    Ok(stream_chunk) => yield Ok(stream_chunk),
122                                    Err(e) => yield Err(ProviderError::Serialization(e)),
123                                }
124                            }
125                        }
126                    }
127                    Err(e) => yield Err(ProviderError::Http(e)),
128                }
129            }
130        });
131
132        Ok(stream)
133    }
134
135    async fn embedding(
136        &self,
137        mut request: EmbeddingRequest,
138    ) -> Result<EmbeddingResponse, ProviderError> {
139        request.model = self.map_model(&request.model);
140
141        let embedding_response: EmbeddingResponse =
142            self.client.post_json("/embeddings", &request).await?;
143        Ok(embedding_response)
144    }
145
146    async fn image_generation(
147        &self,
148        mut request: ImageRequest,
149    ) -> Result<ImageResponse, ProviderError> {
150        if let Some(ref model) = request.model {
151            request.model = Some(self.map_model(model));
152        }
153
154        let image_response: ImageResponse = self
155            .client
156            .post_json("/images/generations", &request)
157            .await?;
158        Ok(image_response)
159    }
160
161    async fn audio_transcription(
162        &self,
163        mut request: AudioRequest,
164    ) -> Result<AudioResponse, ProviderError> {
165        request.model = self.map_model(&request.model);
166
167        let form = reqwest::multipart::Form::new()
168            .part(
169                "file",
170                reqwest::multipart::Part::bytes(request.file)
171                    .file_name("audio.mp3")
172                    .mime_str("audio/mpeg")?,
173            )
174            .text("model", request.model);
175
176        let form = if let Some(language) = request.language {
177            form.text("language", language)
178        } else {
179            form
180        };
181
182        let form = if let Some(prompt) = request.prompt {
183            form.text("prompt", prompt)
184        } else {
185            form
186        };
187
188        let response = self
189            .client
190            .post_multipart("/audio/transcriptions", form)
191            .await?;
192        if !response.status().is_success() {
193            return Err(map_error_response(response).await);
194        }
195        let audio_response: AudioResponse = response.json().await?;
196        Ok(audio_response)
197    }
198
199    async fn text_to_speech(
200        &self,
201        mut request: SpeechRequest,
202    ) -> Result<SpeechResponse, ProviderError> {
203        request.model = self.map_model(&request.model);
204
205        let response = self.client.post_json_raw("/audio/speech", &request).await?;
206        if !response.status().is_success() {
207            return Err(map_error_response(response).await);
208        }
209
210        let content_type = response
211            .headers()
212            .get("content-type")
213            .and_then(|ct| ct.to_str().ok())
214            .unwrap_or("audio/mpeg")
215            .to_string();
216
217        let audio_bytes = response.bytes().await?;
218
219        Ok(SpeechResponse {
220            audio: audio_bytes.to_vec(),
221            content_type,
222        })
223    }
224
225    async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
226        let start = Instant::now();
227
228        let response = self.client.get_json::<serde_json::Value>("/models").await;
229
230        let latency_ms = start.elapsed().as_millis() as u64;
231
232        match response {
233            Ok(_json) => Ok(ProviderHealth {
234                status: HealthStatus::Healthy,
235                latency_ms: Some(latency_ms),
236                error_rate: 0.0,
237                last_check: chrono::Utc::now(),
238                details: HashMap::new(),
239            }),
240            Err(e) => {
241                let mut details = HashMap::new();
242                details.insert("error".to_string(), e.to_string());
243
244                Ok(ProviderHealth {
245                    status: HealthStatus::Degraded,
246                    latency_ms: Some(latency_ms),
247                    error_rate: 1.0,
248                    last_check: chrono::Utc::now(),
249                    details,
250                })
251            }
252        }
253    }
254}