1use super::trait_def::AttentionScores;
7use crate::dag::QueryDag;
8use std::collections::{hash_map::DefaultHasher, HashMap};
9use std::hash::{Hash, Hasher};
10use std::time::{Duration, Instant};
11
/// Tuning parameters for `AttentionCache`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheConfig {
    /// Maximum number of entries kept before the least recently used one is evicted.
    pub capacity: usize,
    /// How long an entry stays valid after it was inserted; `None` disables expiry.
    pub ttl: Option<Duration>,
}

impl Default for CacheConfig {
    /// 1000 entries, each valid for 300 seconds (five minutes) after insertion.
    fn default() -> Self {
        Self {
            capacity: 1000,
            ttl: Some(Duration::from_secs(300)),
        }
    }
}
28
29#[derive(Debug)]
30struct CacheEntry {
31 scores: AttentionScores,
32 timestamp: Instant,
33 access_count: usize,
34}
35
36pub struct AttentionCache {
37 config: CacheConfig,
38 cache: HashMap<u64, CacheEntry>,
39 access_order: Vec<u64>,
40 hits: usize,
41 misses: usize,
42}
43
44impl AttentionCache {
45 pub fn new(config: CacheConfig) -> Self {
46 Self {
47 cache: HashMap::with_capacity(config.capacity),
48 access_order: Vec::with_capacity(config.capacity),
49 config,
50 hits: 0,
51 misses: 0,
52 }
53 }
54
55 fn hash_dag(dag: &QueryDag, mechanism: &str) -> u64 {
57 let mut hasher = DefaultHasher::new();
58
59 mechanism.hash(&mut hasher);
61
62 dag.node_count().hash(&mut hasher);
64
65 let mut edge_list: Vec<(usize, usize)> = Vec::new();
67 for node_id in dag.node_ids() {
68 for &child in dag.children(node_id) {
69 edge_list.push((node_id, child));
70 }
71 }
72 edge_list.sort_unstable();
73
74 for (from, to) in edge_list {
75 from.hash(&mut hasher);
76 to.hash(&mut hasher);
77 }
78
79 hasher.finish()
80 }
81
82 fn is_expired(&self, entry: &CacheEntry) -> bool {
84 if let Some(ttl) = self.config.ttl {
85 entry.timestamp.elapsed() > ttl
86 } else {
87 false
88 }
89 }
90
91 pub fn get(&mut self, dag: &QueryDag, mechanism: &str) -> Option<AttentionScores> {
93 let key = Self::hash_dag(dag, mechanism);
94
95 let is_expired = self
97 .cache
98 .get(&key)
99 .map(|entry| self.is_expired(entry))
100 .unwrap_or(true);
101
102 if is_expired {
103 self.cache.remove(&key);
104 self.access_order.retain(|&k| k != key);
105 self.misses += 1;
106 return None;
107 }
108
109 if let Some(entry) = self.cache.get_mut(&key) {
111 self.access_order.retain(|&k| k != key);
113 self.access_order.push(key);
114 entry.access_count += 1;
115 self.hits += 1;
116
117 Some(entry.scores.clone())
118 } else {
119 self.misses += 1;
120 None
121 }
122 }
123
124 pub fn insert(&mut self, dag: &QueryDag, mechanism: &str, scores: AttentionScores) {
126 let key = Self::hash_dag(dag, mechanism);
127
128 while self.cache.len() >= self.config.capacity && !self.access_order.is_empty() {
130 if let Some(oldest) = self.access_order.first().copied() {
131 self.cache.remove(&oldest);
132 self.access_order.remove(0);
133 }
134 }
135
136 let entry = CacheEntry {
137 scores,
138 timestamp: Instant::now(),
139 access_count: 0,
140 };
141
142 self.cache.insert(key, entry);
143 self.access_order.push(key);
144 }
145
146 pub fn clear(&mut self) {
148 self.cache.clear();
149 self.access_order.clear();
150 self.hits = 0;
151 self.misses = 0;
152 }
153
154 pub fn evict_expired(&mut self) {
156 let expired_keys: Vec<u64> = self
157 .cache
158 .iter()
159 .filter(|(_, entry)| self.is_expired(entry))
160 .map(|(k, _)| *k)
161 .collect();
162
163 for key in expired_keys {
164 self.cache.remove(&key);
165 self.access_order.retain(|&k| k != key);
166 }
167 }
168
169 pub fn stats(&self) -> CacheStats {
171 CacheStats {
172 size: self.cache.len(),
173 capacity: self.config.capacity,
174 hits: self.hits,
175 misses: self.misses,
176 hit_rate: if self.hits + self.misses > 0 {
177 self.hits as f64 / (self.hits + self.misses) as f64
178 } else {
179 0.0
180 },
181 }
182 }
183
184 pub fn most_accessed(&self) -> Option<(&u64, usize)> {
186 self.cache
187 .iter()
188 .max_by_key(|(_, entry)| entry.access_count)
189 .map(|(k, entry)| (k, entry.access_count))
190 }
191}
192
/// Snapshot of cache occupancy and lookup counters, as produced by `AttentionCache::stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct CacheStats {
    /// Number of entries currently stored.
    pub size: usize,
    /// Configured maximum number of entries.
    pub capacity: usize,
    /// Lookups that returned cached scores.
    pub hits: usize,
    /// Lookups that found no usable entry.
    pub misses: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookup has happened yet.
    pub hit_rate: f64,
}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    /// Builds a DAG of `n` scan nodes with costs `1..=n` and, when `n > 1`, one edge 0 -> 1.
    fn create_test_dag(n: usize) -> QueryDag {
        let mut dag = QueryDag::new();
        for i in 0..n {
            let mut node = OperatorNode::new(i, OperatorType::Scan);
            node.estimated_cost = (i + 1) as f64;
            dag.add_node(node);
        }
        if n > 1 {
            let _ = dag.add_edge(0, 1);
        }
        dag
    }

    #[test]
    fn test_cache_insert_and_get() {
        let mut cache = AttentionCache::new(CacheConfig::default());
        let dag = create_test_dag(3);

        let scores = AttentionScores::new(vec![0.5, 0.3, 0.2]);
        let expected_scores = scores.scores.clone();
        cache.insert(&dag, "test_mechanism", scores);

        let retrieved = cache.get(&dag, "test_mechanism").unwrap();
        assert_eq!(retrieved.scores, expected_scores);
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = AttentionCache::new(CacheConfig::default());
        let dag = create_test_dag(3);

        let result = cache.get(&dag, "nonexistent");
        assert!(result.is_none());
    }

    #[test]
    fn test_lru_eviction() {
        let mut cache = AttentionCache::new(CacheConfig {
            capacity: 2,
            ttl: None,
        });

        let dag1 = create_test_dag(1);
        let dag2 = create_test_dag(2);
        let dag3 = create_test_dag(3);

        cache.insert(&dag1, "mech", AttentionScores::new(vec![0.5]));
        cache.insert(&dag2, "mech", AttentionScores::new(vec![0.3, 0.7]));
        cache.insert(&dag3, "mech", AttentionScores::new(vec![0.2, 0.3, 0.5]));

        // dag1 was least recently used when dag3 arrived.
        assert!(cache.get(&dag1, "mech").is_none());
        assert!(cache.get(&dag2, "mech").is_some());
        assert!(cache.get(&dag3, "mech").is_some());
    }

    #[test]
    fn test_get_refreshes_recency() {
        let mut cache = AttentionCache::new(CacheConfig {
            capacity: 2,
            ttl: None,
        });

        let dag1 = create_test_dag(1);
        let dag2 = create_test_dag(2);
        let dag3 = create_test_dag(3);

        cache.insert(&dag1, "mech", AttentionScores::new(vec![0.5]));
        cache.insert(&dag2, "mech", AttentionScores::new(vec![0.3, 0.7]));
        // Touching dag1 makes dag2 the eviction candidate.
        assert!(cache.get(&dag1, "mech").is_some());
        cache.insert(&dag3, "mech", AttentionScores::new(vec![0.2, 0.3, 0.5]));

        assert!(cache.get(&dag1, "mech").is_some());
        assert!(cache.get(&dag2, "mech").is_none());
        assert!(cache.get(&dag3, "mech").is_some());
    }

    #[test]
    fn test_reinsert_replaces_entry() {
        let mut cache = AttentionCache::new(CacheConfig::default());
        let dag = create_test_dag(2);

        cache.insert(&dag, "mech", AttentionScores::new(vec![0.1, 0.9]));
        let replacement = AttentionScores::new(vec![0.6, 0.4]);
        let expected_scores = replacement.scores.clone();
        cache.insert(&dag, "mech", replacement);

        assert_eq!(cache.stats().size, 1);
        assert_eq!(cache.get(&dag, "mech").unwrap().scores, expected_scores);
    }

    #[test]
    fn test_cache_stats() {
        let mut cache = AttentionCache::new(CacheConfig::default());
        let dag = create_test_dag(2);

        cache.insert(&dag, "mech", AttentionScores::new(vec![0.5, 0.5]));

        assert!(cache.get(&dag, "mech").is_some()); // hit
        assert!(cache.get(&dag, "nonexistent").is_none()); // miss

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_clear_resets_entries_and_counters() {
        let mut cache = AttentionCache::new(CacheConfig::default());
        let dag = create_test_dag(2);

        cache.insert(&dag, "mech", AttentionScores::new(vec![0.5, 0.5]));
        assert!(cache.get(&dag, "mech").is_some());

        cache.clear();
        let stats = cache.stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert!(cache.get(&dag, "mech").is_none());
    }

    #[test]
    fn test_ttl_expiration() {
        let mut cache = AttentionCache::new(CacheConfig {
            capacity: 100,
            ttl: Some(Duration::from_millis(100)),
        });

        let dag = create_test_dag(2);
        cache.insert(&dag, "mech", AttentionScores::new(vec![0.5, 0.5]));

        // Freshly inserted: comfortably inside the TTL window.
        assert!(cache.get(&dag, "mech").is_some());

        // `sleep` waits at least this long, so the entry is guaranteed to be past its TTL.
        std::thread::sleep(Duration::from_millis(120));

        assert!(cache.get(&dag, "mech").is_none());
    }

    #[test]
    fn test_evict_expired_without_ttl_keeps_entries() {
        let mut cache = AttentionCache::new(CacheConfig {
            capacity: 10,
            ttl: None,
        });
        let dag = create_test_dag(2);

        cache.insert(&dag, "mech", AttentionScores::new(vec![0.5, 0.5]));
        cache.evict_expired();

        assert_eq!(cache.stats().size, 1);
        assert!(cache.get(&dag, "mech").is_some());
    }

    #[test]
    fn test_hash_consistency() {
        let dag = create_test_dag(3);

        let hash1 = AttentionCache::hash_dag(&dag, "mechanism");
        let hash2 = AttentionCache::hash_dag(&dag, "mechanism");

        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_hash_different_mechanisms() {
        let dag = create_test_dag(3);

        let hash1 = AttentionCache::hash_dag(&dag, "mechanism1");
        let hash2 = AttentionCache::hash_dag(&dag, "mechanism2");

        assert_ne!(hash1, hash2);
    }

    #[test]
    fn test_hash_different_dags() {
        let small = create_test_dag(2);
        let large = create_test_dag(3);

        assert_ne!(
            AttentionCache::hash_dag(&small, "mechanism"),
            AttentionCache::hash_dag(&large, "mechanism")
        );
    }
}