//! OpenAI provider for `ultrafast_models_sdk` (providers/openai.rs).

use std::collections::HashMap;
use std::time::Instant;

use async_stream::stream;
use futures::StreamExt;

use crate::error::ProviderError;
use crate::models::{
    AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
    ImageRequest, ImageResponse, SpeechRequest, SpeechResponse, StreamChunk,
};
use crate::providers::{HealthStatus, Provider, ProviderConfig, ProviderHealth, StreamResult};

use super::http_client::{map_error_response, AuthStrategy, HttpProviderClient};
14pub struct OpenAIProvider {
15 client: HttpProviderClient,
16 config: ProviderConfig,
17}
19impl OpenAIProvider {
20 pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
21 let client = HttpProviderClient::new(
22 config.timeout,
23 config.base_url.clone(),
24 "https://api.openai.com/v1",
25 &config.headers,
26 AuthStrategy::Bearer {
27 token: config.api_key.clone(),
28 },
29 )?;
30
31 Ok(Self { client, config })
32 }
33
34 fn map_model(&self, model: &str) -> String {
35 self.config
36 .model_mapping
37 .get(model)
38 .cloned()
39 .unwrap_or_else(|| model.to_string())
40 }
41}
43#[async_trait::async_trait]
44impl Provider for OpenAIProvider {
45 fn name(&self) -> &str {
46 "openai"
47 }
48
49 fn supports_streaming(&self) -> bool {
50 true
51 }
52
53 fn supports_function_calling(&self) -> bool {
54 true
55 }
56
57 fn supported_models(&self) -> Vec<String> {
58 vec![
59 "gpt-4".to_string(),
60 "gpt-4-turbo".to_string(),
61 "gpt-4-turbo-preview".to_string(),
62 "gpt-3.5-turbo".to_string(),
63 "gpt-3.5-turbo-16k".to_string(),
64 "text-embedding-ada-002".to_string(),
65 "text-embedding-3-small".to_string(),
66 "text-embedding-3-large".to_string(),
67 "dall-e-2".to_string(),
68 "dall-e-3".to_string(),
69 "whisper-1".to_string(),
70 "tts-1".to_string(),
71 "tts-1-hd".to_string(),
72 ]
73 }
74
75 async fn chat_completion(
76 &self,
77 mut request: ChatRequest,
78 ) -> Result<ChatResponse, ProviderError> {
79 request.model = self.map_model(&request.model);
80
81 let chat_response: ChatResponse =
82 self.client.post_json("/chat/completions", &request).await?;
83 Ok(chat_response)
84 }
85
86 async fn stream_chat_completion(
87 &self,
88 mut request: ChatRequest,
89 ) -> Result<StreamResult, ProviderError> {
90 request.model = self.map_model(&request.model);
91 request.stream = Some(true);
92
93 let response = self
94 .client
95 .post_json_raw("/chat/completions", &request)
96 .await?;
97 if !response.status().is_success() {
98 return Err(map_error_response(response).await);
99 }
100
101 let stream = Box::pin(stream! {
102 let mut bytes_stream = response.bytes_stream();
103 let mut buffer = String::new();
104
105 while let Some(chunk_result) = futures::StreamExt::next(&mut bytes_stream).await {
106 match chunk_result {
107 Ok(chunk) => {
108 let chunk_str = String::from_utf8_lossy(&chunk);
109 buffer.push_str(&chunk_str);
110
111 while let Some(line_end) = buffer.find('\n') {
112 let line = buffer[..line_end].trim().to_string();
113 buffer = buffer[line_end + 1..].to_string();
114
115 if let Some(json_str) = line.strip_prefix("data: ") {
116 if json_str == "[DONE]" {
117 return;
118 }
119
120 match serde_json::from_str::<StreamChunk>(json_str) {
121 Ok(stream_chunk) => yield Ok(stream_chunk),
122 Err(e) => yield Err(ProviderError::Serialization(e)),
123 }
124 }
125 }
126 }
127 Err(e) => yield Err(ProviderError::Http(e)),
128 }
129 }
130 });
131
132 Ok(stream)
133 }
134
135 async fn embedding(
136 &self,
137 mut request: EmbeddingRequest,
138 ) -> Result<EmbeddingResponse, ProviderError> {
139 request.model = self.map_model(&request.model);
140
141 let embedding_response: EmbeddingResponse =
142 self.client.post_json("/embeddings", &request).await?;
143 Ok(embedding_response)
144 }
145
146 async fn image_generation(
147 &self,
148 mut request: ImageRequest,
149 ) -> Result<ImageResponse, ProviderError> {
150 if let Some(ref model) = request.model {
151 request.model = Some(self.map_model(model));
152 }
153
154 let image_response: ImageResponse = self
155 .client
156 .post_json("/images/generations", &request)
157 .await?;
158 Ok(image_response)
159 }
160
161 async fn audio_transcription(
162 &self,
163 mut request: AudioRequest,
164 ) -> Result<AudioResponse, ProviderError> {
165 request.model = self.map_model(&request.model);
166
167 let form = reqwest::multipart::Form::new()
168 .part(
169 "file",
170 reqwest::multipart::Part::bytes(request.file)
171 .file_name("audio.mp3")
172 .mime_str("audio/mpeg")?,
173 )
174 .text("model", request.model);
175
176 let form = if let Some(language) = request.language {
177 form.text("language", language)
178 } else {
179 form
180 };
181
182 let form = if let Some(prompt) = request.prompt {
183 form.text("prompt", prompt)
184 } else {
185 form
186 };
187
188 let response = self
189 .client
190 .post_multipart("/audio/transcriptions", form)
191 .await?;
192 if !response.status().is_success() {
193 return Err(map_error_response(response).await);
194 }
195 let audio_response: AudioResponse = response.json().await?;
196 Ok(audio_response)
197 }
198
199 async fn text_to_speech(
200 &self,
201 mut request: SpeechRequest,
202 ) -> Result<SpeechResponse, ProviderError> {
203 request.model = self.map_model(&request.model);
204
205 let response = self.client.post_json_raw("/audio/speech", &request).await?;
206 if !response.status().is_success() {
207 return Err(map_error_response(response).await);
208 }
209
210 let content_type = response
211 .headers()
212 .get("content-type")
213 .and_then(|ct| ct.to_str().ok())
214 .unwrap_or("audio/mpeg")
215 .to_string();
216
217 let audio_bytes = response.bytes().await?;
218
219 Ok(SpeechResponse {
220 audio: audio_bytes.to_vec(),
221 content_type,
222 })
223 }
224
225 async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
226 let start = Instant::now();
227
228 let response = self.client.get_json::<serde_json::Value>("/models").await;
229
230 let latency_ms = start.elapsed().as_millis() as u64;
231
232 match response {
233 Ok(_json) => Ok(ProviderHealth {
234 status: HealthStatus::Healthy,
235 latency_ms: Some(latency_ms),
236 error_rate: 0.0,
237 last_check: chrono::Utc::now(),
238 details: HashMap::new(),
239 }),
240 Err(e) => {
241 let mut details = HashMap::new();
242 details.insert("error".to_string(), e.to_string());
243
244 Ok(ProviderHealth {
245 status: HealthStatus::Degraded,
246 latency_ms: Some(latency_ms),
247 error_rate: 1.0,
248 last_check: chrono::Utc::now(),
249 details,
250 })
251 }
252 }
253 }
254}