1use crate::graph::VertexId;
12use std::collections::{HashMap, VecDeque};
13use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
14use std::sync::RwLock;
15
/// Tunable parameters for [`PathDistanceCache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of cached (source, target) distance entries.
    pub max_entries: usize,
    /// When true, queries are recorded and used to predict future lookups.
    pub enable_prefetch: bool,
    /// Number of recent queries retained for prefetch prediction.
    pub prefetch_history_size: usize,
    /// Maximum number of predicted queries produced per prediction pass.
    pub prefetch_lookahead: usize,
}
28
29impl Default for CacheConfig {
30 fn default() -> Self {
31 Self {
32 max_entries: 10_000,
33 enable_prefetch: true,
34 prefetch_history_size: 100,
35 prefetch_lookahead: 4,
36 }
37 }
38}
39
/// A point-in-time snapshot of cache counters, returned by `stats()`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups answered from the cache.
    pub hits: u64,
    /// Lookups that found no entry.
    pub misses: u64,
    /// Number of entries currently stored.
    pub size: usize,
    /// Hits on entries that were inserted by the prefetcher.
    pub prefetch_hits: u64,
    /// Entries removed to make room for new insertions.
    pub evictions: u64,
}
54
55impl CacheStats {
56 pub fn hit_rate(&self) -> f64 {
58 let total = self.hits + self.misses;
59 if total > 0 {
60 self.hits as f64 / total as f64
61 } else {
62 0.0
63 }
64 }
65}
66
/// A suggestion that distances from `source` to `targets` are likely to
/// be queried soon, produced from the recorded query history.
#[derive(Debug, Clone)]
pub struct PrefetchHint {
    /// Vertex that recent queries frequently involved.
    pub source: VertexId,
    /// Vertices recently paired with `source` in queries.
    pub targets: Vec<VertexId>,
    /// Rough likelihood estimate in `[0.0, 1.0]`, derived from how often
    /// `source` appears in the history.
    pub confidence: f64,
}
77
/// A single cached distance record.
#[derive(Debug, Clone)]
struct CacheEntry {
    // Original (un-normalized) endpoints of the cached query.
    // NOTE(review): `source`, `target`, and `last_access` are written on
    // insert but never read anywhere in this file — possibly kept for
    // debugging or a future age-based eviction policy; confirm before
    // removing.
    source: VertexId,
    target: VertexId,
    /// Cached shortest-path distance.
    distance: f64,
    /// Monotonic timestamp taken from `access_counter` at insert time.
    last_access: u64,
    /// True when the entry was inserted by the prefetcher.
    prefetched: bool,
}
92
/// Canonical map key for an unordered vertex pair.
///
/// Construction via `CacheKey::new` guarantees `source <= target`, so
/// queries for (a, b) and (b, a) share one cache slot.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
struct CacheKey {
    source: VertexId,
    target: VertexId,
}
99
100impl CacheKey {
101 fn new(source: VertexId, target: VertexId) -> Self {
102 if source <= target {
104 Self { source, target }
105 } else {
106 Self {
107 source: target,
108 target: source,
109 }
110 }
111 }
112}
113
/// A thread-safe cache of shortest-path distances with capacity-bounded
/// eviction and optional query-pattern-based prefetch prediction.
pub struct PathDistanceCache {
    /// Capacity and prefetch tunables.
    config: CacheConfig,
    /// Cached distances keyed by normalized vertex pair.
    cache: RwLock<HashMap<CacheKey, CacheEntry>>,
    /// Queue of keys in insertion order; the front is the eviction victim.
    /// NOTE(review): keys are never moved to the back on `get`, so
    /// eviction is effectively FIFO rather than true LRU — confirm whether
    /// that is intentional.
    lru_order: RwLock<VecDeque<CacheKey>>,
    /// Monotonic counter used to timestamp entries.
    access_counter: AtomicU64,
    /// Statistics counters (relaxed ordering; snapshotted by `stats()`).
    hits: AtomicU64,
    misses: AtomicU64,
    prefetch_hits: AtomicU64,
    evictions: AtomicU64,
    /// Recent queries, bounded by `config.prefetch_history_size`.
    query_history: RwLock<VecDeque<CacheKey>>,
    /// Output of the latest prediction pass; see `get_predicted_queries`.
    predicted_queries: RwLock<Vec<CacheKey>>,
}
133
134impl PathDistanceCache {
135 pub fn new() -> Self {
137 Self::with_config(CacheConfig::default())
138 }
139
140 pub fn with_config(config: CacheConfig) -> Self {
142 Self {
143 config,
144 cache: RwLock::new(HashMap::new()),
145 lru_order: RwLock::new(VecDeque::new()),
146 access_counter: AtomicU64::new(0),
147 hits: AtomicU64::new(0),
148 misses: AtomicU64::new(0),
149 prefetch_hits: AtomicU64::new(0),
150 evictions: AtomicU64::new(0),
151 query_history: RwLock::new(VecDeque::new()),
152 predicted_queries: RwLock::new(Vec::new()),
153 }
154 }
155
156 pub fn get(&self, source: VertexId, target: VertexId) -> Option<f64> {
158 let key = CacheKey::new(source, target);
159
160 let cache = self.cache.read().unwrap();
162 if let Some(entry) = cache.get(&key) {
163 self.hits.fetch_add(1, Ordering::Relaxed);
164 if entry.prefetched {
165 self.prefetch_hits.fetch_add(1, Ordering::Relaxed);
166 }
167
168 if self.config.enable_prefetch {
170 self.record_query(key);
171 }
172
173 return Some(entry.distance);
174 }
175 drop(cache);
176
177 self.misses.fetch_add(1, Ordering::Relaxed);
178
179 if self.config.enable_prefetch {
181 self.record_query(key);
182 }
183
184 None
185 }
186
187 pub fn insert(&self, source: VertexId, target: VertexId, distance: f64) {
189 let key = CacheKey::new(source, target);
190 let timestamp = self.access_counter.fetch_add(1, Ordering::Relaxed);
191
192 let entry = CacheEntry {
193 source,
194 target,
195 distance,
196 last_access: timestamp,
197 prefetched: false,
198 };
199
200 self.insert_entry(key, entry);
201 }
202
203 pub fn insert_prefetch(&self, source: VertexId, target: VertexId, distance: f64) {
205 let key = CacheKey::new(source, target);
206 let timestamp = self.access_counter.fetch_add(1, Ordering::Relaxed);
207
208 let entry = CacheEntry {
209 source,
210 target,
211 distance,
212 last_access: timestamp,
213 prefetched: true,
214 };
215
216 self.insert_entry(key, entry);
217 }
218
219 fn insert_entry(&self, key: CacheKey, entry: CacheEntry) {
221 let mut cache = self.cache.write().unwrap();
222 let mut lru = self.lru_order.write().unwrap();
223
224 while cache.len() >= self.config.max_entries {
226 if let Some(evict_key) = lru.pop_front() {
227 cache.remove(&evict_key);
228 self.evictions.fetch_add(1, Ordering::Relaxed);
229 } else {
230 break;
231 }
232 }
233
234 cache.insert(key, entry);
236 lru.push_back(key);
237 }
238
239 pub fn insert_batch(&self, entries: &[(VertexId, VertexId, f64)]) {
241 let mut cache = self.cache.write().unwrap();
242 let mut lru = self.lru_order.write().unwrap();
243
244 for &(source, target, distance) in entries {
245 let key = CacheKey::new(source, target);
246 let timestamp = self.access_counter.fetch_add(1, Ordering::Relaxed);
247
248 let entry = CacheEntry {
249 source,
250 target,
251 distance,
252 last_access: timestamp,
253 prefetched: false,
254 };
255
256 while cache.len() >= self.config.max_entries {
258 if let Some(evict_key) = lru.pop_front() {
259 cache.remove(&evict_key);
260 self.evictions.fetch_add(1, Ordering::Relaxed);
261 } else {
262 break;
263 }
264 }
265
266 cache.insert(key, entry);
267 lru.push_back(key);
268 }
269 }
270
271 pub fn invalidate_vertex(&self, vertex: VertexId) {
273 let mut cache = self.cache.write().unwrap();
274 let mut lru = self.lru_order.write().unwrap();
275
276 let keys_to_remove: Vec<CacheKey> = cache
277 .keys()
278 .filter(|k| k.source == vertex || k.target == vertex)
279 .copied()
280 .collect();
281
282 for key in keys_to_remove {
283 cache.remove(&key);
284 lru.retain(|k| *k != key);
285 }
286 }
287
288 pub fn clear(&self) {
290 let mut cache = self.cache.write().unwrap();
291 let mut lru = self.lru_order.write().unwrap();
292 cache.clear();
293 lru.clear();
294 }
295
296 fn record_query(&self, key: CacheKey) {
298 if let Ok(mut history) = self.query_history.try_write() {
299 history.push_back(key);
300 while history.len() > self.config.prefetch_history_size {
301 history.pop_front();
302 }
303
304 if history.len() % 10 == 0 {
306 self.update_predictions(&history);
307 }
308 }
309 }
310
311 fn update_predictions(&self, history: &VecDeque<CacheKey>) {
313 if history.len() < 10 {
314 return;
315 }
316
317 let mut vertex_frequency: HashMap<VertexId, usize> = HashMap::new();
319 for key in history.iter() {
320 *vertex_frequency.entry(key.source).or_insert(0) += 1;
321 *vertex_frequency.entry(key.target).or_insert(0) += 1;
322 }
323
324 let recent: Vec<_> = history.iter().rev().take(5).collect();
326 let mut predictions = Vec::new();
327
328 for key in recent {
329 for (vertex, &freq) in &vertex_frequency {
331 if freq > 2 && *vertex != key.source && *vertex != key.target {
332 predictions.push(CacheKey::new(key.source, *vertex));
333 if predictions.len() >= self.config.prefetch_lookahead {
334 break;
335 }
336 }
337 }
338 if predictions.len() >= self.config.prefetch_lookahead {
339 break;
340 }
341 }
342
343 if let Ok(mut pred) = self.predicted_queries.try_write() {
344 *pred = predictions;
345 }
346 }
347
348 pub fn get_prefetch_hints(&self) -> Vec<PrefetchHint> {
350 let history = self.query_history.read().unwrap();
351 if history.is_empty() {
352 return Vec::new();
353 }
354
355 let mut source_freq: HashMap<VertexId, Vec<VertexId>> = HashMap::new();
357 for key in history.iter() {
358 source_freq.entry(key.source).or_default().push(key.target);
359 source_freq.entry(key.target).or_default().push(key.source);
360 }
361
362 source_freq
364 .into_iter()
365 .filter(|(_, targets)| targets.len() > 2)
366 .map(|(source, targets)| {
367 let confidence = (targets.len() as f64 / history.len() as f64).min(1.0);
368 PrefetchHint {
369 source,
370 targets,
371 confidence,
372 }
373 })
374 .collect()
375 }
376
377 pub fn get_predicted_queries(&self) -> Vec<(VertexId, VertexId)> {
379 let pred = self.predicted_queries.read().unwrap();
380 pred.iter().map(|key| (key.source, key.target)).collect()
381 }
382
383 pub fn stats(&self) -> CacheStats {
385 let cache = self.cache.read().unwrap();
386 CacheStats {
387 hits: self.hits.load(Ordering::Relaxed),
388 misses: self.misses.load(Ordering::Relaxed),
389 size: cache.len(),
390 prefetch_hits: self.prefetch_hits.load(Ordering::Relaxed),
391 evictions: self.evictions.load(Ordering::Relaxed),
392 }
393 }
394
395 pub fn len(&self) -> usize {
397 self.cache.read().unwrap().len()
398 }
399
400 pub fn is_empty(&self) -> bool {
402 self.cache.read().unwrap().is_empty()
403 }
404}
405
impl Default for PathDistanceCache {
    /// Equivalent to [`PathDistanceCache::new`]: a cache with the
    /// default [`CacheConfig`].
    fn default() -> Self {
        Self::new()
    }
}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Insert + symmetric lookup + miss.
    #[test]
    fn test_basic_cache_operations() {
        let cache = PathDistanceCache::new();

        cache.insert(1, 2, 10.0);
        assert_eq!(cache.get(1, 2), Some(10.0));
        // Keys are canonicalized, so the reversed pair hits as well.
        assert_eq!(cache.get(2, 1), Some(10.0));
        // Never inserted.
        assert_eq!(cache.get(1, 3), None);
    }

    /// Filling past capacity evicts the oldest entry.
    #[test]
    fn test_lru_eviction() {
        let cache = PathDistanceCache::with_config(CacheConfig {
            max_entries: 3,
            ..Default::default()
        });

        cache.insert(1, 2, 1.0);
        cache.insert(2, 3, 2.0);
        cache.insert(3, 4, 3.0);
        assert_eq!(cache.len(), 3);

        // A fourth insert must push out the oldest pair (1, 2).
        cache.insert(4, 5, 4.0);
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(1, 2), None);
        assert_eq!(cache.get(4, 5), Some(4.0));
    }

    /// Batch insertion stores every entry.
    #[test]
    fn test_batch_insert() {
        let cache = PathDistanceCache::new();
        let entries = [(1, 2, 1.0), (2, 3, 2.0), (3, 4, 3.0)];

        cache.insert_batch(&entries);

        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(1, 2), Some(1.0));
        assert_eq!(cache.get(2, 3), Some(2.0));
        assert_eq!(cache.get(3, 4), Some(3.0));
    }

    /// Invalidating a vertex removes exactly the pairs that contain it.
    #[test]
    fn test_invalidate_vertex() {
        let cache = PathDistanceCache::new();
        cache.insert(1, 2, 1.0);
        cache.insert(1, 3, 2.0);
        cache.insert(2, 3, 3.0);

        cache.invalidate_vertex(1);

        assert_eq!(cache.get(1, 2), None);
        assert_eq!(cache.get(1, 3), None);
        // (2, 3) does not involve vertex 1 and must survive.
        assert_eq!(cache.get(2, 3), Some(3.0));
    }

    /// Counters track hits, misses, and current size.
    #[test]
    fn test_statistics() {
        let cache = PathDistanceCache::new();
        cache.insert(1, 2, 1.0);

        // Two hits...
        cache.get(1, 2);
        cache.get(1, 2);
        // ...and one miss.
        cache.get(3, 4);

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.size, 1);
        assert!(stats.hit_rate() > 0.5);
    }

    /// A repetitive query pattern either yields hints or at least hits.
    #[test]
    fn test_prefetch_hints() {
        let cache = PathDistanceCache::with_config(CacheConfig {
            enable_prefetch: true,
            prefetch_history_size: 50,
            ..Default::default()
        });

        for i in 0..20 {
            cache.insert(1, i as u64, i as f64);
            let _ = cache.get(1, i as u64);
        }

        let hints = cache.get_prefetch_hints();
        assert!(!hints.is_empty() || cache.stats().hits > 0);
    }

    /// `clear` empties the cache completely.
    #[test]
    fn test_clear() {
        let cache = PathDistanceCache::new();
        cache.insert(1, 2, 1.0);
        cache.insert(2, 3, 2.0);
        assert_eq!(cache.len(), 2);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }
}