// ultrafast_models_sdk/providers/cohere.rs
use crate::error::ProviderError;
2use crate::models::{
3 AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
4 ImageRequest, ImageResponse, SpeechRequest, SpeechResponse, StreamChunk,
5};
6use crate::providers::{HealthStatus, Provider, ProviderConfig, ProviderHealth, StreamResult};
7use async_stream::stream;
8use serde_json::json;
9
10use super::http_client::{map_error_response, AuthStrategy, HttpProviderClient};
11
12use std::collections::HashMap;
13use std::time::Instant;
14
/// Provider backed by the Cohere v1 API: implements chat completion (plain
/// and streaming) and embeddings; image, audio-transcription, and
/// text-to-speech requests are rejected as unsupported.
pub struct CohereProvider {
    // Pre-configured HTTP client (base URL + bearer-token auth header).
    http: HttpProviderClient,
    // Provider settings; `model_mapping` is consulted on every request.
    config: ProviderConfig,
}
19
20impl CohereProvider {
21 pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
22 let http = HttpProviderClient::new(
23 config.timeout,
24 config.base_url.clone(),
25 "https://api.cohere.ai/v1",
26 &config.headers,
27 AuthStrategy::Bearer {
28 token: config.api_key.clone(),
29 },
30 )?;
31 Ok(Self { http, config })
32 }
33
34 fn map_model(&self, model: &str) -> String {
35 self.config
36 .model_mapping
37 .get(model)
38 .cloned()
39 .unwrap_or_else(|| model.to_string())
40 }
41
42 #[allow(dead_code)]
43 async fn handle_error_response(&self, response: reqwest::Response) -> ProviderError {
44 let status = response.status();
45
46 match response.text().await {
47 Ok(body) => {
48 if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(&body) {
49 let message = error_json
50 .get("message")
51 .and_then(|m| m.as_str())
52 .unwrap_or("Unknown API error")
53 .to_string();
54
55 match status.as_u16() {
56 401 => ProviderError::InvalidApiKey,
57 404 => ProviderError::ModelNotFound {
58 model: "unknown".to_string(),
59 },
60 429 => ProviderError::RateLimit,
61 _ => ProviderError::Api {
62 code: status.as_u16(),
63 message,
64 },
65 }
66 } else {
67 ProviderError::Api {
68 code: status.as_u16(),
69 message: body,
70 }
71 }
72 }
73 Err(_) => ProviderError::Api {
74 code: status.as_u16(),
75 message: "Failed to read error response".to_string(),
76 },
77 }
78 }
79
80 #[allow(dead_code)]
81 fn convert_messages_to_cohere_format(
82 &self,
83 messages: &[crate::models::Message],
84 ) -> Vec<serde_json::Value> {
85 messages
86 .iter()
87 .map(|msg| {
88 let role = match msg.role {
89 crate::models::Role::User => "user",
90 crate::models::Role::Assistant => "assistant",
91 crate::models::Role::System => "system",
92 crate::models::Role::Tool => "user", };
94
95 json!({
96 "role": role,
97 "content": msg.content
98 })
99 })
100 .collect()
101 }
102}
103
104#[async_trait::async_trait]
105impl Provider for CohereProvider {
106 fn name(&self) -> &str {
107 "cohere"
108 }
109
110 fn supports_streaming(&self) -> bool {
111 true
112 }
113
114 fn supports_function_calling(&self) -> bool {
115 false }
117
118 fn supported_models(&self) -> Vec<String> {
119 vec![
120 "command".to_string(),
121 "command-light".to_string(),
122 "command-nightly".to_string(),
123 "command-light-nightly".to_string(),
124 "embed-english-v3.0".to_string(),
125 "embed-multilingual-v3.0".to_string(),
126 ]
127 }
128
129 async fn chat_completion(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
130 let model = self.map_model(&request.model);
131
132 let cohere_request = json!({
134 "model": model,
135 "message": request.messages.last().map(|m| m.content.clone()).unwrap_or_default(),
136 "chat_history": request.messages[..request.messages.len()-1].iter().map(|m| {
137 json!({
138 "role": match m.role {
139 crate::models::Role::User => "user",
140 crate::models::Role::Assistant => "assistant",
141 crate::models::Role::System => "system",
142 crate::models::Role::Tool => "user",
143 },
144 "message": m.content
145 })
146 }).collect::<Vec<_>>(),
147 "temperature": request.temperature.unwrap_or(0.7),
148 "max_tokens": request.max_tokens,
149 "stream": false,
150 });
151
152 let cohere_response: serde_json::Value =
153 self.http.post_json("/chat", &cohere_request).await?;
154
155 let chat_response = ChatResponse {
157 id: cohere_response["response_id"]
158 .as_str()
159 .unwrap_or("")
160 .to_string(),
161 object: "chat.completion".to_string(),
162 created: chrono::Utc::now().timestamp() as u64,
163 model,
164 choices: vec![crate::models::Choice {
165 index: 0,
166 message: crate::models::Message {
167 role: crate::models::Role::Assistant,
168 content: cohere_response["text"].as_str().unwrap_or("").to_string(),
169 name: None,
170 tool_calls: None,
171 tool_call_id: None,
172 },
173 finish_reason: Some("stop".to_string()),
174 logprobs: None,
175 }],
176 usage: Some(crate::models::Usage {
177 prompt_tokens: cohere_response["meta"]["billed_units"]["input_tokens"]
178 .as_u64()
179 .unwrap_or(0) as u32,
180 completion_tokens: cohere_response["meta"]["billed_units"]["output_tokens"]
181 .as_u64()
182 .unwrap_or(0) as u32,
183 total_tokens: cohere_response["meta"]["billed_units"]["input_tokens"]
184 .as_u64()
185 .unwrap_or(0) as u32
186 + cohere_response["meta"]["billed_units"]["output_tokens"]
187 .as_u64()
188 .unwrap_or(0) as u32,
189 }),
190 system_fingerprint: None,
191 };
192
193 Ok(chat_response)
194 }
195
196 async fn stream_chat_completion(
197 &self,
198 request: ChatRequest,
199 ) -> Result<StreamResult, ProviderError> {
200 let model = self.map_model(&request.model);
201
202 let cohere_request = json!({
203 "model": model,
204 "message": request.messages.last().map(|m| m.content.clone()).unwrap_or_default(),
205 "chat_history": request.messages[..request.messages.len()-1].iter().map(|m| {
206 json!({
207 "role": match m.role {
208 crate::models::Role::User => "user",
209 crate::models::Role::Assistant => "assistant",
210 crate::models::Role::System => "system",
211 crate::models::Role::Tool => "user",
212 },
213 "message": m.content
214 })
215 }).collect::<Vec<_>>(),
216 "temperature": request.temperature.unwrap_or(0.7),
217 "max_tokens": request.max_tokens,
218 "stream": true,
219 });
220
221 let response = self.http.post_json_raw("/chat", &cohere_request).await?;
222 if !response.status().is_success() {
223 return Err(map_error_response(response).await);
224 }
225
226 let stream = Box::pin(stream! {
227 let mut bytes_stream = response.bytes_stream();
228 let mut buffer = String::new();
229
230 while let Some(chunk_result) = futures::StreamExt::next(&mut bytes_stream).await {
231 match chunk_result {
232 Ok(chunk) => {
233 let chunk_str = String::from_utf8_lossy(&chunk);
234 buffer.push_str(&chunk_str);
235
236 while let Some(line_end) = buffer.find('\n') {
237 let line = buffer[..line_end].trim().to_string();
238 buffer = buffer[line_end + 1..].to_string();
239
240 if let Some(json_str) = line.strip_prefix("data: ") {
241 if json_str == "[DONE]" {
242 return;
243 }
244
245 if let Ok(cohere_chunk) = serde_json::from_str::<serde_json::Value>(json_str) {
247 if let Some(text) = cohere_chunk["text"].as_str() {
248 let stream_chunk = StreamChunk {
249 id: "cohere-stream".to_string(),
250 object: "chat.completion.chunk".to_string(),
251 created: chrono::Utc::now().timestamp() as u64,
252 model: model.clone(),
253 choices: vec![crate::models::StreamChoice {
254 index: 0,
255 delta: crate::models::Delta {
256 role: None,
257 content: Some(text.to_string()),
258 tool_calls: None,
259 },
260 finish_reason: None,
261 }],
262 };
263 yield Ok(stream_chunk);
264 }
265 }
266 }
267 }
268 }
269 Err(e) => yield Err(ProviderError::Http(e)),
270 }
271 }
272 });
273
274 Ok(stream)
275 }
276
277 async fn embedding(
278 &self,
279 request: EmbeddingRequest,
280 ) -> Result<EmbeddingResponse, ProviderError> {
281 let model = self.map_model(&request.model);
282
283 let input = match &request.input {
284 crate::models::EmbeddingInput::String(s) => vec![s.clone()],
285 crate::models::EmbeddingInput::StringArray(arr) => arr.clone(),
286 _ => {
287 return Err(ProviderError::Configuration {
288 message: "Unsupported embedding input format".to_string(),
289 })
290 }
291 };
292
293 let cohere_request = json!({
294 "model": model,
295 "texts": input,
296 "input_type": "search_document",
297 });
298
299 let cohere_response: serde_json::Value =
300 self.http.post_json("/embed", &cohere_request).await?;
301
302 let embeddings = cohere_response["embeddings"]
304 .as_array()
305 .unwrap_or(&vec![])
306 .iter()
307 .enumerate()
308 .map(|(i, embedding)| {
309 let embedding_vec = embedding["values"]
310 .as_array()
311 .unwrap_or(&vec![])
312 .iter()
313 .filter_map(|v| v.as_f64().map(|f| f as f32))
314 .collect::<Vec<f32>>();
315
316 crate::models::Embedding {
317 object: "embedding".to_string(),
318 embedding: embedding_vec,
319 index: i as u32,
320 }
321 })
322 .collect();
323
324 let embedding_response = EmbeddingResponse {
325 object: "list".to_string(),
326 data: embeddings,
327 model,
328 usage: crate::models::Usage {
329 prompt_tokens: cohere_response["meta"]["billed_units"]["input_tokens"]
330 .as_u64()
331 .unwrap_or(0) as u32,
332 completion_tokens: 0,
333 total_tokens: cohere_response["meta"]["billed_units"]["input_tokens"]
334 .as_u64()
335 .unwrap_or(0) as u32,
336 },
337 };
338
339 Ok(embedding_response)
340 }
341
342 async fn image_generation(
343 &self,
344 _request: ImageRequest,
345 ) -> Result<ImageResponse, ProviderError> {
346 Err(ProviderError::Configuration {
347 message: "Image generation not supported by Cohere".to_string(),
348 })
349 }
350
351 async fn audio_transcription(
352 &self,
353 _request: AudioRequest,
354 ) -> Result<AudioResponse, ProviderError> {
355 Err(ProviderError::Configuration {
356 message: "Audio transcription not supported by Cohere".to_string(),
357 })
358 }
359
360 async fn text_to_speech(
361 &self,
362 _request: SpeechRequest,
363 ) -> Result<SpeechResponse, ProviderError> {
364 Err(ProviderError::Configuration {
365 message: "Text-to-speech not supported by Cohere".to_string(),
366 })
367 }
368
369 async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
370 let start = Instant::now();
371
372 let response = self.http.get_json::<serde_json::Value>("/models").await;
373
374 let latency_ms = start.elapsed().as_millis() as u64;
375
376 match response {
377 Ok(_) => Ok(ProviderHealth {
378 status: HealthStatus::Healthy,
379 latency_ms: Some(latency_ms),
380 error_rate: 0.0,
381 last_check: chrono::Utc::now(),
382 details: HashMap::new(),
383 }),
384 Err(e) => {
385 let mut details = HashMap::new();
386 details.insert("error".to_string(), e.to_string());
387
388 Ok(ProviderHealth {
389 status: HealthStatus::Degraded,
390 latency_ms: Some(latency_ms),
391 error_rate: 1.0,
392 last_check: chrono::Utc::now(),
393 details,
394 })
395 }
396 }
397 }
398}