1use moka::sync::Cache;
52use std::sync::Arc;
53
/// Errors reported by embedding providers and the embedding cache.
#[derive(Debug, Clone)]
pub enum EmbeddingError {
    /// The named embedding model is not available.
    ModelNotAvailable(String),
    /// The input exceeded the provider's maximum length.
    TextTooLong { max_length: usize, actual: usize },
    /// An embedding's length differed from the expected dimension.
    DimensionMismatch { expected: usize, actual: usize },
    /// The underlying provider failed; carries a description.
    ProviderError(String),
    /// The cache failed; carries a description.
    CacheError(String),
}

impl std::fmt::Display for EmbeddingError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ModelNotAvailable(model) => {
                write!(formatter, "Embedding model not available: {name}", name = model)
            }
            Self::TextTooLong { max_length, actual } => write!(
                formatter,
                "Text too long: {got} > {limit} max",
                got = actual,
                limit = max_length
            ),
            Self::DimensionMismatch { expected, actual } => write!(
                formatter,
                "Dimension mismatch: expected {want}, got {got}",
                want = expected,
                got = actual
            ),
            Self::ProviderError(msg) => write!(formatter, "Provider error: {detail}", detail = msg),
            Self::CacheError(msg) => write!(formatter, "Cache error: {detail}", detail = msg),
        }
    }
}

/// Lets `EmbeddingError` flow through `Box<dyn Error>` and `?` conversions.
impl std::error::Error for EmbeddingError {}

/// Result type used throughout the embedding module.
pub type EmbeddingResult<T> = Result<T, EmbeddingError>;
97
98pub trait EmbeddingProvider: Send + Sync {
100 fn model_name(&self) -> &str;
102
103 fn dimension(&self) -> usize;
105
106 fn max_length(&self) -> usize;
108
109 fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>>;
111
112 fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
114 texts.iter().map(|t| self.embed(t)).collect()
116 }
117
118 fn normalize(&self, embedding: &mut [f32]) {
120 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
121 if norm > 1e-10 {
122 for x in embedding.iter_mut() {
123 *x /= norm;
124 }
125 }
126 }
127}
128
/// Settings describing an embedding model and its caching parameters.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Model identifier.
    pub model: String,
    /// Optional filesystem path to the model files.
    pub model_path: Option<String>,
    /// Length of each embedding vector.
    pub dimension: usize,
    /// Maximum accepted input length.
    pub max_length: usize,
    /// Whether providers that honour it scale embeddings to unit L2 norm.
    pub normalize: bool,
    /// Suggested batch size for bulk embedding.
    pub batch_size: usize,
    /// Suggested cache capacity, in entries.
    pub cache_size: usize,
    /// Suggested cache time-to-live, in seconds.
    pub cache_ttl_secs: u64,
}

impl Default for EmbeddingConfig {
    /// `all-MiniLM-L6-v2` with 384 dimensions and a max length of 512,
    /// normalization on, batch size 32, a 10 000-entry cache and a one-hour TTL.
    fn default() -> Self {
        EmbeddingConfig {
            model: String::from("all-MiniLM-L6-v2"),
            model_path: None,
            dimension: 384,
            max_length: 512,
            normalize: true,
            batch_size: 32,
            cache_size: 10_000,
            cache_ttl_secs: 60 * 60,
        }
    }
}

impl EmbeddingConfig {
    /// Preset for a sentence-transformers model.
    ///
    /// `all-mpnet-base-v2` gets 768 dimensions. Every other name gets 384,
    /// including the MiniLM family and unrecognised names. All remaining fields
    /// take their [`Default`] values.
    pub fn sentence_transformer(model: &str) -> Self {
        let dimension = if model == "all-mpnet-base-v2" { 768 } else { 384 };
        EmbeddingConfig {
            model: model.to_owned(),
            dimension,
            ..Self::default()
        }
    }

    /// Preset for an OpenAI embedding model, with a max length of 8192.
    ///
    /// `text-embedding-3-large` gets 3072 dimensions. Every other name gets
    /// 1536, including `text-embedding-ada-002`, `text-embedding-3-small` and
    /// unrecognised names.
    pub fn openai(model: &str) -> Self {
        let dimension = match model {
            "text-embedding-3-large" => 3072,
            _ => 1536,
        };
        EmbeddingConfig {
            model: model.to_owned(),
            dimension,
            max_length: 8192,
            ..Self::default()
        }
    }
}
212
213pub struct MockEmbeddingProvider {
219 config: EmbeddingConfig,
220 use_hash: bool,
222}
223
224impl MockEmbeddingProvider {
225 pub fn new(dimension: usize) -> Self {
227 Self {
228 config: EmbeddingConfig {
229 model: "mock".to_string(),
230 dimension,
231 ..Default::default()
232 },
233 use_hash: true,
234 }
235 }
236
237 pub fn with_config(config: EmbeddingConfig) -> Self {
239 Self {
240 config,
241 use_hash: true,
242 }
243 }
244
245 fn hash_embed(&self, text: &str) -> Vec<f32> {
247 use std::collections::hash_map::DefaultHasher;
248 use std::hash::{Hash, Hasher};
249
250 let mut embedding = Vec::with_capacity(self.config.dimension);
251
252 for i in 0..self.config.dimension {
254 let mut hasher = DefaultHasher::new();
255 text.hash(&mut hasher);
256 i.hash(&mut hasher);
257 let hash = hasher.finish();
258
259 let value = ((hash as f64) / (u64::MAX as f64) * 2.0 - 1.0) as f32;
261 embedding.push(value);
262 }
263
264 embedding
265 }
266}
267
268impl EmbeddingProvider for MockEmbeddingProvider {
269 fn model_name(&self) -> &str {
270 &self.config.model
271 }
272
273 fn dimension(&self) -> usize {
274 self.config.dimension
275 }
276
277 fn max_length(&self) -> usize {
278 self.config.max_length
279 }
280
281 fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
282 if text.len() > self.config.max_length {
283 return Err(EmbeddingError::TextTooLong {
284 max_length: self.config.max_length,
285 actual: text.len(),
286 });
287 }
288
289 let mut embedding = if self.use_hash {
290 self.hash_embed(text)
291 } else {
292 vec![0.0; self.config.dimension]
293 };
294
295 if self.config.normalize {
296 self.normalize(&mut embedding);
297 }
298
299 Ok(embedding)
300 }
301}
302
303pub struct CachedEmbeddingProvider<P: EmbeddingProvider> {
309 inner: P,
311
312 cache: Cache<u64, Vec<f32>>,
314
315 stats: Arc<CacheStats>,
317}
318
/// Counters describing embedding-cache activity.
///
/// The counters are atomics and can be read while other threads update them.
/// `hit_rate` loads them with relaxed ordering, so a reading taken during
/// concurrent updates is only an approximate snapshot.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Lookups served from the cache.
    pub hits: std::sync::atomic::AtomicUsize,
    /// Lookups that had to go to the inner provider.
    pub misses: std::sync::atomic::AtomicUsize,
    /// Number of cache insertions. The code in this module only increments it,
    /// so it does not account for evicted or expired entries.
    pub size: std::sync::atomic::AtomicUsize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    /// Returns `0.0` before any lookup has been recorded.
    pub fn hit_rate(&self) -> f64 {
        use std::sync::atomic::Ordering;

        let hits = self.hits.load(Ordering::Relaxed);
        let misses = self.misses.load(Ordering::Relaxed);
        match hits + misses {
            0 => 0.0,
            lookups => hits as f64 / lookups as f64,
        }
    }
}
343
344impl<P: EmbeddingProvider> CachedEmbeddingProvider<P> {
345 pub fn new(inner: P, cache_size: usize) -> Self {
347 Self {
348 inner,
349 cache: Cache::new(cache_size as u64),
350 stats: Arc::new(CacheStats::default()),
351 }
352 }
353
354 pub fn with_ttl(inner: P, cache_size: usize, ttl_secs: u64) -> Self {
356 let cache = Cache::builder()
357 .max_capacity(cache_size as u64)
358 .time_to_live(std::time::Duration::from_secs(ttl_secs))
359 .build();
360
361 Self {
362 inner,
363 cache,
364 stats: Arc::new(CacheStats::default()),
365 }
366 }
367
368 pub fn stats(&self) -> &Arc<CacheStats> {
370 &self.stats
371 }
372
373 fn text_hash(text: &str) -> u64 {
375 use std::collections::hash_map::DefaultHasher;
376 use std::hash::{Hash, Hasher};
377
378 let mut hasher = DefaultHasher::new();
379 text.hash(&mut hasher);
380 hasher.finish()
381 }
382}
383
384impl<P: EmbeddingProvider> EmbeddingProvider for CachedEmbeddingProvider<P> {
385 fn model_name(&self) -> &str {
386 self.inner.model_name()
387 }
388
389 fn dimension(&self) -> usize {
390 self.inner.dimension()
391 }
392
393 fn max_length(&self) -> usize {
394 self.inner.max_length()
395 }
396
397 fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
398 let hash = Self::text_hash(text);
399
400 if let Some(cached) = self.cache.get(&hash) {
402 self.stats
403 .hits
404 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
405 return Ok(cached);
406 }
407
408 self.stats
409 .misses
410 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
411
412 let embedding = self.inner.embed(text)?;
414
415 self.cache.insert(hash, embedding.clone());
417 self.stats
418 .size
419 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
420
421 Ok(embedding)
422 }
423
424 fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
425 let mut results = Vec::with_capacity(texts.len());
426 let mut uncached: Vec<(usize, &str)> = Vec::new();
427
428 for (i, text) in texts.iter().enumerate() {
430 let hash = Self::text_hash(text);
431 if let Some(cached) = self.cache.get(&hash) {
432 self.stats
433 .hits
434 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
435 results.push((i, cached));
436 } else {
437 self.stats
438 .misses
439 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
440 uncached.push((i, *text));
441 }
442 }
443
444 if !uncached.is_empty() {
446 let uncached_texts: Vec<&str> = uncached.iter().map(|(_, t)| *t).collect();
447 let embeddings = self.inner.embed_batch(&uncached_texts)?;
448
449 for ((i, text), embedding) in uncached.iter().zip(embeddings.into_iter()) {
450 let hash = Self::text_hash(text);
451 self.cache.insert(hash, embedding.clone());
452 self.stats
453 .size
454 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
455 results.push((*i, embedding));
456 }
457 }
458
459 results.sort_by_key(|(i, _)| *i);
461 Ok(results.into_iter().map(|(_, e)| e).collect())
462 }
463}
464
465#[derive(Debug)]
476pub struct LocalOnnxProvider {
477 config: EmbeddingConfig,
478 #[allow(dead_code)]
480 model_loaded: bool,
481}
482
483impl LocalOnnxProvider {
484 pub fn new(config: EmbeddingConfig) -> EmbeddingResult<Self> {
486 Ok(Self {
488 config,
489 model_loaded: false,
490 })
491 }
492
493 pub fn load_pretrained(model_name: &str) -> EmbeddingResult<Self> {
495 let config = EmbeddingConfig::sentence_transformer(model_name);
496 Self::new(config)
497 }
498}
499
500impl EmbeddingProvider for LocalOnnxProvider {
501 fn model_name(&self) -> &str {
502 &self.config.model
503 }
504
505 fn dimension(&self) -> usize {
506 self.config.dimension
507 }
508
509 fn max_length(&self) -> usize {
510 self.config.max_length
511 }
512
513 fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
514 let mock = MockEmbeddingProvider::with_config(self.config.clone());
517 mock.embed(text)
518 }
519}
520
/// Couples a vector index with an embedding provider so the index can be
/// queried by text as well as by raw embeddings.
///
/// Trait bounds live on the `impl` blocks rather than on the struct, so the
/// type itself places no requirements on `V` or `P`.
pub struct EmbeddingVectorIndex<V, P> {
    /// Index that executes embedding searches.
    index: Arc<V>,
    /// Provider used to turn query text into embeddings.
    provider: Arc<P>,
}
537
538impl<V, P> EmbeddingVectorIndex<V, P>
539where
540 V: crate::context_query::VectorIndex,
541 P: EmbeddingProvider,
542{
543 pub fn new(index: Arc<V>, provider: Arc<P>) -> Self {
545 Self { index, provider }
546 }
547
548 pub fn search_text(
550 &self,
551 collection: &str,
552 text: &str,
553 k: usize,
554 min_score: Option<f32>,
555 ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
556 let embedding = self.provider.embed(text).map_err(|e| e.to_string())?;
558
559 self.index
561 .search_by_embedding(collection, &embedding, k, min_score)
562 }
563
564 pub fn search_embedding(
566 &self,
567 collection: &str,
568 embedding: &[f32],
569 k: usize,
570 min_score: Option<f32>,
571 ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
572 if embedding.len() != self.provider.dimension() {
574 return Err(format!(
575 "Embedding dimension mismatch: expected {}, got {}",
576 self.provider.dimension(),
577 embedding.len()
578 ));
579 }
580
581 self.index
582 .search_by_embedding(collection, embedding, k, min_score)
583 }
584
585 pub fn provider(&self) -> &Arc<P> {
587 &self.provider
588 }
589
590 pub fn index(&self) -> &Arc<V> {
592 &self.index
593 }
594}
595
596impl<V, P> crate::context_query::VectorIndex for EmbeddingVectorIndex<V, P>
597where
598 V: crate::context_query::VectorIndex,
599 P: EmbeddingProvider,
600{
601 fn search_by_embedding(
602 &self,
603 collection: &str,
604 embedding: &[f32],
605 k: usize,
606 min_score: Option<f32>,
607 ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
608 self.search_embedding(collection, embedding, k, min_score)
609 }
610
611 fn search_by_text(
612 &self,
613 collection: &str,
614 text: &str,
615 k: usize,
616 min_score: Option<f32>,
617 ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
618 self.search_text(collection, text, k, min_score)
619 }
620
621 fn stats(&self, collection: &str) -> Option<crate::context_query::VectorIndexStats> {
622 self.index.stats(collection)
623 }
624}
625
626pub fn create_mock_provider(
632 dimension: usize,
633 cache_size: usize,
634) -> CachedEmbeddingProvider<MockEmbeddingProvider> {
635 let mock = MockEmbeddingProvider::new(dimension);
636 CachedEmbeddingProvider::new(mock, cache_size)
637}
638
639pub fn create_embedding_index<V: crate::context_query::VectorIndex>(
641 index: Arc<V>,
642 dimension: usize,
643) -> EmbeddingVectorIndex<V, CachedEmbeddingProvider<MockEmbeddingProvider>> {
644 let provider = Arc::new(create_mock_provider(dimension, 10_000));
645 EmbeddingVectorIndex::new(index, provider)
646}
647
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Reads one of the cache's stats counters.
    fn count(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::Relaxed)
    }

    #[test]
    fn test_mock_embedding_deterministic() {
        let provider = MockEmbeddingProvider::new(384);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1, emb2);
        assert_eq!(emb1.len(), 384);
    }

    #[test]
    fn test_mock_embedding_different_texts() {
        let provider = MockEmbeddingProvider::new(384);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(emb1, emb2);
    }

    #[test]
    fn test_cached_provider() {
        let mock = MockEmbeddingProvider::new(128);
        let cached = CachedEmbeddingProvider::new(mock, 100);

        // First lookup misses and populates the cache.
        let _ = cached.embed("test text").unwrap();
        assert_eq!(count(&cached.stats().hits), 0);
        assert_eq!(count(&cached.stats().misses), 1);

        // Second lookup of the same text is served from the cache.
        let _ = cached.embed("test text").unwrap();
        assert_eq!(count(&cached.stats().hits), 1);
        assert_eq!(count(&cached.stats().misses), 1);

        assert!(cached.stats().hit_rate() > 0.4);
    }

    #[test]
    fn test_batch_embedding() {
        let mock = MockEmbeddingProvider::new(128);
        let cached = CachedEmbeddingProvider::new(mock, 100);

        let texts = vec!["hello", "world", "test"];
        let embeddings = cached.embed_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), 128);
        }
    }

    #[test]
    fn test_batch_matches_single_embeddings_in_order() {
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(64), 100);
        let reference = MockEmbeddingProvider::new(64);

        // Pre-populate one entry so the batch mixes cache hits and misses.
        cached.embed("world").unwrap();

        let texts = ["hello", "world", "test"];
        let batch = cached.embed_batch(&texts).unwrap();

        assert_eq!(batch.len(), texts.len());
        for (text, embedding) in texts.iter().zip(&batch) {
            assert_eq!(embedding, &reference.embed(text).unwrap());
        }
        assert_eq!(count(&cached.stats().hits), 1);
        assert_eq!(count(&cached.stats().misses), 3);
    }

    #[test]
    fn test_batch_with_duplicate_texts() {
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(32), 100);

        let batch = cached.embed_batch(&["same", "other", "same"]).unwrap();

        assert_eq!(batch.len(), 3);
        assert_eq!(batch[0], batch[2]);
        assert_ne!(batch[0], batch[1]);
    }

    #[test]
    fn test_empty_batch() {
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(8), 10);
        assert!(cached.embed_batch(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_normalization() {
        let provider = MockEmbeddingProvider::new(3);
        let emb = provider.embed("test").unwrap();

        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_text_too_long() {
        let config = EmbeddingConfig {
            max_length: 10,
            ..Default::default()
        };
        let provider = MockEmbeddingProvider::with_config(config);

        let result = provider.embed("this is a very long text that exceeds the limit");
        assert!(matches!(result, Err(EmbeddingError::TextTooLong { .. })));
    }

    #[test]
    fn test_config_presets() {
        let mpnet = EmbeddingConfig::sentence_transformer("all-mpnet-base-v2");
        assert_eq!(mpnet.dimension, 768);
        assert_eq!(mpnet.max_length, 512);
        assert_eq!(EmbeddingConfig::sentence_transformer("unknown-model").dimension, 384);

        let large = EmbeddingConfig::openai("text-embedding-3-large");
        assert_eq!(large.dimension, 3072);
        assert_eq!(large.max_length, 8192);
        assert_eq!(EmbeddingConfig::openai("text-embedding-ada-002").dimension, 1536);
    }

    #[test]
    fn test_hit_rate_without_lookups() {
        assert_eq!(CacheStats::default().hit_rate(), 0.0);
    }

    #[test]
    fn test_error_display() {
        let err = EmbeddingError::TextTooLong {
            max_length: 10,
            actual: 20,
        };
        assert_eq!(err.to_string(), "Text too long: 20 > 10 max");
    }
}