Skip to main content

ruvector_core/
embeddings.rs

1//! Text Embedding Providers
2//!
3//! This module provides a pluggable embedding system for AgenticDB.
4//!
5//! ## Available Providers
6//!
7//! - **HashEmbedding**: Fast hash-based placeholder (default, not semantic)
8//! - **OnnxEmbedding**: Real semantic embeddings using ONNX Runtime (feature: `onnx-embeddings`) ✅ RECOMMENDED
9//! - **CandleEmbedding**: Real embeddings using candle-transformers (feature: `real-embeddings`)
//! - **ApiEmbedding**: External API calls (OpenAI, Cohere, Voyage AI, etc.)
11//!
12//! ## Usage
13//!
14//! ```rust,no_run
15//! use ruvector_core::embeddings::{EmbeddingProvider, HashEmbedding};
16//!
17//! // Default: Hash-based (fast, but not semantic)
18//! let hash_provider = HashEmbedding::new(384);
19//! let embedding = hash_provider.embed("hello world")?;
20//!
21//! # Ok::<(), Box<dyn std::error::Error>>(())
22//! ```
23//!
24//! ## ONNX Embeddings (Recommended for Production)
25//!
26//! ```rust,ignore
27//! use ruvector_core::embeddings::{EmbeddingProvider, OnnxEmbedding};
28//!
29//! // Real semantic embeddings using all-MiniLM-L6-v2
30//! let provider = OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2")?;
31//! let embedding = provider.embed("hello world")?;
32//! // "dog" and "cat" WILL be similar (semantic understanding!)
33//! ```
34
35use crate::error::Result;
36#[cfg(any(feature = "real-embeddings", feature = "api-embeddings"))]
37use crate::error::RuvectorError;
38use std::sync::Arc;
39
40/// Trait for text embedding providers
41pub trait EmbeddingProvider: Send + Sync {
42    /// Generate embedding vector for the given text
43    fn embed(&self, text: &str) -> Result<Vec<f32>>;
44
45    /// Get the dimensionality of embeddings produced by this provider
46    fn dimensions(&self) -> usize;
47
48    /// Get a description of this provider (for logging/debugging)
49    fn name(&self) -> &str;
50}
51
52/// Hash-based embedding provider (placeholder, not semantic)
53///
54/// ⚠️ **WARNING**: This does NOT produce semantic embeddings!
55/// - "dog" and "cat" will NOT be similar
56/// - "dog" and "god" WILL be similar (same characters)
57///
58/// Use this only for:
59/// - Testing
60/// - Prototyping
61/// - When semantic similarity is not required
62#[derive(Debug, Clone)]
63pub struct HashEmbedding {
64    dimensions: usize,
65}
66
67impl HashEmbedding {
68    /// Create a new hash-based embedding provider
69    pub fn new(dimensions: usize) -> Self {
70        Self { dimensions }
71    }
72}
73
74impl EmbeddingProvider for HashEmbedding {
75    fn embed(&self, text: &str) -> Result<Vec<f32>> {
76        let mut embedding = vec![0.0; self.dimensions];
77        let bytes = text.as_bytes();
78
79        for (i, byte) in bytes.iter().enumerate() {
80            embedding[i % self.dimensions] += (*byte as f32) / 255.0;
81        }
82
83        // Normalize
84        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
85        if norm > 0.0 {
86            for val in &mut embedding {
87                *val /= norm;
88            }
89        }
90
91        Ok(embedding)
92    }
93
94    fn dimensions(&self) -> usize {
95        self.dimensions
96    }
97
98    fn name(&self) -> &str {
99        "HashEmbedding (placeholder)"
100    }
101}
102
/// Real embeddings using candle-transformers
///
/// Requires feature flag: `real-embeddings`
///
/// ⚠️ **Note**: Full candle integration is complex and model-specific.
/// For production use, we recommend:
/// 1. Using `OnnxEmbedding` (feature `onnx-embeddings`) for local semantic embeddings
/// 2. Using the API-based providers (`ApiEmbedding`, feature `api-embeddings`)
/// 3. Implementing your own candle wrapper for your specific model
///
/// This is a stub implementation showing the structure: `CandleEmbedding::from_pretrained`
/// always returns an error and `embed` never produces a vector.
/// Users should implement `EmbeddingProvider` trait for their specific models.
#[cfg(feature = "real-embeddings")]
pub mod candle {
    use super::*;

    /// Candle-based embedding provider stub
    ///
    /// This is a placeholder. For real implementation:
    /// 1. Add candle dependencies for your specific model type
    /// 2. Implement model loading and inference
    /// 3. Handle tokenization appropriately
    ///
    /// Example structure:
    /// ```rust,ignore
    /// pub struct CandleEmbedding {
    ///     model: YourModelType,
    ///     tokenizer: Tokenizer,
    ///     device: Device,
    ///     dimensions: usize,
    /// }
    /// ```
    #[derive(Debug)]
    pub struct CandleEmbedding {
        dimensions: usize,
        // Not read by the stub; kept so a real implementation has the model to load.
        #[allow(dead_code)]
        model_id: String,
    }

    impl CandleEmbedding {
        /// Create a stub candle embedding provider
        ///
        /// **This is not a real implementation!**
        /// For production, implement with actual model loading.
        ///
        /// # Errors
        /// Always returns `RuvectorError::ModelLoadError`; no model is downloaded or loaded.
        ///
        /// # Example
        /// ```rust,no_run
        /// # #[cfg(feature = "real-embeddings")]
        /// # {
        /// use ruvector_core::embeddings::candle::CandleEmbedding;
        ///
        /// // This returns an error - real implementation required
        /// let result = CandleEmbedding::from_pretrained(
        ///     "sentence-transformers/all-MiniLM-L6-v2",
        ///     false
        /// );
        /// assert!(result.is_err());
        /// # }
        /// ```
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            Err(RuvectorError::ModelLoadError(format!(
                "Candle embedding support is a stub. Please:\n\
                     1. Use ApiEmbedding for production (recommended)\n\
                     2. Or implement CandleEmbedding for model: {}\n\
                     3. See docs for ONNX Runtime integration examples",
                model_id
            )))
        }
    }

    impl EmbeddingProvider for CandleEmbedding {
        /// Always fails: the stub has no model to run.
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Err(RuvectorError::ModelInferenceError(
                "Candle embedding not implemented - use ApiEmbedding instead".to_string(),
            ))
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}
187
188#[cfg(feature = "real-embeddings")]
189pub use candle::CandleEmbedding;
190
/// API-based embedding provider (OpenAI, Cohere, Voyage AI, etc.)
///
/// Each [`embed`](EmbeddingProvider::embed) call issues one blocking HTTP POST authenticated
/// with a `Bearer` token. Two response shapes are understood:
/// - OpenAI-style: `{"data": [{"embedding": [...]}]}` (OpenAI, Voyage AI, compatible servers)
/// - Cohere-style: `{"embeddings": [[...]]}`
///
/// Providers created with [`ApiEmbedding::new`], [`ApiEmbedding::openai`] or
/// [`ApiEmbedding::voyage`] send an OpenAI-style request body (`{"input", "model"}`);
/// [`ApiEmbedding::cohere`] sends Cohere's `{"texts", "model", "input_type"}` body.
///
/// # Example (OpenAI)
/// ```rust,no_run
/// use ruvector_core::embeddings::{EmbeddingProvider, ApiEmbedding};
///
/// let provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
/// let embedding = provider.embed("hello world")?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[cfg(feature = "api-embeddings")]
#[derive(Clone)]
pub struct ApiEmbedding {
    api_key: String,
    endpoint: String,
    model: String,
    dimensions: usize,
    request_format: ApiRequestFormat,
    client: reqwest::blocking::Client,
}

/// Layout of the JSON request body expected by the remote embedding endpoint.
#[cfg(feature = "api-embeddings")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ApiRequestFormat {
    /// `{"input": "...", "model": "..."}` — OpenAI and OpenAI-compatible APIs.
    OpenAi,
    /// `{"texts": ["..."], "model": "...", "input_type": "..."}` — Cohere `/v1/embed`.
    Cohere,
}

#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
    /// Create a new API embedding provider that speaks the OpenAI-style request format
    ///
    /// # Arguments
    /// * `api_key` - API key for authentication
    /// * `endpoint` - API endpoint URL
    /// * `model` - Model identifier
    /// * `dimensions` - Expected embedding dimensions
    pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
        Self {
            api_key,
            endpoint,
            model,
            dimensions,
            request_format: ApiRequestFormat::OpenAi,
            client: reqwest::blocking::Client::new(),
        }
    }

    /// Create OpenAI embedding provider
    ///
    /// # Models
    /// - `text-embedding-3-small` - 1536 dimensions, $0.02/1M tokens
    /// - `text-embedding-3-large` - 3072 dimensions, $0.13/1M tokens
    /// - `text-embedding-ada-002` - 1536 dimensions (legacy)
    pub fn openai(api_key: &str, model: &str) -> Self {
        let dimensions = match model {
            "text-embedding-3-large" => 3072,
            _ => 1536, // text-embedding-3-small and ada-002
        };

        Self::new(
            api_key.to_string(),
            "https://api.openai.com/v1/embeddings".to_string(),
            model.to_string(),
            dimensions,
        )
    }

    /// Create Cohere embedding provider
    ///
    /// # Models
    /// - `embed-english-v3.0` - 1024 dimensions
    /// - `embed-multilingual-v3.0` - 1024 dimensions
    pub fn cohere(api_key: &str, model: &str) -> Self {
        Self {
            // Cohere's /v1/embed takes `texts` (+ `input_type`), not OpenAI's `input`.
            request_format: ApiRequestFormat::Cohere,
            ..Self::new(
                api_key.to_string(),
                "https://api.cohere.ai/v1/embed".to_string(),
                model.to_string(),
                1024,
            )
        }
    }

    /// Create Voyage AI embedding provider
    ///
    /// # Models
    /// - `voyage-2` - 1024 dimensions
    /// - `voyage-large-2` - 1536 dimensions
    pub fn voyage(api_key: &str, model: &str) -> Self {
        let dimensions = if model.contains("large") { 1536 } else { 1024 };

        Self::new(
            api_key.to_string(),
            "https://api.voyageai.com/v1/embeddings".to_string(),
            model.to_string(),
            dimensions,
        )
    }
}

#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // `input_type` is mandatory for Cohere v3+ models. `search_document` suits content
        // being embedded for storage. NOTE(review): query-time embeddings would ideally use
        // `search_query`.
        let request_body = match self.request_format {
            ApiRequestFormat::OpenAi => serde_json::json!({
                "input": text,
                "model": self.model,
            }),
            ApiRequestFormat::Cohere => serde_json::json!({
                "texts": [text],
                "model": self.model,
                "input_type": "search_document",
            }),
        };

        let response = self
            .client
            .post(&self.endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .map_err(|e| {
                RuvectorError::ModelInferenceError(format!("API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RuvectorError::ModelInferenceError(format!(
                "API returned error {}: {}",
                status, error_text
            )));
        }

        let response_json: serde_json::Value = response.json().map_err(|e| {
            RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e))
        })?;

        // Handle different API response formats
        let embedding = if let Some(data) = response_json.get("data") {
            // OpenAI format: {"data": [{"embedding": [...]}]}
            data.as_array()
                .and_then(|arr| arr.first())
                .and_then(|obj| obj.get("embedding"))
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid OpenAI response format".to_string())
                })?
        } else if let Some(embeddings) = response_json.get("embeddings") {
            // Cohere format: {"embeddings": [[...]]}
            embeddings
                .as_array()
                .and_then(|arr| arr.first())
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid Cohere response format".to_string())
                })?
        } else {
            return Err(RuvectorError::ModelInferenceError(
                "Unknown API response format".to_string(),
            ));
        };

        embedding
            .iter()
            .map(|v| {
                v.as_f64().map(|f| f as f32).ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid embedding value".to_string())
                })
            })
            .collect()
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn name(&self) -> &str {
        "ApiEmbedding"
    }
}
362
363// ============================================================================
364// ONNX Embeddings (Recommended for Production)
365// ============================================================================
366
367/// ONNX-based embedding provider using ONNX Runtime
368///
369/// Provides **real semantic embeddings** using transformer models like all-MiniLM-L6-v2.
370/// This is the **recommended** embedding provider for production use.
371///
372/// Requires feature flag: `onnx-embeddings`
373///
374/// ## Features
375/// - Real semantic understanding ("dog" and "cat" ARE similar)
376/// - Local inference (no API calls, works offline)
377/// - Fast inference (5-50ms per embedding)
378/// - Automatic model download from HuggingFace
379///
380/// ## Supported Models
381/// - `sentence-transformers/all-MiniLM-L6-v2` (384 dims, recommended)
382/// - `sentence-transformers/all-mpnet-base-v2` (768 dims)
383/// - `BAAI/bge-small-en-v1.5` (384 dims)
384///
385/// # Example
386/// ```rust,ignore
387/// use ruvector_core::embeddings::{EmbeddingProvider, OnnxEmbedding};
388///
389/// let provider = OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2")?;
390/// let embedding = provider.embed("hello world")?;
391/// assert_eq!(embedding.len(), 384);
392/// ```
393#[cfg(feature = "onnx-embeddings")]
394pub mod onnx {
395    use super::*;
396    use crate::error::RuvectorError;
397    use ort::session::Session;
398    use ort::value::{Tensor, ValueType};
399    use parking_lot::RwLock;
400    use std::path::PathBuf;
401    use tokenizers::Tokenizer;
402
403    /// ONNX-based embedding provider
404    pub struct OnnxEmbedding {
405        session: RwLock<Session>,
406        tokenizer: RwLock<Tokenizer>,
407        dimensions: usize,
408        model_id: String,
409        #[allow(dead_code)]
410        max_length: usize,
411    }
412
413    impl OnnxEmbedding {
414        /// Load a pre-trained embedding model from HuggingFace
415        ///
416        /// The model will be downloaded and cached automatically.
417        ///
418        /// # Arguments
419        /// * `model_id` - HuggingFace model identifier (e.g., "sentence-transformers/all-MiniLM-L6-v2")
420        ///
421        /// # Example
422        /// ```rust,ignore
423        /// let provider = OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2")?;
424        /// ```
425        pub fn from_pretrained(model_id: &str) -> Result<Self> {
426            let api = hf_hub::api::sync::Api::new().map_err(|e| {
427                RuvectorError::ModelLoadError(format!("Failed to create HuggingFace API: {}", e))
428            })?;
429
430            let repo = api.model(model_id.to_string());
431
432            // Download model files
433            let model_path = repo
434                .get("model.onnx")
435                .or_else(|_| {
436                    // Try alternative path for some models
437                    repo.get("onnx/model.onnx")
438                })
439                .map_err(|e| {
440                    RuvectorError::ModelLoadError(format!(
441                        "Failed to download ONNX model from {}: {}. \
442                     Make sure the model has an ONNX export available.",
443                        model_id, e
444                    ))
445                })?;
446
447            let tokenizer_path = repo.get("tokenizer.json").map_err(|e| {
448                RuvectorError::ModelLoadError(format!(
449                    "Failed to download tokenizer from {}: {}",
450                    model_id, e
451                ))
452            })?;
453
454            Self::from_files(&model_path, &tokenizer_path, model_id)
455        }
456
457        /// Load from local files
458        ///
459        /// # Arguments
460        /// * `model_path` - Path to the ONNX model file
461        /// * `tokenizer_path` - Path to the tokenizer.json file
462        /// * `model_id` - Model identifier for logging
463        pub fn from_files(
464            model_path: &PathBuf,
465            tokenizer_path: &PathBuf,
466            model_id: &str,
467        ) -> Result<Self> {
468            // Initialize ONNX Runtime (returns bool, true = first init)
469            let _ = ort::init().commit();
470
471            // Load the ONNX session
472            let session = Session::builder()
473                .map_err(|e| {
474                    RuvectorError::ModelLoadError(format!(
475                        "Failed to create session builder: {}",
476                        e
477                    ))
478                })?
479                .with_intra_threads(4)
480                .map_err(|e| {
481                    RuvectorError::ModelLoadError(format!("Failed to set thread count: {}", e))
482                })?
483                .commit_from_file(model_path)
484                .map_err(|e| {
485                    RuvectorError::ModelLoadError(format!("Failed to load ONNX model: {}", e))
486                })?;
487
488            // Load tokenizer
489            let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
490                RuvectorError::ModelLoadError(format!("Failed to load tokenizer: {}", e))
491            })?;
492
493            // Determine dimensions from model output
494            let dimensions = Self::infer_dimensions(&session, model_id)?;
495
496            // Determine max_length from model (default to 512 for sentence transformers)
497            let max_length = 512;
498
499            tracing::info!(
500                "Loaded ONNX embedding model: {} ({}D)",
501                model_id,
502                dimensions
503            );
504
505            Ok(Self {
506                session: RwLock::new(session),
507                tokenizer: RwLock::new(tokenizer),
508                dimensions,
509                model_id: model_id.to_string(),
510                max_length,
511            })
512        }
513
514        fn infer_dimensions(session: &Session, model_id: &str) -> Result<usize> {
515            // Common dimensions for known models
516            let dimensions = match model_id {
517                id if id.contains("all-MiniLM-L6") => 384,
518                id if id.contains("all-mpnet-base") => 768,
519                id if id.contains("bge-small") => 384,
520                id if id.contains("bge-base") => 768,
521                id if id.contains("bge-large") => 1024,
522                id if id.contains("e5-small") => 384,
523                id if id.contains("e5-base") => 768,
524                id if id.contains("e5-large") => 1024,
525                _ => {
526                    // Try to infer from output shape via session.outputs() method
527                    if let Some(output) = session.outputs().first() {
528                        if let ValueType::Tensor { shape, .. } = output.dtype() {
529                            let dims: Vec<i64> = shape.iter().copied().collect();
530                            if dims.len() >= 2 {
531                                let last_dim = dims[dims.len() - 1];
532                                if last_dim > 0 {
533                                    return Ok(last_dim as usize);
534                                }
535                            }
536                        }
537                    }
538                    // Default to 384 (most common)
539                    384
540                }
541            };
542
543            Ok(dimensions)
544        }
545
546        /// Embed multiple texts in a batch (more efficient than individual calls)
547        pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
548            texts.iter().map(|text| self.embed(text)).collect()
549        }
550
551        fn mean_pooling(
552            token_embeddings: &[f32],
553            attention_mask: &[i64],
554            seq_len: usize,
555            hidden_size: usize,
556        ) -> Vec<f32> {
557            let mut pooled = vec![0.0f32; hidden_size];
558            let mut mask_sum = 0.0f32;
559
560            for i in 0..seq_len {
561                let mask = attention_mask[i] as f32;
562                mask_sum += mask;
563                for j in 0..hidden_size {
564                    pooled[j] += token_embeddings[i * hidden_size + j] * mask;
565                }
566            }
567
568            // Avoid division by zero
569            if mask_sum > 0.0 {
570                for val in &mut pooled {
571                    *val /= mask_sum;
572                }
573            }
574
575            // L2 normalize
576            let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
577            if norm > 0.0 {
578                for val in &mut pooled {
579                    *val /= norm;
580                }
581            }
582
583            pooled
584        }
585    }
586
587    impl EmbeddingProvider for OnnxEmbedding {
588        fn embed(&self, text: &str) -> Result<Vec<f32>> {
589            // Tokenize
590            let encoding = {
591                let tokenizer = self.tokenizer.read();
592                tokenizer.encode(text, true).map_err(|e| {
593                    RuvectorError::ModelInferenceError(format!("Tokenization failed: {}", e))
594                })?
595            };
596
597            // Prepare inputs
598            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
599            let attention_mask: Vec<i64> = encoding
600                .get_attention_mask()
601                .iter()
602                .map(|&x| x as i64)
603                .collect();
604            let token_type_ids: Vec<i64> =
605                encoding.get_type_ids().iter().map(|&x| x as i64).collect();
606
607            let seq_len = input_ids.len();
608
609            // Create ONNX tensors using ort 2.0 API (batch_size=1)
610            // Tensor::from_array takes (shape, owned_data)
611            let input_ids_tensor =
612                Tensor::<i64>::from_array(([1, seq_len], input_ids.clone().into_boxed_slice()))
613                    .map_err(|e| {
614                        RuvectorError::ModelInferenceError(format!(
615                            "Failed to create input_ids tensor: {}",
616                            e
617                        ))
618                    })?;
619
620            let attention_mask_tensor = Tensor::<i64>::from_array((
621                [1, seq_len],
622                attention_mask.clone().into_boxed_slice(),
623            ))
624            .map_err(|e| {
625                RuvectorError::ModelInferenceError(format!(
626                    "Failed to create attention_mask tensor: {}",
627                    e
628                ))
629            })?;
630
631            let token_type_ids_tensor =
632                Tensor::<i64>::from_array(([1, seq_len], token_type_ids.into_boxed_slice()))
633                    .map_err(|e| {
634                        RuvectorError::ModelInferenceError(format!(
635                            "Failed to create token_type_ids tensor: {}",
636                            e
637                        ))
638                    })?;
639
640            // Run inference and extract output (needs mutable access to session)
641            // We must extract all data while holding the lock since SessionOutputs has a lifetime
642            let (output_data, output_shape_vec) = {
643                let mut session = self.session.write();
644                let outputs = session
645                    .run(ort::inputs![
646                        "input_ids" => input_ids_tensor,
647                        "attention_mask" => attention_mask_tensor,
648                        "token_type_ids" => token_type_ids_tensor,
649                    ])
650                    .map_err(|e| {
651                        RuvectorError::ModelInferenceError(format!("ONNX inference failed: {}", e))
652                    })?;
653
654                // Extract output using indexing (ort 2.0 API)
655                // Sentence transformers output shape: [batch_size, seq_len, hidden_size]
656                let output_value = &outputs[0];
657
658                // Extract as ndarray view
659                let output_array = output_value.try_extract_array::<f32>().map_err(|e| {
660                    RuvectorError::ModelInferenceError(format!(
661                        "Failed to extract output tensor: {}",
662                        e
663                    ))
664                })?;
665
666                let output_shape_vec: Vec<usize> = output_array.shape().to_vec();
667                let output_data_vec: Vec<f32> = output_array.iter().copied().collect();
668
669                (output_data_vec, output_shape_vec)
670            };
671
672            // Determine if we need pooling based on output shape
673            let embedding = if output_shape_vec.len() == 3 {
674                // Shape: [batch_size, seq_len, hidden_size] - needs pooling
675                let hidden_size = output_shape_vec[2];
676                Self::mean_pooling(&output_data, &attention_mask, seq_len, hidden_size)
677            } else if output_shape_vec.len() == 2 {
678                // Shape: [batch_size, hidden_size] - already pooled
679                let mut emb = output_data;
680                // L2 normalize
681                let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
682                if norm > 0.0 {
683                    for val in &mut emb {
684                        *val /= norm;
685                    }
686                }
687                emb
688            } else {
689                return Err(RuvectorError::ModelInferenceError(format!(
690                    "Unexpected output shape: {:?}",
691                    output_shape_vec
692                )));
693            };
694
695            Ok(embedding)
696        }
697
698        fn dimensions(&self) -> usize {
699            self.dimensions
700        }
701
702        fn name(&self) -> &str {
703            &self.model_id
704        }
705    }
706
707    impl std::fmt::Debug for OnnxEmbedding {
708        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
709            f.debug_struct("OnnxEmbedding")
710                .field("model_id", &self.model_id)
711                .field("dimensions", &self.dimensions)
712                .field("max_length", &self.max_length)
713                .finish()
714        }
715    }
716}
717
718#[cfg(feature = "onnx-embeddings")]
719pub use onnx::OnnxEmbedding;
720
721/// Type-erased embedding provider for dynamic dispatch
722pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
723
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_embedding() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1.len(), 128);
        assert_eq!(emb1, emb2, "Same text should produce same embedding");

        // Check normalization
        let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    #[test]
    fn test_hash_embedding_different_text() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(
            emb1, emb2,
            "Different text should produce different embeddings"
        );
    }

    /// The candle provider is a stub: loading must fail immediately, without any download.
    #[cfg(feature = "real-embeddings")]
    #[test]
    fn test_candle_embedding() {
        let result =
            CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false);
        assert!(result.is_err(), "Candle stub should refuse to load models");
    }

    // Gated on the feature: `ApiEmbedding` does not exist without it, and `#[ignore]` alone
    // does not stop the test from being compiled.
    #[cfg(feature = "api-embeddings")]
    #[test]
    #[ignore] // Requires API key
    fn test_api_embedding_openai() {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
        let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 1536);
    }

    #[cfg(feature = "onnx-embeddings")]
    mod onnx_tests {
        use super::*;

        #[test]
        #[ignore] // Requires model download (~90MB)
        fn test_onnx_embedding_minilm() {
            let provider =
                OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2").unwrap();

            let embedding = provider.embed("hello world").unwrap();
            assert_eq!(embedding.len(), 384);

            // Check normalization
            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-4,
                "Embedding should be normalized, got norm={}",
                norm
            );
        }

        #[test]
        #[ignore] // Requires model download
        fn test_onnx_semantic_similarity() {
            let provider =
                OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2").unwrap();

            let emb_dog = provider.embed("dog").unwrap();
            let emb_cat = provider.embed("cat").unwrap();
            let emb_car = provider.embed("car").unwrap();

            // Cosine similarity (embeddings are normalized, so dot product = cosine)
            let sim_dog_cat: f32 = emb_dog.iter().zip(&emb_cat).map(|(a, b)| a * b).sum();
            let sim_dog_car: f32 = emb_dog.iter().zip(&emb_car).map(|(a, b)| a * b).sum();

            // dog and cat should be more similar than dog and car
            assert!(
                sim_dog_cat > sim_dog_car,
                "Expected dog-cat similarity ({}) > dog-car similarity ({})",
                sim_dog_cat,
                sim_dog_car
            );
        }

        #[test]
        #[ignore] // Requires model download
        fn test_onnx_batch_embedding() {
            let provider =
                OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2").unwrap();

            let texts = vec!["hello world", "goodbye world", "rust programming"];
            let embeddings = provider.embed_batch(&texts).unwrap();

            assert_eq!(embeddings.len(), 3);
            for emb in &embeddings {
                assert_eq!(emb.len(), 384);
            }
        }
    }
}