ultrafast_models_sdk/providers/
custom.rs

1use crate::error::ProviderError;
2use crate::models::{
3    AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
4    ImageRequest, ImageResponse, SpeechRequest, SpeechResponse, StreamChunk,
5};
6use crate::providers::{HealthStatus, Provider, ProviderConfig, ProviderHealth, StreamResult};
7use async_stream::stream;
8use serde_json::json;
9
10use super::http_client::{map_error_response, AuthStrategy, HttpProviderClient};
11
12use std::collections::HashMap;
13use std::time::Instant;
14
/// Endpoint, wire-format and authentication settings for a [`CustomProvider`].
#[derive(Debug, Clone)]
pub struct CustomProviderConfig {
    /// URL handed to the HTTP client for both regular and streaming chat
    /// requests.
    pub chat_endpoint: String,
    /// URL handed to the HTTP client for embedding requests. When `None`,
    /// `embedding` returns a configuration error.
    pub embedding_endpoint: Option<String>,
    /// Not read by the code visible in this file: `image_generation` returns
    /// a "not supported" error unconditionally.
    pub image_endpoint: Option<String>,
    /// Not read by the code visible in this file: `audio_transcription`
    /// returns a "not supported" error unconditionally.
    pub audio_endpoint: Option<String>,
    /// Not read by the code visible in this file: `text_to_speech` returns a
    /// "not supported" error unconditionally.
    pub speech_endpoint: Option<String>,
    /// Selects how `format_request` builds the outgoing chat request body.
    pub request_format: RequestFormat,
    /// Selects how `parse_response` turns the returned JSON into a
    /// `ChatResponse`.
    pub response_format: ResponseFormat,
    /// Selects how `ProviderConfig::api_key` is attached to outgoing requests.
    pub auth_type: AuthType,
}
26
/// Shape of the JSON body sent to the chat endpoint.
#[derive(Debug, Clone)]
pub enum RequestFormat {
    /// Body with `model`, `messages`, `temperature`, `max_tokens` and
    /// `stream` keys, with the messages serialized as they are.
    OpenAI,
    /// Body whose messages are rebuilt as `role`/`content` pairs with the
    /// roles re-mapped (tool messages are sent with the `user` role).
    Anthropic,
    /// Body produced from a caller-supplied JSON template.
    ///
    /// The placeholders `{{model}}`, `{{temperature}}` and `{{max_tokens}}`
    /// are substituted textually (temperature falls back to `0.7` and
    /// max_tokens to `100` when the request leaves them unset), and the
    /// result must parse as JSON.
    Custom { template: String },
}
33
/// Shape of the JSON body expected back from the chat endpoint.
#[derive(Debug, Clone)]
pub enum ResponseFormat {
    /// The body is deserialized directly into a `ChatResponse`.
    OpenAI,
    /// The body's `id`, `model`, `content[..].text` and
    /// `usage.input_tokens` / `usage.output_tokens` fields are copied into a
    /// freshly built `ChatResponse`.
    Anthropic,
    /// The placeholder `{{response}}` in the template is replaced with the
    /// serialized response JSON, and the result is deserialized as a
    /// `ChatResponse`.
    Custom { template: String },
}
40
/// How `ProviderConfig::api_key` is presented to the remote service.
#[derive(Debug, Clone)]
pub enum AuthType {
    /// The key is used as the token of `AuthStrategy::Bearer`.
    Bearer,
    /// The key is sent as the value of an `X-API-Key` header.
    ApiKey,
    /// The key is sent as the value of the header named by `header`.
    Custom { header: String },
    /// No credentials are attached (`AuthStrategy::None`).
    None,
}
48
/// Provider that talks to a user-configured HTTP endpoint, translating chat
/// requests and responses according to a [`CustomProviderConfig`].
pub struct CustomProvider {
    /// HTTP client built in `new` from the timeout, base URL, headers and
    /// authentication strategy.
    http: HttpProviderClient,
    /// Generic provider settings; `model_mapping` is consulted by `map_model`.
    config: ProviderConfig,
    /// Endpoint and wire-format settings specific to this provider.
    custom_config: CustomProviderConfig,
}
54
55impl CustomProvider {
56    pub fn new(
57        config: ProviderConfig,
58        custom_config: CustomProviderConfig,
59    ) -> Result<Self, ProviderError> {
60        let auth = match &custom_config.auth_type {
61            AuthType::Bearer => AuthStrategy::Bearer {
62                token: config.api_key.clone(),
63            },
64            AuthType::ApiKey => AuthStrategy::Header {
65                name: "X-API-Key".to_string(),
66                value: config.api_key.clone(),
67            },
68            AuthType::Custom { header } => AuthStrategy::Header {
69                name: header.clone(),
70                value: config.api_key.clone(),
71            },
72            AuthType::None => AuthStrategy::None,
73        };
74
75        let http = HttpProviderClient::new(
76            config.timeout,
77            config.base_url.clone(),
78            "http://localhost:8080",
79            &config.headers,
80            auth,
81        )?;
82
83        Ok(Self {
84            http,
85            config,
86            custom_config,
87        })
88    }
89
90    fn map_model(&self, model: &str) -> String {
91        self.config
92            .model_mapping
93            .get(model)
94            .cloned()
95            .unwrap_or_else(|| model.to_string())
96    }
97
98    #[allow(dead_code)]
99    async fn handle_error_response(&self, response: reqwest::Response) -> ProviderError {
100        let status = response.status();
101
102        match response.text().await {
103            Ok(body) => {
104                if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(&body) {
105                    let message = error_json
106                        .get("error")
107                        .and_then(|e| e.get("message"))
108                        .and_then(|m| m.as_str())
109                        .unwrap_or("Unknown API error")
110                        .to_string();
111
112                    match status.as_u16() {
113                        401 => ProviderError::InvalidApiKey,
114                        404 => ProviderError::ModelNotFound {
115                            model: "unknown".to_string(),
116                        },
117                        429 => ProviderError::RateLimit,
118                        _ => ProviderError::Api {
119                            code: status.as_u16(),
120                            message,
121                        },
122                    }
123                } else {
124                    ProviderError::Api {
125                        code: status.as_u16(),
126                        message: body,
127                    }
128                }
129            }
130            Err(_) => ProviderError::Api {
131                code: status.as_u16(),
132                message: "Failed to read error response".to_string(),
133            },
134        }
135    }
136
137    fn format_request(&self, request: &ChatRequest) -> Result<serde_json::Value, ProviderError> {
138        match &self.custom_config.request_format {
139            RequestFormat::OpenAI => Ok(json!({
140                "model": self.map_model(&request.model),
141                "messages": request.messages,
142                "temperature": request.temperature,
143                "max_tokens": request.max_tokens,
144                "stream": request.stream,
145            })),
146            RequestFormat::Anthropic => {
147                let messages = request
148                    .messages
149                    .iter()
150                    .map(|msg| {
151                        json!({
152                            "role": match msg.role {
153                                crate::models::Role::User => "user",
154                                crate::models::Role::Assistant => "assistant",
155                                crate::models::Role::System => "system",
156                                crate::models::Role::Tool => "user",
157                            },
158                            "content": msg.content
159                        })
160                    })
161                    .collect::<Vec<_>>();
162
163                Ok(json!({
164                    "model": self.map_model(&request.model),
165                    "messages": messages,
166                    "temperature": request.temperature,
167                    "max_tokens": request.max_tokens,
168                    "stream": request.stream,
169                }))
170            }
171            RequestFormat::Custom { template } => {
172                // Simple template substitution - in a real implementation, you'd want a proper templating engine
173                let mut formatted = template.clone();
174                formatted = formatted.replace("{{model}}", &self.map_model(&request.model));
175                formatted = formatted.replace(
176                    "{{temperature}}",
177                    &request.temperature.unwrap_or(0.7).to_string(),
178                );
179                formatted = formatted.replace(
180                    "{{max_tokens}}",
181                    &request.max_tokens.unwrap_or(100).to_string(),
182                );
183
184                serde_json::from_str(&formatted).map_err(|e| ProviderError::Configuration {
185                    message: format!("Invalid custom request template: {e}"),
186                })
187            }
188        }
189    }
190
191    fn parse_response(&self, response: serde_json::Value) -> Result<ChatResponse, ProviderError> {
192        match &self.custom_config.response_format {
193            ResponseFormat::OpenAI => {
194                let chat_response: ChatResponse =
195                    serde_json::from_value(response).map_err(ProviderError::Serialization)?;
196                Ok(chat_response)
197            }
198            ResponseFormat::Anthropic => {
199                // Convert Anthropic format to OpenAI format
200                let chat_response = ChatResponse {
201                    id: response["id"].as_str().unwrap_or("").to_string(),
202                    object: "chat.completion".to_string(),
203                    created: chrono::Utc::now().timestamp() as u64,
204                    model: response["model"].as_str().unwrap_or("").to_string(),
205                    choices: vec![crate::models::Choice {
206                        index: 0,
207                        message: crate::models::Message {
208                            role: crate::models::Role::Assistant,
209                            content: response["content"][0]["text"]
210                                .as_str()
211                                .unwrap_or("")
212                                .to_string(),
213                            name: None,
214                            tool_calls: None,
215                            tool_call_id: None,
216                        },
217                        finish_reason: Some("stop".to_string()),
218                        logprobs: None,
219                    }],
220                    usage: Some(crate::models::Usage {
221                        prompt_tokens: response["usage"]["input_tokens"].as_u64().unwrap_or(0)
222                            as u32,
223                        completion_tokens: response["usage"]["output_tokens"].as_u64().unwrap_or(0)
224                            as u32,
225                        total_tokens: response["usage"]["input_tokens"].as_u64().unwrap_or(0)
226                            as u32
227                            + response["usage"]["output_tokens"].as_u64().unwrap_or(0) as u32,
228                    }),
229                    system_fingerprint: None,
230                };
231                Ok(chat_response)
232            }
233            ResponseFormat::Custom { template } => {
234                // Simple template parsing - in a real implementation, you'd want a proper templating engine
235                let response_str =
236                    serde_json::to_string(&response).map_err(ProviderError::Serialization)?;
237
238                let mut formatted = template.clone();
239                formatted = formatted.replace("{{response}}", &response_str);
240
241                serde_json::from_str(&formatted).map_err(|e| ProviderError::Configuration {
242                    message: format!("Invalid custom response template: {e}"),
243                })
244            }
245        }
246    }
247}
248
249#[async_trait::async_trait]
250impl Provider for CustomProvider {
251    fn name(&self) -> &str {
252        "custom"
253    }
254
255    fn supports_streaming(&self) -> bool {
256        true
257    }
258
259    fn supports_function_calling(&self) -> bool {
260        false // Custom providers don't support function calling by default
261    }
262
263    fn supported_models(&self) -> Vec<String> {
264        vec!["custom-model".to_string()]
265    }
266
267    async fn chat_completion(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
268        let formatted_request = self.format_request(&request)?;
269
270        let url = self.custom_config.chat_endpoint.to_string();
271        let response_json: serde_json::Value =
272            self.http.post_json(&url, &formatted_request).await?;
273        let chat_response = self.parse_response(response_json)?;
274        Ok(chat_response)
275    }
276
277    async fn stream_chat_completion(
278        &self,
279        request: ChatRequest,
280    ) -> Result<StreamResult, ProviderError> {
281        let mut formatted_request = self.format_request(&request)?;
282        formatted_request["stream"] = serde_json::Value::Bool(true);
283
284        let url = self.custom_config.chat_endpoint.to_string();
285        let response = self.http.post_json_raw(&url, &formatted_request).await?;
286        if !response.status().is_success() {
287            return Err(map_error_response(response).await);
288        }
289
290        let stream = Box::pin(stream! {
291            let mut bytes_stream = response.bytes_stream();
292            let mut buffer = String::new();
293
294            while let Some(chunk_result) = futures::StreamExt::next(&mut bytes_stream).await {
295                match chunk_result {
296                    Ok(chunk) => {
297                        let chunk_str = String::from_utf8_lossy(&chunk);
298                        buffer.push_str(&chunk_str);
299
300                        while let Some(line_end) = buffer.find('\n') {
301                            let line = buffer[..line_end].trim().to_string();
302                            buffer = buffer[line_end + 1..].to_string();
303
304                            if let Some(json_str) = line.strip_prefix("data: ") {
305                                if json_str == "[DONE]" {
306                                    return;
307                                }
308
309                                match serde_json::from_str::<StreamChunk>(json_str) {
310                                    Ok(stream_chunk) => yield Ok(stream_chunk),
311                                    Err(e) => yield Err(ProviderError::Serialization(e)),
312                                }
313                            }
314                        }
315                    }
316                    Err(e) => yield Err(ProviderError::Http(e)),
317                }
318            }
319        });
320
321        Ok(stream)
322    }
323
324    async fn embedding(
325        &self,
326        request: EmbeddingRequest,
327    ) -> Result<EmbeddingResponse, ProviderError> {
328        if let Some(embedding_endpoint) = &self.custom_config.embedding_endpoint {
329            let model = self.map_model(&request.model);
330
331            let input = match &request.input {
332                crate::models::EmbeddingInput::String(s) => vec![s.clone()],
333                crate::models::EmbeddingInput::StringArray(arr) => arr.clone(),
334                _ => {
335                    return Err(ProviderError::Configuration {
336                        message: "Unsupported embedding input format".to_string(),
337                    })
338                }
339            };
340
341            let embedding_request = json!({
342                "model": model,
343                "input": input,
344            });
345
346            let url = embedding_endpoint.to_string();
347            let embedding_response: EmbeddingResponse =
348                self.http.post_json(&url, &embedding_request).await?;
349            Ok(embedding_response)
350        } else {
351            Err(ProviderError::Configuration {
352                message: "Embeddings not supported by this custom provider".to_string(),
353            })
354        }
355    }
356
357    async fn image_generation(
358        &self,
359        _request: ImageRequest,
360    ) -> Result<ImageResponse, ProviderError> {
361        Err(ProviderError::Configuration {
362            message: "Image generation not supported by custom providers".to_string(),
363        })
364    }
365
366    async fn audio_transcription(
367        &self,
368        _request: AudioRequest,
369    ) -> Result<AudioResponse, ProviderError> {
370        Err(ProviderError::Configuration {
371            message: "Audio transcription not supported by custom providers".to_string(),
372        })
373    }
374
375    async fn text_to_speech(
376        &self,
377        _request: SpeechRequest,
378    ) -> Result<SpeechResponse, ProviderError> {
379        Err(ProviderError::Configuration {
380            message: "Text-to-speech not supported by custom providers".to_string(),
381        })
382    }
383
384    async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
385        let start = Instant::now();
386
387        let response = self.http.get_json::<serde_json::Value>("/health").await;
388
389        let latency_ms = start.elapsed().as_millis() as u64;
390
391        match response {
392            Ok(_) => Ok(ProviderHealth {
393                status: HealthStatus::Healthy,
394                latency_ms: Some(latency_ms),
395                error_rate: 0.0,
396                last_check: chrono::Utc::now(),
397                details: HashMap::new(),
398            }),
399            Err(e) => {
400                let mut details = HashMap::new();
401                details.insert("error".to_string(), e.to_string());
402
403                Ok(ProviderHealth {
404                    status: HealthStatus::Degraded,
405                    latency_ms: Some(latency_ms),
406                    error_rate: 1.0,
407                    last_check: chrono::Utc::now(),
408                    details,
409                })
410            }
411        }
412    }
413}