ultrafast_models_sdk/providers/cohere.rs

use crate::error::ProviderError;
use crate::models::{
    AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
    ImageRequest, ImageResponse, SpeechRequest, SpeechResponse, StreamChunk,
};
use crate::providers::{HealthStatus, Provider, ProviderConfig, ProviderHealth, StreamResult};
use async_stream::stream;
use serde_json::json;

use super::http_client::{map_error_response, AuthStrategy, HttpProviderClient};

use std::collections::HashMap;
use std::time::Instant;

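/// Adapter that exposes Cohere's v1 REST API through the SDK's
/// OpenAI-style `Provider` trait, translating requests and responses
/// in both directions.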
pub struct CohereProvider {
    http: HttpProviderClient,
    config: ProviderConfig,
}

impl CohereProvider {
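    /// Builds a provider from the given `ProviderConfig`, defaulting the
    /// base URL to `https://api.cohere.ai/v1` and authenticating every
    /// request with a Bearer token.
    ///
    /// A minimal usage sketch; whether `ProviderConfig` implements
    /// `Default` is an assumption here, so adjust to however the config
    /// is actually constructed:
    ///
    /// ```ignore
    /// let config = ProviderConfig {
    ///     api_key: "co-...".into(),
    ///     ..Default::default() // assumes ProviderConfig: Default
    /// };
    /// let provider = CohereProvider::new(config)?;
    /// ```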
    pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
        let http = HttpProviderClient::new(
            config.timeout,
            config.base_url.clone(),
            "https://api.cohere.ai/v1",
            &config.headers,
            AuthStrategy::Bearer {
                token: config.api_key.clone(),
            },
        )?;
        Ok(Self { http, config })
    }

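    /// Resolves a caller-supplied model name through the configured
    /// `model_mapping`, falling back to the name unchanged.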
    fn map_model(&self, model: &str) -> String {
        self.config
            .model_mapping
            .get(model)
            .cloned()
            .unwrap_or_else(|| model.to_string())
    }

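    /// Maps a non-success HTTP response onto a `ProviderError`, pulling
    /// the human-readable `message` field out of Cohere's JSON error body
    /// when one is present.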
    #[allow(dead_code)]
    async fn handle_error_response(&self, response: reqwest::Response) -> ProviderError {
        let status = response.status();

        match response.text().await {
            Ok(body) => {
                if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(&body) {
                    let message = error_json
                        .get("message")
                        .and_then(|m| m.as_str())
                        .unwrap_or("Unknown API error")
                        .to_string();

                    match status.as_u16() {
                        401 => ProviderError::InvalidApiKey,
                        404 => ProviderError::ModelNotFound {
                            model: "unknown".to_string(),
                        },
                        429 => ProviderError::RateLimit,
                        _ => ProviderError::Api {
                            code: status.as_u16(),
                            message,
                        },
                    }
                } else {
                    ProviderError::Api {
                        code: status.as_u16(),
                        message: body,
                    }
                }
            }
            Err(_) => ProviderError::Api {
                code: status.as_u16(),
                message: "Failed to read error response".to_string(),
            },
        }
    }

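    /// Converts the SDK's OpenAI-style messages into Cohere `{role, content}`
    /// objects. Tool messages are downgraded to user messages because
    /// Cohere's v1 chat API has no tool role.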
    #[allow(dead_code)]
    fn convert_messages_to_cohere_format(
        &self,
        messages: &[crate::models::Message],
    ) -> Vec<serde_json::Value> {
        messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    crate::models::Role::User => "user",
                    crate::models::Role::Assistant => "assistant",
                    crate::models::Role::System => "system",
                    crate::models::Role::Tool => "user", // Cohere doesn't have a tool role
                };

                json!({
                    "role": role,
                    "content": msg.content
                })
            })
            .collect()
    }
}

#[async_trait::async_trait]
impl Provider for CohereProvider {
    fn name(&self) -> &str {
        "cohere"
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_function_calling(&self) -> bool {
        false // function calling is not implemented for this provider
    }

    fn supported_models(&self) -> Vec<String> {
        vec![
            "command".to_string(),
            "command-light".to_string(),
            "command-nightly".to_string(),
            "command-light-nightly".to_string(),
            "embed-english-v3.0".to_string(),
            "embed-multilingual-v3.0".to_string(),
        ]
    }

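    /// Sends a non-streaming chat request. The OpenAI-style message list is
    /// split into Cohere's `message` (the latest turn) plus `chat_history`
    /// (every earlier turn), and Cohere's reply is mapped back into an
    /// OpenAI-style `ChatResponse`.
    ///
    /// A call sketch using only fields visible in this file; whether
    /// `ChatRequest` implements `Default` is an assumption:
    ///
    /// ```ignore
    /// let request = ChatRequest {
    ///     model: "command".into(),
    ///     messages: vec![Message {
    ///         role: Role::User,
    ///         content: "Hello!".into(),
    ///         name: None,
    ///         tool_calls: None,
    ///         tool_call_id: None,
    ///     }],
    ///     ..Default::default() // assumes ChatRequest: Default
    /// };
    /// let response = provider.chat_completion(request).await?;
    /// ```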
    async fn chat_completion(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
        let model = self.map_model(&request.model);

        // Convert OpenAI format to Cohere format: the last message becomes
        // `message`, everything before it becomes `chat_history`.
        // `saturating_sub` avoids an underflow panic on an empty list.
        let history_len = request.messages.len().saturating_sub(1);
        let cohere_request = json!({
            "model": model,
            "message": request.messages.last().map(|m| m.content.clone()).unwrap_or_default(),
            "chat_history": request.messages[..history_len].iter().map(|m| {
                json!({
                    "role": match m.role {
                        crate::models::Role::User => "user",
                        crate::models::Role::Assistant => "assistant",
                        crate::models::Role::System => "system",
                        crate::models::Role::Tool => "user",
                    },
                    "message": m.content
                })
            }).collect::<Vec<_>>(),
            "temperature": request.temperature.unwrap_or(0.7),
            "max_tokens": request.max_tokens,
            "stream": false,
        });

        let cohere_response: serde_json::Value =
            self.http.post_json("/chat", &cohere_request).await?;

        // Convert Cohere response to OpenAI format
        let chat_response = ChatResponse {
            id: cohere_response["response_id"]
                .as_str()
                .unwrap_or("")
                .to_string(),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model,
            choices: vec![crate::models::Choice {
                index: 0,
                message: crate::models::Message {
                    role: crate::models::Role::Assistant,
                    content: cohere_response["text"].as_str().unwrap_or("").to_string(),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
                logprobs: None,
            }],
            usage: Some(crate::models::Usage {
                prompt_tokens: cohere_response["meta"]["billed_units"]["input_tokens"]
                    .as_u64()
                    .unwrap_or(0) as u32,
                completion_tokens: cohere_response["meta"]["billed_units"]["output_tokens"]
                    .as_u64()
                    .unwrap_or(0) as u32,
                total_tokens: cohere_response["meta"]["billed_units"]["input_tokens"]
                    .as_u64()
                    .unwrap_or(0) as u32
                    + cohere_response["meta"]["billed_units"]["output_tokens"]
                        .as_u64()
                        .unwrap_or(0) as u32,
            }),
            system_fingerprint: None,
        };

        Ok(chat_response)
    }

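    /// Streams a chat completion. The request mirrors `chat_completion`
    /// with `stream: true`; incoming bytes are buffered, split into lines,
    /// and re-emitted as OpenAI-style `StreamChunk`s.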
    async fn stream_chat_completion(
        &self,
        request: ChatRequest,
    ) -> Result<StreamResult, ProviderError> {
        let model = self.map_model(&request.model);

        // Same OpenAI-to-Cohere conversion as chat_completion, with
        // streaming enabled; `saturating_sub` guards the empty-list case.
        let history_len = request.messages.len().saturating_sub(1);
        let cohere_request = json!({
            "model": model,
            "message": request.messages.last().map(|m| m.content.clone()).unwrap_or_default(),
            "chat_history": request.messages[..history_len].iter().map(|m| {
                json!({
                    "role": match m.role {
                        crate::models::Role::User => "user",
                        crate::models::Role::Assistant => "assistant",
                        crate::models::Role::System => "system",
                        crate::models::Role::Tool => "user",
                    },
                    "message": m.content
                })
            }).collect::<Vec<_>>(),
            "temperature": request.temperature.unwrap_or(0.7),
            "max_tokens": request.max_tokens,
            "stream": true,
        });

        let response = self.http.post_json_raw("/chat", &cohere_request).await?;
        if !response.status().is_success() {
            return Err(map_error_response(response).await);
        }

        let stream = Box::pin(stream! {
            let mut bytes_stream = response.bytes_stream();
            let mut buffer = String::new();

            while let Some(chunk_result) = futures::StreamExt::next(&mut bytes_stream).await {
                match chunk_result {
                    Ok(chunk) => {
                        let chunk_str = String::from_utf8_lossy(&chunk);
                        buffer.push_str(&chunk_str);

                        // Process every complete line currently in the buffer;
                        // a partial trailing line is kept for the next chunk.
                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].trim().to_string();
                            buffer = buffer[line_end + 1..].to_string();

                            if line.is_empty() {
                                continue;
                            }

                            // Cohere v1 streams newline-delimited JSON events;
                            // tolerate an SSE-style "data: " prefix as well.
                            let json_str = line.strip_prefix("data: ").unwrap_or(line.as_str());
                            if json_str == "[DONE]" {
                                return;
                            }

                            // Parse the Cohere streaming event and convert it to an
                            // OpenAI-style chunk; only text deltas are forwarded.
                            if let Ok(cohere_chunk) = serde_json::from_str::<serde_json::Value>(json_str) {
                                if let Some(text) = cohere_chunk["text"].as_str() {
                                    let stream_chunk = StreamChunk {
                                        id: "cohere-stream".to_string(),
                                        object: "chat.completion.chunk".to_string(),
                                        created: chrono::Utc::now().timestamp() as u64,
                                        model: model.clone(),
                                        choices: vec![crate::models::StreamChoice {
                                            index: 0,
                                            delta: crate::models::Delta {
                                                role: None,
                                                content: Some(text.to_string()),
                                                tool_calls: None,
                                            },
                                            finish_reason: None,
                                        }],
                                    };
                                    yield Ok(stream_chunk);
                                }
                            }
                        }
                    }
                    Err(e) => yield Err(ProviderError::Http(e)),
                }
            }
        });

        Ok(stream)
    }

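    /// Requests embeddings from Cohere's `/embed` endpoint. String and
    /// string-array inputs are supported; other input formats are rejected
    /// with a configuration error.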
    async fn embedding(
        &self,
        request: EmbeddingRequest,
    ) -> Result<EmbeddingResponse, ProviderError> {
        let model = self.map_model(&request.model);

        let input = match &request.input {
            crate::models::EmbeddingInput::String(s) => vec![s.clone()],
            crate::models::EmbeddingInput::StringArray(arr) => arr.clone(),
            _ => {
                return Err(ProviderError::Configuration {
                    message: "Unsupported embedding input format".to_string(),
                })
            }
        };

        let cohere_request = json!({
            "model": model,
            "texts": input,
            "input_type": "search_document",
        });

        let cohere_response: serde_json::Value =
            self.http.post_json("/embed", &cohere_request).await?;

        // Convert Cohere response to OpenAI format. Cohere returns
        // `embeddings` as an array of float arrays, so each element is
        // decoded directly rather than through a nested "values" key.
        let embeddings = cohere_response["embeddings"]
            .as_array()
            .map(Vec::as_slice)
            .unwrap_or(&[])
            .iter()
            .enumerate()
            .map(|(i, embedding)| {
                let embedding_vec = embedding
                    .as_array()
                    .map(Vec::as_slice)
                    .unwrap_or(&[])
                    .iter()
                    .filter_map(|v| v.as_f64().map(|f| f as f32))
                    .collect::<Vec<f32>>();

                crate::models::Embedding {
                    object: "embedding".to_string(),
                    embedding: embedding_vec,
                    index: i as u32,
                }
            })
            .collect();

        let embedding_response = EmbeddingResponse {
            object: "list".to_string(),
            data: embeddings,
            model,
            usage: crate::models::Usage {
                prompt_tokens: cohere_response["meta"]["billed_units"]["input_tokens"]
                    .as_u64()
                    .unwrap_or(0) as u32,
                completion_tokens: 0,
                total_tokens: cohere_response["meta"]["billed_units"]["input_tokens"]
                    .as_u64()
                    .unwrap_or(0) as u32,
            },
        };

        Ok(embedding_response)
    }

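    // Cohere's v1 API offers no image generation, transcription, or
    // text-to-speech endpoints, so these capabilities are reported as
    // unsupported.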
    async fn image_generation(
        &self,
        _request: ImageRequest,
    ) -> Result<ImageResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Image generation not supported by Cohere".to_string(),
        })
    }

    async fn audio_transcription(
        &self,
        _request: AudioRequest,
    ) -> Result<AudioResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Audio transcription not supported by Cohere".to_string(),
        })
    }

    async fn text_to_speech(
        &self,
        _request: SpeechRequest,
    ) -> Result<SpeechResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Text-to-speech not supported by Cohere".to_string(),
        })
    }

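    /// Probes the API by listing models and measures round-trip latency.
    /// Failures are reported as a `Degraded` status with the error in
    /// `details` rather than propagated as a hard error.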
    async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
        let start = Instant::now();

        let response = self.http.get_json::<serde_json::Value>("/models").await;

        let latency_ms = start.elapsed().as_millis() as u64;

        match response {
            Ok(_) => Ok(ProviderHealth {
                status: HealthStatus::Healthy,
                latency_ms: Some(latency_ms),
                error_rate: 0.0,
                last_check: chrono::Utc::now(),
                details: HashMap::new(),
            }),
            Err(e) => {
                let mut details = HashMap::new();
                details.insert("error".to_string(), e.to_string());

                Ok(ProviderHealth {
                    status: HealthStatus::Degraded,
                    latency_ms: Some(latency_ms),
                    error_rate: 1.0,
                    last_check: chrono::Utc::now(),
                    details,
                })
            }
        }
    }
}