Skip to main content

tensorlogic_quantrs_hooks/
cache.rs

1//! Caching and memoization for factor operations.
2//!
3//! This module provides caching mechanisms to avoid recomputing expensive factor operations
4//! like product, marginalization, and division. This can significantly improve performance
5//! when the same operations are performed repeatedly.
6
7use crate::error::Result;
8use crate::factor::Factor;
9use std::collections::HashMap;
10use std::hash::Hash;
11use std::sync::{Arc, Mutex};
12
/// Lookup key identifying one memoized factor operation.
///
/// Operands are identified by name, so two keys compare equal exactly when the
/// operation kind and every operand match, in the same order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum CacheKey {
    /// Product: `(first factor name, second factor name)`
    Product(String, String),
    /// Marginalization: `(factor name, variable name)`
    Marginalize(String, String),
    /// Division: `(first factor name, second factor name)`
    Divide(String, String),
    /// Reduction given evidence: `(factor name, variable name, value)`
    Reduce(String, String, usize),
}
25
/// A cache for factor operations.
///
/// This cache stores the results of expensive factor operations to avoid recomputation.
/// It is bounded by `max_size`: when the cache is full, inserting evicts an arbitrary
/// entry (the first key yielded by the map's iterator). This is NOT a true LRU policy;
/// no access order is tracked.
pub struct FactorCache {
    /// The cached factors, keyed by operation kind and operand names
    cache: Arc<Mutex<HashMap<CacheKey, Factor>>>,
    /// Maximum number of cached entries
    max_size: usize,
    /// Number of lookups that found a cached entry
    hits: Arc<Mutex<usize>>,
    /// Number of lookups that found nothing
    misses: Arc<Mutex<usize>>,
}
39
40impl Default for FactorCache {
41    fn default() -> Self {
42        Self::new(1000)
43    }
44}
45
46impl FactorCache {
47    /// Create a new factor cache with a maximum size.
48    pub fn new(max_size: usize) -> Self {
49        Self {
50            cache: Arc::new(Mutex::new(HashMap::new())),
51            max_size,
52            hits: Arc::new(Mutex::new(0)),
53            misses: Arc::new(Mutex::new(0)),
54        }
55    }
56
57    /// Get a cached product result.
58    pub fn get_product(&self, f1_name: &str, f2_name: &str) -> Option<Factor> {
59        let key = CacheKey::Product(f1_name.to_string(), f2_name.to_string());
60        self.get(&key)
61    }
62
63    /// Cache a product result.
64    pub fn put_product(&self, f1_name: &str, f2_name: &str, result: Factor) {
65        let key = CacheKey::Product(f1_name.to_string(), f2_name.to_string());
66        self.put(key, result);
67    }
68
69    /// Get a cached marginalization result.
70    pub fn get_marginalize(&self, factor_name: &str, var: &str) -> Option<Factor> {
71        let key = CacheKey::Marginalize(factor_name.to_string(), var.to_string());
72        self.get(&key)
73    }
74
75    /// Cache a marginalization result.
76    pub fn put_marginalize(&self, factor_name: &str, var: &str, result: Factor) {
77        let key = CacheKey::Marginalize(factor_name.to_string(), var.to_string());
78        self.put(key, result);
79    }
80
81    /// Get a cached division result.
82    pub fn get_divide(&self, f1_name: &str, f2_name: &str) -> Option<Factor> {
83        let key = CacheKey::Divide(f1_name.to_string(), f2_name.to_string());
84        self.get(&key)
85    }
86
87    /// Cache a division result.
88    pub fn put_divide(&self, f1_name: &str, f2_name: &str, result: Factor) {
89        let key = CacheKey::Divide(f1_name.to_string(), f2_name.to_string());
90        self.put(key, result);
91    }
92
93    /// Get a cached reduction result.
94    pub fn get_reduce(&self, factor_name: &str, var: &str, value: usize) -> Option<Factor> {
95        let key = CacheKey::Reduce(factor_name.to_string(), var.to_string(), value);
96        self.get(&key)
97    }
98
99    /// Cache a reduction result.
100    pub fn put_reduce(&self, factor_name: &str, var: &str, value: usize, result: Factor) {
101        let key = CacheKey::Reduce(factor_name.to_string(), var.to_string(), value);
102        self.put(key, result);
103    }
104
105    /// Get from cache.
106    fn get(&self, key: &CacheKey) -> Option<Factor> {
107        let cache = self.cache.lock().expect("lock should not be poisoned");
108        if let Some(factor) = cache.get(key) {
109            *self.hits.lock().expect("lock should not be poisoned") += 1;
110            Some(factor.clone())
111        } else {
112            *self.misses.lock().expect("lock should not be poisoned") += 1;
113            None
114        }
115    }
116
117    /// Put into cache.
118    fn put(&self, key: CacheKey, factor: Factor) {
119        let mut cache = self.cache.lock().expect("lock should not be poisoned");
120
121        // Simple eviction: if at max size, remove a random entry
122        if cache.len() >= self.max_size {
123            if let Some(first_key) = cache.keys().next().cloned() {
124                cache.remove(&first_key);
125            }
126        }
127
128        cache.insert(key, factor);
129    }
130
131    /// Clear the cache.
132    pub fn clear(&self) {
133        self.cache
134            .lock()
135            .expect("lock should not be poisoned")
136            .clear();
137        *self.hits.lock().expect("lock should not be poisoned") = 0;
138        *self.misses.lock().expect("lock should not be poisoned") = 0;
139    }
140
141    /// Get cache statistics.
142    pub fn stats(&self) -> CacheStats {
143        let hits = *self.hits.lock().expect("lock should not be poisoned");
144        let misses = *self.misses.lock().expect("lock should not be poisoned");
145        let size = self
146            .cache
147            .lock()
148            .expect("lock should not be poisoned")
149            .len();
150
151        CacheStats {
152            hits,
153            misses,
154            size,
155            hit_rate: if hits + misses > 0 {
156                hits as f64 / (hits + misses) as f64
157            } else {
158                0.0
159            },
160        }
161    }
162
163    /// Get current cache size.
164    pub fn size(&self) -> usize {
165        self.cache
166            .lock()
167            .expect("lock should not be poisoned")
168            .len()
169    }
170}
171
/// Point-in-time counters reported by a factor cache.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Lookups that found a cached entry
    pub hits: usize,
    /// Lookups that found nothing
    pub misses: usize,
    /// Entries stored when the statistics were taken
    pub size: usize,
    /// Fraction of lookups that were hits: `hits / (hits + misses)`
    pub hit_rate: f64,
}
184
/// A cached factor that memoizes operations.
///
/// This wraps a factor and caches the results of operations in a shared
/// [`FactorCache`]. Cached results are looked up by `factor.name` (plus the other
/// operand's name or the variable involved), not by the factor's contents.
pub struct CachedFactor {
    /// The underlying factor
    pub factor: Factor,
    /// The cache, shared by reference count with any other holders of the same `Arc`
    cache: Arc<FactorCache>,
}
194
195impl CachedFactor {
196    /// Create a new cached factor.
197    pub fn new(factor: Factor, cache: Arc<FactorCache>) -> Self {
198        Self { factor, cache }
199    }
200
201    /// Compute product with caching.
202    pub fn product_cached(&self, other: &CachedFactor) -> Result<Factor> {
203        // Try to get from cache
204        if let Some(cached) = self
205            .cache
206            .get_product(&self.factor.name, &other.factor.name)
207        {
208            return Ok(cached);
209        }
210
211        // Compute and cache
212        let result = self.factor.product(&other.factor)?;
213        self.cache
214            .put_product(&self.factor.name, &other.factor.name, result.clone());
215
216        Ok(result)
217    }
218
219    /// Compute marginalization with caching.
220    pub fn marginalize_out_cached(&self, var: &str) -> Result<Factor> {
221        // Try to get from cache
222        if let Some(cached) = self.cache.get_marginalize(&self.factor.name, var) {
223            return Ok(cached);
224        }
225
226        // Compute and cache
227        let result = self.factor.marginalize_out(var)?;
228        self.cache
229            .put_marginalize(&self.factor.name, var, result.clone());
230
231        Ok(result)
232    }
233
234    /// Compute division with caching.
235    pub fn divide_cached(&self, other: &CachedFactor) -> Result<Factor> {
236        // Try to get from cache
237        if let Some(cached) = self.cache.get_divide(&self.factor.name, &other.factor.name) {
238            return Ok(cached);
239        }
240
241        // Compute and cache
242        let result = self.factor.divide(&other.factor)?;
243        self.cache
244            .put_divide(&self.factor.name, &other.factor.name, result.clone());
245
246        Ok(result)
247    }
248
249    /// Compute reduction with caching.
250    pub fn reduce_cached(&self, var: &str, value: usize) -> Result<Factor> {
251        // Try to get from cache
252        if let Some(cached) = self.cache.get_reduce(&self.factor.name, var, value) {
253            return Ok(cached);
254        }
255
256        // Compute and cache
257        let result = self.factor.reduce(var, value)?;
258        self.cache
259            .put_reduce(&self.factor.name, var, value, result.clone());
260
261        Ok(result)
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Builds a 2x2 factor named `name` over the variables `X` and `Y`.
    fn create_test_factor(name: &str) -> Factor {
        let values = vec![0.1, 0.2, 0.3, 0.4];
        let array = Array::from_shape_vec(vec![2, 2], values)
            .expect("unwrap")
            .into_dyn();
        Factor::new(
            name.to_string(),
            vec!["X".to_string(), "Y".to_string()],
            array,
        )
        .expect("unwrap")
    }

    /// A repeated product is a miss the first time and a hit the second time.
    #[test]
    fn test_cache_product() {
        let cache = Arc::new(FactorCache::new(100));
        let f1 = CachedFactor::new(create_test_factor("f1"), cache.clone());
        let f2 = CachedFactor::new(create_test_factor("f2"), cache.clone());

        // First call - cache miss
        let result1 = f1.product_cached(&f2).expect("unwrap");
        let stats1 = cache.stats();
        assert_eq!(stats1.misses, 1);
        assert_eq!(stats1.hits, 0);

        // Second call - cache hit
        let result2 = f1.product_cached(&f2).expect("unwrap");
        let stats2 = cache.stats();
        assert_eq!(stats2.misses, 1);
        assert_eq!(stats2.hits, 1);

        // Results should be the same
        assert_eq!(result1.name, result2.name);
    }

    /// A repeated marginalization is a miss the first time and a hit the second time.
    #[test]
    fn test_cache_marginalize() {
        let cache = Arc::new(FactorCache::new(100));
        let f = CachedFactor::new(create_test_factor("f"), cache.clone());

        // First call - cache miss
        let _result1 = f.marginalize_out_cached("Y").expect("unwrap");
        let stats1 = cache.stats();
        assert_eq!(stats1.misses, 1);

        // Second call - cache hit
        let _result2 = f.marginalize_out_cached("Y").expect("unwrap");
        let stats2 = cache.stats();
        assert_eq!(stats2.hits, 1);
    }

    /// A fresh cache reports zero counters and a zero hit rate.
    #[test]
    fn test_cache_stats() {
        let cache = FactorCache::new(100);
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate, 0.0);
    }

    /// Clearing drops all entries and resets the counters.
    #[test]
    fn test_cache_clear() {
        let cache = Arc::new(FactorCache::new(100));
        let f = CachedFactor::new(create_test_factor("f"), cache.clone());

        // Populate cache
        let _ = f.marginalize_out_cached("Y").expect("unwrap");
        assert_eq!(cache.size(), 1);

        // Clear cache
        cache.clear();
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
    }

    /// Inserting more distinct keys than the capacity keeps the cache at capacity.
    #[test]
    fn test_cache_eviction() {
        let cache = Arc::new(FactorCache::new(2));

        // Add 3 distinct items to a cache that holds 2. The third insert evicts one
        // existing entry; which one is unspecified (eviction is not ordered by age).
        cache.put_marginalize("f1", "X", create_test_factor("result1"));
        cache.put_marginalize("f2", "Y", create_test_factor("result2"));
        cache.put_marginalize("f3", "Z", create_test_factor("result3"));

        // Exactly one entry was evicted, so the cache is full but not over capacity.
        assert_eq!(cache.size(), 2);

        // The newest entry is inserted after eviction, so it is always present.
        assert!(cache.get_marginalize("f3", "Z").is_some());
    }
}