1use crate::error::Result;
36#[cfg(any(feature = "real-embeddings", feature = "api-embeddings"))]
37use crate::error::RuvectorError;
38use std::sync::Arc;
39
/// Common interface for anything that can turn text into a dense vector.
///
/// Implementations must be `Send + Sync` so a single provider can be shared
/// across threads (see [`BoxedEmbeddingProvider`]).
pub trait EmbeddingProvider: Send + Sync {
    /// Converts `text` into a dense embedding vector.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Length of the vectors produced by [`embed`](Self::embed).
    fn dimensions(&self) -> usize;

    /// Human-readable provider name, for logging and diagnostics.
    fn name(&self) -> &str;
}
51
/// Deterministic, dependency-free placeholder embedding.
///
/// Folds the input's bytes into a fixed number of buckets and normalizes.
/// Useful for tests and wiring; carries no semantic meaning.
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    // Length of every vector returned by `embed`.
    dimensions: usize,
}
66
67impl HashEmbedding {
68 pub fn new(dimensions: usize) -> Self {
70 Self { dimensions }
71 }
72}
73
74impl EmbeddingProvider for HashEmbedding {
75 fn embed(&self, text: &str) -> Result<Vec<f32>> {
76 let mut embedding = vec![0.0; self.dimensions];
77 let bytes = text.as_bytes();
78
79 for (i, byte) in bytes.iter().enumerate() {
80 embedding[i % self.dimensions] += (*byte as f32) / 255.0;
81 }
82
83 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
85 if norm > 0.0 {
86 for val in &mut embedding {
87 *val /= norm;
88 }
89 }
90
91 Ok(embedding)
92 }
93
94 fn dimensions(&self) -> usize {
95 self.dimensions
96 }
97
98 fn name(&self) -> &str {
99 "HashEmbedding (placeholder)"
100 }
101}
102
#[cfg(feature = "real-embeddings")]
pub mod candle {
    use super::*;

    /// Placeholder for a local Candle-based embedding model.
    ///
    /// Construction always fails; this exists so the `real-embeddings`
    /// feature compiles while directing users to `ApiEmbedding`.
    pub struct CandleEmbedding {
        dimensions: usize,
        model_id: String,
    }

    impl CandleEmbedding {
        /// Always returns `ModelLoadError` — Candle support is a stub.
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            let message = format!(
                "Candle embedding support is a stub. Please:\n\
                 1. Use ApiEmbedding for production (recommended)\n\
                 2. Or implement CandleEmbedding for model: {}\n\
                 3. See docs for ONNX Runtime integration examples",
                model_id
            );
            Err(RuvectorError::ModelLoadError(message))
        }
    }

    impl EmbeddingProvider for CandleEmbedding {
        /// Always fails: no inference backend is wired up.
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            let reason = "Candle embedding not implemented - use ApiEmbedding instead";
            Err(RuvectorError::ModelInferenceError(reason.to_string()))
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}
187
188#[cfg(feature = "real-embeddings")]
189pub use candle::CandleEmbedding;
190
#[cfg(feature = "api-embeddings")]
/// Embedding provider backed by a remote HTTP embeddings API
/// (OpenAI-, Cohere-, or Voyage-style endpoints).
#[derive(Clone)]
pub struct ApiEmbedding {
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Full URL of the embeddings endpoint.
    endpoint: String,
    // Model identifier forwarded in the request body.
    model: String,
    // Expected embedding length (not validated against responses).
    dimensions: usize,
    // Reused blocking HTTP client.
    client: reqwest::blocking::Client,
}
212
#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
    /// Builds a provider for an arbitrary embeddings endpoint.
    ///
    /// `dimensions` is the expected output length; it is reported via
    /// `dimensions()` but not enforced against API responses.
    pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
        Self {
            api_key,
            endpoint,
            model,
            dimensions,
            client: reqwest::blocking::Client::new(),
        }
    }

    /// Convenience constructor for OpenAI's embeddings endpoint.
    pub fn openai(api_key: &str, model: &str) -> Self {
        // text-embedding-3-large is 3072-D; 3-small and ada-002 are 1536-D.
        let dimensions = match model {
            "text-embedding-3-large" => 3072,
            _ => 1536,
        };

        Self::new(
            api_key.to_string(),
            "https://api.openai.com/v1/embeddings".to_string(),
            model.to_string(),
            dimensions,
        )
    }

    /// Convenience constructor for Cohere's embed endpoint.
    pub fn cohere(api_key: &str, model: &str) -> Self {
        // Fix: the "light" v3 models (e.g. embed-english-light-v3.0) emit
        // 384-D vectors per Cohere's docs; only the full-size v3 models are
        // 1024-D. Previously 1024 was hard-coded for every model.
        let dimensions = if model.contains("light") { 384 } else { 1024 };

        Self::new(
            api_key.to_string(),
            "https://api.cohere.ai/v1/embed".to_string(),
            model.to_string(),
            dimensions,
        )
    }

    /// Convenience constructor for Voyage AI's embeddings endpoint.
    pub fn voyage(api_key: &str, model: &str) -> Self {
        // voyage-large-* models are 1536-D; the rest default to 1024-D.
        let dimensions = if model.contains("large") { 1536 } else { 1024 };

        Self::new(
            api_key.to_string(),
            "https://api.voyageai.com/v1/embeddings".to_string(),
            model.to_string(),
            dimensions,
        )
    }
}
282
#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
    /// POSTs `text` to the configured endpoint and extracts the first
    /// embedding from the JSON response.
    ///
    /// Two response layouts are recognized:
    /// - OpenAI-style: `{"data": [{"embedding": [...]}]}`
    /// - Cohere-style: `{"embeddings": [[...]]}`
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // NOTE(review): this body uses OpenAI's "input" key for every
        // provider; Cohere's /v1/embed expects a "texts" array — confirm
        // the cohere() constructor path works against the live API.
        let request_body = serde_json::json!({
            "input": text,
            "model": self.model,
        });

        let response = self
            .client
            .post(&self.endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .map_err(|e| {
                RuvectorError::ModelInferenceError(format!("API request failed: {}", e))
            })?;

        // Surface non-2xx responses with the provider's own error text.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RuvectorError::ModelInferenceError(format!(
                "API returned error {}: {}",
                status, error_text
            )));
        }

        let response_json: serde_json::Value = response.json().map_err(|e| {
            RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e))
        })?;

        // Probe the OpenAI layout first, then fall back to Cohere's.
        // Only the FIRST embedding is taken — single-input requests only.
        let embedding = if let Some(data) = response_json.get("data") {
            data.as_array()
                .and_then(|arr| arr.first())
                .and_then(|obj| obj.get("embedding"))
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid OpenAI response format".to_string())
                })?
        } else if let Some(embeddings) = response_json.get("embeddings") {
            embeddings
                .as_array()
                .and_then(|arr| arr.first())
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid Cohere response format".to_string())
                })?
        } else {
            return Err(RuvectorError::ModelInferenceError(
                "Unknown API response format".to_string(),
            ));
        };

        // Convert each JSON number to f32, failing on any non-numeric entry.
        let embedding_vec: Result<Vec<f32>> = embedding
            .iter()
            .map(|v| {
                v.as_f64().map(|f| f as f32).ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid embedding value".to_string())
                })
            })
            .collect();

        embedding_vec
    }

    /// Configured dimensionality; not checked against API responses.
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn name(&self) -> &str {
        "ApiEmbedding"
    }
}
362
#[cfg(feature = "onnx-embeddings")]
pub mod onnx {
    //! Local embedding inference via ONNX Runtime (`ort`) plus HuggingFace
    //! `tokenizers`, with model files fetched through `hf_hub`.

    use super::*;
    use crate::error::RuvectorError;
    use ort::session::Session;
    use ort::value::{Tensor, ValueType};
    use parking_lot::RwLock;
    use std::path::PathBuf;
    use tokenizers::Tokenizer;

    /// Sentence-embedding provider that runs a transformer ONNX export.
    pub struct OnnxEmbedding {
        // Locks are needed because the `EmbeddingProvider` trait takes
        // `&self` while `Session::run` requires mutable access here.
        session: RwLock<Session>,
        tokenizer: RwLock<Tokenizer>,
        // Output embedding width (see `infer_dimensions`).
        dimensions: usize,
        // HuggingFace model id; also returned by `name()`.
        model_id: String,
        // NOTE(review): stored but never enforced — tokenizer output is not
        // truncated to this length anywhere in this module.
        #[allow(dead_code)]
        max_length: usize,
    }

    impl OnnxEmbedding {
        /// Downloads `model.onnx` (or `onnx/model.onnx`) and
        /// `tokenizer.json` from the HuggingFace Hub, then loads them.
        ///
        /// # Errors
        /// `ModelLoadError` if the Hub API cannot be created or either file
        /// cannot be downloaded / loaded.
        pub fn from_pretrained(model_id: &str) -> Result<Self> {
            let api = hf_hub::api::sync::Api::new().map_err(|e| {
                RuvectorError::ModelLoadError(format!("Failed to create HuggingFace API: {}", e))
            })?;

            let repo = api.model(model_id.to_string());

            // Some repos ship the export at the root, others under onnx/.
            let model_path = repo.get("model.onnx").or_else(|_| {
                repo.get("onnx/model.onnx")
            }).map_err(|e| {
                RuvectorError::ModelLoadError(format!(
                    "Failed to download ONNX model from {}: {}. \
                     Make sure the model has an ONNX export available.",
                    model_id, e
                ))
            })?;

            let tokenizer_path = repo.get("tokenizer.json").map_err(|e| {
                RuvectorError::ModelLoadError(format!(
                    "Failed to download tokenizer from {}: {}",
                    model_id, e
                ))
            })?;

            Self::from_files(&model_path, &tokenizer_path, model_id)
        }

        /// Loads a session and tokenizer from local files. `model_id` is
        /// retained for dimension inference and `name()`.
        pub fn from_files(
            model_path: &PathBuf,
            tokenizer_path: &PathBuf,
            model_id: &str,
        ) -> Result<Self> {
            // Global ORT init; the result is ignored so a second call
            // (already initialized) is harmless.
            let _ = ort::init().commit();

            let session = Session::builder()
                .map_err(|e| {
                    RuvectorError::ModelLoadError(format!("Failed to create session builder: {}", e))
                })?
                .with_intra_threads(4)
                .map_err(|e| {
                    RuvectorError::ModelLoadError(format!("Failed to set thread count: {}", e))
                })?
                .commit_from_file(model_path)
                .map_err(|e| {
                    RuvectorError::ModelLoadError(format!("Failed to load ONNX model: {}", e))
                })?;

            let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
                RuvectorError::ModelLoadError(format!("Failed to load tokenizer: {}", e))
            })?;

            let dimensions = Self::infer_dimensions(&session, model_id)?;

            // Fixed cap, currently informational only (see struct field).
            let max_length = 512;

            tracing::info!(
                "Loaded ONNX embedding model: {} ({}D)",
                model_id,
                dimensions
            );

            Ok(Self {
                session: RwLock::new(session),
                tokenizer: RwLock::new(tokenizer),
                dimensions,
                model_id: model_id.to_string(),
                max_length,
            })
        }

        /// Determines the embedding width: known model families first, then
        /// the session's declared output shape, else a 384 fallback.
        fn infer_dimensions(session: &Session, model_id: &str) -> Result<usize> {
            let dimensions = match model_id {
                id if id.contains("all-MiniLM-L6") => 384,
                id if id.contains("all-mpnet-base") => 768,
                id if id.contains("bge-small") => 384,
                id if id.contains("bge-base") => 768,
                id if id.contains("bge-large") => 1024,
                id if id.contains("e5-small") => 384,
                id if id.contains("e5-base") => 768,
                id if id.contains("e5-large") => 1024,
                _ => {
                    // Unknown family: read the last static dimension of the
                    // first output tensor, if the model declares one.
                    if let Some(output) = session.outputs().first() {
                        if let ValueType::Tensor { shape, .. } = output.dtype() {
                            let dims: Vec<i64> = shape.iter().copied().collect();
                            if dims.len() >= 2 {
                                let last_dim = dims[dims.len() - 1];
                                // Dynamic dims are non-positive; only trust
                                // a concrete, positive size.
                                if last_dim > 0 {
                                    return Ok(last_dim as usize);
                                }
                            }
                        }
                    }
                    384
                }
            };

            Ok(dimensions)
        }

        /// Embeds several texts by calling [`embed`](EmbeddingProvider::embed)
        /// once per text (sequential; no true batched inference).
        pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            texts.iter().map(|text| self.embed(text)).collect()
        }

        /// Attention-mask-weighted mean over token embeddings, followed by
        /// L2 normalization.
        ///
        /// `token_embeddings` is row-major `[seq_len, hidden_size]`; both
        /// slices are assumed to cover at least `seq_len` entries.
        fn mean_pooling(
            token_embeddings: &[f32],
            attention_mask: &[i64],
            seq_len: usize,
            hidden_size: usize,
        ) -> Vec<f32> {
            let mut pooled = vec![0.0f32; hidden_size];
            let mut mask_sum = 0.0f32;

            for i in 0..seq_len {
                let mask = attention_mask[i] as f32;
                mask_sum += mask;
                for j in 0..hidden_size {
                    pooled[j] += token_embeddings[i * hidden_size + j] * mask;
                }
            }

            // Average over unmasked tokens only (guard against all-masked).
            if mask_sum > 0.0 {
                for val in &mut pooled {
                    *val /= mask_sum;
                }
            }

            // L2 normalize so dot product behaves like cosine similarity.
            let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for val in &mut pooled {
                    *val /= norm;
                }
            }

            pooled
        }
    }

    impl EmbeddingProvider for OnnxEmbedding {
        /// Tokenizes `text`, runs the model, and pools/normalizes the output
        /// into a single sentence vector.
        fn embed(&self, text: &str) -> Result<Vec<f32>> {
            // Tokenize under a short-lived read lock.
            let encoding = {
                let tokenizer = self.tokenizer.read();
                tokenizer
                    .encode(text, true)
                    .map_err(|e| {
                        RuvectorError::ModelInferenceError(format!("Tokenization failed: {}", e))
                    })?
            };

            // BERT-style models take i64 inputs; widen the u32 token data.
            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
            let attention_mask: Vec<i64> = encoding
                .get_attention_mask()
                .iter()
                .map(|&x| x as i64)
                .collect();
            let token_type_ids: Vec<i64> = encoding
                .get_type_ids()
                .iter()
                .map(|&x| x as i64)
                .collect();

            let seq_len = input_ids.len();

            // Batch size 1: every input tensor has shape [1, seq_len].
            let input_ids_tensor = Tensor::<i64>::from_array(([1, seq_len], input_ids.clone().into_boxed_slice()))
                .map_err(|e| {
                    RuvectorError::ModelInferenceError(format!(
                        "Failed to create input_ids tensor: {}",
                        e
                    ))
                })?;

            let attention_mask_tensor =
                Tensor::<i64>::from_array(([1, seq_len], attention_mask.clone().into_boxed_slice())).map_err(|e| {
                    RuvectorError::ModelInferenceError(format!(
                        "Failed to create attention_mask tensor: {}",
                        e
                    ))
                })?;

            let token_type_ids_tensor =
                Tensor::<i64>::from_array(([1, seq_len], token_type_ids.into_boxed_slice())).map_err(|e| {
                    RuvectorError::ModelInferenceError(format!(
                        "Failed to create token_type_ids tensor: {}",
                        e
                    ))
                })?;

            // Run inference under the write lock, copying shape + data out
            // so the lock is released before post-processing.
            let (output_data, output_shape_vec) = {
                let mut session = self.session.write();
                let outputs = session
                    .run(ort::inputs![
                        "input_ids" => input_ids_tensor,
                        "attention_mask" => attention_mask_tensor,
                        "token_type_ids" => token_type_ids_tensor,
                    ])
                    .map_err(|e| {
                        RuvectorError::ModelInferenceError(format!("ONNX inference failed: {}", e))
                    })?;

                // Only the first output is consumed.
                let output_value = &outputs[0];

                let output_array = output_value.try_extract_array::<f32>().map_err(|e| {
                    RuvectorError::ModelInferenceError(format!("Failed to extract output tensor: {}", e))
                })?;

                let output_shape_vec: Vec<usize> = output_array.shape().to_vec();
                let output_data_vec: Vec<f32> = output_array.iter().copied().collect();

                (output_data_vec, output_shape_vec)
            };

            // Rank 3 => per-token hidden states: mean-pool over tokens.
            // Rank 2 => the model already pooled: just L2-normalize.
            let embedding = if output_shape_vec.len() == 3 {
                let hidden_size = output_shape_vec[2];
                Self::mean_pooling(&output_data, &attention_mask, seq_len, hidden_size)
            } else if output_shape_vec.len() == 2 {
                let mut emb = output_data;
                let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    for val in &mut emb {
                        *val /= norm;
                    }
                }
                emb
            } else {
                return Err(RuvectorError::ModelInferenceError(format!(
                    "Unexpected output shape: {:?}",
                    output_shape_vec
                )));
            };

            Ok(embedding)
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            &self.model_id
        }
    }

    // Manual Debug that reports metadata and omits the session/tokenizer
    // fields.
    impl std::fmt::Debug for OnnxEmbedding {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("OnnxEmbedding")
                .field("model_id", &self.model_id)
                .field("dimensions", &self.dimensions)
                .field("max_length", &self.max_length)
                .finish()
        }
    }
}
708
709#[cfg(feature = "onnx-embeddings")]
710pub use onnx::OnnxEmbedding;
711
/// Shared, thread-safe handle to any embedding provider.
/// (Named "Boxed" but actually an `Arc`, so clones are cheap.)
pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
714
#[cfg(test)]
mod tests {
    use super::*;

    /// HashEmbedding must be deterministic, correctly sized, and unit-norm.
    #[test]
    fn test_hash_embedding() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1.len(), 128);
        assert_eq!(emb1, emb2, "Same text should produce same embedding");

        let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    /// Distinct inputs must not collide on the placeholder embedding.
    #[test]
    fn test_hash_embedding_different_text() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(
            emb1, emb2,
            "Different text should produce different embeddings"
        );
    }

    #[cfg(feature = "real-embeddings")]
    #[test]
    #[ignore] // CandleEmbedding is a stub; from_pretrained always errors.
    fn test_candle_embedding() {
        let provider =
            CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false)
                .unwrap();

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 384);

        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    // Fix: this test was missing its feature gate. `ApiEmbedding` only
    // exists under `api-embeddings`, so without the cfg the whole test
    // module failed to compile in default builds.
    #[cfg(feature = "api-embeddings")]
    #[test]
    #[ignore] // requires OPENAI_API_KEY and network access
    fn test_api_embedding_openai() {
        let api_key = std::env::var("OPENAI_API_KEY").unwrap();
        let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 1536);
    }

    #[cfg(feature = "onnx-embeddings")]
    mod onnx_tests {
        use super::*;

        #[test]
        #[ignore] // downloads a model from the HuggingFace Hub
        fn test_onnx_embedding_minilm() {
            let provider =
                OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2").unwrap();

            let embedding = provider.embed("hello world").unwrap();
            assert_eq!(embedding.len(), 384);

            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-4,
                "Embedding should be normalized, got norm={}",
                norm
            );
        }

        /// Related words should land closer together than unrelated ones.
        #[test]
        #[ignore] // downloads a model from the HuggingFace Hub
        fn test_onnx_semantic_similarity() {
            let provider =
                OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2").unwrap();

            let emb_dog = provider.embed("dog").unwrap();
            let emb_cat = provider.embed("cat").unwrap();
            let emb_car = provider.embed("car").unwrap();

            // Embeddings are unit-norm, so dot product == cosine similarity.
            let sim_dog_cat: f32 = emb_dog.iter().zip(&emb_cat).map(|(a, b)| a * b).sum();
            let sim_dog_car: f32 = emb_dog.iter().zip(&emb_car).map(|(a, b)| a * b).sum();

            assert!(
                sim_dog_cat > sim_dog_car,
                "Expected dog-cat similarity ({}) > dog-car similarity ({})",
                sim_dog_cat,
                sim_dog_car
            );
        }

        #[test]
        #[ignore] // downloads a model from the HuggingFace Hub
        fn test_onnx_batch_embedding() {
            let provider =
                OnnxEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2").unwrap();

            let texts = vec!["hello world", "goodbye world", "rust programming"];
            let embeddings = provider.embed_batch(&texts).unwrap();

            assert_eq!(embeddings.len(), 3);
            for emb in &embeddings {
                assert_eq!(emb.len(), 384);
            }
        }
    }
}