Skip to main content

tensorlogic_adapters/
query_cache.rs

1//! Advanced query result caching system for performance optimization.
2//!
3//! This module provides sophisticated caching mechanisms for expensive query operations,
4//! including TTL-based expiration, size limits, and cache statistics tracking.
5
6use crate::{PredicateInfo, SymbolTable};
7use std::collections::{HashMap, VecDeque};
8use std::time::{Duration, Instant};
9
/// Identifies a cached query result.
///
/// A key is the query kind (the variant) plus that query's argument (the
/// payload). Two keys are equal only when both match, so the same string used
/// under two different variants never collides.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CacheKey {
    /// A single predicate looked up by its name.
    PredicateByName(String),
    /// All predicates with the given arity.
    PredicatesByArity(usize),
    /// All predicates that use the named domain.
    PredicatesByDomain(String),
    /// All predicates with the given signature.
    PredicatesBySignature(Vec<String>),
    /// All predicates matching a wildcard pattern.
    PredicatesByPattern(String),
    /// Usage count of the named domain.
    DomainUsageCount(String),
    /// The list of all domain names.
    AllDomainNames,
    /// The list of all predicate names.
    AllPredicateNames,
    /// A caller-defined query, identified by an arbitrary string.
    Custom(String),
}
32
/// A stored query result together with its bookkeeping data.
#[derive(Debug, Clone)]
pub struct CachedResult<T> {
    /// The stored payload.
    pub value: T,
    /// Instant at which this entry was constructed.
    pub created_at: Instant,
    /// Instant of the most recent recorded access (equals `created_at` at first).
    pub last_accessed: Instant,
    /// Access counter; starts at 1 and grows by one per `update_access` call.
    pub access_count: u64,
    /// Lifetime measured from `created_at`; `None` means the entry never expires.
    pub ttl: Option<Duration>,
}

impl<T> CachedResult<T> {
    /// Wrap `value` in a fresh entry stamped with the current time.
    pub fn new(value: T, ttl: Option<Duration>) -> Self {
        let stamp = Instant::now();
        Self {
            value,
            ttl,
            access_count: 1,
            created_at: stamp,
            last_accessed: stamp,
        }
    }

    /// Report whether the entry has outlived its TTL.
    ///
    /// An entry without a TTL never expires. Otherwise it is expired once its
    /// age is strictly greater than the TTL.
    pub fn is_expired(&self) -> bool {
        match self.ttl {
            Some(limit) => self.age() > limit,
            None => false,
        }
    }

    /// Record one access: refresh `last_accessed` and bump `access_count`.
    pub fn update_access(&mut self) {
        self.last_accessed = Instant::now();
        self.access_count += 1;
    }

    /// Time elapsed since the entry was created.
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }
}
81
/// Tunable settings for a query cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Upper bound on the number of stored entries.
    pub max_entries: usize,
    /// TTL applied to entries inserted without an explicit one; `None` disables expiry.
    pub default_ttl: Option<Duration>,
    /// Whether LRU eviction is enabled.
    pub enable_lru: bool,
    /// Whether hit/miss/eviction counters are maintained.
    pub enable_stats: bool,
}

impl Default for CacheConfig {
    /// 1000 entries, five-minute TTL, LRU and statistics enabled.
    fn default() -> Self {
        Self::preset(1000, Some(Duration::from_secs(300)))
    }
}

impl CacheConfig {
    /// Shared constructor for the presets: LRU and statistics are always on,
    /// only the capacity and the default TTL vary.
    fn preset(max_entries: usize, default_ttl: Option<Duration>) -> Self {
        Self {
            max_entries,
            default_ttl,
            enable_lru: true,
            enable_stats: true,
        }
    }

    /// Preset for a small cache: 100 entries, one-minute TTL.
    pub fn small() -> Self {
        Self::preset(100, Some(Duration::from_secs(60)))
    }

    /// Preset for a large cache: 10000 entries, ten-minute TTL.
    pub fn large() -> Self {
        Self::preset(10000, Some(Duration::from_secs(600)))
    }

    /// Preset without expiry: 1000 entries that stay until evicted.
    pub fn no_ttl() -> Self {
        Self::preset(1000, None)
    }
}
137
/// Counters describing how a query cache has performed.
#[derive(Debug, Clone, Default)]
pub struct QueryCacheStats {
    /// Lookups that returned a live entry.
    pub hits: u64,
    /// Lookups that found nothing usable (absent or expired).
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Entries removed because their TTL had elapsed.
    pub expirations: u64,
    /// Entries removed by explicit invalidation.
    pub invalidations: u64,
}

impl QueryCacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookup has been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.total_accesses() {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }

    /// Fraction of lookups that were misses, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookup has been recorded yet: an unused cache
    /// has not missed anything, so it must not report a 100% miss rate.
    pub fn miss_rate(&self) -> f64 {
        if self.total_accesses() == 0 {
            0.0
        } else {
            1.0 - self.hit_rate()
        }
    }

    /// Total number of recorded lookups (hits plus misses).
    ///
    /// The sum saturates at `u64::MAX` instead of overflowing.
    pub fn total_accesses(&self) -> u64 {
        self.hits.saturating_add(self.misses)
    }
}
174
175/// A generic query result cache with TTL and LRU eviction.
176pub struct QueryCache<T> {
177    /// The cache storage
178    cache: HashMap<CacheKey, CachedResult<T>>,
179    /// LRU queue for eviction
180    lru_queue: VecDeque<CacheKey>,
181    /// Cache configuration
182    config: CacheConfig,
183    /// Cache statistics
184    stats: QueryCacheStats,
185}
186
187impl<T: Clone> QueryCache<T> {
188    /// Create a new query cache with default configuration.
189    pub fn new() -> Self {
190        Self::with_config(CacheConfig::default())
191    }
192
193    /// Create a new query cache with custom configuration.
194    pub fn with_config(config: CacheConfig) -> Self {
195        Self {
196            cache: HashMap::new(),
197            lru_queue: VecDeque::new(),
198            config,
199            stats: QueryCacheStats::default(),
200        }
201    }
202
203    /// Get a value from the cache.
204    pub fn get(&mut self, key: &CacheKey) -> Option<T> {
205        // Check if entry exists and not expired
206        let is_expired = self
207            .cache
208            .get(key)
209            .map(|entry| entry.is_expired())
210            .unwrap_or(false);
211
212        if is_expired {
213            self.cache.remove(key);
214            if self.config.enable_stats {
215                self.stats.expirations += 1;
216                self.stats.misses += 1;
217            }
218            return None;
219        }
220
221        // Get mutable entry and update
222        if let Some(entry) = self.cache.get_mut(key) {
223            // Update access statistics
224            entry.update_access();
225            if self.config.enable_stats {
226                self.stats.hits += 1;
227            }
228
229            let value = entry.value.clone();
230
231            // Update LRU queue if enabled
232            if self.config.enable_lru {
233                self.update_lru(key);
234            }
235
236            Some(value)
237        } else {
238            if self.config.enable_stats {
239                self.stats.misses += 1;
240            }
241            None
242        }
243    }
244
245    /// Insert a value into the cache.
246    pub fn insert(&mut self, key: CacheKey, value: T) {
247        self.insert_with_ttl(key, value, self.config.default_ttl);
248    }
249
250    /// Insert a value with a custom TTL.
251    pub fn insert_with_ttl(&mut self, key: CacheKey, value: T, ttl: Option<Duration>) {
252        // Check if we need to evict
253        if self.cache.len() >= self.config.max_entries {
254            self.evict_one();
255        }
256
257        // Insert the new entry
258        let entry = CachedResult::new(value, ttl);
259        self.cache.insert(key.clone(), entry);
260
261        // Update LRU queue
262        if self.config.enable_lru {
263            self.lru_queue.push_back(key);
264        }
265    }
266
267    /// Invalidate a specific cache entry.
268    pub fn invalidate(&mut self, key: &CacheKey) -> bool {
269        if self.cache.remove(key).is_some() {
270            if self.config.enable_stats {
271                self.stats.invalidations += 1;
272            }
273            // Remove from LRU queue
274            if self.config.enable_lru {
275                self.lru_queue.retain(|k| k != key);
276            }
277            true
278        } else {
279            false
280        }
281    }
282
283    /// Clear all cache entries.
284    pub fn clear(&mut self) {
285        self.cache.clear();
286        self.lru_queue.clear();
287    }
288
289    /// Remove expired entries.
290    pub fn cleanup_expired(&mut self) -> usize {
291        let mut removed = 0;
292        let expired_keys: Vec<CacheKey> = self
293            .cache
294            .iter()
295            .filter(|(_, v)| v.is_expired())
296            .map(|(k, _)| k.clone())
297            .collect();
298
299        for key in expired_keys {
300            self.cache.remove(&key);
301            self.lru_queue.retain(|k| k != &key);
302            removed += 1;
303        }
304
305        if self.config.enable_stats {
306            self.stats.expirations += removed as u64;
307        }
308
309        removed
310    }
311
312    /// Get cache statistics.
313    pub fn stats(&self) -> &QueryCacheStats {
314        &self.stats
315    }
316
317    /// Get the number of entries in the cache.
318    pub fn len(&self) -> usize {
319        self.cache.len()
320    }
321
322    /// Check if the cache is empty.
323    pub fn is_empty(&self) -> bool {
324        self.cache.is_empty()
325    }
326
327    /// Get the cache configuration.
328    pub fn config(&self) -> &CacheConfig {
329        &self.config
330    }
331
332    /// Update the LRU queue when an entry is accessed.
333    fn update_lru(&mut self, key: &CacheKey) {
334        // Remove the key from its current position
335        self.lru_queue.retain(|k| k != key);
336        // Add it to the back (most recently used)
337        self.lru_queue.push_back(key.clone());
338    }
339
340    /// Evict one entry using LRU strategy.
341    fn evict_one(&mut self) {
342        if let Some(key) = self.lru_queue.pop_front() {
343            self.cache.remove(&key);
344            if self.config.enable_stats {
345                self.stats.evictions += 1;
346            }
347        }
348    }
349}
350
351impl<T: Clone> Default for QueryCache<T> {
352    fn default() -> Self {
353        Self::new()
354    }
355}
356
357/// A specialized cache for symbol table queries.
358pub struct SymbolTableCache {
359    /// Cache for predicate queries
360    predicate_cache: QueryCache<Vec<PredicateInfo>>,
361    /// Cache for domain name queries
362    domain_cache: QueryCache<Vec<String>>,
363    /// Cache for scalar results
364    scalar_cache: QueryCache<usize>,
365}
366
367impl SymbolTableCache {
368    /// Create a new symbol table cache.
369    pub fn new() -> Self {
370        Self {
371            predicate_cache: QueryCache::new(),
372            domain_cache: QueryCache::new(),
373            scalar_cache: QueryCache::new(),
374        }
375    }
376
377    /// Create a new cache with custom configuration.
378    pub fn with_config(config: CacheConfig) -> Self {
379        Self {
380            predicate_cache: QueryCache::with_config(config.clone()),
381            domain_cache: QueryCache::with_config(config.clone()),
382            scalar_cache: QueryCache::with_config(config),
383        }
384    }
385
386    /// Get predicates by arity (cached).
387    pub fn get_predicates_by_arity(
388        &mut self,
389        table: &SymbolTable,
390        arity: usize,
391    ) -> Vec<PredicateInfo> {
392        let key = CacheKey::PredicatesByArity(arity);
393
394        if let Some(result) = self.predicate_cache.get(&key) {
395            return result;
396        }
397
398        // Cache miss - compute and cache
399        let result: Vec<PredicateInfo> = table
400            .predicates
401            .values()
402            .filter(|p| p.arg_domains.len() == arity)
403            .cloned()
404            .collect();
405
406        self.predicate_cache.insert(key, result.clone());
407        result
408    }
409
410    /// Get predicates using a domain (cached).
411    pub fn get_predicates_by_domain(
412        &mut self,
413        table: &SymbolTable,
414        domain: &str,
415    ) -> Vec<PredicateInfo> {
416        let key = CacheKey::PredicatesByDomain(domain.to_string());
417
418        if let Some(result) = self.predicate_cache.get(&key) {
419            return result;
420        }
421
422        // Cache miss - compute and cache
423        let result: Vec<PredicateInfo> = table
424            .predicates
425            .values()
426            .filter(|p| p.arg_domains.contains(&domain.to_string()))
427            .cloned()
428            .collect();
429
430        self.predicate_cache.insert(key, result.clone());
431        result
432    }
433
434    /// Get all domain names (cached).
435    pub fn get_domain_names(&mut self, table: &SymbolTable) -> Vec<String> {
436        let key = CacheKey::AllDomainNames;
437
438        if let Some(result) = self.domain_cache.get(&key) {
439            return result;
440        }
441
442        // Cache miss - compute and cache
443        let mut result: Vec<String> = table.domains.keys().cloned().collect();
444        result.sort();
445
446        self.domain_cache.insert(key, result.clone());
447        result
448    }
449
450    /// Get domain usage count (cached).
451    pub fn get_domain_usage_count(&mut self, table: &SymbolTable, domain: &str) -> usize {
452        let key = CacheKey::DomainUsageCount(domain.to_string());
453
454        if let Some(result) = self.scalar_cache.get(&key) {
455            return result;
456        }
457
458        // Cache miss - compute and cache
459        let mut count = 0;
460        for predicate in table.predicates.values() {
461            count += predicate
462                .arg_domains
463                .iter()
464                .filter(|d| d.as_str() == domain)
465                .count();
466        }
467
468        for var_domain in table.variables.values() {
469            if var_domain == domain {
470                count += 1;
471            }
472        }
473
474        self.scalar_cache.insert(key, count);
475        count
476    }
477
478    /// Invalidate all caches.
479    pub fn invalidate_all(&mut self) {
480        self.predicate_cache.clear();
481        self.domain_cache.clear();
482        self.scalar_cache.clear();
483    }
484
485    /// Invalidate caches related to a specific domain.
486    pub fn invalidate_domain(&mut self, domain: &str) {
487        self.predicate_cache
488            .invalidate(&CacheKey::PredicatesByDomain(domain.to_string()));
489        self.scalar_cache
490            .invalidate(&CacheKey::DomainUsageCount(domain.to_string()));
491        self.domain_cache.invalidate(&CacheKey::AllDomainNames);
492    }
493
494    /// Invalidate caches related to predicates.
495    pub fn invalidate_predicates(&mut self) {
496        self.predicate_cache.clear();
497    }
498
499    /// Get combined statistics from all caches.
500    pub fn combined_stats(&self) -> QueryCacheStats {
501        let pred_stats = self.predicate_cache.stats();
502        let domain_stats = self.domain_cache.stats();
503        let scalar_stats = self.scalar_cache.stats();
504
505        QueryCacheStats {
506            hits: pred_stats.hits + domain_stats.hits + scalar_stats.hits,
507            misses: pred_stats.misses + domain_stats.misses + scalar_stats.misses,
508            evictions: pred_stats.evictions + domain_stats.evictions + scalar_stats.evictions,
509            expirations: pred_stats.expirations
510                + domain_stats.expirations
511                + scalar_stats.expirations,
512            invalidations: pred_stats.invalidations
513                + domain_stats.invalidations
514                + scalar_stats.invalidations,
515        }
516    }
517
518    /// Cleanup expired entries in all caches.
519    pub fn cleanup_expired(&mut self) -> usize {
520        self.predicate_cache.cleanup_expired()
521            + self.domain_cache.cleanup_expired()
522            + self.scalar_cache.cleanup_expired()
523    }
524}
525
526impl Default for SymbolTableCache {
527    fn default() -> Self {
528        Self::new()
529    }
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;

    /// Shorthand for a `CacheKey::Custom` built from a string slice.
    fn custom(name: &str) -> CacheKey {
        CacheKey::Custom(name.to_string())
    }

    /// String cache whose entries expire after 10 ms; other settings are defaults.
    fn short_lived_cache() -> QueryCache<String> {
        QueryCache::with_config(CacheConfig {
            default_ttl: Some(Duration::from_millis(10)),
            ..Default::default()
        })
    }

    /// A stored value can be read back and is counted as a hit.
    #[test]
    fn test_cache_basic_operations() {
        let mut store: QueryCache<String> = QueryCache::new();
        let k = custom("test");

        store.insert(k.clone(), "value".to_string());
        let fetched = store.get(&k);
        assert_eq!(fetched, Some("value".to_string()));

        let counters = store.stats();
        assert_eq!(counters.hits, 1);
        assert_eq!(counters.misses, 0);
    }

    /// Looking up an unknown key yields nothing and is counted as a miss.
    #[test]
    fn test_cache_miss() {
        let mut store: QueryCache<String> = QueryCache::new();

        assert_eq!(store.get(&custom("nonexistent")), None);
        assert_eq!(store.stats().misses, 1);
    }

    /// An invalidated entry is no longer retrievable.
    #[test]
    fn test_cache_invalidation() {
        let mut store: QueryCache<String> = QueryCache::new();
        let k = custom("test");
        store.insert(k.clone(), "value".to_string());

        let was_present = store.invalidate(&k);
        assert!(was_present);
        assert_eq!(store.get(&k), None);
    }

    /// An entry read after its TTL has elapsed is reported as expired.
    #[test]
    fn test_cache_expiration() {
        let mut store = short_lived_cache();
        let k = custom("test");
        store.insert(k.clone(), "value".to_string());

        // Wait past the 10 ms TTL.
        std::thread::sleep(Duration::from_millis(20));

        assert_eq!(store.get(&k), None);
        assert_eq!(store.stats().expirations, 1);
    }

    /// Inserting beyond `max_entries` evicts exactly one entry.
    #[test]
    fn test_cache_eviction() {
        let mut store: QueryCache<String> = QueryCache::with_config(CacheConfig {
            max_entries: 2,
            enable_lru: true,
            ..Default::default()
        });

        for n in 1..=3 {
            store.insert(custom(&format!("key{}", n)), format!("value{}", n));
        }

        // The first key had to make room for the third.
        assert_eq!(store.len(), 2);
        assert_eq!(store.stats().evictions, 1);
    }

    /// The symbol table cache misses on the first query and hits on the second.
    #[test]
    fn test_symbol_table_cache() {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();
        let knows = PredicateInfo::new(
            "knows",
            vec!["Person".to_string(), "Person".to_string()],
        );
        table.add_predicate(knows).unwrap();
        let age = PredicateInfo::new("age", vec!["Person".to_string()]);
        table.add_predicate(age).unwrap();

        let mut cache = SymbolTableCache::new();

        // Cold lookup: computed from the table.
        let binary = cache.get_predicates_by_arity(&table, 2);
        assert_eq!(binary.len(), 1);
        assert_eq!(cache.predicate_cache.stats().misses, 1);

        // Warm lookup: served from the cache.
        let binary = cache.get_predicates_by_arity(&table, 2);
        assert_eq!(binary.len(), 1);
        assert_eq!(cache.predicate_cache.stats().hits, 1);
    }

    /// The configuration presets carry their documented values.
    #[test]
    fn test_cache_config_presets() {
        assert_eq!(CacheConfig::small().max_entries, 100);
        assert_eq!(CacheConfig::large().max_entries, 10000);
        assert!(CacheConfig::no_ttl().default_ttl.is_none());
    }

    /// One hit and one miss give 50% rates and two accesses in total.
    #[test]
    fn test_cache_stats() {
        let mut store: QueryCache<String> = QueryCache::new();
        let present = custom("key1");
        let absent = custom("key2");

        store.insert(present.clone(), "value1".to_string());
        store.get(&present); // hit
        store.get(&absent); // miss

        let counters = store.stats();
        assert_eq!(counters.hit_rate(), 0.5);
        assert_eq!(counters.miss_rate(), 0.5);
        assert_eq!(counters.total_accesses(), 2);
    }

    /// `cleanup_expired` removes every entry whose TTL has elapsed.
    #[test]
    fn test_cleanup_expired() {
        let mut store = short_lived_cache();
        store.insert(custom("key1"), "value1".to_string());
        store.insert(custom("key2"), "value2".to_string());

        // Wait past the 10 ms TTL.
        std::thread::sleep(Duration::from_millis(20));

        let dropped = store.cleanup_expired();
        assert_eq!(dropped, 2);
        assert!(store.is_empty());
    }
}
677}