1use crate::error::ProviderError;
2use crate::models::{
3 AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
4 ImageRequest, ImageResponse, SpeechRequest, SpeechResponse, StreamChunk,
5};
6use crate::providers::{HealthStatus, Provider, ProviderConfig, ProviderHealth, StreamResult};
7use async_stream::stream;
8use serde_json::json;
9
10use super::http_client::{map_error_response, AuthStrategy, HttpProviderClient};
11
12use std::collections::HashMap;
13use std::time::Instant;
14
/// Endpoint layout and wire-format settings for a user-defined provider.
///
/// Only `chat_endpoint` is mandatory; the optional endpoints gate which
/// capabilities the provider exposes (e.g. `embedding` returns a
/// configuration error when `embedding_endpoint` is `None`).
#[derive(Debug, Clone)]
pub struct CustomProviderConfig {
    /// URL of the chat-completions endpoint. Required.
    pub chat_endpoint: String,
    /// Optional embeddings endpoint; `None` disables embeddings.
    pub embedding_endpoint: Option<String>,
    /// Optional image-generation endpoint (not consulted by the visible code).
    pub image_endpoint: Option<String>,
    /// Optional audio-transcription endpoint (not consulted by the visible code).
    pub audio_endpoint: Option<String>,
    /// Optional text-to-speech endpoint (not consulted by the visible code).
    pub speech_endpoint: Option<String>,
    /// How outgoing chat requests are serialized.
    pub request_format: RequestFormat,
    /// How backend chat responses are parsed.
    pub response_format: ResponseFormat,
    /// Which authentication scheme to attach to outgoing requests.
    pub auth_type: AuthType,
}
26
/// Wire format used when serializing an outgoing chat request.
#[derive(Debug, Clone)]
pub enum RequestFormat {
    /// OpenAI chat-completions body (`model`, `messages`, `temperature`, ...).
    OpenAI,
    /// Anthropic messages body; roles are re-tagged (tool messages sent as `user`).
    Anthropic,
    /// User-supplied JSON template with `{{model}}`, `{{temperature}}` and
    /// `{{max_tokens}}` placeholders substituted before parsing.
    Custom { template: String },
}
33
/// Wire format expected when parsing a backend chat response.
#[derive(Debug, Clone)]
pub enum ResponseFormat {
    /// Response already deserializes directly into `ChatResponse`.
    OpenAI,
    /// Anthropic messages response; mapped field-by-field into `ChatResponse`.
    Anthropic,
    /// User-supplied JSON template; `{{response}}` is replaced with the raw
    /// backend response serialized as JSON text.
    Custom { template: String },
}
40
/// Authentication scheme applied to outgoing requests.
#[derive(Debug, Clone)]
pub enum AuthType {
    /// Bearer-token auth using the provider's configured API key.
    Bearer,
    /// `X-API-Key: <api_key>` header.
    ApiKey,
    /// `<header>: <api_key>` with a caller-chosen header name.
    Custom { header: String },
    /// No authentication.
    None,
}
48
/// Provider backed by an arbitrary user-configured HTTP API.
pub struct CustomProvider {
    // Pre-configured HTTP client carrying base URL, headers, and auth.
    http: HttpProviderClient,
    // Generic provider settings (api key, timeout, model mapping, ...).
    config: ProviderConfig,
    // Endpoint/format configuration specific to this custom backend.
    custom_config: CustomProviderConfig,
}
54
55impl CustomProvider {
56 pub fn new(
57 config: ProviderConfig,
58 custom_config: CustomProviderConfig,
59 ) -> Result<Self, ProviderError> {
60 let auth = match &custom_config.auth_type {
61 AuthType::Bearer => AuthStrategy::Bearer {
62 token: config.api_key.clone(),
63 },
64 AuthType::ApiKey => AuthStrategy::Header {
65 name: "X-API-Key".to_string(),
66 value: config.api_key.clone(),
67 },
68 AuthType::Custom { header } => AuthStrategy::Header {
69 name: header.clone(),
70 value: config.api_key.clone(),
71 },
72 AuthType::None => AuthStrategy::None,
73 };
74
75 let http = HttpProviderClient::new(
76 config.timeout,
77 config.base_url.clone(),
78 "http://localhost:8080",
79 &config.headers,
80 auth,
81 )?;
82
83 Ok(Self {
84 http,
85 config,
86 custom_config,
87 })
88 }
89
90 fn map_model(&self, model: &str) -> String {
91 self.config
92 .model_mapping
93 .get(model)
94 .cloned()
95 .unwrap_or_else(|| model.to_string())
96 }
97
98 #[allow(dead_code)]
99 async fn handle_error_response(&self, response: reqwest::Response) -> ProviderError {
100 let status = response.status();
101
102 match response.text().await {
103 Ok(body) => {
104 if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(&body) {
105 let message = error_json
106 .get("error")
107 .and_then(|e| e.get("message"))
108 .and_then(|m| m.as_str())
109 .unwrap_or("Unknown API error")
110 .to_string();
111
112 match status.as_u16() {
113 401 => ProviderError::InvalidApiKey,
114 404 => ProviderError::ModelNotFound {
115 model: "unknown".to_string(),
116 },
117 429 => ProviderError::RateLimit,
118 _ => ProviderError::Api {
119 code: status.as_u16(),
120 message,
121 },
122 }
123 } else {
124 ProviderError::Api {
125 code: status.as_u16(),
126 message: body,
127 }
128 }
129 }
130 Err(_) => ProviderError::Api {
131 code: status.as_u16(),
132 message: "Failed to read error response".to_string(),
133 },
134 }
135 }
136
137 fn format_request(&self, request: &ChatRequest) -> Result<serde_json::Value, ProviderError> {
138 match &self.custom_config.request_format {
139 RequestFormat::OpenAI => Ok(json!({
140 "model": self.map_model(&request.model),
141 "messages": request.messages,
142 "temperature": request.temperature,
143 "max_tokens": request.max_tokens,
144 "stream": request.stream,
145 })),
146 RequestFormat::Anthropic => {
147 let messages = request
148 .messages
149 .iter()
150 .map(|msg| {
151 json!({
152 "role": match msg.role {
153 crate::models::Role::User => "user",
154 crate::models::Role::Assistant => "assistant",
155 crate::models::Role::System => "system",
156 crate::models::Role::Tool => "user",
157 },
158 "content": msg.content
159 })
160 })
161 .collect::<Vec<_>>();
162
163 Ok(json!({
164 "model": self.map_model(&request.model),
165 "messages": messages,
166 "temperature": request.temperature,
167 "max_tokens": request.max_tokens,
168 "stream": request.stream,
169 }))
170 }
171 RequestFormat::Custom { template } => {
172 let mut formatted = template.clone();
174 formatted = formatted.replace("{{model}}", &self.map_model(&request.model));
175 formatted = formatted.replace(
176 "{{temperature}}",
177 &request.temperature.unwrap_or(0.7).to_string(),
178 );
179 formatted = formatted.replace(
180 "{{max_tokens}}",
181 &request.max_tokens.unwrap_or(100).to_string(),
182 );
183
184 serde_json::from_str(&formatted).map_err(|e| ProviderError::Configuration {
185 message: format!("Invalid custom request template: {e}"),
186 })
187 }
188 }
189 }
190
191 fn parse_response(&self, response: serde_json::Value) -> Result<ChatResponse, ProviderError> {
192 match &self.custom_config.response_format {
193 ResponseFormat::OpenAI => {
194 let chat_response: ChatResponse =
195 serde_json::from_value(response).map_err(ProviderError::Serialization)?;
196 Ok(chat_response)
197 }
198 ResponseFormat::Anthropic => {
199 let chat_response = ChatResponse {
201 id: response["id"].as_str().unwrap_or("").to_string(),
202 object: "chat.completion".to_string(),
203 created: chrono::Utc::now().timestamp() as u64,
204 model: response["model"].as_str().unwrap_or("").to_string(),
205 choices: vec![crate::models::Choice {
206 index: 0,
207 message: crate::models::Message {
208 role: crate::models::Role::Assistant,
209 content: response["content"][0]["text"]
210 .as_str()
211 .unwrap_or("")
212 .to_string(),
213 name: None,
214 tool_calls: None,
215 tool_call_id: None,
216 },
217 finish_reason: Some("stop".to_string()),
218 logprobs: None,
219 }],
220 usage: Some(crate::models::Usage {
221 prompt_tokens: response["usage"]["input_tokens"].as_u64().unwrap_or(0)
222 as u32,
223 completion_tokens: response["usage"]["output_tokens"].as_u64().unwrap_or(0)
224 as u32,
225 total_tokens: response["usage"]["input_tokens"].as_u64().unwrap_or(0)
226 as u32
227 + response["usage"]["output_tokens"].as_u64().unwrap_or(0) as u32,
228 }),
229 system_fingerprint: None,
230 };
231 Ok(chat_response)
232 }
233 ResponseFormat::Custom { template } => {
234 let response_str =
236 serde_json::to_string(&response).map_err(ProviderError::Serialization)?;
237
238 let mut formatted = template.clone();
239 formatted = formatted.replace("{{response}}", &response_str);
240
241 serde_json::from_str(&formatted).map_err(|e| ProviderError::Configuration {
242 message: format!("Invalid custom response template: {e}"),
243 })
244 }
245 }
246 }
247}
248
249#[async_trait::async_trait]
250impl Provider for CustomProvider {
    /// Stable identifier for this provider implementation.
    fn name(&self) -> &str {
        "custom"
    }
254
    /// Streaming is supported via `stream_chat_completion` (SSE parsing).
    fn supports_streaming(&self) -> bool {
        true
    }
258
    /// Function/tool calling is not supported for generic custom providers.
    fn supports_function_calling(&self) -> bool {
        false
    }
    /// Placeholder model list; the models actually accepted are whatever the
    /// configured backend supports (see `map_model` / `model_mapping`).
    fn supported_models(&self) -> Vec<String> {
        vec!["custom-model".to_string()]
    }
266
267 async fn chat_completion(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
268 let formatted_request = self.format_request(&request)?;
269
270 let url = self.custom_config.chat_endpoint.to_string();
271 let response_json: serde_json::Value =
272 self.http.post_json(&url, &formatted_request).await?;
273 let chat_response = self.parse_response(response_json)?;
274 Ok(chat_response)
275 }
276
277 async fn stream_chat_completion(
278 &self,
279 request: ChatRequest,
280 ) -> Result<StreamResult, ProviderError> {
281 let mut formatted_request = self.format_request(&request)?;
282 formatted_request["stream"] = serde_json::Value::Bool(true);
283
284 let url = self.custom_config.chat_endpoint.to_string();
285 let response = self.http.post_json_raw(&url, &formatted_request).await?;
286 if !response.status().is_success() {
287 return Err(map_error_response(response).await);
288 }
289
290 let stream = Box::pin(stream! {
291 let mut bytes_stream = response.bytes_stream();
292 let mut buffer = String::new();
293
294 while let Some(chunk_result) = futures::StreamExt::next(&mut bytes_stream).await {
295 match chunk_result {
296 Ok(chunk) => {
297 let chunk_str = String::from_utf8_lossy(&chunk);
298 buffer.push_str(&chunk_str);
299
300 while let Some(line_end) = buffer.find('\n') {
301 let line = buffer[..line_end].trim().to_string();
302 buffer = buffer[line_end + 1..].to_string();
303
304 if let Some(json_str) = line.strip_prefix("data: ") {
305 if json_str == "[DONE]" {
306 return;
307 }
308
309 match serde_json::from_str::<StreamChunk>(json_str) {
310 Ok(stream_chunk) => yield Ok(stream_chunk),
311 Err(e) => yield Err(ProviderError::Serialization(e)),
312 }
313 }
314 }
315 }
316 Err(e) => yield Err(ProviderError::Http(e)),
317 }
318 }
319 });
320
321 Ok(stream)
322 }
323
324 async fn embedding(
325 &self,
326 request: EmbeddingRequest,
327 ) -> Result<EmbeddingResponse, ProviderError> {
328 if let Some(embedding_endpoint) = &self.custom_config.embedding_endpoint {
329 let model = self.map_model(&request.model);
330
331 let input = match &request.input {
332 crate::models::EmbeddingInput::String(s) => vec![s.clone()],
333 crate::models::EmbeddingInput::StringArray(arr) => arr.clone(),
334 _ => {
335 return Err(ProviderError::Configuration {
336 message: "Unsupported embedding input format".to_string(),
337 })
338 }
339 };
340
341 let embedding_request = json!({
342 "model": model,
343 "input": input,
344 });
345
346 let url = embedding_endpoint.to_string();
347 let embedding_response: EmbeddingResponse =
348 self.http.post_json(&url, &embedding_request).await?;
349 Ok(embedding_response)
350 } else {
351 Err(ProviderError::Configuration {
352 message: "Embeddings not supported by this custom provider".to_string(),
353 })
354 }
355 }
356
    /// Image generation is not implemented for custom providers; always
    /// returns a configuration error.
    async fn image_generation(
        &self,
        _request: ImageRequest,
    ) -> Result<ImageResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Image generation not supported by custom providers".to_string(),
        })
    }

    /// Audio transcription is not implemented for custom providers; always
    /// returns a configuration error.
    async fn audio_transcription(
        &self,
        _request: AudioRequest,
    ) -> Result<AudioResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Audio transcription not supported by custom providers".to_string(),
        })
    }

    /// Text-to-speech is not implemented for custom providers; always
    /// returns a configuration error.
    async fn text_to_speech(
        &self,
        _request: SpeechRequest,
    ) -> Result<SpeechResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Text-to-speech not supported by custom providers".to_string(),
        })
    }
383
384 async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
385 let start = Instant::now();
386
387 let response = self.http.get_json::<serde_json::Value>("/health").await;
388
389 let latency_ms = start.elapsed().as_millis() as u64;
390
391 match response {
392 Ok(_) => Ok(ProviderHealth {
393 status: HealthStatus::Healthy,
394 latency_ms: Some(latency_ms),
395 error_rate: 0.0,
396 last_check: chrono::Utc::now(),
397 details: HashMap::new(),
398 }),
399 Err(e) => {
400 let mut details = HashMap::new();
401 details.insert("error".to_string(), e.to_string());
402
403 Ok(ProviderHealth {
404 status: HealthStatus::Degraded,
405 latency_ms: Some(latency_ms),
406 error_rate: 1.0,
407 last_check: chrono::Utc::now(),
408 details,
409 })
410 }
411 }
412 }
413}