1use crate::{PredicateInfo, SymbolTable};
7use std::collections::{HashMap, VecDeque};
8use std::time::{Duration, Instant};
9
/// Identifies one cached query.
///
/// Two keys compare (and hash) equal exactly when they are the same variant
/// carrying equal parameters, so each distinct query gets its own slot.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum CacheKey {
    /// Key for a lookup by predicate name.
    PredicateByName(String),
    /// Key for the list of predicates with the given number of arguments.
    PredicatesByArity(usize),
    /// Key for the list of predicates that mention the given domain.
    PredicatesByDomain(String),
    /// Key for a lookup by a list of argument domains.
    PredicatesBySignature(Vec<String>),
    /// Key for a lookup by a name pattern string.
    PredicatesByPattern(String),
    /// Key for the usage count of the given domain.
    DomainUsageCount(String),
    /// Key for the list of all domain names.
    AllDomainNames,
    /// Key for the list of all predicate names.
    AllPredicateNames,
    /// Free-form, caller-defined key.
    Custom(String),
}
32
/// A cached value together with its creation/access bookkeeping and TTL.
#[derive(Debug, Clone)]
pub struct CachedResult<T> {
    /// The cached payload.
    pub value: T,
    /// Moment the entry was created; expiry and age are measured from here.
    pub created_at: Instant,
    /// Moment of the most recent creation or `update_access` call.
    pub last_accessed: Instant,
    /// Number of accesses, with creation counted as the first one.
    pub access_count: u64,
    /// How long the entry stays valid; `None` means it never expires.
    pub ttl: Option<Duration>,
}

impl<T> CachedResult<T> {
    /// Wraps `value` in a new entry stamped with the current time.
    ///
    /// `created_at` and `last_accessed` share the same instant and
    /// `access_count` starts at 1.
    pub fn new(value: T, ttl: Option<Duration>) -> Self {
        let stamp = Instant::now();
        Self {
            value,
            ttl,
            access_count: 1,
            created_at: stamp,
            last_accessed: stamp,
        }
    }

    /// Returns `true` once strictly more than `ttl` has passed since creation.
    /// Entries without a TTL never expire.
    pub fn is_expired(&self) -> bool {
        match self.ttl {
            Some(limit) => self.age() > limit,
            None => false,
        }
    }

    /// Records a read: bumps `access_count` and refreshes `last_accessed`.
    pub fn update_access(&mut self) {
        self.access_count += 1;
        self.last_accessed = Instant::now();
    }

    /// Time elapsed since the entry was created.
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }
}
81
/// Settings that control capacity, expiry and bookkeeping of a query cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Number of entries at which inserting a new one triggers eviction.
    pub max_entries: usize,
    /// TTL given to entries inserted without an explicit one; `None` = never expire.
    pub default_ttl: Option<Duration>,
    /// Whether recency is tracked for least-recently-used eviction.
    pub enable_lru: bool,
    /// Whether hit/miss/eviction counters are maintained.
    pub enable_stats: bool,
}

impl Default for CacheConfig {
    /// 1000 entries, a five-minute TTL, LRU tracking and statistics enabled.
    fn default() -> Self {
        Self {
            max_entries: 1000,
            default_ttl: Some(Duration::from_secs(5 * 60)),
            enable_lru: true,
            enable_stats: true,
        }
    }
}

impl CacheConfig {
    /// 100 entries with a one-minute TTL; LRU and statistics as in `default()`.
    pub fn small() -> Self {
        Self {
            max_entries: 100,
            default_ttl: Some(Duration::from_secs(60)),
            ..Self::default()
        }
    }

    /// 10 000 entries with a ten-minute TTL; LRU and statistics as in `default()`.
    pub fn large() -> Self {
        Self {
            max_entries: 10_000,
            default_ttl: Some(Duration::from_secs(10 * 60)),
            ..Self::default()
        }
    }

    /// Default capacity with no TTL: entries live until evicted or invalidated.
    pub fn no_ttl() -> Self {
        Self {
            default_ttl: None,
            ..Self::default()
        }
    }
}
137
/// Running counters describing how a query cache has been used.
#[derive(Debug, Clone, Default)]
pub struct QueryCacheStats {
    /// Lookups that returned a cached value.
    pub hits: u64,
    /// Lookups that returned nothing (key absent or entry expired).
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Entries dropped because their TTL had elapsed.
    pub expirations: u64,
    /// Entries removed explicitly through invalidation.
    pub invalidations: u64,
}

impl QueryCacheStats {
    /// Fraction of lookups that were hits, or `0.0` if none were recorded.
    pub fn hit_rate(&self) -> f64 {
        Self::fraction(self.hits, self.total_accesses())
    }

    /// Fraction of lookups that were misses, or `0.0` if none were recorded.
    ///
    /// Computed directly from `misses` rather than as `1.0 - hit_rate()`, so an
    /// untouched cache does not report a 100% miss rate and the result avoids
    /// an extra floating-point rounding step.
    pub fn miss_rate(&self) -> f64 {
        Self::fraction(self.misses, self.total_accesses())
    }

    /// Total number of recorded lookups (hits plus misses), saturating at
    /// `u64::MAX` instead of overflowing.
    pub fn total_accesses(&self) -> u64 {
        self.hits.saturating_add(self.misses)
    }

    /// `part / total` as a float, defined as `0.0` when `total` is zero.
    fn fraction(part: u64, total: u64) -> f64 {
        if total == 0 {
            0.0
        } else {
            part as f64 / total as f64
        }
    }
}
174
175pub struct QueryCache<T> {
177 cache: HashMap<CacheKey, CachedResult<T>>,
179 lru_queue: VecDeque<CacheKey>,
181 config: CacheConfig,
183 stats: QueryCacheStats,
185}
186
187impl<T: Clone> QueryCache<T> {
188 pub fn new() -> Self {
190 Self::with_config(CacheConfig::default())
191 }
192
193 pub fn with_config(config: CacheConfig) -> Self {
195 Self {
196 cache: HashMap::new(),
197 lru_queue: VecDeque::new(),
198 config,
199 stats: QueryCacheStats::default(),
200 }
201 }
202
203 pub fn get(&mut self, key: &CacheKey) -> Option<T> {
205 let is_expired = self
207 .cache
208 .get(key)
209 .map(|entry| entry.is_expired())
210 .unwrap_or(false);
211
212 if is_expired {
213 self.cache.remove(key);
214 if self.config.enable_stats {
215 self.stats.expirations += 1;
216 self.stats.misses += 1;
217 }
218 return None;
219 }
220
221 if let Some(entry) = self.cache.get_mut(key) {
223 entry.update_access();
225 if self.config.enable_stats {
226 self.stats.hits += 1;
227 }
228
229 let value = entry.value.clone();
230
231 if self.config.enable_lru {
233 self.update_lru(key);
234 }
235
236 Some(value)
237 } else {
238 if self.config.enable_stats {
239 self.stats.misses += 1;
240 }
241 None
242 }
243 }
244
245 pub fn insert(&mut self, key: CacheKey, value: T) {
247 self.insert_with_ttl(key, value, self.config.default_ttl);
248 }
249
250 pub fn insert_with_ttl(&mut self, key: CacheKey, value: T, ttl: Option<Duration>) {
252 if self.cache.len() >= self.config.max_entries {
254 self.evict_one();
255 }
256
257 let entry = CachedResult::new(value, ttl);
259 self.cache.insert(key.clone(), entry);
260
261 if self.config.enable_lru {
263 self.lru_queue.push_back(key);
264 }
265 }
266
267 pub fn invalidate(&mut self, key: &CacheKey) -> bool {
269 if self.cache.remove(key).is_some() {
270 if self.config.enable_stats {
271 self.stats.invalidations += 1;
272 }
273 if self.config.enable_lru {
275 self.lru_queue.retain(|k| k != key);
276 }
277 true
278 } else {
279 false
280 }
281 }
282
283 pub fn clear(&mut self) {
285 self.cache.clear();
286 self.lru_queue.clear();
287 }
288
289 pub fn cleanup_expired(&mut self) -> usize {
291 let mut removed = 0;
292 let expired_keys: Vec<CacheKey> = self
293 .cache
294 .iter()
295 .filter(|(_, v)| v.is_expired())
296 .map(|(k, _)| k.clone())
297 .collect();
298
299 for key in expired_keys {
300 self.cache.remove(&key);
301 self.lru_queue.retain(|k| k != &key);
302 removed += 1;
303 }
304
305 if self.config.enable_stats {
306 self.stats.expirations += removed as u64;
307 }
308
309 removed
310 }
311
312 pub fn stats(&self) -> &QueryCacheStats {
314 &self.stats
315 }
316
317 pub fn len(&self) -> usize {
319 self.cache.len()
320 }
321
322 pub fn is_empty(&self) -> bool {
324 self.cache.is_empty()
325 }
326
327 pub fn config(&self) -> &CacheConfig {
329 &self.config
330 }
331
332 fn update_lru(&mut self, key: &CacheKey) {
334 self.lru_queue.retain(|k| k != key);
336 self.lru_queue.push_back(key.clone());
338 }
339
340 fn evict_one(&mut self) {
342 if let Some(key) = self.lru_queue.pop_front() {
343 self.cache.remove(&key);
344 if self.config.enable_stats {
345 self.stats.evictions += 1;
346 }
347 }
348 }
349}
350
351impl<T: Clone> Default for QueryCache<T> {
352 fn default() -> Self {
353 Self::new()
354 }
355}
356
357pub struct SymbolTableCache {
359 predicate_cache: QueryCache<Vec<PredicateInfo>>,
361 domain_cache: QueryCache<Vec<String>>,
363 scalar_cache: QueryCache<usize>,
365}
366
367impl SymbolTableCache {
368 pub fn new() -> Self {
370 Self {
371 predicate_cache: QueryCache::new(),
372 domain_cache: QueryCache::new(),
373 scalar_cache: QueryCache::new(),
374 }
375 }
376
377 pub fn with_config(config: CacheConfig) -> Self {
379 Self {
380 predicate_cache: QueryCache::with_config(config.clone()),
381 domain_cache: QueryCache::with_config(config.clone()),
382 scalar_cache: QueryCache::with_config(config),
383 }
384 }
385
386 pub fn get_predicates_by_arity(
388 &mut self,
389 table: &SymbolTable,
390 arity: usize,
391 ) -> Vec<PredicateInfo> {
392 let key = CacheKey::PredicatesByArity(arity);
393
394 if let Some(result) = self.predicate_cache.get(&key) {
395 return result;
396 }
397
398 let result: Vec<PredicateInfo> = table
400 .predicates
401 .values()
402 .filter(|p| p.arg_domains.len() == arity)
403 .cloned()
404 .collect();
405
406 self.predicate_cache.insert(key, result.clone());
407 result
408 }
409
410 pub fn get_predicates_by_domain(
412 &mut self,
413 table: &SymbolTable,
414 domain: &str,
415 ) -> Vec<PredicateInfo> {
416 let key = CacheKey::PredicatesByDomain(domain.to_string());
417
418 if let Some(result) = self.predicate_cache.get(&key) {
419 return result;
420 }
421
422 let result: Vec<PredicateInfo> = table
424 .predicates
425 .values()
426 .filter(|p| p.arg_domains.contains(&domain.to_string()))
427 .cloned()
428 .collect();
429
430 self.predicate_cache.insert(key, result.clone());
431 result
432 }
433
434 pub fn get_domain_names(&mut self, table: &SymbolTable) -> Vec<String> {
436 let key = CacheKey::AllDomainNames;
437
438 if let Some(result) = self.domain_cache.get(&key) {
439 return result;
440 }
441
442 let mut result: Vec<String> = table.domains.keys().cloned().collect();
444 result.sort();
445
446 self.domain_cache.insert(key, result.clone());
447 result
448 }
449
450 pub fn get_domain_usage_count(&mut self, table: &SymbolTable, domain: &str) -> usize {
452 let key = CacheKey::DomainUsageCount(domain.to_string());
453
454 if let Some(result) = self.scalar_cache.get(&key) {
455 return result;
456 }
457
458 let mut count = 0;
460 for predicate in table.predicates.values() {
461 count += predicate
462 .arg_domains
463 .iter()
464 .filter(|d| d.as_str() == domain)
465 .count();
466 }
467
468 for var_domain in table.variables.values() {
469 if var_domain == domain {
470 count += 1;
471 }
472 }
473
474 self.scalar_cache.insert(key, count);
475 count
476 }
477
478 pub fn invalidate_all(&mut self) {
480 self.predicate_cache.clear();
481 self.domain_cache.clear();
482 self.scalar_cache.clear();
483 }
484
485 pub fn invalidate_domain(&mut self, domain: &str) {
487 self.predicate_cache
488 .invalidate(&CacheKey::PredicatesByDomain(domain.to_string()));
489 self.scalar_cache
490 .invalidate(&CacheKey::DomainUsageCount(domain.to_string()));
491 self.domain_cache.invalidate(&CacheKey::AllDomainNames);
492 }
493
494 pub fn invalidate_predicates(&mut self) {
496 self.predicate_cache.clear();
497 }
498
499 pub fn combined_stats(&self) -> QueryCacheStats {
501 let pred_stats = self.predicate_cache.stats();
502 let domain_stats = self.domain_cache.stats();
503 let scalar_stats = self.scalar_cache.stats();
504
505 QueryCacheStats {
506 hits: pred_stats.hits + domain_stats.hits + scalar_stats.hits,
507 misses: pred_stats.misses + domain_stats.misses + scalar_stats.misses,
508 evictions: pred_stats.evictions + domain_stats.evictions + scalar_stats.evictions,
509 expirations: pred_stats.expirations
510 + domain_stats.expirations
511 + scalar_stats.expirations,
512 invalidations: pred_stats.invalidations
513 + domain_stats.invalidations
514 + scalar_stats.invalidations,
515 }
516 }
517
518 pub fn cleanup_expired(&mut self) -> usize {
520 self.predicate_cache.cleanup_expired()
521 + self.domain_cache.cleanup_expired()
522 + self.scalar_cache.cleanup_expired()
523 }
524}
525
526impl Default for SymbolTableCache {
527 fn default() -> Self {
528 Self::new()
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;

    /// Builds a `CacheKey::Custom` key from a string slice.
    fn custom(name: &str) -> CacheKey {
        CacheKey::Custom(name.to_string())
    }

    /// Default configuration except that entries expire after 10ms.
    fn short_ttl() -> CacheConfig {
        CacheConfig {
            default_ttl: Some(Duration::from_millis(10)),
            ..Default::default()
        }
    }

    #[test]
    fn test_cache_basic_operations() {
        let mut cache: QueryCache<String> = QueryCache::new();
        let key = custom("test");

        cache.insert(key.clone(), "value".to_string());
        // The stored value comes back and the lookup counts as a hit.
        assert_eq!(cache.get(&key), Some("value".to_string()));

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_cache_miss() {
        let mut cache: QueryCache<String> = QueryCache::new();

        assert_eq!(cache.get(&custom("nonexistent")), None);
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_cache_invalidation() {
        let mut cache: QueryCache<String> = QueryCache::new();
        let key = custom("test");

        cache.insert(key.clone(), "value".to_string());
        assert!(cache.invalidate(&key));
        assert_eq!(cache.get(&key), None);
    }

    #[test]
    fn test_cache_expiration() {
        let mut cache: QueryCache<String> = QueryCache::with_config(short_ttl());
        let key = custom("test");

        cache.insert(key.clone(), "value".to_string());
        // Wait past the 10ms TTL so the next lookup sees a stale entry.
        std::thread::sleep(Duration::from_millis(20));

        assert_eq!(cache.get(&key), None);
        assert_eq!(cache.stats().expirations, 1);
    }

    #[test]
    fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            enable_lru: true,
            ..Default::default()
        };
        let mut cache: QueryCache<String> = QueryCache::with_config(config);

        for n in 1..=3 {
            cache.insert(custom(&format!("key{}", n)), format!("value{}", n));
        }

        // The third insert pushes one entry out to stay within capacity.
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.stats().evictions, 1);
    }

    #[test]
    fn test_symbol_table_cache() {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();
        table
            .add_predicate(PredicateInfo::new(
                "knows",
                vec!["Person".to_string(), "Person".to_string()],
            ))
            .unwrap();
        table
            .add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .unwrap();

        let mut cache = SymbolTableCache::new();

        // First lookup computes the result.
        let first = cache.get_predicates_by_arity(&table, 2);
        assert_eq!(first.len(), 1);
        assert_eq!(cache.predicate_cache.stats().misses, 1);

        // Second lookup is served from the cache.
        let second = cache.get_predicates_by_arity(&table, 2);
        assert_eq!(second.len(), 1);
        assert_eq!(cache.predicate_cache.stats().hits, 1);
    }

    #[test]
    fn test_cache_config_presets() {
        assert_eq!(CacheConfig::small().max_entries, 100);
        assert_eq!(CacheConfig::large().max_entries, 10000);
        assert!(CacheConfig::no_ttl().default_ttl.is_none());
    }

    #[test]
    fn test_cache_stats() {
        let mut cache: QueryCache<String> = QueryCache::new();
        let present = custom("key1");
        let absent = custom("key2");

        cache.insert(present.clone(), "value1".to_string());
        let _ = cache.get(&present); // hit
        let _ = cache.get(&absent); // miss

        let stats = cache.stats();
        assert_eq!(stats.hit_rate(), 0.5);
        assert_eq!(stats.miss_rate(), 0.5);
        assert_eq!(stats.total_accesses(), 2);
    }

    #[test]
    fn test_cleanup_expired() {
        let mut cache: QueryCache<String> = QueryCache::with_config(short_ttl());

        for n in 1..=2 {
            cache.insert(custom(&format!("key{}", n)), format!("value{}", n));
        }

        std::thread::sleep(Duration::from_millis(20));

        assert_eq!(cache.cleanup_expired(), 2);
        assert!(cache.is_empty());
    }
}