synaptic_cohere/
embeddings.rs1use async_trait::async_trait;
8use serde_json::json;
9use synaptic_core::{Embeddings, SynapticError};
10
/// Value placed in the `input_type` field of a Cohere embed request.
///
/// Each variant maps to exactly one wire string; see [`CohereInputType::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CohereInputType {
    /// Sent as `"search_document"`.
    SearchDocument,
    /// Sent as `"search_query"`.
    SearchQuery,
    /// Sent as `"classification"`.
    Classification,
    /// Sent as `"clustering"`.
    Clustering,
}

impl CohereInputType {
    /// Returns the string used for this variant in the request body.
    ///
    /// Every arm is a string literal, so the result is `'static` and may
    /// outlive the value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::SearchDocument => "search_document",
            Self::SearchQuery => "search_query",
            Self::Classification => "classification",
            Self::Clustering => "clustering",
        }
    }
}
36
37#[derive(Debug, Clone)]
39pub struct CohereEmbeddingsConfig {
40 pub api_key: String,
41 pub model: String,
43 pub input_type: CohereInputType,
45 pub query_input_type: CohereInputType,
47 pub base_url: String,
49}
50
51impl CohereEmbeddingsConfig {
52 pub fn new(api_key: impl Into<String>) -> Self {
53 Self {
54 api_key: api_key.into(),
55 model: "embed-english-v3.0".to_string(),
56 input_type: CohereInputType::SearchDocument,
57 query_input_type: CohereInputType::SearchQuery,
58 base_url: "https://api.cohere.ai/v2".to_string(),
59 }
60 }
61
62 pub fn with_model(mut self, model: impl Into<String>) -> Self {
63 self.model = model.into();
64 self
65 }
66
67 pub fn with_input_type(mut self, input_type: CohereInputType) -> Self {
68 self.input_type = input_type;
69 self
70 }
71
72 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
73 self.base_url = base_url.into();
74 self
75 }
76}
77
78pub struct CohereEmbeddings {
83 config: CohereEmbeddingsConfig,
84 client: reqwest::Client,
85}
86
87impl CohereEmbeddings {
88 pub fn new(config: CohereEmbeddingsConfig) -> Self {
89 Self {
90 config,
91 client: reqwest::Client::new(),
92 }
93 }
94
95 pub fn with_client(config: CohereEmbeddingsConfig, client: reqwest::Client) -> Self {
96 Self { config, client }
97 }
98
99 async fn embed_with_type(
100 &self,
101 texts: &[&str],
102 input_type: &CohereInputType,
103 ) -> Result<Vec<Vec<f32>>, SynapticError> {
104 if texts.is_empty() {
105 return Ok(Vec::new());
106 }
107
108 let body = json!({
109 "model": self.config.model,
110 "texts": texts,
111 "input_type": input_type.as_str(),
112 "embedding_types": ["float"],
113 });
114
115 let response = self
116 .client
117 .post(format!("{}/embed", self.config.base_url))
118 .header("Authorization", format!("Bearer {}", self.config.api_key))
119 .header("Content-Type", "application/json")
120 .json(&body)
121 .send()
122 .await
123 .map_err(|e| SynapticError::Embedding(format!("Cohere embed request: {e}")))?;
124
125 if !response.status().is_success() {
126 let status = response.status().as_u16();
127 let text = response.text().await.unwrap_or_default();
128 return Err(SynapticError::Embedding(format!(
129 "Cohere embed API error ({status}): {text}"
130 )));
131 }
132
133 let resp_body: serde_json::Value = response
134 .json()
135 .await
136 .map_err(|e| SynapticError::Embedding(format!("Cohere embed parse: {e}")))?;
137
138 let float_embeddings = resp_body["embeddings"]["float"]
139 .as_array()
140 .ok_or_else(|| SynapticError::Embedding("missing embeddings.float".to_string()))?;
141
142 let mut result = Vec::with_capacity(float_embeddings.len());
143 for embedding in float_embeddings {
144 let vec = embedding
145 .as_array()
146 .ok_or_else(|| SynapticError::Embedding("embedding is not array".to_string()))?
147 .iter()
148 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
149 .collect();
150 result.push(vec);
151 }
152
153 Ok(result)
154 }
155}
156
157#[async_trait]
158impl Embeddings for CohereEmbeddings {
159 async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
160 self.embed_with_type(texts, &self.config.input_type).await
161 }
162
163 async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
164 let mut results = self
165 .embed_with_type(&[text], &self.config.query_input_type)
166 .await?;
167 results
168 .pop()
169 .ok_or_else(|| SynapticError::Embedding("empty response".to_string()))
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built config carries the default model and input types.
    #[test]
    fn config_defaults() {
        let CohereEmbeddingsConfig {
            model,
            input_type,
            query_input_type,
            ..
        } = CohereEmbeddingsConfig::new("test-key");

        assert_eq!(model, "embed-english-v3.0");
        assert_eq!(input_type, CohereInputType::SearchDocument);
        assert_eq!(query_input_type, CohereInputType::SearchQuery);
    }

    /// Builder methods overwrite exactly the fields they are named after.
    #[test]
    fn config_builder() {
        let base = CohereEmbeddingsConfig::new("key");
        let with_model = base.with_model("embed-multilingual-v3.0");
        let customised = with_model.with_input_type(CohereInputType::Clustering);

        assert_eq!(customised.model, "embed-multilingual-v3.0");
        assert_eq!(customised.input_type, CohereInputType::Clustering);
    }
}