Skip to main content

ruvector_dag/attention/
cache.rs

//! Attention Cache: LRU cache for computed attention scores.
//!
//! Caches attention scores to avoid redundant computation for structurally
//! identical DAGs evaluated with the same attention mechanism.
//! Uses an LRU eviction policy, plus an optional time-to-live, to bound memory usage.
5
6use super::trait_def::AttentionScores;
7use crate::dag::QueryDag;
8use std::collections::{hash_map::DefaultHasher, HashMap};
9use std::hash::{Hash, Hasher};
10use std::time::{Duration, Instant};
11
/// Configuration for an [`AttentionCache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries held at once; least-recently-used entries
    /// are evicted to make room for new ones.
    pub capacity: usize,
    /// Time-to-live for entries, measured from insertion. `None` disables
    /// expiry entirely.
    pub ttl: Option<Duration>,
}
19
20impl Default for CacheConfig {
21    fn default() -> Self {
22        Self {
23            capacity: 1000,
24            ttl: Some(Duration::from_secs(300)), // 5 minutes
25        }
26    }
27}
28
29#[derive(Debug)]
30struct CacheEntry {
31    scores: AttentionScores,
32    timestamp: Instant,
33    access_count: usize,
34}
35
36pub struct AttentionCache {
37    config: CacheConfig,
38    cache: HashMap<u64, CacheEntry>,
39    access_order: Vec<u64>,
40    hits: usize,
41    misses: usize,
42}
43
44impl AttentionCache {
45    pub fn new(config: CacheConfig) -> Self {
46        Self {
47            cache: HashMap::with_capacity(config.capacity),
48            access_order: Vec::with_capacity(config.capacity),
49            config,
50            hits: 0,
51            misses: 0,
52        }
53    }
54
55    /// Hash a DAG for cache key
56    fn hash_dag(dag: &QueryDag, mechanism: &str) -> u64 {
57        let mut hasher = DefaultHasher::new();
58
59        // Hash mechanism name
60        mechanism.hash(&mut hasher);
61
62        // Hash number of nodes
63        dag.node_count().hash(&mut hasher);
64
65        // Hash edges structure
66        let mut edge_list: Vec<(usize, usize)> = Vec::new();
67        for node_id in dag.node_ids() {
68            for &child in dag.children(node_id) {
69                edge_list.push((node_id, child));
70            }
71        }
72        edge_list.sort_unstable();
73
74        for (from, to) in edge_list {
75            from.hash(&mut hasher);
76            to.hash(&mut hasher);
77        }
78
79        hasher.finish()
80    }
81
82    /// Check if entry is expired
83    fn is_expired(&self, entry: &CacheEntry) -> bool {
84        if let Some(ttl) = self.config.ttl {
85            entry.timestamp.elapsed() > ttl
86        } else {
87            false
88        }
89    }
90
91    /// Get cached scores for a DAG and mechanism
92    pub fn get(&mut self, dag: &QueryDag, mechanism: &str) -> Option<AttentionScores> {
93        let key = Self::hash_dag(dag, mechanism);
94
95        // Check if key exists and is not expired
96        let is_expired = self
97            .cache
98            .get(&key)
99            .map(|entry| self.is_expired(entry))
100            .unwrap_or(true);
101
102        if is_expired {
103            self.cache.remove(&key);
104            self.access_order.retain(|&k| k != key);
105            self.misses += 1;
106            return None;
107        }
108
109        // Update access and return clone
110        if let Some(entry) = self.cache.get_mut(&key) {
111            // Update access order (move to end = most recently used)
112            self.access_order.retain(|&k| k != key);
113            self.access_order.push(key);
114            entry.access_count += 1;
115            self.hits += 1;
116
117            Some(entry.scores.clone())
118        } else {
119            self.misses += 1;
120            None
121        }
122    }
123
124    /// Insert scores into cache
125    pub fn insert(&mut self, dag: &QueryDag, mechanism: &str, scores: AttentionScores) {
126        let key = Self::hash_dag(dag, mechanism);
127
128        // Evict if at capacity
129        while self.cache.len() >= self.config.capacity && !self.access_order.is_empty() {
130            if let Some(oldest) = self.access_order.first().copied() {
131                self.cache.remove(&oldest);
132                self.access_order.remove(0);
133            }
134        }
135
136        let entry = CacheEntry {
137            scores,
138            timestamp: Instant::now(),
139            access_count: 0,
140        };
141
142        self.cache.insert(key, entry);
143        self.access_order.push(key);
144    }
145
146    /// Clear all entries
147    pub fn clear(&mut self) {
148        self.cache.clear();
149        self.access_order.clear();
150        self.hits = 0;
151        self.misses = 0;
152    }
153
154    /// Remove expired entries
155    pub fn evict_expired(&mut self) {
156        let expired_keys: Vec<u64> = self
157            .cache
158            .iter()
159            .filter(|(_, entry)| self.is_expired(entry))
160            .map(|(k, _)| *k)
161            .collect();
162
163        for key in expired_keys {
164            self.cache.remove(&key);
165            self.access_order.retain(|&k| k != key);
166        }
167    }
168
169    /// Get cache statistics
170    pub fn stats(&self) -> CacheStats {
171        CacheStats {
172            size: self.cache.len(),
173            capacity: self.config.capacity,
174            hits: self.hits,
175            misses: self.misses,
176            hit_rate: if self.hits + self.misses > 0 {
177                self.hits as f64 / (self.hits + self.misses) as f64
178            } else {
179                0.0
180            },
181        }
182    }
183
184    /// Get entry with most accesses
185    pub fn most_accessed(&self) -> Option<(&u64, usize)> {
186        self.cache
187            .iter()
188            .max_by_key(|(_, entry)| entry.access_count)
189            .map(|(k, entry)| (k, entry.access_count))
190    }
191}
192
/// Snapshot of an [`AttentionCache`]'s size and hit/miss counters.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries currently stored. This may include entries that have
    /// expired but have not yet been removed by a lookup or an expiry sweep.
    pub size: usize,
    /// Configured maximum number of entries.
    pub capacity: usize,
    /// Number of lookups that returned cached scores.
    pub hits: usize,
    /// Number of lookups that found no live entry.
    pub misses: usize,
    /// `hits / (hits + misses)`, or `0.0` before any lookup has happened.
    pub hit_rate: f64,
}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    /// Builds `n` scan nodes with costs 1.0..=n, linked by a single edge 0 -> 1
    /// whenever at least two nodes exist.
    fn create_test_dag(n: usize) -> QueryDag {
        let mut dag = QueryDag::new();
        for id in 0..n {
            let mut op = OperatorNode::new(id, OperatorType::Scan);
            op.estimated_cost = (id + 1) as f64;
            dag.add_node(op);
        }
        if n >= 2 {
            let _ = dag.add_edge(0, 1);
        }
        dag
    }

    /// A cache with the default configuration.
    fn default_cache() -> AttentionCache {
        AttentionCache::new(CacheConfig::default())
    }

    #[test]
    fn test_cache_insert_and_get() {
        let mut cache = default_cache();
        let dag = create_test_dag(3);

        let original = AttentionScores::new(vec![0.5, 0.3, 0.2]);
        let want = original.scores.clone();
        cache.insert(&dag, "test_mechanism", original);

        let got = cache
            .get(&dag, "test_mechanism")
            .expect("freshly inserted entry must be cached");
        assert_eq!(got.scores, want);
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = default_cache();
        let dag = create_test_dag(3);

        assert!(cache.get(&dag, "nonexistent").is_none());
    }

    #[test]
    fn test_lru_eviction() {
        let mut cache = AttentionCache::new(CacheConfig {
            capacity: 2,
            ttl: None,
        });

        let dags: Vec<QueryDag> = (1..=3).map(create_test_dag).collect();
        let payloads = vec![vec![0.5], vec![0.3, 0.7], vec![0.2, 0.3, 0.5]];
        for (dag, payload) in dags.iter().zip(payloads) {
            cache.insert(dag, "mech", AttentionScores::new(payload));
        }

        // The first DAG is the least recently used and must have been evicted;
        // the other two must survive. Lookups run in DAG order.
        let present: Vec<bool> = dags
            .iter()
            .map(|dag| cache.get(dag, "mech").is_some())
            .collect();
        assert_eq!(present, vec![false, true, true]);
    }

    #[test]
    fn test_cache_stats() {
        let mut cache = default_cache();
        let dag = create_test_dag(2);
        cache.insert(&dag, "mech", AttentionScores::new(vec![0.5, 0.5]));

        let _hit = cache.get(&dag, "mech");
        let _miss = cache.get(&dag, "nonexistent");

        let CacheStats {
            hits,
            misses,
            hit_rate,
            ..
        } = cache.stats();
        assert_eq!(hits, 1);
        assert_eq!(misses, 1);
        assert!((hit_rate - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_ttl_expiration() {
        let ttl = Duration::from_millis(50);
        let mut cache = AttentionCache::new(CacheConfig {
            capacity: 100,
            ttl: Some(ttl),
        });

        let dag = create_test_dag(2);
        cache.insert(&dag, "mech", AttentionScores::new(vec![0.5, 0.5]));
        assert!(
            cache.get(&dag, "mech").is_some(),
            "entry should be live right after insertion"
        );

        // Sleep past the TTL so the entry expires.
        std::thread::sleep(Duration::from_millis(60));
        assert!(
            cache.get(&dag, "mech").is_none(),
            "entry should have expired"
        );
    }

    #[test]
    fn test_hash_consistency() {
        let dag = create_test_dag(3);

        let first = AttentionCache::hash_dag(&dag, "mechanism");
        let second = AttentionCache::hash_dag(&dag, "mechanism");
        assert_eq!(first, second);
    }

    #[test]
    fn test_hash_different_mechanisms() {
        let dag = create_test_dag(3);

        let [left, right] =
            ["mechanism1", "mechanism2"].map(|name| AttentionCache::hash_dag(&dag, name));
        assert_ne!(left, right);
    }
}