Skip to main content

synaptic_cohere/
embeddings.rs

//! Cohere Embeddings implementation using the native Cohere v2 API.
//!
//! Unlike the OpenAI-compatible endpoint, this implementation supports Cohere's
//! `input_type` parameter, which is required for optimal retrieval performance:
//! use `search_document` when embedding documents and `search_query` when embedding queries.
6
7use async_trait::async_trait;
8use serde_json::json;
9use synaptic_core::{Embeddings, SynapticError};
10
/// Input type for Cohere embeddings.
///
/// Using the correct input type is important for retrieval quality. The
/// selected variant is sent to the API as the `input_type` request field
/// (see [`CohereInputType::as_str`] for the exact strings).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CohereInputType {
    /// For embedding documents to be stored in a vector database.
    SearchDocument,
    /// For embedding queries used to search the vector database.
    SearchQuery,
    /// For classification tasks.
    Classification,
    /// For clustering tasks.
    Clustering,
}

impl CohereInputType {
    /// Returns the string sent to the API for this input type.
    ///
    /// Every variant maps to a string literal, so the result is `'static`
    /// and may outlive the enum value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::SearchDocument => "search_document",
            Self::SearchQuery => "search_query",
            Self::Classification => "classification",
            Self::Clustering => "clustering",
        }
    }
}
36
37/// Configuration for [`CohereEmbeddings`].
38#[derive(Debug, Clone)]
39pub struct CohereEmbeddingsConfig {
40    pub api_key: String,
41    /// Model name (default: `"embed-english-v3.0"`).
42    pub model: String,
43    /// Input type for document embedding (default: `SearchDocument`).
44    pub input_type: CohereInputType,
45    /// Query input type (default: `SearchQuery`).
46    pub query_input_type: CohereInputType,
47    /// Base URL (default: `"https://api.cohere.ai/v2"`).
48    pub base_url: String,
49}
50
51impl CohereEmbeddingsConfig {
52    pub fn new(api_key: impl Into<String>) -> Self {
53        Self {
54            api_key: api_key.into(),
55            model: "embed-english-v3.0".to_string(),
56            input_type: CohereInputType::SearchDocument,
57            query_input_type: CohereInputType::SearchQuery,
58            base_url: "https://api.cohere.ai/v2".to_string(),
59        }
60    }
61
62    pub fn with_model(mut self, model: impl Into<String>) -> Self {
63        self.model = model.into();
64        self
65    }
66
67    pub fn with_input_type(mut self, input_type: CohereInputType) -> Self {
68        self.input_type = input_type;
69        self
70    }
71
72    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
73        self.base_url = base_url.into();
74        self
75    }
76}
77
78/// Embeddings backed by the Cohere Embed API.
79///
80/// Supports all Cohere embedding models including `embed-english-v3.0` (1024-dim)
81/// and `embed-multilingual-v3.0` (1024-dim).
82pub struct CohereEmbeddings {
83    config: CohereEmbeddingsConfig,
84    client: reqwest::Client,
85}
86
87impl CohereEmbeddings {
88    pub fn new(config: CohereEmbeddingsConfig) -> Self {
89        Self {
90            config,
91            client: reqwest::Client::new(),
92        }
93    }
94
95    pub fn with_client(config: CohereEmbeddingsConfig, client: reqwest::Client) -> Self {
96        Self { config, client }
97    }
98
99    async fn embed_with_type(
100        &self,
101        texts: &[&str],
102        input_type: &CohereInputType,
103    ) -> Result<Vec<Vec<f32>>, SynapticError> {
104        if texts.is_empty() {
105            return Ok(Vec::new());
106        }
107
108        let body = json!({
109            "model": self.config.model,
110            "texts": texts,
111            "input_type": input_type.as_str(),
112            "embedding_types": ["float"],
113        });
114
115        let response = self
116            .client
117            .post(format!("{}/embed", self.config.base_url))
118            .header("Authorization", format!("Bearer {}", self.config.api_key))
119            .header("Content-Type", "application/json")
120            .json(&body)
121            .send()
122            .await
123            .map_err(|e| SynapticError::Embedding(format!("Cohere embed request: {e}")))?;
124
125        if !response.status().is_success() {
126            let status = response.status().as_u16();
127            let text = response.text().await.unwrap_or_default();
128            return Err(SynapticError::Embedding(format!(
129                "Cohere embed API error ({status}): {text}"
130            )));
131        }
132
133        let resp_body: serde_json::Value = response
134            .json()
135            .await
136            .map_err(|e| SynapticError::Embedding(format!("Cohere embed parse: {e}")))?;
137
138        let float_embeddings = resp_body["embeddings"]["float"]
139            .as_array()
140            .ok_or_else(|| SynapticError::Embedding("missing embeddings.float".to_string()))?;
141
142        let mut result = Vec::with_capacity(float_embeddings.len());
143        for embedding in float_embeddings {
144            let vec = embedding
145                .as_array()
146                .ok_or_else(|| SynapticError::Embedding("embedding is not array".to_string()))?
147                .iter()
148                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
149                .collect();
150            result.push(vec);
151        }
152
153        Ok(result)
154    }
155}
156
157#[async_trait]
158impl Embeddings for CohereEmbeddings {
159    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
160        self.embed_with_type(texts, &self.config.input_type).await
161    }
162
163    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
164        let mut results = self
165            .embed_with_type(&[text], &self.config.query_input_type)
166            .await?;
167        results
168            .pop()
169            .ok_or_else(|| SynapticError::Embedding("empty response".to_string()))
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built config carries the documented default for every field.
    #[test]
    fn config_defaults() {
        let config = CohereEmbeddingsConfig::new("test-key");
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.model, "embed-english-v3.0");
        assert_eq!(config.input_type, CohereInputType::SearchDocument);
        assert_eq!(config.query_input_type, CohereInputType::SearchQuery);
        assert_eq!(config.base_url, "https://api.cohere.ai/v2");
    }

    /// Builder methods override only the field they name.
    #[test]
    fn config_builder() {
        let config = CohereEmbeddingsConfig::new("key")
            .with_model("embed-multilingual-v3.0")
            .with_input_type(CohereInputType::Clustering)
            .with_base_url("http://localhost:8080/v2");
        assert_eq!(config.model, "embed-multilingual-v3.0");
        assert_eq!(config.input_type, CohereInputType::Clustering);
        assert_eq!(config.base_url, "http://localhost:8080/v2");
        // Fields that were not overridden keep their defaults.
        assert_eq!(config.api_key, "key");
        assert_eq!(config.query_input_type, CohereInputType::SearchQuery);
    }

    /// The strings sent to the API must match the names Cohere expects.
    #[test]
    fn input_type_wire_names() {
        assert_eq!(CohereInputType::SearchDocument.as_str(), "search_document");
        assert_eq!(CohereInputType::SearchQuery.as_str(), "search_query");
        assert_eq!(CohereInputType::Classification.as_str(), "classification");
        assert_eq!(CohereInputType::Clustering.as_str(), "clustering");
    }
}
193}