1use std::sync::Arc;
49use moka::sync::Cache;
50
/// Errors produced while generating or caching embeddings.
#[derive(Debug, Clone)]
pub enum EmbeddingError {
    /// The requested model is not installed/known; payload is the model name.
    ModelNotAvailable(String),
    /// Input exceeded the provider's configured maximum length.
    TextTooLong { max_length: usize, actual: usize },
    /// A vector had a different dimensionality than the provider expects.
    DimensionMismatch { expected: usize, actual: usize },
    /// Opaque failure reported by the underlying provider.
    ProviderError(String),
    /// Failure in the embedding cache layer.
    CacheError(String),
}

impl std::fmt::Display for EmbeddingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render each variant to a message, then emit it in one call.
        let message = match self {
            Self::ModelNotAvailable(model) => {
                format!("Embedding model not available: {}", model)
            }
            Self::TextTooLong { max_length, actual } => {
                format!("Text too long: {} > {} max", actual, max_length)
            }
            Self::DimensionMismatch { expected, actual } => {
                format!("Dimension mismatch: expected {}, got {}", expected, actual)
            }
            Self::ProviderError(msg) => format!("Provider error: {}", msg),
            Self::CacheError(msg) => format!("Cache error: {}", msg),
        };
        f.write_str(&message)
    }
}

impl std::error::Error for EmbeddingError {}

/// Convenience alias used by all provider APIs in this module.
pub type EmbeddingResult<T> = Result<T, EmbeddingError>;
90
/// Abstraction over anything that can turn text into fixed-size embedding
/// vectors.
///
/// Implementations must be `Send + Sync` so a provider can be shared across
/// threads (e.g. behind an `Arc`).
pub trait EmbeddingProvider: Send + Sync {
    /// Name of the underlying model (e.g. "all-MiniLM-L6-v2").
    fn model_name(&self) -> &str;

    /// Number of components in each embedding vector.
    fn dimension(&self) -> usize;

    /// Maximum accepted input length. Units are implementation-defined —
    /// the mock provider in this file measures bytes; TODO confirm for
    /// real providers (tokens vs bytes).
    fn max_length(&self) -> usize;

    /// Embed a single text into a vector of `dimension()` floats.
    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>>;

    /// Embed several texts. The default implementation embeds one at a
    /// time and fails fast on the first error.
    fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// Scale `embedding` to unit L2 norm in place. Vectors with near-zero
    /// norm (<= 1e-10) are left untouched to avoid dividing by ~0.
    fn normalize(&self, embedding: &mut [f32]) {
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for x in embedding.iter_mut() {
                *x /= norm;
            }
        }
    }
}
121
/// Configuration shared by all embedding providers in this module.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Model identifier (e.g. "all-MiniLM-L6-v2").
    pub model: String,
    /// Optional filesystem path to local model weights.
    pub model_path: Option<String>,
    /// Dimensionality of produced embedding vectors.
    pub dimension: usize,
    /// Maximum accepted input length.
    pub max_length: usize,
    /// Whether embeddings are L2-normalized after generation.
    pub normalize: bool,
    /// Number of texts processed per batch.
    pub batch_size: usize,
    /// Maximum number of cached embeddings.
    pub cache_size: usize,
    /// Cache entry time-to-live, in seconds.
    pub cache_ttl_secs: u64,
}

impl Default for EmbeddingConfig {
    /// Defaults match the all-MiniLM-L6-v2 sentence-transformer preset.
    fn default() -> Self {
        EmbeddingConfig {
            model: String::from("all-MiniLM-L6-v2"),
            model_path: None,
            dimension: 384,
            max_length: 512,
            normalize: true,
            batch_size: 32,
            cache_size: 10_000,
            cache_ttl_secs: 3600,
        }
    }
}

impl EmbeddingConfig {
    /// Preset for common sentence-transformer models. Unrecognized model
    /// names fall back to a 384-dimensional configuration.
    pub fn sentence_transformer(model: &str) -> Self {
        let dimension = match model {
            "all-mpnet-base-v2" => 768,
            "all-MiniLM-L6-v2"
            | "all-MiniLM-L12-v2"
            | "paraphrase-MiniLM-L6-v2"
            | "multi-qa-MiniLM-L6-cos-v1" => 384,
            // Unknown models: assume the most common small-model dimension.
            _ => 384,
        };

        Self {
            model: model.to_string(),
            dimension,
            ..Self::default()
        }
    }

    /// Preset for OpenAI embedding endpoints. Unrecognized model names fall
    /// back to 1536 dimensions; `max_length` reflects the larger context.
    pub fn openai(model: &str) -> Self {
        let dimension = match model {
            "text-embedding-3-large" => 3072,
            "text-embedding-ada-002" | "text-embedding-3-small" => 1536,
            // Unknown models: assume the ada-002 dimension.
            _ => 1536,
        };

        Self {
            model: model.to_string(),
            dimension,
            max_length: 8192,
            ..Self::default()
        }
    }
}
205
206pub struct MockEmbeddingProvider {
212 config: EmbeddingConfig,
213 use_hash: bool,
215}
216
217impl MockEmbeddingProvider {
218 pub fn new(dimension: usize) -> Self {
220 Self {
221 config: EmbeddingConfig {
222 model: "mock".to_string(),
223 dimension,
224 ..Default::default()
225 },
226 use_hash: true,
227 }
228 }
229
230 pub fn with_config(config: EmbeddingConfig) -> Self {
232 Self {
233 config,
234 use_hash: true,
235 }
236 }
237
238 fn hash_embed(&self, text: &str) -> Vec<f32> {
240 use std::hash::{Hash, Hasher};
241 use std::collections::hash_map::DefaultHasher;
242
243 let mut embedding = Vec::with_capacity(self.config.dimension);
244
245 for i in 0..self.config.dimension {
247 let mut hasher = DefaultHasher::new();
248 text.hash(&mut hasher);
249 i.hash(&mut hasher);
250 let hash = hasher.finish();
251
252 let value = ((hash as f64) / (u64::MAX as f64) * 2.0 - 1.0) as f32;
254 embedding.push(value);
255 }
256
257 embedding
258 }
259}
260
261impl EmbeddingProvider for MockEmbeddingProvider {
262 fn model_name(&self) -> &str {
263 &self.config.model
264 }
265
266 fn dimension(&self) -> usize {
267 self.config.dimension
268 }
269
270 fn max_length(&self) -> usize {
271 self.config.max_length
272 }
273
274 fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
275 if text.len() > self.config.max_length {
276 return Err(EmbeddingError::TextTooLong {
277 max_length: self.config.max_length,
278 actual: text.len(),
279 });
280 }
281
282 let mut embedding = if self.use_hash {
283 self.hash_embed(text)
284 } else {
285 vec![0.0; self.config.dimension]
286 };
287
288 if self.config.normalize {
289 self.normalize(&mut embedding);
290 }
291
292 Ok(embedding)
293 }
294}
295
/// Decorator that memoizes embeddings computed by an inner provider.
pub struct CachedEmbeddingProvider<P: EmbeddingProvider> {
    // The wrapped provider that actually computes embeddings on cache misses.
    inner: P,
    // Cache keyed by a 64-bit hash of the input text.
    // NOTE(review): keying on the hash alone means two texts that collide on
    // the u64 would share one cached embedding — confirm this risk is acceptable.
    cache: Cache<u64, Vec<f32>>,
    // Shared hit/miss counters; Arc so callers can hold stats independently.
    stats: Arc<CacheStats>,
}
311
/// Lock-free counters describing cache effectiveness.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Lookups served from the cache.
    pub hits: std::sync::atomic::AtomicUsize,
    /// Lookups that fell through to the inner provider.
    pub misses: std::sync::atomic::AtomicUsize,
    /// Insertions performed (an approximation of cache size).
    pub size: std::sync::atomic::AtomicUsize,
}

impl CacheStats {
    /// Fraction of lookups served from the cache, in [0.0, 1.0].
    /// Returns 0.0 before any lookup has been recorded.
    pub fn hit_rate(&self) -> f64 {
        use std::sync::atomic::Ordering::Relaxed;

        let h = self.hits.load(Relaxed);
        let m = self.misses.load(Relaxed);
        match h + m {
            0 => 0.0,
            total => h as f64 / total as f64,
        }
    }
}
336
337impl<P: EmbeddingProvider> CachedEmbeddingProvider<P> {
338 pub fn new(inner: P, cache_size: usize) -> Self {
340 Self {
341 inner,
342 cache: Cache::new(cache_size as u64),
343 stats: Arc::new(CacheStats::default()),
344 }
345 }
346
347 pub fn with_ttl(inner: P, cache_size: usize, ttl_secs: u64) -> Self {
349 let cache = Cache::builder()
350 .max_capacity(cache_size as u64)
351 .time_to_live(std::time::Duration::from_secs(ttl_secs))
352 .build();
353
354 Self {
355 inner,
356 cache,
357 stats: Arc::new(CacheStats::default()),
358 }
359 }
360
361 pub fn stats(&self) -> &Arc<CacheStats> {
363 &self.stats
364 }
365
366 fn text_hash(text: &str) -> u64 {
368 use std::hash::{Hash, Hasher};
369 use std::collections::hash_map::DefaultHasher;
370
371 let mut hasher = DefaultHasher::new();
372 text.hash(&mut hasher);
373 hasher.finish()
374 }
375}
376
impl<P: EmbeddingProvider> EmbeddingProvider for CachedEmbeddingProvider<P> {
    fn model_name(&self) -> &str {
        self.inner.model_name()
    }

    fn dimension(&self) -> usize {
        self.inner.dimension()
    }

    fn max_length(&self) -> usize {
        self.inner.max_length()
    }

    /// Embed `text`, consulting the cache first; exactly one of
    /// `stats.hits` / `stats.misses` is incremented per call.
    ///
    /// NOTE(review): the cache key is a 64-bit hash of the text, so two
    /// colliding texts would share one cached embedding — confirm acceptable.
    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
        let hash = Self::text_hash(text);

        // Cache hit: count it and return the stored vector.
        if let Some(cached) = self.cache.get(&hash) {
            self.stats.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            return Ok(cached);
        }

        self.stats.misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        // Miss: compute via the inner provider, then cache a copy.
        let embedding = self.inner.embed(text)?;

        self.cache.insert(hash, embedding.clone());
        // NOTE(review): `size` is only ever incremented — it counts inserts,
        // not live entries, so it drifts upward once the cache evicts or
        // entries expire. Rename or derive from the cache if accuracy matters.
        self.stats.size.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        Ok(embedding)
    }

    /// Batch embed with per-text cache lookups; only uncached texts are
    /// forwarded to the inner provider, and output order matches `texts`.
    /// Any inner-provider error fails the whole batch.
    fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
        let mut results = Vec::with_capacity(texts.len());
        let mut uncached: Vec<(usize, &str)> = Vec::new();

        // First pass: satisfy what we can from the cache, remembering the
        // original index of every text so order can be restored at the end.
        for (i, text) in texts.iter().enumerate() {
            let hash = Self::text_hash(text);
            if let Some(cached) = self.cache.get(&hash) {
                self.stats.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                results.push((i, cached));
            } else {
                self.stats.misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                uncached.push((i, *text));
            }
        }

        // Second pass: embed the misses in one inner batch call and cache them.
        if !uncached.is_empty() {
            let uncached_texts: Vec<&str> = uncached.iter().map(|(_, t)| *t).collect();
            let embeddings = self.inner.embed_batch(&uncached_texts)?;

            // `uncached` and `embeddings` are index-aligned by construction.
            for ((i, text), embedding) in uncached.iter().zip(embeddings.into_iter()) {
                let hash = Self::text_hash(text);
                self.cache.insert(hash, embedding.clone());
                self.stats.size.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                results.push((*i, embedding));
            }
        }

        // Restore the caller's ordering before stripping the indices.
        results.sort_by_key(|(i, _)| *i);
        Ok(results.into_iter().map(|(_, e)| e).collect())
    }
}
445
/// Stub for a local ONNX-runtime embedding provider.
///
/// NOTE(review): no ONNX runtime is wired up yet — `model_loaded` stays
/// false and inference is delegated to the hash-based mock provider.
#[derive(Debug)]
pub struct LocalOnnxProvider {
    // Model configuration (name, dimension, max length, ...).
    config: EmbeddingConfig,
    // Always false for now; kept for when real model loading lands.
    #[allow(dead_code)]
    model_loaded: bool,
}

impl LocalOnnxProvider {
    /// Create a provider from `config`. Currently infallible and does not
    /// actually load a model (`model_loaded` remains false).
    pub fn new(config: EmbeddingConfig) -> EmbeddingResult<Self> {
        Ok(Self {
            config,
            model_loaded: false,
        })
    }

    /// Convenience constructor using the sentence-transformer preset for
    /// `model_name`.
    pub fn load_pretrained(model_name: &str) -> EmbeddingResult<Self> {
        let config = EmbeddingConfig::sentence_transformer(model_name);
        Self::new(config)
    }
}
480
impl EmbeddingProvider for LocalOnnxProvider {
    fn model_name(&self) -> &str {
        &self.config.model
    }

    fn dimension(&self) -> usize {
        self.config.dimension
    }

    fn max_length(&self) -> usize {
        self.config.max_length
    }

    /// Stub inference: delegates to a hash-based mock provider built from
    /// this provider's config.
    ///
    /// NOTE(review): a fresh `MockEmbeddingProvider` (including a config
    /// clone) is constructed on every call — fine for a stub, but should go
    /// away once real ONNX inference lands.
    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
        let mock = MockEmbeddingProvider::with_config(self.config.clone());
        mock.embed(text)
    }
}
501
/// Adapter pairing a vector index with the embedding provider used to turn
/// query text into vectors.
pub struct EmbeddingVectorIndex<V, P>
where
    V: crate::context_query::VectorIndex,
    P: EmbeddingProvider,
{
    // The underlying vector index searched by embedding.
    index: Arc<V>,
    // Provider used to embed free-text queries before searching.
    provider: Arc<P>,
}
518
519impl<V, P> EmbeddingVectorIndex<V, P>
520where
521 V: crate::context_query::VectorIndex,
522 P: EmbeddingProvider,
523{
524 pub fn new(index: Arc<V>, provider: Arc<P>) -> Self {
526 Self { index, provider }
527 }
528
529 pub fn search_text(
531 &self,
532 collection: &str,
533 text: &str,
534 k: usize,
535 min_score: Option<f32>,
536 ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
537 let embedding = self.provider.embed(text)
539 .map_err(|e| e.to_string())?;
540
541 self.index.search_by_embedding(collection, &embedding, k, min_score)
543 }
544
545 pub fn search_embedding(
547 &self,
548 collection: &str,
549 embedding: &[f32],
550 k: usize,
551 min_score: Option<f32>,
552 ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
553 if embedding.len() != self.provider.dimension() {
555 return Err(format!(
556 "Embedding dimension mismatch: expected {}, got {}",
557 self.provider.dimension(),
558 embedding.len()
559 ));
560 }
561
562 self.index.search_by_embedding(collection, embedding, k, min_score)
563 }
564
565 pub fn provider(&self) -> &Arc<P> {
567 &self.provider
568 }
569
570 pub fn index(&self) -> &Arc<V> {
572 &self.index
573 }
574}
575
/// Lets the adapter itself be used anywhere a `VectorIndex` is expected;
/// each method delegates to the corresponding inherent method above.
impl<V, P> crate::context_query::VectorIndex for EmbeddingVectorIndex<V, P>
where
    V: crate::context_query::VectorIndex,
    P: EmbeddingProvider,
{
    fn search_by_embedding(
        &self,
        collection: &str,
        embedding: &[f32],
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
        // Delegates, which adds the dimension check from `search_embedding`.
        self.search_embedding(collection, embedding, k, min_score)
    }

    fn search_by_text(
        &self,
        collection: &str,
        text: &str,
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
        self.search_text(collection, text, k, min_score)
    }

    fn stats(&self, collection: &str) -> Option<crate::context_query::VectorIndexStats> {
        // Stats come straight from the wrapped index; the provider has none.
        self.index.stats(collection)
    }
}
605
606pub fn create_mock_provider(dimension: usize, cache_size: usize) -> CachedEmbeddingProvider<MockEmbeddingProvider> {
612 let mock = MockEmbeddingProvider::new(dimension);
613 CachedEmbeddingProvider::new(mock, cache_size)
614}
615
616pub fn create_embedding_index<V: crate::context_query::VectorIndex>(
618 index: Arc<V>,
619 dimension: usize,
620) -> EmbeddingVectorIndex<V, CachedEmbeddingProvider<MockEmbeddingProvider>> {
621 let provider = Arc::new(create_mock_provider(dimension, 10_000));
622 EmbeddingVectorIndex::new(index, provider)
623}
624
#[cfg(test)]
mod tests {
    use super::*;

    // Same text must always map to the same vector of the right length.
    #[test]
    fn test_mock_embedding_deterministic() {
        let provider = MockEmbeddingProvider::new(384);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1, emb2);
        assert_eq!(emb1.len(), 384);
    }

    // Distinct texts should (with overwhelming probability) differ.
    #[test]
    fn test_mock_embedding_different_texts() {
        let provider = MockEmbeddingProvider::new(384);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(emb1, emb2);
    }

    // First embed is a miss, the repeat is a hit; counters track exactly one
    // increment per call.
    #[test]
    fn test_cached_provider() {
        let mock = MockEmbeddingProvider::new(128);
        let cached = CachedEmbeddingProvider::new(mock, 100);

        let _ = cached.embed("test text").unwrap();
        assert_eq!(cached.stats().hits.load(std::sync::atomic::Ordering::Relaxed), 0);
        assert_eq!(cached.stats().misses.load(std::sync::atomic::Ordering::Relaxed), 1);

        let _ = cached.embed("test text").unwrap();
        assert_eq!(cached.stats().hits.load(std::sync::atomic::Ordering::Relaxed), 1);
        assert_eq!(cached.stats().misses.load(std::sync::atomic::Ordering::Relaxed), 1);

        // 1 hit / 2 lookups = 0.5 > 0.4.
        assert!(cached.stats().hit_rate() > 0.4);
    }

    // Batch embedding preserves count and per-vector dimensionality.
    #[test]
    fn test_batch_embedding() {
        let mock = MockEmbeddingProvider::new(128);
        let cached = CachedEmbeddingProvider::new(mock, 100);

        let texts = vec!["hello", "world", "test"];
        let embeddings = cached.embed_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), 128);
        }
    }

    // Default config normalizes, so the L2 norm should be ~1.
    #[test]
    fn test_normalization() {
        let provider = MockEmbeddingProvider::new(3);
        let emb = provider.embed("test").unwrap();

        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    // Inputs longer (in bytes) than max_length are rejected, not truncated.
    #[test]
    fn test_text_too_long() {
        let config = EmbeddingConfig {
            max_length: 10,
            ..Default::default()
        };
        let provider = MockEmbeddingProvider::with_config(config);

        let result = provider.embed("this is a very long text that exceeds the limit");
        assert!(matches!(result, Err(EmbeddingError::TextTooLong { .. })));
    }
}