// ruvector_core/embeddings.rs
1//! Text Embedding Providers
2//!
3//! This module provides a pluggable embedding system for AgenticDB.
4//!
5//! ## Available Providers
6//!
7//! - **HashEmbedding**: Fast hash-based placeholder (default, not semantic)
8//! - **OnnxEmbedding**: Real semantic embeddings using ONNX Runtime (feature: `onnx-embeddings`) ✅ RECOMMENDED
9//! - **CandleEmbedding**: Real embeddings using candle-transformers (feature: `real-embeddings`)
10//! - **ApiEmbedding**: External API calls (OpenAI, Anthropic, Cohere, etc.)
11//!
12//! ## Usage
13//!
14//! ```rust,no_run
15//! use ruvector_core::embeddings::{EmbeddingProvider, HashEmbedding};
16//!
17//! // Default: Hash-based (fast, but not semantic)
18//! let hash_provider = HashEmbedding::new(384);
19//! let embedding = hash_provider.embed("hello world")?;
20//!
21//! # Ok::<(), Box<dyn std::error::Error>>(())
22//! ```
23//!
24//! ## ONNX Embeddings (Recommended for Production)
25//!
26//! ```rust,ignore
27//! use ruvector_core::embeddings::{EmbeddingProvider, OnnxEmbedding};
28//!
29//! // Real semantic embeddings using all-MiniLM-L6-v2
30//! let provider = OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2")?;
31//! let embedding = provider.embed("hello world")?;
32//! // "dog" and "cat" WILL be similar (semantic understanding!)
33//! ```
34
35use crate::error::Result;
36#[cfg(any(feature = "real-embeddings", feature = "api-embeddings"))]
37use crate::error::RuvectorError;
38use std::sync::Arc;
39
/// Trait for text embedding providers.
///
/// Implementations turn arbitrary text into a fixed-length `f32` vector.
/// Providers must be `Send + Sync` so they can be shared across threads
/// (e.g. behind an `Arc` as a `BoxedEmbeddingProvider`).
pub trait EmbeddingProvider: Send + Sync {
    /// Generate embedding vector for the given text.
    ///
    /// The returned vector's length must equal `dimensions()`. Whether the
    /// vector is L2-normalized depends on the provider (the hash and ONNX
    /// providers normalize; API responses are returned as-is).
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Get the dimensionality of embeddings produced by this provider
    fn dimensions(&self) -> usize;

    /// Get a description of this provider (for logging/debugging)
    fn name(&self) -> &str;
}
51
/// Hash-based embedding provider (placeholder, not semantic)
///
/// ⚠️ **WARNING**: This does NOT produce semantic embeddings!
/// - "dog" and "cat" will NOT be similar
/// - "dog" and "god" WILL be similar (same characters)
///
/// Use this only for:
/// - Testing
/// - Prototyping
/// - When semantic similarity is not required
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    // Output vector length; every embedding returned has exactly this length.
    dimensions: usize,
}
66
67impl HashEmbedding {
68    /// Create a new hash-based embedding provider
69    pub fn new(dimensions: usize) -> Self {
70        Self { dimensions }
71    }
72}
73
74impl EmbeddingProvider for HashEmbedding {
75    fn embed(&self, text: &str) -> Result<Vec<f32>> {
76        let mut embedding = vec![0.0; self.dimensions];
77        let bytes = text.as_bytes();
78
79        for (i, byte) in bytes.iter().enumerate() {
80            embedding[i % self.dimensions] += (*byte as f32) / 255.0;
81        }
82
83        // Normalize
84        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
85        if norm > 0.0 {
86            for val in &mut embedding {
87                *val /= norm;
88            }
89        }
90
91        Ok(embedding)
92    }
93
94    fn dimensions(&self) -> usize {
95        self.dimensions
96    }
97
98    fn name(&self) -> &str {
99        "HashEmbedding (placeholder)"
100    }
101}
102
/// Real embeddings using candle-transformers
///
/// Requires feature flag: `real-embeddings`
///
/// ⚠️ **Note**: Full candle integration is complex and model-specific.
/// For production use, we recommend:
/// 1. Using the API-based providers (simpler, always up-to-date)
/// 2. Using ONNX Runtime with pre-exported models
/// 3. Implementing your own candle wrapper for your specific model
///
/// This is a stub implementation showing the structure.
/// Users should implement `EmbeddingProvider` trait for their specific models.
#[cfg(feature = "real-embeddings")]
pub mod candle {
    use super::*;

    /// Candle-based embedding provider stub
    ///
    /// This is a placeholder. For real implementation:
    /// 1. Add candle dependencies for your specific model type
    /// 2. Implement model loading and inference
    /// 3. Handle tokenization appropriately
    ///
    /// Example structure:
    /// ```rust,ignore
    /// pub struct CandleEmbedding {
    ///     model: YourModelType,
    ///     tokenizer: Tokenizer,
    ///     device: Device,
    ///     dimensions: usize,
    /// }
    /// ```
    pub struct CandleEmbedding {
        // NOTE(review): these fields are never populated because
        // `from_pretrained` always errors — they document the intended shape.
        dimensions: usize,
        model_id: String,
    }

    impl CandleEmbedding {
        /// Create a stub candle embedding provider
        ///
        /// **This is not a real implementation!**
        /// For production, implement with actual model loading.
        ///
        /// Always returns `Err(RuvectorError::ModelLoadError)`.
        ///
        /// # Example
        /// ```rust,no_run
        /// # #[cfg(feature = "real-embeddings")]
        /// # {
        /// use ruvector_core::embeddings::candle::CandleEmbedding;
        ///
        /// // This returns an error - real implementation required
        /// let result = CandleEmbedding::from_pretrained(
        ///     "sentence-transformers/all-MiniLM-L6-v2",
        ///     false
        /// );
        /// assert!(result.is_err());
        /// # }
        /// ```
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            Err(RuvectorError::ModelLoadError(format!(
                "Candle embedding support is a stub. Please:\n\
                     1. Use ApiEmbedding for production (recommended)\n\
                     2. Or implement CandleEmbedding for model: {}\n\
                     3. See docs for ONNX Runtime integration examples",
                model_id
            )))
        }
    }

    impl EmbeddingProvider for CandleEmbedding {
        // Unreachable in practice (no constructor succeeds), but required by
        // the trait; kept as a clear error for hypothetical callers.
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Err(RuvectorError::ModelInferenceError(
                "Candle embedding not implemented - use ApiEmbedding instead".to_string(),
            ))
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}
187
188#[cfg(feature = "real-embeddings")]
189pub use candle::CandleEmbedding;
190
/// API-based embedding provider (OpenAI, Anthropic, Cohere, etc.)
///
/// Supports any API that accepts JSON and returns embeddings in a standard format.
///
/// Uses a blocking HTTP client, so `embed` performs a synchronous network
/// round trip per call.
///
/// # Example (OpenAI)
/// ```rust,no_run
/// use ruvector_core::embeddings::{EmbeddingProvider, ApiEmbedding};
///
/// let provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
/// let embedding = provider.embed("hello world")?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[cfg(feature = "api-embeddings")]
#[derive(Clone)]
pub struct ApiEmbedding {
    // Sent as a `Bearer` token in the Authorization header.
    api_key: String,
    // Full URL of the embeddings endpoint.
    endpoint: String,
    // Model identifier forwarded in the request body.
    model: String,
    // Expected embedding length; not enforced against the API response.
    dimensions: usize,
    client: reqwest::blocking::Client,
}
212
#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
    /// Create a new API embedding provider.
    ///
    /// # Arguments
    /// * `api_key` - API key for authentication
    /// * `endpoint` - API endpoint URL
    /// * `model` - Model identifier
    /// * `dimensions` - Expected embedding dimensions
    pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
        let client = reqwest::blocking::Client::new();
        Self {
            api_key,
            endpoint,
            model,
            dimensions,
            client,
        }
    }

    /// Create OpenAI embedding provider.
    ///
    /// # Models
    /// - `text-embedding-3-small` - 1536 dimensions, $0.02/1M tokens
    /// - `text-embedding-3-large` - 3072 dimensions, $0.13/1M tokens
    /// - `text-embedding-ada-002` - 1536 dimensions (legacy)
    pub fn openai(api_key: &str, model: &str) -> Self {
        // Only the "large" model differs; 3-small and ada-002 are both 1536D.
        let dimensions = if model == "text-embedding-3-large" {
            3072
        } else {
            1536
        };
        let endpoint = "https://api.openai.com/v1/embeddings".to_string();

        Self::new(api_key.to_string(), endpoint, model.to_string(), dimensions)
    }

    /// Create Cohere embedding provider.
    ///
    /// # Models
    /// - `embed-english-v3.0` - 1024 dimensions
    /// - `embed-multilingual-v3.0` - 1024 dimensions
    pub fn cohere(api_key: &str, model: &str) -> Self {
        // All supported Cohere v3 embed models produce 1024 dimensions.
        let endpoint = "https://api.cohere.ai/v1/embed".to_string();

        Self::new(api_key.to_string(), endpoint, model.to_string(), 1024)
    }

    /// Create Voyage AI embedding provider.
    ///
    /// # Models
    /// - `voyage-2` - 1024 dimensions
    /// - `voyage-large-2` - 1536 dimensions
    pub fn voyage(api_key: &str, model: &str) -> Self {
        Self::new(
            api_key.to_string(),
            "https://api.voyageai.com/v1/embeddings".to_string(),
            model.to_string(),
            // "large" variants emit 1536-dimensional vectors, others 1024.
            if model.contains("large") { 1536 } else { 1024 },
        )
    }
}
282
#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
    /// POST the text to the configured endpoint and parse the embedding out
    /// of the JSON response.
    ///
    /// Blocks on the network call. Returns `ModelInferenceError` for request
    /// failures, non-2xx statuses, unparseable JSON, or unrecognized response
    /// shapes. The response vector is returned as-is (not re-normalized), and
    /// its length is NOT checked against `self.dimensions`.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // OpenAI-style request body; Cohere and Voyage accept the same
        // "input"/"model" keys for single-text requests.
        // NOTE(review): Cohere's v1/embed expects "texts" as an array —
        // confirm single-string "input" works against that endpoint.
        let request_body = serde_json::json!({
            "input": text,
            "model": self.model,
        });

        let response = self
            .client
            .post(&self.endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .map_err(|e| {
                RuvectorError::ModelInferenceError(format!("API request failed: {}", e))
            })?;

        // Surface HTTP-level failures with the response body for diagnosis.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RuvectorError::ModelInferenceError(format!(
                "API returned error {}: {}",
                status, error_text
            )));
        }

        let response_json: serde_json::Value = response.json().map_err(|e| {
            RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e))
        })?;

        // Handle different API response formats
        let embedding = if let Some(data) = response_json.get("data") {
            // OpenAI format: {"data": [{"embedding": [...]}]}
            data.as_array()
                .and_then(|arr| arr.first())
                .and_then(|obj| obj.get("embedding"))
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid OpenAI response format".to_string())
                })?
        } else if let Some(embeddings) = response_json.get("embeddings") {
            // Cohere format: {"embeddings": [[...]]}
            embeddings
                .as_array()
                .and_then(|arr| arr.first())
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid Cohere response format".to_string())
                })?
        } else {
            return Err(RuvectorError::ModelInferenceError(
                "Unknown API response format".to_string(),
            ));
        };

        // Convert JSON numbers to f32, failing on any non-numeric element.
        let embedding_vec: Result<Vec<f32>> = embedding
            .iter()
            .map(|v| {
                v.as_f64().map(|f| f as f32).ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid embedding value".to_string())
                })
            })
            .collect();

        embedding_vec
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn name(&self) -> &str {
        "ApiEmbedding"
    }
}
362
363// ============================================================================
364// ONNX Embeddings (Recommended for Production)
365// ============================================================================
366
367/// ONNX-based embedding provider using ONNX Runtime
368///
369/// Provides **real semantic embeddings** using transformer models like all-MiniLM-L6-v2.
370/// This is the **recommended** embedding provider for production use.
371///
372/// Requires feature flag: `onnx-embeddings`
373///
374/// ## Features
375/// - Real semantic understanding ("dog" and "cat" ARE similar)
376/// - Local inference (no API calls, works offline)
377/// - Fast inference (5-50ms per embedding)
378/// - Automatic model download from HuggingFace
379///
380/// ## Supported Models
381/// - `sentence-transformers/all-MiniLM-L6-v2` (384 dims, recommended)
382/// - `sentence-transformers/all-mpnet-base-v2` (768 dims)
383/// - `BAAI/bge-small-en-v1.5` (384 dims)
384///
385/// # Example
386/// ```rust,ignore
387/// use ruvector_core::embeddings::{EmbeddingProvider, OnnxEmbedding};
388///
389/// let provider = OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2")?;
390/// let embedding = provider.embed("hello world")?;
391/// assert_eq!(embedding.len(), 384);
392/// ```
393#[cfg(feature = "onnx-embeddings")]
394pub mod onnx {
395    use super::*;
396    use crate::error::RuvectorError;
397    use ort::session::Session;
398    use ort::value::{Tensor, ValueType};
399    use parking_lot::RwLock;
400    use std::path::PathBuf;
401    use tokenizers::Tokenizer;
402
    /// ONNX-based embedding provider
    ///
    /// The session and tokenizer live behind `RwLock`s: tokenization only
    /// needs shared access, while `Session::run` needs mutable access (see
    /// `embed`, which takes the write lock for inference).
    pub struct OnnxEmbedding {
        session: RwLock<Session>,
        tokenizer: RwLock<Tokenizer>,
        // Output embedding length, determined at load time by `infer_dimensions`.
        dimensions: usize,
        // HuggingFace model id (or caller-supplied label); also used as `name()`.
        model_id: String,
        // Stored for Debug output only; not currently enforced during tokenization.
        #[allow(dead_code)]
        max_length: usize,
    }
412
    impl OnnxEmbedding {
        /// Load a pre-trained embedding model from HuggingFace
        ///
        /// The model will be downloaded and cached automatically.
        ///
        /// # Arguments
        /// * `model_id` - HuggingFace model identifier (e.g., "sentence-transformers/all-MiniLM-L6-v2")
        ///
        /// # Errors
        /// Returns `ModelLoadError` if the HuggingFace API cannot be created
        /// or the model/tokenizer files cannot be downloaded.
        ///
        /// # Example
        /// ```rust,ignore
        /// let provider = OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2")?;
        /// ```
        pub fn from_pretrained(model_id: &str) -> Result<Self> {
            let api = hf_hub::api::sync::Api::new().map_err(|e| {
                RuvectorError::ModelLoadError(format!("Failed to create HuggingFace API: {}", e))
            })?;

            let repo = api.model(model_id.to_string());

            // Download model files
            let model_path = repo.get("model.onnx").or_else(|_| {
                // Try alternative path for some models
                repo.get("onnx/model.onnx")
            }).map_err(|e| {
                RuvectorError::ModelLoadError(format!(
                    "Failed to download ONNX model from {}: {}. \
                     Make sure the model has an ONNX export available.",
                    model_id, e
                ))
            })?;

            let tokenizer_path = repo.get("tokenizer.json").map_err(|e| {
                RuvectorError::ModelLoadError(format!(
                    "Failed to download tokenizer from {}: {}",
                    model_id, e
                ))
            })?;

            Self::from_files(&model_path, &tokenizer_path, model_id)
        }

        /// Load from local files
        ///
        /// # Arguments
        /// * `model_path` - Path to the ONNX model file
        /// * `tokenizer_path` - Path to the tokenizer.json file
        /// * `model_id` - Model identifier for logging
        ///
        /// # Errors
        /// Returns `ModelLoadError` if the session cannot be built or the
        /// tokenizer file fails to parse.
        pub fn from_files(
            model_path: &PathBuf,
            tokenizer_path: &PathBuf,
            model_id: &str,
        ) -> Result<Self> {
            // Initialize ONNX Runtime (returns bool, true = first init)
            let _ = ort::init().commit();

            // Load the ONNX session
            let session = Session::builder()
                .map_err(|e| {
                    RuvectorError::ModelLoadError(format!("Failed to create session builder: {}", e))
                })?
                .with_intra_threads(4)
                .map_err(|e| {
                    RuvectorError::ModelLoadError(format!("Failed to set thread count: {}", e))
                })?
                .commit_from_file(model_path)
                .map_err(|e| {
                    RuvectorError::ModelLoadError(format!("Failed to load ONNX model: {}", e))
                })?;

            // Load tokenizer
            let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
                RuvectorError::ModelLoadError(format!("Failed to load tokenizer: {}", e))
            })?;

            // Determine dimensions from model output
            let dimensions = Self::infer_dimensions(&session, model_id)?;

            // Determine max_length from model (default to 512 for sentence transformers)
            // NOTE(review): hard-coded; not read from the model/tokenizer config.
            let max_length = 512;

            tracing::info!(
                "Loaded ONNX embedding model: {} ({}D)",
                model_id,
                dimensions
            );

            Ok(Self {
                session: RwLock::new(session),
                tokenizer: RwLock::new(tokenizer),
                dimensions,
                model_id: model_id.to_string(),
                max_length,
            })
        }

        /// Determine the embedding dimensionality for a model.
        ///
        /// Known model families are matched by substring; otherwise the last
        /// axis of the session's first output tensor is used when it has a
        /// static size, falling back to 384.
        fn infer_dimensions(session: &Session, model_id: &str) -> Result<usize> {
            // Common dimensions for known models
            let dimensions = match model_id {
                id if id.contains("all-MiniLM-L6") => 384,
                id if id.contains("all-mpnet-base") => 768,
                id if id.contains("bge-small") => 384,
                id if id.contains("bge-base") => 768,
                id if id.contains("bge-large") => 1024,
                id if id.contains("e5-small") => 384,
                id if id.contains("e5-base") => 768,
                id if id.contains("e5-large") => 1024,
                _ => {
                    // Try to infer from output shape via session.outputs() method
                    if let Some(output) = session.outputs().first() {
                        if let ValueType::Tensor { shape, .. } = output.dtype() {
                            let dims: Vec<i64> = shape.iter().copied().collect();
                            if dims.len() >= 2 {
                                let last_dim = dims[dims.len() - 1];
                                // Dynamic axes are reported as non-positive values,
                                // so only trust a strictly positive size.
                                if last_dim > 0 {
                                    return Ok(last_dim as usize);
                                }
                            }
                        }
                    }
                    // Default to 384 (most common)
                    384
                }
            };

            Ok(dimensions)
        }

        /// Embed multiple texts in a batch (more efficient than individual calls)
        ///
        /// NOTE(review): currently a convenience loop over `embed`, one
        /// inference per text — no true batched tensor is built.
        pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            texts.iter().map(|text| self.embed(text)).collect()
        }

        /// Mask-weighted mean over the sequence axis, then L2-normalization.
        ///
        /// `token_embeddings` is a flattened `[seq_len, hidden_size]` buffer;
        /// padding positions (mask 0) contribute nothing to the average.
        fn mean_pooling(
            token_embeddings: &[f32],
            attention_mask: &[i64],
            seq_len: usize,
            hidden_size: usize,
        ) -> Vec<f32> {
            let mut pooled = vec![0.0f32; hidden_size];
            let mut mask_sum = 0.0f32;

            for i in 0..seq_len {
                let mask = attention_mask[i] as f32;
                mask_sum += mask;
                for j in 0..hidden_size {
                    pooled[j] += token_embeddings[i * hidden_size + j] * mask;
                }
            }

            // Avoid division by zero
            if mask_sum > 0.0 {
                for val in &mut pooled {
                    *val /= mask_sum;
                }
            }

            // L2 normalize
            let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for val in &mut pooled {
                    *val /= norm;
                }
            }

            pooled
        }
    }
580
    impl EmbeddingProvider for OnnxEmbedding {
        /// Tokenize `text`, run the ONNX model, and pool/normalize the output
        /// into a single embedding vector.
        fn embed(&self, text: &str) -> Result<Vec<f32>> {
            // Tokenize (second arg `true` = add special tokens).
            // NOTE(review): no explicit truncation to `max_length` is applied
            // here — confirm the tokenizer config handles over-long inputs.
            let encoding = {
                let tokenizer = self.tokenizer.read();
                tokenizer
                    .encode(text, true)
                    .map_err(|e| {
                        RuvectorError::ModelInferenceError(format!("Tokenization failed: {}", e))
                    })?
            };

            // Prepare inputs (widen u32 token ids/masks to the i64 the model expects)
            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
            let attention_mask: Vec<i64> = encoding
                .get_attention_mask()
                .iter()
                .map(|&x| x as i64)
                .collect();
            let token_type_ids: Vec<i64> = encoding
                .get_type_ids()
                .iter()
                .map(|&x| x as i64)
                .collect();

            let seq_len = input_ids.len();

            // Create ONNX tensors using ort 2.0 API (batch_size=1)
            // Tensor::from_array takes (shape, owned_data)
            let input_ids_tensor = Tensor::<i64>::from_array(([1, seq_len], input_ids.clone().into_boxed_slice()))
                .map_err(|e| {
                    RuvectorError::ModelInferenceError(format!(
                        "Failed to create input_ids tensor: {}",
                        e
                    ))
                })?;

            let attention_mask_tensor =
                Tensor::<i64>::from_array(([1, seq_len], attention_mask.clone().into_boxed_slice())).map_err(|e| {
                    RuvectorError::ModelInferenceError(format!(
                        "Failed to create attention_mask tensor: {}",
                        e
                    ))
                })?;

            let token_type_ids_tensor =
                Tensor::<i64>::from_array(([1, seq_len], token_type_ids.into_boxed_slice())).map_err(|e| {
                    RuvectorError::ModelInferenceError(format!(
                        "Failed to create token_type_ids tensor: {}",
                        e
                    ))
                })?;

            // Run inference and extract output (needs mutable access to session)
            // We must extract all data while holding the lock since SessionOutputs has a lifetime
            let (output_data, output_shape_vec) = {
                let mut session = self.session.write();
                let outputs = session
                    .run(ort::inputs![
                        "input_ids" => input_ids_tensor,
                        "attention_mask" => attention_mask_tensor,
                        "token_type_ids" => token_type_ids_tensor,
                    ])
                    .map_err(|e| {
                        RuvectorError::ModelInferenceError(format!("ONNX inference failed: {}", e))
                    })?;

                // Extract output using indexing (ort 2.0 API)
                // Sentence transformers output shape: [batch_size, seq_len, hidden_size]
                let output_value = &outputs[0];

                // Extract as ndarray view
                let output_array = output_value.try_extract_array::<f32>().map_err(|e| {
                    RuvectorError::ModelInferenceError(format!("Failed to extract output tensor: {}", e))
                })?;

                // Copy shape and data out so the lock (and the borrowed
                // SessionOutputs) can be dropped before pooling.
                let output_shape_vec: Vec<usize> = output_array.shape().to_vec();
                let output_data_vec: Vec<f32> = output_array.iter().copied().collect();

                (output_data_vec, output_shape_vec)
            };

            // Determine if we need pooling based on output shape
            let embedding = if output_shape_vec.len() == 3 {
                // Shape: [batch_size, seq_len, hidden_size] - needs pooling
                let hidden_size = output_shape_vec[2];
                Self::mean_pooling(&output_data, &attention_mask, seq_len, hidden_size)
            } else if output_shape_vec.len() == 2 {
                // Shape: [batch_size, hidden_size] - already pooled
                let mut emb = output_data;
                // L2 normalize
                let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    for val in &mut emb {
                        *val /= norm;
                    }
                }
                emb
            } else {
                return Err(RuvectorError::ModelInferenceError(format!(
                    "Unexpected output shape: {:?}",
                    output_shape_vec
                )));
            };

            Ok(embedding)
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            &self.model_id
        }
    }
697
    // Manual Debug impl that prints only lightweight metadata; the session
    // and tokenizer fields are omitted (NOTE(review): presumably because
    // `Session`/`Tokenizer` are not usefully `Debug`-printable).
    impl std::fmt::Debug for OnnxEmbedding {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("OnnxEmbedding")
                .field("model_id", &self.model_id)
                .field("dimensions", &self.dimensions)
                .field("max_length", &self.max_length)
                .finish()
        }
    }
707}
708
709#[cfg(feature = "onnx-embeddings")]
710pub use onnx::OnnxEmbedding;
711
/// Type-erased embedding provider for dynamic dispatch
///
/// Wrap any `EmbeddingProvider` in an `Arc` to share one provider across
/// threads and store different provider types behind a single alias.
pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
714
#[cfg(test)]
mod tests {
    use super::*;

    /// Hash embeddings are deterministic, of the requested length, and
    /// L2-normalized.
    #[test]
    fn test_hash_embedding() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1.len(), 128);
        assert_eq!(emb1, emb2, "Same text should produce same embedding");

        // Check normalization
        let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    /// Distinct inputs must not hash-collide into identical vectors.
    #[test]
    fn test_hash_embedding_different_text() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(
            emb1, emb2,
            "Different text should produce different embeddings"
        );
    }

    /// `CandleEmbedding::from_pretrained` is a documented stub that ALWAYS
    /// returns `Err` (see the candle module), so the previous `.unwrap()`
    /// could never succeed — even with a model download. Assert the stub
    /// behavior instead; no download is needed, so the test is not ignored.
    #[cfg(feature = "real-embeddings")]
    #[test]
    fn test_candle_embedding() {
        let result =
            CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false);
        assert!(result.is_err(), "candle provider is a stub and must error");
    }

    // Gated on `api-embeddings`: `ApiEmbedding` only exists under that
    // feature, so without this cfg the test build fails when the feature
    // is disabled.
    #[cfg(feature = "api-embeddings")]
    #[test]
    #[ignore] // Requires API key
    fn test_api_embedding_openai() {
        let api_key = std::env::var("OPENAI_API_KEY").unwrap();
        let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 1536);
    }

    #[cfg(feature = "onnx-embeddings")]
    mod onnx_tests {
        use super::*;

        #[test]
        #[ignore] // Requires model download (~90MB)
        fn test_onnx_embedding_minilm() {
            let provider =
                OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2").unwrap();

            let embedding = provider.embed("hello world").unwrap();
            assert_eq!(embedding.len(), 384);

            // Check normalization
            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-4,
                "Embedding should be normalized, got norm={}",
                norm
            );
        }

        #[test]
        #[ignore] // Requires model download
        fn test_onnx_semantic_similarity() {
            let provider =
                OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2").unwrap();

            let emb_dog = provider.embed("dog").unwrap();
            let emb_cat = provider.embed("cat").unwrap();
            let emb_car = provider.embed("car").unwrap();

            // Cosine similarity (embeddings are normalized, so dot product = cosine)
            let sim_dog_cat: f32 = emb_dog.iter().zip(&emb_cat).map(|(a, b)| a * b).sum();
            let sim_dog_car: f32 = emb_dog.iter().zip(&emb_car).map(|(a, b)| a * b).sum();

            // dog and cat should be more similar than dog and car
            assert!(
                sim_dog_cat > sim_dog_car,
                "Expected dog-cat similarity ({}) > dog-car similarity ({})",
                sim_dog_cat,
                sim_dog_car
            );
        }

        #[test]
        #[ignore] // Requires model download
        fn test_onnx_batch_embedding() {
            let provider =
                OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2").unwrap();

            let texts = vec!["hello world", "goodbye world", "rust programming"];
            let embeddings = provider.embed_batch(&texts).unwrap();

            assert_eq!(embeddings.len(), 3);
            for emb in &embeddings {
                assert_eq!(emb.len(), 384);
            }
        }
    }
}