// ruvector_core/embeddings.rs

1//! Text Embedding Providers
2//!
3//! This module provides a pluggable embedding system for AgenticDB.
4//!
5//! ## Available Providers
6//!
7//! - **HashEmbedding**: Fast hash-based placeholder (default, not semantic)
8//! - **CandleEmbedding**: Real embeddings using candle-transformers (feature: `real-embeddings`)
9//! - **ApiEmbedding**: External API calls (OpenAI, Anthropic, Cohere, etc.)
10//!
11//! ## Usage
12//!
13//! ```rust,no_run
14//! use ruvector_core::embeddings::{EmbeddingProvider, HashEmbedding, ApiEmbedding};
15//! use ruvector_core::AgenticDB;
16//!
17//! // Default: Hash-based (fast, but not semantic)
18//! let hash_provider = HashEmbedding::new(384);
19//! let embedding = hash_provider.embed("hello world")?;
20//!
21//! // API-based (requires API key)
22//! let api_provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
23//! let embedding = api_provider.embed("hello world")?;
24//! # Ok::<(), Box<dyn std::error::Error>>(())
25//! ```
26
27use crate::error::{Result, RuvectorError};
28use std::sync::Arc;
29
/// Trait for text embedding providers
///
/// Implementors must be `Send + Sync` so a provider can be shared across
/// threads (see [`BoxedEmbeddingProvider`]).
pub trait EmbeddingProvider: Send + Sync {
    /// Generate embedding vector for the given text
    ///
    /// # Errors
    /// Returns an error if the provider cannot produce an embedding
    /// (e.g. model inference failure or a failed API call).
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Get the dimensionality of embeddings produced by this provider
    fn dimensions(&self) -> usize;

    /// Get a description of this provider (for logging/debugging)
    fn name(&self) -> &str;
}
41
/// Hash-based embedding provider (placeholder, not semantic)
///
/// ⚠️ **WARNING**: This does NOT produce semantic embeddings!
/// - "dog" and "cat" will NOT be similar
/// - "dog" and "god" WILL be similar (same characters)
///
/// Use this only for:
/// - Testing
/// - Prototyping
/// - When semantic similarity is not required
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    /// Length of every vector returned by [`EmbeddingProvider::embed`].
    dimensions: usize,
}
56
57impl HashEmbedding {
58    /// Create a new hash-based embedding provider
59    pub fn new(dimensions: usize) -> Self {
60        Self { dimensions }
61    }
62}
63
64impl EmbeddingProvider for HashEmbedding {
65    fn embed(&self, text: &str) -> Result<Vec<f32>> {
66        let mut embedding = vec![0.0; self.dimensions];
67        let bytes = text.as_bytes();
68
69        for (i, byte) in bytes.iter().enumerate() {
70            embedding[i % self.dimensions] += (*byte as f32) / 255.0;
71        }
72
73        // Normalize
74        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
75        if norm > 0.0 {
76            for val in &mut embedding {
77                *val /= norm;
78            }
79        }
80
81        Ok(embedding)
82    }
83
84    fn dimensions(&self) -> usize {
85        self.dimensions
86    }
87
88    fn name(&self) -> &str {
89        "HashEmbedding (placeholder)"
90    }
91}
92
93/// Real embeddings using candle-transformers
94///
95/// Requires feature flag: `real-embeddings`
96///
97/// ⚠️ **Note**: Full candle integration is complex and model-specific.
98/// For production use, we recommend:
99/// 1. Using the API-based providers (simpler, always up-to-date)
100/// 2. Using ONNX Runtime with pre-exported models
101/// 3. Implementing your own candle wrapper for your specific model
102///
103/// This is a stub implementation showing the structure.
104/// Users should implement `EmbeddingProvider` trait for their specific models.
#[cfg(feature = "real-embeddings")]
pub mod candle {
    use super::*;

    /// Stub candle-based embedding provider.
    ///
    /// This type exists only to document the intended shape of a real
    /// candle integration; every operation returns an error. A working
    /// version would hold something like:
    /// ```rust,ignore
    /// pub struct CandleEmbedding {
    ///     model: YourModelType,
    ///     tokenizer: Tokenizer,
    ///     device: Device,
    ///     dimensions: usize,
    /// }
    /// ```
    /// along with model loading, tokenization, and inference code.
    pub struct CandleEmbedding {
        dimensions: usize,
        model_id: String,
    }

    impl CandleEmbedding {
        /// Attempt to load a model — always fails in this stub.
        ///
        /// A production integration must replace this with real model
        /// loading (or use [`super::ApiEmbedding`] instead).
        ///
        /// # Errors
        /// Always returns `ModelLoadError` explaining the alternatives.
        ///
        /// # Example
        /// ```rust,no_run
        /// # #[cfg(feature = "real-embeddings")]
        /// # {
        /// use ruvector_core::embeddings::candle::CandleEmbedding;
        ///
        /// // This returns an error - real implementation required
        /// let result = CandleEmbedding::from_pretrained(
        ///     "sentence-transformers/all-MiniLM-L6-v2",
        ///     false
        /// );
        /// assert!(result.is_err());
        /// # }
        /// ```
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            let message = format!(
                "Candle embedding support is a stub. Please:\n\
                 1. Use ApiEmbedding for production (recommended)\n\
                 2. Or implement CandleEmbedding for model: {}\n\
                 3. See docs for ONNX Runtime integration examples",
                model_id
            );
            Err(RuvectorError::ModelLoadError(message))
        }
    }

    impl EmbeddingProvider for CandleEmbedding {
        /// Always fails: no model is ever loaded by the stub.
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Err(RuvectorError::ModelInferenceError(
                "Candle embedding not implemented - use ApiEmbedding instead".to_string()
            ))
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}
179
180#[cfg(feature = "real-embeddings")]
181pub use candle::CandleEmbedding;
182
/// API-based embedding provider (OpenAI, Anthropic, Cohere, etc.)
///
/// Supports any API that accepts JSON and returns embeddings in a standard format.
///
/// # Example (OpenAI)
/// ```rust,no_run
/// use ruvector_core::embeddings::{EmbeddingProvider, ApiEmbedding};
///
/// let provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
/// let embedding = provider.embed("hello world")?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[cfg(feature = "api-embeddings")]
#[derive(Clone)]
pub struct ApiEmbedding {
    /// Bearer token sent in the `Authorization` header.
    api_key: String,
    /// Full URL of the embeddings endpoint.
    endpoint: String,
    /// Model identifier included in the request body.
    model: String,
    /// Dimensionality this endpoint/model is expected to return.
    dimensions: usize,
    /// Blocking HTTP client, reused across requests.
    client: reqwest::blocking::Client,
}
204
#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
    /// Build a provider for an arbitrary embeddings endpoint.
    ///
    /// # Arguments
    /// * `api_key` - API key for authentication
    /// * `endpoint` - API endpoint URL
    /// * `model` - Model identifier
    /// * `dimensions` - Expected embedding dimensions
    pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
        let client = reqwest::blocking::Client::new();
        ApiEmbedding {
            api_key,
            endpoint,
            model,
            dimensions,
            client,
        }
    }

    /// Convenience constructor for OpenAI's embeddings API.
    ///
    /// # Models
    /// - `text-embedding-3-small` - 1536 dimensions, $0.02/1M tokens
    /// - `text-embedding-3-large` - 3072 dimensions, $0.13/1M tokens
    /// - `text-embedding-ada-002` - 1536 dimensions (legacy)
    pub fn openai(api_key: &str, model: &str) -> Self {
        // Only the "large" v3 model is 3072-dim; 3-small and ada-002 are 1536.
        let dims = if model == "text-embedding-3-large" { 3072 } else { 1536 };

        Self::new(
            api_key.to_string(),
            "https://api.openai.com/v1/embeddings".to_string(),
            model.to_string(),
            dims,
        )
    }

    /// Convenience constructor for Cohere's embeddings API.
    ///
    /// # Models
    /// - `embed-english-v3.0` - 1024 dimensions
    /// - `embed-multilingual-v3.0` - 1024 dimensions
    pub fn cohere(api_key: &str, model: &str) -> Self {
        // Both supported Cohere v3 models produce 1024-dim vectors.
        Self::new(
            api_key.to_string(),
            "https://api.cohere.ai/v1/embed".to_string(),
            model.to_string(),
            1024,
        )
    }

    /// Convenience constructor for Voyage AI's embeddings API.
    ///
    /// # Models
    /// - `voyage-2` - 1024 dimensions
    /// - `voyage-large-2` - 1536 dimensions
    pub fn voyage(api_key: &str, model: &str) -> Self {
        // "large" variants are 1536-dim, the rest 1024.
        let dims = if model.contains("large") { 1536 } else { 1024 };

        Self::new(
            api_key.to_string(),
            "https://api.voyageai.com/v1/embeddings".to_string(),
            model.to_string(),
            dims,
        )
    }
}
274
#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
    /// POST the text to the configured endpoint and parse the embedding.
    ///
    /// # Errors
    /// Returns `ModelInferenceError` when the HTTP request fails, the API
    /// responds with a non-success status, or the body does not match a
    /// known format (OpenAI `data[0].embedding` or Cohere `embeddings[0]`).
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Cohere's /v1/embed expects `{"texts": [...]}` plus an `input_type`
        // (required for its v3 models) — the OpenAI-style `{"input": ...}`
        // payload is rejected there, even though the response parser below
        // already understands Cohere's reply shape.
        let request_body = if self.endpoint.contains("cohere") {
            serde_json::json!({
                "texts": [text],
                "model": self.model,
                "input_type": "search_document",
            })
        } else {
            // OpenAI and Voyage both accept the `input`/`model` form.
            serde_json::json!({
                "input": text,
                "model": self.model,
            })
        };

        let response = self.client
            .post(&self.endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .map_err(|e| RuvectorError::ModelInferenceError(format!("API request failed: {}", e)))?;

        if !response.status().is_success() {
            let status = response.status();
            // Read the body for diagnostics; `text()` consumes the response.
            let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RuvectorError::ModelInferenceError(
                format!("API returned error {}: {}", status, error_text)
            ));
        }

        let response_json: serde_json::Value = response.json()
            .map_err(|e| RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e)))?;

        // Handle different API response formats
        let embedding = if let Some(data) = response_json.get("data") {
            // OpenAI format: {"data": [{"embedding": [...]}]}
            data.as_array()
                .and_then(|arr| arr.first())
                .and_then(|obj| obj.get("embedding"))
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| RuvectorError::ModelInferenceError(
                    "Invalid OpenAI response format".to_string()
                ))?
        } else if let Some(embeddings) = response_json.get("embeddings") {
            // Cohere format: {"embeddings": [[...]]}
            embeddings.as_array()
                .and_then(|arr| arr.first())
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| RuvectorError::ModelInferenceError(
                    "Invalid Cohere response format".to_string()
                ))?
        } else {
            return Err(RuvectorError::ModelInferenceError(
                "Unknown API response format".to_string()
            ));
        };

        // Convert JSON numbers to f32, failing on any non-numeric element.
        let embedding_vec: Result<Vec<f32>> = embedding
            .iter()
            .map(|v| v.as_f64()
                .map(|f| f as f32)
                .ok_or_else(|| RuvectorError::ModelInferenceError(
                    "Invalid embedding value".to_string()
                ))
            )
            .collect();

        embedding_vec
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn name(&self) -> &str {
        "ApiEmbedding"
    }
}
347
348/// Type-erased embedding provider for dynamic dispatch
349pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
350
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_embedding() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1.len(), 128);
        assert_eq!(emb1, emb2, "Same text should produce same embedding");

        // Check normalization
        let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    #[test]
    fn test_hash_embedding_different_text() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(emb1, emb2, "Different text should produce different embeddings");
    }

    #[cfg(feature = "real-embeddings")]
    #[test]
    #[ignore] // Requires model download
    fn test_candle_embedding() {
        // NOTE(review): `from_pretrained` is currently a stub that always
        // errors, so this test would fail if un-ignored — it documents the
        // expected behavior of a real implementation.
        let provider = CandleEmbedding::from_pretrained(
            "sentence-transformers/all-MiniLM-L6-v2",
            false
        ).unwrap();

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 384);

        // Check normalization
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    // Feature gate was missing: `ApiEmbedding` only exists with the
    // `api-embeddings` feature, so without this cfg the whole test module
    // failed to compile when that feature was disabled.
    #[cfg(feature = "api-embeddings")]
    #[test]
    #[ignore] // Requires API key
    fn test_api_embedding_openai() {
        let api_key = std::env::var("OPENAI_API_KEY").unwrap();
        let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 1536);
    }
}