Skip to main content

synaptic_embeddings/
cached.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::SynapseError;
6use tokio::sync::RwLock;
7
8use crate::Embeddings;
9
10/// An embeddings wrapper that caches results in memory.
11///
12/// Previously computed embeddings are stored in an in-memory hash map keyed
13/// by the input text. On subsequent calls, cached embeddings are returned
14/// directly, and only uncached texts are sent to the inner embeddings provider.
15pub struct CacheBackedEmbeddings {
16    inner: Arc<dyn Embeddings>,
17    cache: Arc<RwLock<HashMap<String, Vec<f32>>>>,
18}
19
20impl CacheBackedEmbeddings {
21    /// Create a new cached embeddings wrapper around the given embeddings provider.
22    pub fn new(inner: Arc<dyn Embeddings>) -> Self {
23        Self {
24            inner,
25            cache: Arc::new(RwLock::new(HashMap::new())),
26        }
27    }
28}
29
30#[async_trait]
31impl Embeddings for CacheBackedEmbeddings {
32    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapseError> {
33        // Determine which texts need embedding
34        let cache = self.cache.read().await;
35        let mut results: Vec<Option<Vec<f32>>> = Vec::with_capacity(texts.len());
36        let mut uncached_indices: Vec<usize> = Vec::new();
37        let mut uncached_texts: Vec<&str> = Vec::new();
38
39        for (i, text) in texts.iter().enumerate() {
40            if let Some(cached) = cache.get(*text) {
41                results.push(Some(cached.clone()));
42            } else {
43                results.push(None);
44                uncached_indices.push(i);
45                uncached_texts.push(text);
46            }
47        }
48        drop(cache);
49
50        // Embed uncached texts
51        if !uncached_texts.is_empty() {
52            let new_embeddings = self.inner.embed_documents(&uncached_texts).await?;
53
54            // Store new embeddings in cache
55            let mut cache = self.cache.write().await;
56            for (idx, embedding) in uncached_indices.iter().zip(new_embeddings.into_iter()) {
57                cache.insert(texts[*idx].to_string(), embedding.clone());
58                results[*idx] = Some(embedding);
59            }
60        }
61
62        // All results should now be Some
63        Ok(results.into_iter().map(|r| r.unwrap()).collect())
64    }
65
66    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapseError> {
67        // Check cache
68        {
69            let cache = self.cache.read().await;
70            if let Some(cached) = cache.get(text) {
71                return Ok(cached.clone());
72            }
73        }
74
75        // Cache miss: compute embedding
76        let embedding = self.inner.embed_query(text).await?;
77
78        // Store in cache
79        {
80            let mut cache = self.cache.write().await;
81            cache.insert(text.to_string(), embedding.clone());
82        }
83
84        Ok(embedding)
85    }
86}