// ruvector_core/embeddings.rs

//! Text Embedding Providers
//!
//! This module provides a pluggable embedding system for AgenticDB.
//!
//! ## Available Providers
//!
//! - **HashEmbedding**: Fast hash-based placeholder (default, not semantic)
//! - **CandleEmbedding**: Real embeddings using candle-transformers (feature: `real-embeddings`)
//! - **ApiEmbedding**: External API calls (OpenAI, Anthropic, Cohere, etc.)
//!
//! ## Usage
//!
//! ```rust,no_run
//! use ruvector_core::embeddings::{EmbeddingProvider, HashEmbedding, ApiEmbedding};
//! use ruvector_core::AgenticDB;
//!
//! // Default: Hash-based (fast, but not semantic)
//! let hash_provider = HashEmbedding::new(384);
//! let embedding = hash_provider.embed("hello world")?;
//!
//! // API-based (requires API key)
//! let api_provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
//! let embedding = api_provider.embed("hello world")?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
26
27use crate::error::{Result, RuvectorError};
28use std::sync::Arc;
29
/// Trait implemented by every text-embedding backend.
///
/// Implementors must be `Send + Sync` so a single provider can be shared
/// across threads behind an `Arc` (see [`BoxedEmbeddingProvider`]).
pub trait EmbeddingProvider: Send + Sync {
    /// Generate an embedding vector for the given text.
    ///
    /// # Errors
    /// Returns an error when the backend fails (e.g. model inference or an
    /// API call error). The returned vector's length is expected to match
    /// [`EmbeddingProvider::dimensions`].
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Get the dimensionality of embeddings produced by this provider.
    fn dimensions(&self) -> usize;

    /// Get a human-readable description of this provider (for logging/debugging).
    fn name(&self) -> &str;
}
41
/// Hash-based embedding provider (placeholder, not semantic)
///
/// Buckets the input's raw bytes into a fixed-size vector, so the output
/// depends only on the characters present — not on meaning.
///
/// ⚠️ **WARNING**: This does NOT produce semantic embeddings!
/// - "dog" and "cat" will NOT be similar
/// - "dog" and "god" WILL be similar (same characters)
///
/// Use this only for:
/// - Testing
/// - Prototyping
/// - When semantic similarity is not required
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    // Length of every vector returned by `embed`.
    dimensions: usize,
}
56
57impl HashEmbedding {
58    /// Create a new hash-based embedding provider
59    pub fn new(dimensions: usize) -> Self {
60        Self { dimensions }
61    }
62}
63
64impl EmbeddingProvider for HashEmbedding {
65    fn embed(&self, text: &str) -> Result<Vec<f32>> {
66        let mut embedding = vec![0.0; self.dimensions];
67        let bytes = text.as_bytes();
68
69        for (i, byte) in bytes.iter().enumerate() {
70            embedding[i % self.dimensions] += (*byte as f32) / 255.0;
71        }
72
73        // Normalize
74        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
75        if norm > 0.0 {
76            for val in &mut embedding {
77                *val /= norm;
78            }
79        }
80
81        Ok(embedding)
82    }
83
84    fn dimensions(&self) -> usize {
85        self.dimensions
86    }
87
88    fn name(&self) -> &str {
89        "HashEmbedding (placeholder)"
90    }
91}
92
/// Real embeddings using candle-transformers
///
/// Requires feature flag: `real-embeddings`
///
/// ⚠️ **Note**: Full candle integration is complex and model-specific.
/// For production use, we recommend:
/// 1. Using the API-based providers (simpler, always up-to-date)
/// 2. Using ONNX Runtime with pre-exported models
/// 3. Implementing your own candle wrapper for your specific model
///
/// This is a stub implementation showing the structure.
/// Users should implement `EmbeddingProvider` trait for their specific models.
#[cfg(feature = "real-embeddings")]
pub mod candle {
    use super::*;

    /// Candle-based embedding provider stub
    ///
    /// This is a placeholder. For a real implementation:
    /// 1. Add candle dependencies for your specific model type
    /// 2. Implement model loading and inference
    /// 3. Handle tokenization appropriately
    ///
    /// Example structure:
    /// ```rust,ignore
    /// pub struct CandleEmbedding {
    ///     model: YourModelType,
    ///     tokenizer: Tokenizer,
    ///     device: Device,
    ///     dimensions: usize,
    /// }
    /// ```
    // NOTE(review): `from_pretrained` always errors, so this struct is never
    // actually constructed in the current code — the fields exist only to
    // sketch the intended shape.
    pub struct CandleEmbedding {
        dimensions: usize,
        model_id: String,
    }

    impl CandleEmbedding {
        /// Create a stub candle embedding provider
        ///
        /// **This is not a real implementation!** It always returns
        /// `ModelLoadError`; for production, implement actual model loading.
        ///
        /// # Example
        /// ```rust,no_run
        /// # #[cfg(feature = "real-embeddings")]
        /// # {
        /// use ruvector_core::embeddings::candle::CandleEmbedding;
        ///
        /// // This returns an error - real implementation required
        /// let result = CandleEmbedding::from_pretrained(
        ///     "sentence-transformers/all-MiniLM-L6-v2",
        ///     false
        /// );
        /// assert!(result.is_err());
        /// # }
        /// ```
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            // Build the guidance message first, then wrap it in the error type.
            let message = format!(
                "Candle embedding support is a stub. Please:\n\
                 1. Use ApiEmbedding for production (recommended)\n\
                 2. Or implement CandleEmbedding for model: {}\n\
                 3. See docs for ONNX Runtime integration examples",
                model_id
            );
            Err(RuvectorError::ModelLoadError(message))
        }
    }

    impl EmbeddingProvider for CandleEmbedding {
        /// Always fails — inference is not implemented in this stub.
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            let message =
                "Candle embedding not implemented - use ApiEmbedding instead".to_string();
            Err(RuvectorError::ModelInferenceError(message))
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}

#[cfg(feature = "real-embeddings")]
pub use candle::CandleEmbedding;
182
/// API-based embedding provider (OpenAI, Anthropic, Cohere, etc.)
///
/// Supports any API that accepts JSON and returns embeddings in a standard format.
/// Requests are issued synchronously via `reqwest::blocking`.
///
/// # Example (OpenAI)
/// ```rust,no_run
/// use ruvector_core::embeddings::{EmbeddingProvider, ApiEmbedding};
///
/// let provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
/// let embedding = provider.embed("hello world")?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
// NOTE(review): `Debug` is not derived — presumably to keep `api_key` out of
// debug output; confirm before adding it.
#[derive(Clone)]
pub struct ApiEmbedding {
    // Secret sent as a `Bearer` token on every request.
    api_key: String,
    // Full URL of the embeddings endpoint.
    endpoint: String,
    // Model identifier forwarded in the request body.
    model: String,
    // Expected embedding dimensionality (not verified against responses).
    dimensions: usize,
    // Blocking HTTP client, reused across calls.
    client: reqwest::blocking::Client,
}
203
204impl ApiEmbedding {
205    /// Create a new API embedding provider
206    ///
207    /// # Arguments
208    /// * `api_key` - API key for authentication
209    /// * `endpoint` - API endpoint URL
210    /// * `model` - Model identifier
211    /// * `dimensions` - Expected embedding dimensions
212    pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
213        Self {
214            api_key,
215            endpoint,
216            model,
217            dimensions,
218            client: reqwest::blocking::Client::new(),
219        }
220    }
221
222    /// Create OpenAI embedding provider
223    ///
224    /// # Models
225    /// - `text-embedding-3-small` - 1536 dimensions, $0.02/1M tokens
226    /// - `text-embedding-3-large` - 3072 dimensions, $0.13/1M tokens
227    /// - `text-embedding-ada-002` - 1536 dimensions (legacy)
228    pub fn openai(api_key: &str, model: &str) -> Self {
229        let dimensions = match model {
230            "text-embedding-3-large" => 3072,
231            _ => 1536, // text-embedding-3-small and ada-002
232        };
233
234        Self::new(
235            api_key.to_string(),
236            "https://api.openai.com/v1/embeddings".to_string(),
237            model.to_string(),
238            dimensions,
239        )
240    }
241
242    /// Create Cohere embedding provider
243    ///
244    /// # Models
245    /// - `embed-english-v3.0` - 1024 dimensions
246    /// - `embed-multilingual-v3.0` - 1024 dimensions
247    pub fn cohere(api_key: &str, model: &str) -> Self {
248        Self::new(
249            api_key.to_string(),
250            "https://api.cohere.ai/v1/embed".to_string(),
251            model.to_string(),
252            1024,
253        )
254    }
255
256    /// Create Voyage AI embedding provider
257    ///
258    /// # Models
259    /// - `voyage-2` - 1024 dimensions
260    /// - `voyage-large-2` - 1536 dimensions
261    pub fn voyage(api_key: &str, model: &str) -> Self {
262        let dimensions = if model.contains("large") { 1536 } else { 1024 };
263
264        Self::new(
265            api_key.to_string(),
266            "https://api.voyageai.com/v1/embeddings".to_string(),
267            model.to_string(),
268            dimensions,
269        )
270    }
271}
272
273impl EmbeddingProvider for ApiEmbedding {
274    fn embed(&self, text: &str) -> Result<Vec<f32>> {
275        let request_body = serde_json::json!({
276            "input": text,
277            "model": self.model,
278        });
279
280        let response = self.client
281            .post(&self.endpoint)
282            .header("Authorization", format!("Bearer {}", self.api_key))
283            .header("Content-Type", "application/json")
284            .json(&request_body)
285            .send()
286            .map_err(|e| RuvectorError::ModelInferenceError(format!("API request failed: {}", e)))?;
287
288        if !response.status().is_success() {
289            let status = response.status();
290            let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
291            return Err(RuvectorError::ModelInferenceError(
292                format!("API returned error {}: {}", status, error_text)
293            ));
294        }
295
296        let response_json: serde_json::Value = response.json()
297            .map_err(|e| RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e)))?;
298
299        // Handle different API response formats
300        let embedding = if let Some(data) = response_json.get("data") {
301            // OpenAI format: {"data": [{"embedding": [...]}]}
302            data.as_array()
303                .and_then(|arr| arr.first())
304                .and_then(|obj| obj.get("embedding"))
305                .and_then(|emb| emb.as_array())
306                .ok_or_else(|| RuvectorError::ModelInferenceError(
307                    "Invalid OpenAI response format".to_string()
308                ))?
309        } else if let Some(embeddings) = response_json.get("embeddings") {
310            // Cohere format: {"embeddings": [[...]]}
311            embeddings.as_array()
312                .and_then(|arr| arr.first())
313                .and_then(|emb| emb.as_array())
314                .ok_or_else(|| RuvectorError::ModelInferenceError(
315                    "Invalid Cohere response format".to_string()
316                ))?
317        } else {
318            return Err(RuvectorError::ModelInferenceError(
319                "Unknown API response format".to_string()
320            ));
321        };
322
323        let embedding_vec: Result<Vec<f32>> = embedding
324            .iter()
325            .map(|v| v.as_f64()
326                .map(|f| f as f32)
327                .ok_or_else(|| RuvectorError::ModelInferenceError(
328                    "Invalid embedding value".to_string()
329                ))
330            )
331            .collect();
332
333        embedding_vec
334    }
335
336    fn dimensions(&self) -> usize {
337        self.dimensions
338    }
339
340    fn name(&self) -> &str {
341        "ApiEmbedding"
342    }
343}
344
/// Type-erased, reference-counted embedding provider for dynamic dispatch.
///
/// `Arc` makes the provider cheap to clone and share across threads; the
/// trait's `Send + Sync` bound makes that sharing safe.
pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
347
#[cfg(test)]
mod tests {
    use super::*;

    /// The hash embedding must be deterministic, sized correctly, and unit-norm.
    #[test]
    fn test_hash_embedding() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1.len(), 128);
        assert_eq!(emb1, emb2, "Same text should produce same embedding");

        // Check L2 normalization.
        let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    /// Different inputs should hash to different vectors.
    #[test]
    fn test_hash_embedding_different_text() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(emb1, emb2, "Different text should produce different embeddings");
    }

    /// `CandleEmbedding::from_pretrained` is an explicit stub that ALWAYS
    /// returns an error (its own doc example asserts `is_err()`).
    ///
    /// BUG FIX: the previous version of this test unwrapped the result and
    /// asserted on a real 384-d embedding, so it could never pass — even with
    /// the feature enabled and a model downloaded. Assert the documented stub
    /// behavior instead; no `#[ignore]` needed since nothing is downloaded.
    #[cfg(feature = "real-embeddings")]
    #[test]
    fn test_candle_embedding() {
        let result = CandleEmbedding::from_pretrained(
            "sentence-transformers/all-MiniLM-L6-v2",
            false
        );
        assert!(result.is_err(), "stub must report that it is unimplemented");
    }

    /// Live OpenAI smoke test; ignored because it needs a real key and network.
    #[test]
    #[ignore] // Requires API key
    fn test_api_embedding_openai() {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
        let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 1536);
    }
}