1use crate::error::Result;
36#[cfg(any(feature = "real-embeddings", feature = "api-embeddings"))]
37use crate::error::RuvectorError;
38use std::sync::Arc;
39
40pub trait EmbeddingProvider: Send + Sync {
42 fn embed(&self, text: &str) -> Result<Vec<f32>>;
44
45 fn dimensions(&self) -> usize;
47
48 fn name(&self) -> &str;
50}
51
/// Embedding provider whose only configuration is the width of its output.
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    /// Number of components in every vector this provider returns.
    dimensions: usize,
}

impl HashEmbedding {
    /// Creates a provider that emits vectors with `dimensions` components.
    pub fn new(dimensions: usize) -> Self {
        HashEmbedding { dimensions }
    }
}
73
74impl EmbeddingProvider for HashEmbedding {
75 fn embed(&self, text: &str) -> Result<Vec<f32>> {
76 let mut embedding = vec![0.0; self.dimensions];
77 let bytes = text.as_bytes();
78
79 for (i, byte) in bytes.iter().enumerate() {
80 embedding[i % self.dimensions] += (*byte as f32) / 255.0;
81 }
82
83 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
85 if norm > 0.0 {
86 for val in &mut embedding {
87 *val /= norm;
88 }
89 }
90
91 Ok(embedding)
92 }
93
94 fn dimensions(&self) -> usize {
95 self.dimensions
96 }
97
98 fn name(&self) -> &str {
99 "HashEmbedding (placeholder)"
100 }
101}
102
#[cfg(feature = "real-embeddings")]
pub mod candle {
    use super::*;

    /// Stub provider for Candle-backed models.
    ///
    /// Nothing in this module constructs a value of this type:
    /// [`CandleEmbedding::from_pretrained`] always returns an error.
    pub struct CandleEmbedding {
        /// Vector width reported by `dimensions()`.
        dimensions: usize,
        /// Identifier of the model this provider was meant to load.
        model_id: String,
    }

    impl CandleEmbedding {
        /// Always fails with `ModelLoadError`, naming the requested `model_id`.
        ///
        /// `_use_gpu` is accepted but ignored.
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            let message = format!(
                "Candle embedding support is a stub. Please:\n\
                 1. Use ApiEmbedding for production (recommended)\n\
                 2. Or implement CandleEmbedding for model: {}\n\
                 3. See docs for ONNX Runtime integration examples",
                model_id
            );
            Err(RuvectorError::ModelLoadError(message))
        }
    }

    impl EmbeddingProvider for CandleEmbedding {
        /// Always fails with `ModelInferenceError`; the input text is ignored.
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            let message =
                String::from("Candle embedding not implemented - use ApiEmbedding instead");
            Err(RuvectorError::ModelInferenceError(message))
        }

        /// Returns the stored vector width.
        fn dimensions(&self) -> usize {
            self.dimensions
        }

        /// Returns the fixed label marking this provider as a stub.
        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}
187
188#[cfg(feature = "real-embeddings")]
189pub use candle::CandleEmbedding;
190
/// Embedding provider that calls a remote HTTP embeddings endpoint.
///
/// Only `Clone` is derived; there is no `Debug` impl for this type, which
/// holds the API key in plain text.
#[cfg(feature = "api-embeddings")]
#[derive(Clone)]
pub struct ApiEmbedding {
    /// Secret sent as a bearer token with every request.
    api_key: String,
    /// Full URL that requests are posted to.
    endpoint: String,
    /// Model name forwarded in the request body.
    model: String,
    /// Vector width reported by `dimensions()`.
    dimensions: usize,
    /// Blocking HTTP client used for all requests.
    client: reqwest::blocking::Client,
}
212
#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
    /// Creates a provider for an arbitrary endpoint.
    ///
    /// `dimensions` is stored as given and is what `dimensions()` reports.
    pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
        let client = reqwest::blocking::Client::new();
        Self {
            api_key,
            endpoint,
            model,
            dimensions,
            client,
        }
    }

    /// Creates a provider targeting the OpenAI embeddings endpoint.
    ///
    /// Reports 3072 dimensions for `text-embedding-3-large` and 1536 for
    /// every other model name.
    pub fn openai(api_key: &str, model: &str) -> Self {
        let dimensions = if model == "text-embedding-3-large" {
            3072
        } else {
            1536
        };

        Self::new(
            api_key.to_owned(),
            String::from("https://api.openai.com/v1/embeddings"),
            model.to_owned(),
            dimensions,
        )
    }

    /// Creates a provider targeting the Cohere embed endpoint.
    ///
    /// Always reports 1024 dimensions, whatever `model` is.
    pub fn cohere(api_key: &str, model: &str) -> Self {
        let endpoint = String::from("https://api.cohere.ai/v1/embed");
        Self::new(api_key.to_owned(), endpoint, model.to_owned(), 1024)
    }

    /// Creates a provider targeting the Voyage AI embeddings endpoint.
    ///
    /// Reports 1536 dimensions when the model name contains `"large"`,
    /// otherwise 1024.
    pub fn voyage(api_key: &str, model: &str) -> Self {
        let dimensions = match model.contains("large") {
            true => 1536,
            false => 1024,
        };

        Self::new(
            api_key.to_owned(),
            String::from("https://api.voyageai.com/v1/embeddings"),
            model.to_owned(),
            dimensions,
        )
    }
}
282
#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
    /// Requests an embedding for `text` from the configured endpoint.
    ///
    /// Two response layouts are understood:
    /// * `{"data": [{"embedding": [...]}, ...]}`: the first entry is used;
    /// * `{"embeddings": [[...], ...]}`: the first inner array is used.
    ///
    /// The returned vector is passed through as-is; its length is not
    /// compared against `dimensions()`.
    ///
    /// # Errors
    ///
    /// Returns `ModelInferenceError` when the request cannot be sent, the
    /// status is not a success, the body is not JSON, the JSON matches neither
    /// layout, or an embedding entry is not a number.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // NOTE(review): this `input`/`model` body is sent to every endpoint,
        // including the one configured by `ApiEmbedding::cohere`. Confirm that
        // provider accepts this request shape.
        let payload = serde_json::json!({
            "input": text,
            "model": self.model,
        });
        let bearer = format!("Bearer {}", self.api_key);

        let response = self
            .client
            .post(&self.endpoint)
            .header("Authorization", bearer)
            .header("Content-Type", "application/json")
            .json(&payload)
            .send()
            .map_err(|e| {
                RuvectorError::ModelInferenceError(format!("API request failed: {}", e))
            })?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable error body must not hide the status.
            let body = match response.text() {
                Ok(body) => body,
                Err(_) => "Unknown error".to_string(),
            };
            return Err(RuvectorError::ModelInferenceError(format!(
                "API returned error {}: {}",
                status, body
            )));
        }

        let parsed: serde_json::Value = response.json().map_err(|e| {
            RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e))
        })?;

        // A `data` key takes priority; `embeddings` is consulted only without it.
        let values = match (parsed.get("data"), parsed.get("embeddings")) {
            (Some(data), _) => data
                .as_array()
                .and_then(|items| items.first())
                .and_then(|item| item.get("embedding"))
                .and_then(|vector| vector.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError(
                        "Invalid OpenAI response format".to_string(),
                    )
                })?,
            (None, Some(embeddings)) => embeddings
                .as_array()
                .and_then(|items| items.first())
                .and_then(|vector| vector.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError(
                        "Invalid Cohere response format".to_string(),
                    )
                })?,
            (None, None) => {
                return Err(RuvectorError::ModelInferenceError(
                    "Unknown API response format".to_string(),
                ));
            }
        };

        // Collecting into `Result` stops at the first non-numeric entry.
        values
            .iter()
            .map(|value| match value.as_f64() {
                Some(number) => Ok(number as f32),
                None => Err(RuvectorError::ModelInferenceError(
                    "Invalid embedding value".to_string(),
                )),
            })
            .collect()
    }

    /// Returns the vector width supplied at construction time.
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Returns the fixed label for this provider.
    fn name(&self) -> &str {
        "ApiEmbedding"
    }
}
362
363#[cfg(feature = "onnx-embeddings")]
394pub mod onnx {
395 use super::*;
396 use crate::error::RuvectorError;
397 use ort::session::Session;
398 use ort::value::{Tensor, ValueType};
399 use parking_lot::RwLock;
400 use std::path::PathBuf;
401 use tokenizers::Tokenizer;
402
403 pub struct OnnxEmbedding {
405 session: RwLock<Session>,
406 tokenizer: RwLock<Tokenizer>,
407 dimensions: usize,
408 model_id: String,
409 #[allow(dead_code)]
410 max_length: usize,
411 }
412
413 impl OnnxEmbedding {
414 pub fn from_pretrained(model_id: &str) -> Result<Self> {
426 let api = hf_hub::api::sync::Api::new().map_err(|e| {
427 RuvectorError::ModelLoadError(format!("Failed to create HuggingFace API: {}", e))
428 })?;
429
430 let repo = api.model(model_id.to_string());
431
432 let model_path = repo
434 .get("model.onnx")
435 .or_else(|_| {
436 repo.get("onnx/model.onnx")
438 })
439 .map_err(|e| {
440 RuvectorError::ModelLoadError(format!(
441 "Failed to download ONNX model from {}: {}. \
442 Make sure the model has an ONNX export available.",
443 model_id, e
444 ))
445 })?;
446
447 let tokenizer_path = repo.get("tokenizer.json").map_err(|e| {
448 RuvectorError::ModelLoadError(format!(
449 "Failed to download tokenizer from {}: {}",
450 model_id, e
451 ))
452 })?;
453
454 Self::from_files(&model_path, &tokenizer_path, model_id)
455 }
456
457 pub fn from_files(
464 model_path: &PathBuf,
465 tokenizer_path: &PathBuf,
466 model_id: &str,
467 ) -> Result<Self> {
468 let _ = ort::init().commit();
470
471 let session = Session::builder()
473 .map_err(|e| {
474 RuvectorError::ModelLoadError(format!(
475 "Failed to create session builder: {}",
476 e
477 ))
478 })?
479 .with_intra_threads(4)
480 .map_err(|e| {
481 RuvectorError::ModelLoadError(format!("Failed to set thread count: {}", e))
482 })?
483 .commit_from_file(model_path)
484 .map_err(|e| {
485 RuvectorError::ModelLoadError(format!("Failed to load ONNX model: {}", e))
486 })?;
487
488 let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
490 RuvectorError::ModelLoadError(format!("Failed to load tokenizer: {}", e))
491 })?;
492
493 let dimensions = Self::infer_dimensions(&session, model_id)?;
495
496 let max_length = 512;
498
499 tracing::info!(
500 "Loaded ONNX embedding model: {} ({}D)",
501 model_id,
502 dimensions
503 );
504
505 Ok(Self {
506 session: RwLock::new(session),
507 tokenizer: RwLock::new(tokenizer),
508 dimensions,
509 model_id: model_id.to_string(),
510 max_length,
511 })
512 }
513
514 fn infer_dimensions(session: &Session, model_id: &str) -> Result<usize> {
515 let dimensions = match model_id {
517 id if id.contains("all-MiniLM-L6") => 384,
518 id if id.contains("all-mpnet-base") => 768,
519 id if id.contains("bge-small") => 384,
520 id if id.contains("bge-base") => 768,
521 id if id.contains("bge-large") => 1024,
522 id if id.contains("e5-small") => 384,
523 id if id.contains("e5-base") => 768,
524 id if id.contains("e5-large") => 1024,
525 _ => {
526 if let Some(output) = session.outputs().first() {
528 if let ValueType::Tensor { shape, .. } = output.dtype() {
529 let dims: Vec<i64> = shape.iter().copied().collect();
530 if dims.len() >= 2 {
531 let last_dim = dims[dims.len() - 1];
532 if last_dim > 0 {
533 return Ok(last_dim as usize);
534 }
535 }
536 }
537 }
538 384
540 }
541 };
542
543 Ok(dimensions)
544 }
545
546 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
548 texts.iter().map(|text| self.embed(text)).collect()
549 }
550
/// Averages token vectors weighted by the attention mask, then L2-normalises.
///
/// `token_embeddings` is read as consecutive rows of `hidden_size` values,
/// one row per token, paired in order with `attention_mask`. At most
/// `seq_len` rows are used. If either slice is shorter than `seq_len`
/// implies, pooling covers only the rows present in both instead of
/// panicking on an out-of-range index.
///
/// Returns `hidden_size` values. The result is all zeros when every mask
/// entry is zero, and empty when `hidden_size` is zero.
fn mean_pooling(
    token_embeddings: &[f32],
    attention_mask: &[i64],
    seq_len: usize,
    hidden_size: usize,
) -> Vec<f32> {
    let mut pooled = vec![0.0f32; hidden_size];

    // `chunks_exact` rejects a chunk size of zero, and there is nothing to pool.
    if hidden_size == 0 {
        return pooled;
    }

    let mut mask_sum = 0.0f32;
    let rows = token_embeddings
        .chunks_exact(hidden_size)
        .zip(attention_mask)
        .take(seq_len);
    for (row, &mask) in rows {
        let weight = mask as f32;
        mask_sum += weight;
        for (acc, &value) in pooled.iter_mut().zip(row) {
            *acc += value * weight;
        }
    }

    // Skip the division when every token was masked out.
    if mask_sum > 0.0 {
        for val in &mut pooled {
            *val /= mask_sum;
        }
    }

    // Leave an all-zero vector untouched rather than dividing by zero.
    let norm = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for val in &mut pooled {
            *val /= norm;
        }
    }

    pooled
}
585 }
586
587 impl EmbeddingProvider for OnnxEmbedding {
588 fn embed(&self, text: &str) -> Result<Vec<f32>> {
589 let encoding = {
591 let tokenizer = self.tokenizer.read();
592 tokenizer.encode(text, true).map_err(|e| {
593 RuvectorError::ModelInferenceError(format!("Tokenization failed: {}", e))
594 })?
595 };
596
597 let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
599 let attention_mask: Vec<i64> = encoding
600 .get_attention_mask()
601 .iter()
602 .map(|&x| x as i64)
603 .collect();
604 let token_type_ids: Vec<i64> =
605 encoding.get_type_ids().iter().map(|&x| x as i64).collect();
606
607 let seq_len = input_ids.len();
608
609 let input_ids_tensor =
612 Tensor::<i64>::from_array(([1, seq_len], input_ids.clone().into_boxed_slice()))
613 .map_err(|e| {
614 RuvectorError::ModelInferenceError(format!(
615 "Failed to create input_ids tensor: {}",
616 e
617 ))
618 })?;
619
620 let attention_mask_tensor = Tensor::<i64>::from_array((
621 [1, seq_len],
622 attention_mask.clone().into_boxed_slice(),
623 ))
624 .map_err(|e| {
625 RuvectorError::ModelInferenceError(format!(
626 "Failed to create attention_mask tensor: {}",
627 e
628 ))
629 })?;
630
631 let token_type_ids_tensor =
632 Tensor::<i64>::from_array(([1, seq_len], token_type_ids.into_boxed_slice()))
633 .map_err(|e| {
634 RuvectorError::ModelInferenceError(format!(
635 "Failed to create token_type_ids tensor: {}",
636 e
637 ))
638 })?;
639
640 let (output_data, output_shape_vec) = {
643 let mut session = self.session.write();
644 let outputs = session
645 .run(ort::inputs![
646 "input_ids" => input_ids_tensor,
647 "attention_mask" => attention_mask_tensor,
648 "token_type_ids" => token_type_ids_tensor,
649 ])
650 .map_err(|e| {
651 RuvectorError::ModelInferenceError(format!("ONNX inference failed: {}", e))
652 })?;
653
654 let output_value = &outputs[0];
657
658 let output_array = output_value.try_extract_array::<f32>().map_err(|e| {
660 RuvectorError::ModelInferenceError(format!(
661 "Failed to extract output tensor: {}",
662 e
663 ))
664 })?;
665
666 let output_shape_vec: Vec<usize> = output_array.shape().to_vec();
667 let output_data_vec: Vec<f32> = output_array.iter().copied().collect();
668
669 (output_data_vec, output_shape_vec)
670 };
671
672 let embedding = if output_shape_vec.len() == 3 {
674 let hidden_size = output_shape_vec[2];
676 Self::mean_pooling(&output_data, &attention_mask, seq_len, hidden_size)
677 } else if output_shape_vec.len() == 2 {
678 let mut emb = output_data;
680 let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
682 if norm > 0.0 {
683 for val in &mut emb {
684 *val /= norm;
685 }
686 }
687 emb
688 } else {
689 return Err(RuvectorError::ModelInferenceError(format!(
690 "Unexpected output shape: {:?}",
691 output_shape_vec
692 )));
693 };
694
695 Ok(embedding)
696 }
697
698 fn dimensions(&self) -> usize {
699 self.dimensions
700 }
701
702 fn name(&self) -> &str {
703 &self.model_id
704 }
705 }
706
707 impl std::fmt::Debug for OnnxEmbedding {
708 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
709 f.debug_struct("OnnxEmbedding")
710 .field("model_id", &self.model_id)
711 .field("dimensions", &self.dimensions)
712 .field("max_length", &self.max_length)
713 .finish()
714 }
715 }
716}
717
718#[cfg(feature = "onnx-embeddings")]
719pub use onnx::OnnxEmbedding;
720
721pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
723
724#[cfg(test)]
725mod tests {
726 use super::*;
727
/// Same input gives the same 128-component, unit-length vector.
#[test]
fn test_hash_embedding() {
    let provider = HashEmbedding::new(128);

    let first = provider.embed("hello world").unwrap();
    let second = provider.embed("hello world").unwrap();

    assert_eq!(first.len(), 128);
    assert_eq!(first, second, "Same text should produce same embedding");

    let squared_sum: f32 = first.iter().map(|x| x * x).sum();
    let norm = squared_sum.sqrt();
    assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
}
742
/// Two different inputs must not map to the same vector.
#[test]
fn test_hash_embedding_different_text() {
    let provider = HashEmbedding::new(128);

    let hello = provider.embed("hello").unwrap();
    let world = provider.embed("world").unwrap();

    assert_ne!(
        hello, world,
        "Different text should produce different embeddings"
    );
}
755
/// Expects a loaded Candle model to return a unit-length 384-component vector.
///
/// Ignored by default. `CandleEmbedding::from_pretrained` currently always
/// returns an error, so the first `unwrap` panics if this test is run.
#[cfg(feature = "real-embeddings")]
#[test]
#[ignore]
fn test_candle_embedding() {
    let model_id = "sentence-transformers/all-MiniLM-L6-v2";
    let provider = CandleEmbedding::from_pretrained(model_id, false).unwrap();

    let embedding = provider.embed("hello world").unwrap();
    assert_eq!(embedding.len(), 384);

    let squared_sum: f32 = embedding.iter().map(|x| x * x).sum();
    let norm = squared_sum.sqrt();
    assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
}
771
/// Calls the OpenAI endpoint and expects a 1536-component vector back.
///
/// Compiled only with the `api-embeddings` feature, because `ApiEmbedding`
/// does not exist without it. Ignored by default: it needs the
/// `OPENAI_API_KEY` environment variable and makes a real HTTP request.
#[cfg(feature = "api-embeddings")]
#[test]
#[ignore]
fn test_api_embedding_openai() {
    let api_key =
        std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set to run this test");
    let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

    let embedding = provider.embed("hello world").unwrap();
    assert_eq!(embedding.len(), 1536);
}
781
/// Tests for `OnnxEmbedding`. All are ignored by default because each one
/// loads the model through `from_pretrained`.
#[cfg(feature = "onnx-embeddings")]
mod onnx_tests {
    use super::*;

    /// Model identifier used by every test in this module.
    const MINILM: &str = "sentence-transformers/all-MiniLM-L6-v2";

    /// Dot product of two vectors, pairing components in order.
    fn dot(left: &[f32], right: &[f32]) -> f32 {
        left.iter().zip(right).map(|(a, b)| a * b).sum()
    }

    /// A single embedding has 384 components and unit length.
    #[test]
    #[ignore]
    fn test_onnx_embedding_minilm() {
        let provider = OnnxEmbedding::from_pretrained(MINILM).unwrap();

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 384);

        let norm: f32 = dot(&embedding, &embedding).sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-4,
            "Embedding should be normalized, got norm={}",
            norm
        );
    }

    /// "dog" should score closer to "cat" than to "car".
    #[test]
    #[ignore]
    fn test_onnx_semantic_similarity() {
        let provider = OnnxEmbedding::from_pretrained(MINILM).unwrap();

        let emb_dog = provider.embed("dog").unwrap();
        let emb_cat = provider.embed("cat").unwrap();
        let emb_car = provider.embed("car").unwrap();

        let sim_dog_cat: f32 = dot(&emb_dog, &emb_cat);
        let sim_dog_car: f32 = dot(&emb_dog, &emb_car);

        assert!(
            sim_dog_cat > sim_dog_car,
            "Expected dog-cat similarity ({}) > dog-car similarity ({})",
            sim_dog_cat,
            sim_dog_car
        );
    }

    /// Batch embedding returns one 384-component vector per input.
    #[test]
    #[ignore]
    fn test_onnx_batch_embedding() {
        let provider = OnnxEmbedding::from_pretrained(MINILM).unwrap();

        let texts = vec!["hello world", "goodbye world", "rust programming"];
        let embeddings = provider.embed_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|emb| emb.len() == 384));
    }
}
842}