Skip to main content

ruvector_core/
embeddings.rs

1//! Text Embedding Providers
2//!
3//! This module provides a pluggable embedding system for AgenticDB.
4//!
5//! ## Available Providers
6//!
7//! - **HashEmbedding**: Fast hash-based placeholder (default, not semantic)
8//! - **CandleEmbedding**: Real embeddings using candle-transformers (feature: `real-embeddings`)
9//! - **ApiEmbedding**: External API calls (OpenAI, Anthropic, Cohere, etc.)
10//!
11//! ## Usage
12//!
13//! ```rust,no_run
14//! use ruvector_core::embeddings::{EmbeddingProvider, HashEmbedding, ApiEmbedding};
15//! use ruvector_core::AgenticDB;
16//!
17//! // Default: Hash-based (fast, but not semantic)
18//! let hash_provider = HashEmbedding::new(384);
19//! let embedding = hash_provider.embed("hello world")?;
20//!
21//! // API-based (requires API key)
22//! let api_provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
23//! let embedding = api_provider.embed("hello world")?;
24//! # Ok::<(), Box<dyn std::error::Error>>(())
25//! ```
26
27use crate::error::Result;
28#[cfg(any(feature = "real-embeddings", feature = "api-embeddings"))]
29use crate::error::RuvectorError;
30use std::sync::Arc;
31
/// Trait for text embedding providers
///
/// Implementors map arbitrary text to fixed-length `f32` vectors. The
/// `Send + Sync` bounds allow a provider to be shared across threads
/// (see [`BoxedEmbeddingProvider`]-style `Arc<dyn EmbeddingProvider>` usage).
pub trait EmbeddingProvider: Send + Sync {
    /// Generate embedding vector for the given text
    ///
    /// # Errors
    /// Provider-specific: e.g. API providers fail on network/HTTP errors,
    /// while the hash provider is infallible.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Get the dimensionality of embeddings produced by this provider
    ///
    /// Callers should be able to rely on `embed` returning vectors of this
    /// length (assumption based on usage here — TODO confirm for all impls).
    fn dimensions(&self) -> usize;

    /// Get a description of this provider (for logging/debugging)
    fn name(&self) -> &str;
}
43
/// Hash-based embedding provider (placeholder, not semantic)
///
/// ⚠️ **WARNING**: This does NOT produce semantic embeddings!
/// - "dog" and "cat" will NOT be similar
/// - "dog" and "god" WILL be similar (same characters)
///
/// Use this only for:
/// - Testing
/// - Prototyping
/// - When semantic similarity is not required
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    // Length of every vector produced by `embed`.
    dimensions: usize,
}
58
59impl HashEmbedding {
60    /// Create a new hash-based embedding provider
61    pub fn new(dimensions: usize) -> Self {
62        Self { dimensions }
63    }
64}
65
66impl EmbeddingProvider for HashEmbedding {
67    fn embed(&self, text: &str) -> Result<Vec<f32>> {
68        let mut embedding = vec![0.0; self.dimensions];
69        let bytes = text.as_bytes();
70
71        for (i, byte) in bytes.iter().enumerate() {
72            embedding[i % self.dimensions] += (*byte as f32) / 255.0;
73        }
74
75        // Normalize
76        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
77        if norm > 0.0 {
78            for val in &mut embedding {
79                *val /= norm;
80            }
81        }
82
83        Ok(embedding)
84    }
85
86    fn dimensions(&self) -> usize {
87        self.dimensions
88    }
89
90    fn name(&self) -> &str {
91        "HashEmbedding (placeholder)"
92    }
93}
94
/// Real embeddings using candle-transformers
///
/// Requires feature flag: `real-embeddings`
///
/// ⚠️ **Note**: Full candle integration is complex and model-specific.
/// For production use, we recommend:
/// 1. Using the API-based providers (simpler, always up-to-date)
/// 2. Using ONNX Runtime with pre-exported models
/// 3. Implementing your own candle wrapper for your specific model
///
/// This is a stub implementation showing the structure.
/// Users should implement `EmbeddingProvider` trait for their specific models.
#[cfg(feature = "real-embeddings")]
pub mod candle {
    use super::*;

    /// Candle-based embedding provider stub
    ///
    /// This is a placeholder. For real implementation:
    /// 1. Add candle dependencies for your specific model type
    /// 2. Implement model loading and inference
    /// 3. Handle tokenization appropriately
    ///
    /// Example structure:
    /// ```rust,ignore
    /// pub struct CandleEmbedding {
    ///     model: YourModelType,
    ///     tokenizer: Tokenizer,
    ///     device: Device,
    ///     dimensions: usize,
    /// }
    /// ```
    pub struct CandleEmbedding {
        // NOTE(review): nothing in this module ever constructs the struct
        // (`from_pretrained` always errors), so these fields are currently
        // placeholders for a future real implementation.
        dimensions: usize,
        model_id: String,
    }

    impl CandleEmbedding {
        /// Create a stub candle embedding provider
        ///
        /// **This is not a real implementation!**
        /// For production, implement with actual model loading.
        ///
        /// # Example
        /// ```rust,no_run
        /// # #[cfg(feature = "real-embeddings")]
        /// # {
        /// use ruvector_core::embeddings::candle::CandleEmbedding;
        ///
        /// // This returns an error - real implementation required
        /// let result = CandleEmbedding::from_pretrained(
        ///     "sentence-transformers/all-MiniLM-L6-v2",
        ///     false
        /// );
        /// assert!(result.is_err());
        /// # }
        /// ```
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            // Unconditionally an error: this keeps the API surface stable
            // while making it impossible to mistake the stub for a model.
            Err(RuvectorError::ModelLoadError(format!(
                "Candle embedding support is a stub. Please:\n\
                     1. Use ApiEmbedding for production (recommended)\n\
                     2. Or implement CandleEmbedding for model: {}\n\
                     3. See docs for ONNX Runtime integration examples",
                model_id
            )))
        }
    }

    impl EmbeddingProvider for CandleEmbedding {
        // Always errors; kept so the stub satisfies the trait contract.
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Err(RuvectorError::ModelInferenceError(
                "Candle embedding not implemented - use ApiEmbedding instead".to_string(),
            ))
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}
179
180#[cfg(feature = "real-embeddings")]
181pub use candle::CandleEmbedding;
182
/// API-based embedding provider (OpenAI, Anthropic, Cohere, etc.)
///
/// Supports any API that accepts JSON and returns embeddings in a standard format.
///
/// # Example (OpenAI)
/// ```rust,no_run
/// use ruvector_core::embeddings::{EmbeddingProvider, ApiEmbedding};
///
/// let provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
/// let embedding = provider.embed("hello world")?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
// NOTE(review): `Debug` is not derived here — deriving it would risk
// printing `api_key` in logs; confirm this omission is intentional.
#[cfg(feature = "api-embeddings")]
#[derive(Clone)]
pub struct ApiEmbedding {
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Full URL of the embeddings endpoint.
    endpoint: String,
    // Model identifier forwarded in the request body.
    model: String,
    // Vector length this endpoint is expected to return.
    dimensions: usize,
    // Blocking HTTP client, reused across requests (connection pooling).
    client: reqwest::blocking::Client,
}
204
#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
    /// Build a provider for an arbitrary embeddings endpoint.
    ///
    /// # Arguments
    /// * `api_key` - bearer token sent in the `Authorization` header
    /// * `endpoint` - full URL of the embeddings API
    /// * `model` - model identifier forwarded in the request body
    /// * `dimensions` - vector length this endpoint is expected to return
    pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
        let client = reqwest::blocking::Client::new();
        Self {
            client,
            dimensions,
            model,
            endpoint,
            api_key,
        }
    }

    /// Provider preconfigured for OpenAI's embeddings endpoint.
    ///
    /// # Models
    /// - `text-embedding-3-small` - 1536 dimensions, $0.02/1M tokens
    /// - `text-embedding-3-large` - 3072 dimensions, $0.13/1M tokens
    /// - `text-embedding-ada-002` - 1536 dimensions (legacy)
    pub fn openai(api_key: &str, model: &str) -> Self {
        // Every listed model is 1536-dimensional except 3-large.
        let dims = if model == "text-embedding-3-large" {
            3072
        } else {
            1536
        };

        Self::new(
            api_key.to_owned(),
            String::from("https://api.openai.com/v1/embeddings"),
            model.to_owned(),
            dims,
        )
    }

    /// Provider preconfigured for Cohere's embed endpoint.
    ///
    /// # Models
    /// - `embed-english-v3.0` - 1024 dimensions
    /// - `embed-multilingual-v3.0` - 1024 dimensions
    pub fn cohere(api_key: &str, model: &str) -> Self {
        Self::new(
            api_key.to_owned(),
            String::from("https://api.cohere.ai/v1/embed"),
            model.to_owned(),
            1024,
        )
    }

    /// Provider preconfigured for Voyage AI's embeddings endpoint.
    ///
    /// # Models
    /// - `voyage-2` - 1024 dimensions
    /// - `voyage-large-2` - 1536 dimensions
    pub fn voyage(api_key: &str, model: &str) -> Self {
        // "large" variants produce the wider 1536-entry vectors.
        let dims = match model.contains("large") {
            true => 1536,
            false => 1024,
        };

        Self::new(
            api_key.to_owned(),
            String::from("https://api.voyageai.com/v1/embeddings"),
            model.to_owned(),
            dims,
        )
    }
}
274
#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
    /// Request an embedding for `text` from the configured HTTP endpoint.
    ///
    /// Sends `{"input": text, "model": self.model}` with a bearer token and
    /// accepts either the OpenAI response shape
    /// (`{"data": [{"embedding": [...]}]}`) or the Cohere shape
    /// (`{"embeddings": [[...]]}`).
    ///
    /// # Errors
    /// `ModelInferenceError` if the request fails, the server returns a
    /// non-success status, the body cannot be parsed, a value is not a
    /// number, or the vector length differs from `self.dimensions`.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let request_body = serde_json::json!({
            "input": text,
            "model": self.model,
        });

        // Blocking call — do not invoke from within an async runtime.
        let response = self
            .client
            .post(&self.endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .map_err(|e| {
                RuvectorError::ModelInferenceError(format!("API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let status = response.status();
            // Include the body (if readable) so the error is actionable.
            let error_text = response
                .text()
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RuvectorError::ModelInferenceError(format!(
                "API returned error {}: {}",
                status, error_text
            )));
        }

        let response_json: serde_json::Value = response.json().map_err(|e| {
            RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e))
        })?;

        // Handle different API response formats
        let embedding = if let Some(data) = response_json.get("data") {
            // OpenAI format: {"data": [{"embedding": [...]}]}
            data.as_array()
                .and_then(|arr| arr.first())
                .and_then(|obj| obj.get("embedding"))
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid OpenAI response format".to_string())
                })?
        } else if let Some(embeddings) = response_json.get("embeddings") {
            // Cohere format: {"embeddings": [[...]]}
            embeddings
                .as_array()
                .and_then(|arr| arr.first())
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid Cohere response format".to_string())
                })?
        } else {
            return Err(RuvectorError::ModelInferenceError(
                "Unknown API response format".to_string(),
            ));
        };

        let embedding_vec: Vec<f32> = embedding
            .iter()
            .map(|v| {
                v.as_f64().map(|f| f as f32).ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid embedding value".to_string())
                })
            })
            .collect::<Result<Vec<f32>>>()?;

        // Fail fast on a provider/dimensions mismatch (e.g. a wrong value
        // passed to `ApiEmbedding::new`) instead of silently returning a
        // vector whose length contradicts `dimensions()`.
        if embedding_vec.len() != self.dimensions {
            return Err(RuvectorError::ModelInferenceError(format!(
                "API returned {} dimensions, expected {}",
                embedding_vec.len(),
                self.dimensions
            )));
        }

        Ok(embedding_vec)
    }

    /// Configured output dimensionality; enforced by `embed`.
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn name(&self) -> &str {
        "ApiEmbedding"
    }
}
354
/// Type-erased embedding provider for dynamic dispatch
///
/// `Arc` lets a single provider instance be shared cheaply across threads;
/// the trait's `Send + Sync` bounds make that sharing sound.
pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Same input must deterministically produce the same normalized vector.
    #[test]
    fn test_hash_embedding() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1.len(), 128);
        assert_eq!(emb1, emb2, "Same text should produce same embedding");

        // Check normalization
        let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    #[test]
    fn test_hash_embedding_different_text() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(
            emb1, emb2,
            "Different text should produce different embeddings"
        );
    }

    /// `CandleEmbedding::from_pretrained` is a documented stub that always
    /// returns an error (its own doc example asserts `is_err()`). The
    /// previous test `unwrap()`ed the result and therefore could never pass.
    #[cfg(feature = "real-embeddings")]
    #[test]
    fn test_candle_embedding_is_stub() {
        let result =
            CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false);
        assert!(result.is_err(), "stub must refuse to load models");
    }

    /// Live API smoke test; run with `cargo test -- --ignored` and a key.
    // The cfg gate was missing here: `ApiEmbedding` only exists under the
    // `api-embeddings` feature, so this module failed to compile whenever
    // that feature was disabled.
    #[cfg(feature = "api-embeddings")]
    #[test]
    #[ignore] // Requires API key
    fn test_api_embedding_openai() {
        let api_key = std::env::var("OPENAI_API_KEY").unwrap();
        let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 1536);
    }
}
415}