1use std::sync::Arc;
52use moka::sync::Cache;
53
/// Errors that can occur while producing or caching embeddings.
///
/// `PartialEq`/`Eq` are derived (all payload fields are `String`/`usize`)
/// so callers and tests can compare errors directly instead of matching
/// field-by-field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbeddingError {
    /// The requested embedding model is not available.
    ModelNotAvailable(String),
    /// Input exceeded the provider's maximum accepted length.
    TextTooLong { max_length: usize, actual: usize },
    /// A vector's dimension did not match what the provider expects.
    DimensionMismatch { expected: usize, actual: usize },
    /// The underlying provider reported a failure.
    ProviderError(String),
    /// The embedding cache reported a failure.
    CacheError(String),
}
72
73impl std::fmt::Display for EmbeddingError {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 match self {
76 Self::ModelNotAvailable(model) => write!(f, "Embedding model not available: {}", model),
77 Self::TextTooLong { max_length, actual } => {
78 write!(f, "Text too long: {} > {} max", actual, max_length)
79 }
80 Self::DimensionMismatch { expected, actual } => {
81 write!(f, "Dimension mismatch: expected {}, got {}", expected, actual)
82 }
83 Self::ProviderError(msg) => write!(f, "Provider error: {}", msg),
84 Self::CacheError(msg) => write!(f, "Cache error: {}", msg),
85 }
86 }
87}
88
// Marker impl: `Display` above supplies the message; there is no source chain.
impl std::error::Error for EmbeddingError {}

/// Convenience alias for fallible embedding operations in this module.
pub type EmbeddingResult<T> = Result<T, EmbeddingError>;
93
94pub trait EmbeddingProvider: Send + Sync {
96 fn model_name(&self) -> &str;
98
99 fn dimension(&self) -> usize;
101
102 fn max_length(&self) -> usize;
104
105 fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>>;
107
108 fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
110 texts.iter().map(|t| self.embed(t)).collect()
112 }
113
114 fn normalize(&self, embedding: &mut [f32]) {
116 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
117 if norm > 1e-10 {
118 for x in embedding.iter_mut() {
119 *x /= norm;
120 }
121 }
122 }
123}
124
/// Tunables shared by every embedding provider in this module.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Model identifier (e.g. "all-MiniLM-L6-v2", "text-embedding-3-small").
    pub model: String,

    /// Optional filesystem path to local model weights.
    pub model_path: Option<String>,

    /// Output vector dimension.
    pub dimension: usize,

    /// Maximum accepted input length. Units depend on the provider; the
    /// mock provider compares this against byte length — TODO confirm
    /// intended units for real providers.
    pub max_length: usize,

    /// Whether embeddings are scaled to unit L2 norm.
    pub normalize: bool,

    /// Number of texts per provider batch call.
    pub batch_size: usize,

    /// Maximum number of cached embeddings.
    pub cache_size: usize,

    /// Cache entry time-to-live, in seconds.
    pub cache_ttl_secs: u64,
}
156
157impl Default for EmbeddingConfig {
158 fn default() -> Self {
159 Self {
160 model: "all-MiniLM-L6-v2".to_string(),
161 model_path: None,
162 dimension: 384, max_length: 512,
164 normalize: true,
165 batch_size: 32,
166 cache_size: 10_000,
167 cache_ttl_secs: 3600, }
169 }
170}
171
172impl EmbeddingConfig {
173 pub fn sentence_transformer(model: &str) -> Self {
175 let dimension = match model {
176 "all-MiniLM-L6-v2" => 384,
177 "all-MiniLM-L12-v2" => 384,
178 "all-mpnet-base-v2" => 768,
179 "paraphrase-MiniLM-L6-v2" => 384,
180 "multi-qa-MiniLM-L6-cos-v1" => 384,
181 _ => 384, };
183
184 Self {
185 model: model.to_string(),
186 dimension,
187 ..Default::default()
188 }
189 }
190
191 pub fn openai(model: &str) -> Self {
193 let dimension = match model {
194 "text-embedding-ada-002" => 1536,
195 "text-embedding-3-small" => 1536,
196 "text-embedding-3-large" => 3072,
197 _ => 1536,
198 };
199
200 Self {
201 model: model.to_string(),
202 dimension,
203 max_length: 8192,
204 ..Default::default()
205 }
206 }
207}
208
/// Deterministic stand-in provider that derives vectors from a hash of
/// the input text — no real model involved. Intended for tests/wiring.
pub struct MockEmbeddingProvider {
    // Model/dimension/normalization settings; `model` is "mock" when
    // built via `new`.
    config: EmbeddingConfig,
    // When true (always, with the current constructors), embeddings come
    // from `hash_embed`; otherwise a zero vector is produced.
    use_hash: bool,
}
219
220impl MockEmbeddingProvider {
221 pub fn new(dimension: usize) -> Self {
223 Self {
224 config: EmbeddingConfig {
225 model: "mock".to_string(),
226 dimension,
227 ..Default::default()
228 },
229 use_hash: true,
230 }
231 }
232
233 pub fn with_config(config: EmbeddingConfig) -> Self {
235 Self {
236 config,
237 use_hash: true,
238 }
239 }
240
241 fn hash_embed(&self, text: &str) -> Vec<f32> {
243 use std::hash::{Hash, Hasher};
244 use std::collections::hash_map::DefaultHasher;
245
246 let mut embedding = Vec::with_capacity(self.config.dimension);
247
248 for i in 0..self.config.dimension {
250 let mut hasher = DefaultHasher::new();
251 text.hash(&mut hasher);
252 i.hash(&mut hasher);
253 let hash = hasher.finish();
254
255 let value = ((hash as f64) / (u64::MAX as f64) * 2.0 - 1.0) as f32;
257 embedding.push(value);
258 }
259
260 embedding
261 }
262}
263
264impl EmbeddingProvider for MockEmbeddingProvider {
265 fn model_name(&self) -> &str {
266 &self.config.model
267 }
268
269 fn dimension(&self) -> usize {
270 self.config.dimension
271 }
272
273 fn max_length(&self) -> usize {
274 self.config.max_length
275 }
276
277 fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
278 if text.len() > self.config.max_length {
279 return Err(EmbeddingError::TextTooLong {
280 max_length: self.config.max_length,
281 actual: text.len(),
282 });
283 }
284
285 let mut embedding = if self.use_hash {
286 self.hash_embed(text)
287 } else {
288 vec![0.0; self.config.dimension]
289 };
290
291 if self.config.normalize {
292 self.normalize(&mut embedding);
293 }
294
295 Ok(embedding)
296 }
297}
298
/// Decorator that memoizes another provider's embeddings in a `moka` cache.
pub struct CachedEmbeddingProvider<P: EmbeddingProvider> {
    // The wrapped provider that does the actual embedding work.
    inner: P,

    // Cache keyed by a 64-bit hash of the input text (see `text_hash`);
    // NOTE(review): distinct texts could in principle collide on the key
    // and share one cached embedding.
    cache: Cache<u64, Vec<f32>>,

    // Shared handle so callers can keep reading stats while the provider
    // is in use.
    stats: Arc<CacheStats>,
}
314
/// Cumulative cache counters, updated with relaxed atomics.
#[derive(Debug, Default)]
pub struct CacheStats {
    // Lookups answered from the cache.
    pub hits: std::sync::atomic::AtomicUsize,
    // Lookups that fell through to the inner provider.
    pub misses: std::sync::atomic::AtomicUsize,
    // Insertions performed; never decremented, so this is a cumulative
    // insert count, not the cache's current entry count.
    pub size: std::sync::atomic::AtomicUsize,
}
325
326impl CacheStats {
327 pub fn hit_rate(&self) -> f64 {
329 let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
330 let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
331 let total = hits + misses;
332 if total == 0 {
333 0.0
334 } else {
335 hits as f64 / total as f64
336 }
337 }
338}
339
340impl<P: EmbeddingProvider> CachedEmbeddingProvider<P> {
341 pub fn new(inner: P, cache_size: usize) -> Self {
343 Self {
344 inner,
345 cache: Cache::new(cache_size as u64),
346 stats: Arc::new(CacheStats::default()),
347 }
348 }
349
350 pub fn with_ttl(inner: P, cache_size: usize, ttl_secs: u64) -> Self {
352 let cache = Cache::builder()
353 .max_capacity(cache_size as u64)
354 .time_to_live(std::time::Duration::from_secs(ttl_secs))
355 .build();
356
357 Self {
358 inner,
359 cache,
360 stats: Arc::new(CacheStats::default()),
361 }
362 }
363
364 pub fn stats(&self) -> &Arc<CacheStats> {
366 &self.stats
367 }
368
369 fn text_hash(text: &str) -> u64 {
371 use std::hash::{Hash, Hasher};
372 use std::collections::hash_map::DefaultHasher;
373
374 let mut hasher = DefaultHasher::new();
375 text.hash(&mut hasher);
376 hasher.finish()
377 }
378}
379
impl<P: EmbeddingProvider> EmbeddingProvider for CachedEmbeddingProvider<P> {
    // The next three methods expose the wrapped provider's metadata as-is.
    fn model_name(&self) -> &str {
        self.inner.model_name()
    }

    fn dimension(&self) -> usize {
        self.inner.dimension()
    }

    fn max_length(&self) -> usize {
        self.inner.max_length()
    }

    /// Embed `text`, serving from the cache when possible.
    ///
    /// NOTE(review): the cache key is a 64-bit hash of the text, so two
    /// distinct texts that collide would share one embedding — likely
    /// acceptable for a cache, but worth confirming for real providers.
    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
        let hash = Self::text_hash(text);

        if let Some(cached) = self.cache.get(&hash) {
            self.stats.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            return Ok(cached);
        }

        self.stats.misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        let embedding = self.inner.embed(text)?;

        // Store a copy; `size` counts inserts only and is never
        // decremented, so it tracks cumulative insertions rather than
        // current cache occupancy.
        self.cache.insert(hash, embedding.clone());
        self.stats.size.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        Ok(embedding)
    }

    /// Batch embed with per-text cache lookups.
    ///
    /// Cache hits are answered immediately; the misses are forwarded to
    /// the inner provider in a single `embed_batch` call, and results are
    /// re-sorted back into the caller's input order at the end.
    fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
        let mut results = Vec::with_capacity(texts.len());
        let mut uncached: Vec<(usize, &str)> = Vec::new();

        // Partition inputs into hits (pushed with their original index)
        // and misses (collected for one batched provider call).
        for (i, text) in texts.iter().enumerate() {
            let hash = Self::text_hash(text);
            if let Some(cached) = self.cache.get(&hash) {
                self.stats.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                results.push((i, cached));
            } else {
                self.stats.misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                uncached.push((i, *text));
            }
        }

        if !uncached.is_empty() {
            let uncached_texts: Vec<&str> = uncached.iter().map(|(_, t)| *t).collect();
            // An Err here abandons the entire batch, even though hit/miss
            // counters were already bumped above.
            let embeddings = self.inner.embed_batch(&uncached_texts)?;

            for ((i, text), embedding) in uncached.iter().zip(embeddings.into_iter()) {
                let hash = Self::text_hash(text);
                self.cache.insert(hash, embedding.clone());
                self.stats.size.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                results.push((*i, embedding));
            }
        }

        // Restore the caller's input order before stripping the indices.
        results.sort_by_key(|(i, _)| *i);
        Ok(results.into_iter().map(|(_, e)| e).collect())
    }
}
448
/// Stub for a local ONNX-runtime embedding provider.
///
/// NOTE(review): no ONNX session is held yet; `embed` currently falls
/// back to the hash-based mock (see the trait impl below).
#[derive(Debug)]
pub struct LocalOnnxProvider {
    config: EmbeddingConfig,
    // Always false at present — kept for the future real loading path.
    #[allow(dead_code)]
    model_loaded: bool,
}
466
impl LocalOnnxProvider {
    /// Create a provider from `config`.
    ///
    /// NOTE(review): no model is actually loaded here — `model_loaded`
    /// is set to false and `config.model_path` is never read. The
    /// `Result` return reserves room for real loading errors later.
    pub fn new(config: EmbeddingConfig) -> EmbeddingResult<Self> {
        Ok(Self {
            config,
            model_loaded: false,
        })
    }

    /// Build a provider configured for a known sentence-transformer name.
    pub fn load_pretrained(model_name: &str) -> EmbeddingResult<Self> {
        let config = EmbeddingConfig::sentence_transformer(model_name);
        Self::new(config)
    }
}
483
impl EmbeddingProvider for LocalOnnxProvider {
    fn model_name(&self) -> &str {
        &self.config.model
    }

    fn dimension(&self) -> usize {
        self.config.dimension
    }

    fn max_length(&self) -> usize {
        self.config.max_length
    }

    /// Placeholder: delegates to a hash-based mock provider until real
    /// ONNX inference is wired up.
    ///
    /// NOTE(review): a fresh mock (and a clone of `self.config`) is
    /// constructed on every call — cheap, but should be hoisted once this
    /// becomes a real inference path.
    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
        let mock = MockEmbeddingProvider::with_config(self.config.clone());
        mock.embed(text)
    }
}
504
/// Pairs a [`crate::context_query::VectorIndex`] with an
/// [`EmbeddingProvider`] so callers can search by raw text.
pub struct EmbeddingVectorIndex<V, P>
where
    V: crate::context_query::VectorIndex,
    P: EmbeddingProvider,
{
    // Underlying vector index that executes the similarity search.
    index: Arc<V>,

    // Provider used to turn query text into an embedding.
    provider: Arc<P>,
}
521
522impl<V, P> EmbeddingVectorIndex<V, P>
523where
524 V: crate::context_query::VectorIndex,
525 P: EmbeddingProvider,
526{
527 pub fn new(index: Arc<V>, provider: Arc<P>) -> Self {
529 Self { index, provider }
530 }
531
532 pub fn search_text(
534 &self,
535 collection: &str,
536 text: &str,
537 k: usize,
538 min_score: Option<f32>,
539 ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
540 let embedding = self.provider.embed(text)
542 .map_err(|e| e.to_string())?;
543
544 self.index.search_by_embedding(collection, &embedding, k, min_score)
546 }
547
548 pub fn search_embedding(
550 &self,
551 collection: &str,
552 embedding: &[f32],
553 k: usize,
554 min_score: Option<f32>,
555 ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
556 if embedding.len() != self.provider.dimension() {
558 return Err(format!(
559 "Embedding dimension mismatch: expected {}, got {}",
560 self.provider.dimension(),
561 embedding.len()
562 ));
563 }
564
565 self.index.search_by_embedding(collection, embedding, k, min_score)
566 }
567
568 pub fn provider(&self) -> &Arc<P> {
570 &self.provider
571 }
572
573 pub fn index(&self) -> &Arc<V> {
575 &self.index
576 }
577}
578
// Lets an `EmbeddingVectorIndex` be used anywhere a plain `VectorIndex`
// is expected; every method delegates to the inherent methods above.
impl<V, P> crate::context_query::VectorIndex for EmbeddingVectorIndex<V, P>
where
    V: crate::context_query::VectorIndex,
    P: EmbeddingProvider,
{
    // Delegates to `search_embedding`, which adds dimension validation.
    fn search_by_embedding(
        &self,
        collection: &str,
        embedding: &[f32],
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
        self.search_embedding(collection, embedding, k, min_score)
    }

    // Delegates to `search_text`, which embeds the query first.
    fn search_by_text(
        &self,
        collection: &str,
        text: &str,
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
        self.search_text(collection, text, k, min_score)
    }

    // Stats pass straight through to the wrapped index.
    fn stats(&self, collection: &str) -> Option<crate::context_query::VectorIndexStats> {
        self.index.stats(collection)
    }
}
608
609pub fn create_mock_provider(dimension: usize, cache_size: usize) -> CachedEmbeddingProvider<MockEmbeddingProvider> {
615 let mock = MockEmbeddingProvider::new(dimension);
616 CachedEmbeddingProvider::new(mock, cache_size)
617}
618
619pub fn create_embedding_index<V: crate::context_query::VectorIndex>(
621 index: Arc<V>,
622 dimension: usize,
623) -> EmbeddingVectorIndex<V, CachedEmbeddingProvider<MockEmbeddingProvider>> {
624 let provider = Arc::new(create_mock_provider(dimension, 10_000));
625 EmbeddingVectorIndex::new(index, provider)
626}
627
#[cfg(test)]
mod tests {
    use super::*;

    // The same input must always hash to the same embedding.
    #[test]
    fn test_mock_embedding_deterministic() {
        let provider = MockEmbeddingProvider::new(384);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1, emb2);
        assert_eq!(emb1.len(), 384);
    }

    // Distinct inputs should produce distinct vectors (hash-derived, so
    // equality would require an astronomically unlikely collision).
    #[test]
    fn test_mock_embedding_different_texts() {
        let provider = MockEmbeddingProvider::new(384);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(emb1, emb2);
    }

    // First embed of a text is a miss; an identical second embed is a hit.
    #[test]
    fn test_cached_provider() {
        let mock = MockEmbeddingProvider::new(128);
        let cached = CachedEmbeddingProvider::new(mock, 100);

        // Cold cache: one miss, no hits.
        let _ = cached.embed("test text").unwrap();
        assert_eq!(cached.stats().hits.load(std::sync::atomic::Ordering::Relaxed), 0);
        assert_eq!(cached.stats().misses.load(std::sync::atomic::Ordering::Relaxed), 1);

        // Same text again: served from cache.
        let _ = cached.embed("test text").unwrap();
        assert_eq!(cached.stats().hits.load(std::sync::atomic::Ordering::Relaxed), 1);
        assert_eq!(cached.stats().misses.load(std::sync::atomic::Ordering::Relaxed), 1);

        // 1 hit out of 2 lookups = 0.5, safely above the 0.4 threshold.
        assert!(cached.stats().hit_rate() > 0.4);
    }

    // Batch embedding returns one vector per input, each of the
    // configured dimension.
    #[test]
    fn test_batch_embedding() {
        let mock = MockEmbeddingProvider::new(128);
        let cached = CachedEmbeddingProvider::new(mock, 100);

        let texts = vec!["hello", "world", "test"];
        let embeddings = cached.embed_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), 128);
        }
    }

    // The default config normalizes, so the L2 norm should be ~1.
    #[test]
    fn test_normalization() {
        let provider = MockEmbeddingProvider::new(3);
        let emb = provider.embed("test").unwrap();

        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    // Inputs longer than `max_length` (compared in bytes) are rejected.
    #[test]
    fn test_text_too_long() {
        let config = EmbeddingConfig {
            max_length: 10,
            ..Default::default()
        };
        let provider = MockEmbeddingProvider::with_config(config);

        let result = provider.embed("this is a very long text that exceeds the limit");
        assert!(matches!(result, Err(EmbeddingError::TextTooLong { .. })));
    }
}