//! Caching and memoization for factor operations.
//!
//! This module provides caching mechanisms to avoid recomputing expensive factor operations
//! like product, marginalization, and division. This can significantly improve performance
//! when the same operations are performed repeatedly.
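//!
//! # Example
//!
//! A minimal usage sketch. The factor construction mirrors the one used in the tests
//! below; the concrete values are illustrative only.
//!
//! ```ignore
//! use std::sync::Arc;
//! use scirs2_core::ndarray::Array;
//!
//! // Build a small 2x2 factor over variables X and Y.
//! let array = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
//!     .unwrap()
//!     .into_dyn();
//! let factor = Factor::new(
//!     "f".to_string(),
//!     vec!["X".to_string(), "Y".to_string()],
//!     array,
//! )
//! .unwrap();
//!
//! // Wrap it with a shared cache; the first call computes and stores the result,
//! // identical later calls are served from the cache.
//! let cache = Arc::new(FactorCache::new(100));
//! let cached = CachedFactor::new(factor, cache.clone());
//! let _first = cached.marginalize_out_cached("Y").unwrap();
//! let _second = cached.marginalize_out_cached("Y").unwrap();
//! assert_eq!(cache.stats().hits, 1);
//! ```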

use crate::error::Result;
use crate::factor::Factor;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, Mutex};

/// A key for caching factor operations.
///
/// Keys are built from factor and variable names, so operand order matters:
/// `Product(a, b)` and `Product(b, a)` are distinct cache entries.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
enum CacheKey {
    /// Product of two factors
    Product(String, String),
    /// Marginalization of a factor over a variable
    Marginalize(String, String),
    /// Division of two factors
    Divide(String, String),
    /// Reduction of a factor given evidence
    Reduce(String, String, usize),
}

/// A cache for factor operations.
///
/// This cache stores the results of expensive factor operations to avoid recomputation.
/// When the cache reaches its maximum size, an arbitrary entry is evicted to make room.
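///
/// # Example
///
/// A minimal sketch of direct get/put usage; `some_factor` stands for an
/// already-constructed [`Factor`].
///
/// ```ignore
/// let cache = FactorCache::new(100);
///
/// // A miss is recorded first, then a hit for the same key.
/// assert!(cache.get_marginalize("f", "Y").is_none());
/// cache.put_marginalize("f", "Y", some_factor);
/// assert!(cache.get_marginalize("f", "Y").is_some());
/// assert_eq!(cache.stats().hits, 1);
/// assert_eq!(cache.stats().misses, 1);
/// ```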
pub struct FactorCache {
    /// The cached factors
    cache: Arc<Mutex<HashMap<CacheKey, Factor>>>,
    /// Maximum number of cached entries
    max_size: usize,
    /// Number of cache hits
    hits: Arc<Mutex<usize>>,
    /// Number of cache misses
    misses: Arc<Mutex<usize>>,
}

impl Default for FactorCache {
    fn default() -> Self {
        Self::new(1000)
    }
}

impl FactorCache {
    /// Create a new factor cache with a maximum size.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            max_size,
            hits: Arc::new(Mutex::new(0)),
            misses: Arc::new(Mutex::new(0)),
        }
    }

    /// Get a cached product result.
    pub fn get_product(&self, f1_name: &str, f2_name: &str) -> Option<Factor> {
        let key = CacheKey::Product(f1_name.to_string(), f2_name.to_string());
        self.get(&key)
    }

    /// Cache a product result.
    pub fn put_product(&self, f1_name: &str, f2_name: &str, result: Factor) {
        let key = CacheKey::Product(f1_name.to_string(), f2_name.to_string());
        self.put(key, result);
    }

    /// Get a cached marginalization result.
    pub fn get_marginalize(&self, factor_name: &str, var: &str) -> Option<Factor> {
        let key = CacheKey::Marginalize(factor_name.to_string(), var.to_string());
        self.get(&key)
    }

    /// Cache a marginalization result.
    pub fn put_marginalize(&self, factor_name: &str, var: &str, result: Factor) {
        let key = CacheKey::Marginalize(factor_name.to_string(), var.to_string());
        self.put(key, result);
    }

    /// Get a cached division result.
    pub fn get_divide(&self, f1_name: &str, f2_name: &str) -> Option<Factor> {
        let key = CacheKey::Divide(f1_name.to_string(), f2_name.to_string());
        self.get(&key)
    }

    /// Cache a division result.
    pub fn put_divide(&self, f1_name: &str, f2_name: &str, result: Factor) {
        let key = CacheKey::Divide(f1_name.to_string(), f2_name.to_string());
        self.put(key, result);
    }

    /// Get a cached reduction result.
    pub fn get_reduce(&self, factor_name: &str, var: &str, value: usize) -> Option<Factor> {
        let key = CacheKey::Reduce(factor_name.to_string(), var.to_string(), value);
        self.get(&key)
    }

    /// Cache a reduction result.
    pub fn put_reduce(&self, factor_name: &str, var: &str, value: usize, result: Factor) {
        let key = CacheKey::Reduce(factor_name.to_string(), var.to_string(), value);
        self.put(key, result);
    }

    /// Get from cache.
    fn get(&self, key: &CacheKey) -> Option<Factor> {
        let cache = self.cache.lock().unwrap();
        if let Some(factor) = cache.get(key) {
            *self.hits.lock().unwrap() += 1;
            Some(factor.clone())
        } else {
            *self.misses.lock().unwrap() += 1;
            None
        }
    }

    /// Put into cache.
    fn put(&self, key: CacheKey, factor: Factor) {
        let mut cache = self.cache.lock().unwrap();

        // Simple eviction: at the size limit, remove an arbitrary entry to make room
        if cache.len() >= self.max_size {
            if let Some(first_key) = cache.keys().next().cloned() {
                cache.remove(&first_key);
            }
        }

        cache.insert(key, factor);
    }

    /// Clear the cache.
    pub fn clear(&self) {
        self.cache.lock().unwrap().clear();
        *self.hits.lock().unwrap() = 0;
        *self.misses.lock().unwrap() = 0;
    }

    /// Get cache statistics.
    pub fn stats(&self) -> CacheStats {
        let hits = *self.hits.lock().unwrap();
        let misses = *self.misses.lock().unwrap();
        let size = self.cache.lock().unwrap().len();

        CacheStats {
            hits,
            misses,
            size,
            hit_rate: if hits + misses > 0 {
                hits as f64 / (hits + misses) as f64
            } else {
                0.0
            },
        }
    }

    /// Get current cache size.
    pub fn size(&self) -> usize {
        self.cache.lock().unwrap().len()
    }
}

/// Cache statistics.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of cache hits
    pub hits: usize,
    /// Number of cache misses
    pub misses: usize,
    /// Current cache size
    pub size: usize,
    /// Hit rate (hits / (hits + misses))
    pub hit_rate: f64,
}

/// A cached factor that memoizes operations.
///
/// This wraps a factor and caches the results of operations.
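///
/// # Example
///
/// A minimal sketch; `factor_a` and `factor_b` stand for already-constructed
/// [`Factor`] values with compatible variables.
///
/// ```ignore
/// use std::sync::Arc;
///
/// let cache = Arc::new(FactorCache::new(100));
/// let a = CachedFactor::new(factor_a, cache.clone());
/// let b = CachedFactor::new(factor_b, cache.clone());
///
/// // The product is computed once and served from the cache afterwards.
/// let _first = a.product_cached(&b).unwrap();
/// let _second = a.product_cached(&b).unwrap();
/// assert_eq!(cache.stats().hits, 1);
/// ```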
pub struct CachedFactor {
    /// The underlying factor
    pub factor: Factor,
    /// The cache
    cache: Arc<FactorCache>,
}

impl CachedFactor {
    /// Create a new cached factor.
    pub fn new(factor: Factor, cache: Arc<FactorCache>) -> Self {
        Self { factor, cache }
    }

    /// Compute product with caching.
    pub fn product_cached(&self, other: &CachedFactor) -> Result<Factor> {
        // Try to get from cache
        if let Some(cached) = self
            .cache
            .get_product(&self.factor.name, &other.factor.name)
        {
            return Ok(cached);
        }

        // Compute and cache
        let result = self.factor.product(&other.factor)?;
        self.cache
            .put_product(&self.factor.name, &other.factor.name, result.clone());

        Ok(result)
    }

    /// Compute marginalization with caching.
    pub fn marginalize_out_cached(&self, var: &str) -> Result<Factor> {
        // Try to get from cache
        if let Some(cached) = self.cache.get_marginalize(&self.factor.name, var) {
            return Ok(cached);
        }

        // Compute and cache
        let result = self.factor.marginalize_out(var)?;
        self.cache
            .put_marginalize(&self.factor.name, var, result.clone());

        Ok(result)
    }

    /// Compute division with caching.
    pub fn divide_cached(&self, other: &CachedFactor) -> Result<Factor> {
        // Try to get from cache
        if let Some(cached) = self.cache.get_divide(&self.factor.name, &other.factor.name) {
            return Ok(cached);
        }

        // Compute and cache
        let result = self.factor.divide(&other.factor)?;
        self.cache
            .put_divide(&self.factor.name, &other.factor.name, result.clone());

        Ok(result)
    }

    /// Compute reduction with caching.
    pub fn reduce_cached(&self, var: &str, value: usize) -> Result<Factor> {
        // Try to get from cache
        if let Some(cached) = self.cache.get_reduce(&self.factor.name, var, value) {
            return Ok(cached);
        }

        // Compute and cache
        let result = self.factor.reduce(var, value)?;
        self.cache
            .put_reduce(&self.factor.name, var, value, result.clone());

        Ok(result)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    fn create_test_factor(name: &str) -> Factor {
        let values = vec![0.1, 0.2, 0.3, 0.4];
        let array = Array::from_shape_vec(vec![2, 2], values)
            .unwrap()
            .into_dyn();
        Factor::new(
            name.to_string(),
            vec!["X".to_string(), "Y".to_string()],
            array,
        )
        .unwrap()
    }

    #[test]
    fn test_cache_product() {
        let cache = Arc::new(FactorCache::new(100));
        let f1 = CachedFactor::new(create_test_factor("f1"), cache.clone());
        let f2 = CachedFactor::new(create_test_factor("f2"), cache.clone());

        // First call - cache miss
        let result1 = f1.product_cached(&f2).unwrap();
        let stats1 = cache.stats();
        assert_eq!(stats1.misses, 1);
        assert_eq!(stats1.hits, 0);

        // Second call - cache hit
        let result2 = f1.product_cached(&f2).unwrap();
        let stats2 = cache.stats();
        assert_eq!(stats2.misses, 1);
        assert_eq!(stats2.hits, 1);

        // Results should be the same
        assert_eq!(result1.name, result2.name);
    }

    #[test]
    fn test_cache_marginalize() {
        let cache = Arc::new(FactorCache::new(100));
        let f = CachedFactor::new(create_test_factor("f"), cache.clone());

        // First call - cache miss
        let _result1 = f.marginalize_out_cached("Y").unwrap();
        let stats1 = cache.stats();
        assert_eq!(stats1.misses, 1);

        // Second call - cache hit
        let _result2 = f.marginalize_out_cached("Y").unwrap();
        let stats2 = cache.stats();
        assert_eq!(stats2.hits, 1);
    }

    #[test]
    fn test_cache_stats() {
        let cache = FactorCache::new(100);
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate, 0.0);
    }

    #[test]
    fn test_cache_clear() {
        let cache = Arc::new(FactorCache::new(100));
        let f = CachedFactor::new(create_test_factor("f"), cache.clone());

        // Populate cache
        let _ = f.marginalize_out_cached("Y").unwrap();
        assert_eq!(cache.size(), 1);

        // Clear cache
        cache.clear();
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = Arc::new(FactorCache::new(2));

        // Add 3 items (one entry should be evicted)
        cache.put_marginalize("f1", "X", create_test_factor("result1"));
        cache.put_marginalize("f2", "Y", create_test_factor("result2"));
        cache.put_marginalize("f3", "Z", create_test_factor("result3"));

        // Size should be at most 2
        assert!(cache.size() <= 2);
    }
}