// tensorlogic_quantrs_hooks/cache.rs
use crate::error::Result;
use crate::factor::Factor;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, Mutex};

/// Identifies one cached factor operation by the kind of operation and the
/// names of its operands.
///
/// Keys compare by value, and operand order matters: `Product("a", "b")` and
/// `Product("b", "a")` are different entries.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum CacheKey {
    /// Product of two factors: (left factor name, right factor name).
    Product(String, String),
    /// Marginalizing a variable out of a factor: (factor name, variable name).
    Marginalize(String, String),
    /// Division of two factors: (factor name, other factor name) for `f1 / f2`.
    Divide(String, String),
    /// Reducing a factor on a variable: (factor name, variable name, value).
    Reduce(String, String, usize),
}

26pub struct FactorCache {
31 cache: Arc<Mutex<HashMap<CacheKey, Factor>>>,
33 max_size: usize,
35 hits: Arc<Mutex<usize>>,
37 misses: Arc<Mutex<usize>>,
38}
39
40impl Default for FactorCache {
41 fn default() -> Self {
42 Self::new(1000)
43 }
44}
45
46impl FactorCache {
47 pub fn new(max_size: usize) -> Self {
49 Self {
50 cache: Arc::new(Mutex::new(HashMap::new())),
51 max_size,
52 hits: Arc::new(Mutex::new(0)),
53 misses: Arc::new(Mutex::new(0)),
54 }
55 }
56
57 pub fn get_product(&self, f1_name: &str, f2_name: &str) -> Option<Factor> {
59 let key = CacheKey::Product(f1_name.to_string(), f2_name.to_string());
60 self.get(&key)
61 }
62
63 pub fn put_product(&self, f1_name: &str, f2_name: &str, result: Factor) {
65 let key = CacheKey::Product(f1_name.to_string(), f2_name.to_string());
66 self.put(key, result);
67 }
68
69 pub fn get_marginalize(&self, factor_name: &str, var: &str) -> Option<Factor> {
71 let key = CacheKey::Marginalize(factor_name.to_string(), var.to_string());
72 self.get(&key)
73 }
74
75 pub fn put_marginalize(&self, factor_name: &str, var: &str, result: Factor) {
77 let key = CacheKey::Marginalize(factor_name.to_string(), var.to_string());
78 self.put(key, result);
79 }
80
81 pub fn get_divide(&self, f1_name: &str, f2_name: &str) -> Option<Factor> {
83 let key = CacheKey::Divide(f1_name.to_string(), f2_name.to_string());
84 self.get(&key)
85 }
86
87 pub fn put_divide(&self, f1_name: &str, f2_name: &str, result: Factor) {
89 let key = CacheKey::Divide(f1_name.to_string(), f2_name.to_string());
90 self.put(key, result);
91 }
92
93 pub fn get_reduce(&self, factor_name: &str, var: &str, value: usize) -> Option<Factor> {
95 let key = CacheKey::Reduce(factor_name.to_string(), var.to_string(), value);
96 self.get(&key)
97 }
98
99 pub fn put_reduce(&self, factor_name: &str, var: &str, value: usize, result: Factor) {
101 let key = CacheKey::Reduce(factor_name.to_string(), var.to_string(), value);
102 self.put(key, result);
103 }
104
105 fn get(&self, key: &CacheKey) -> Option<Factor> {
107 let cache = self.cache.lock().expect("lock should not be poisoned");
108 if let Some(factor) = cache.get(key) {
109 *self.hits.lock().expect("lock should not be poisoned") += 1;
110 Some(factor.clone())
111 } else {
112 *self.misses.lock().expect("lock should not be poisoned") += 1;
113 None
114 }
115 }
116
117 fn put(&self, key: CacheKey, factor: Factor) {
119 let mut cache = self.cache.lock().expect("lock should not be poisoned");
120
121 if cache.len() >= self.max_size {
123 if let Some(first_key) = cache.keys().next().cloned() {
124 cache.remove(&first_key);
125 }
126 }
127
128 cache.insert(key, factor);
129 }
130
131 pub fn clear(&self) {
133 self.cache
134 .lock()
135 .expect("lock should not be poisoned")
136 .clear();
137 *self.hits.lock().expect("lock should not be poisoned") = 0;
138 *self.misses.lock().expect("lock should not be poisoned") = 0;
139 }
140
141 pub fn stats(&self) -> CacheStats {
143 let hits = *self.hits.lock().expect("lock should not be poisoned");
144 let misses = *self.misses.lock().expect("lock should not be poisoned");
145 let size = self
146 .cache
147 .lock()
148 .expect("lock should not be poisoned")
149 .len();
150
151 CacheStats {
152 hits,
153 misses,
154 size,
155 hit_rate: if hits + misses > 0 {
156 hits as f64 / (hits + misses) as f64
157 } else {
158 0.0
159 },
160 }
161 }
162
163 pub fn size(&self) -> usize {
165 self.cache
166 .lock()
167 .expect("lock should not be poisoned")
168 .len()
169 }
170}
171
/// Snapshot of a `FactorCache`'s usage counters, as returned by
/// `FactorCache::stats`.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Lookups that found a stored result.
    pub hits: usize,
    /// Lookups that found nothing.
    pub misses: usize,
    /// Number of entries stored when the snapshot was taken.
    pub size: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookups have happened.
    pub hit_rate: f64,
}

185pub struct CachedFactor {
189 pub factor: Factor,
191 cache: Arc<FactorCache>,
193}
194
195impl CachedFactor {
196 pub fn new(factor: Factor, cache: Arc<FactorCache>) -> Self {
198 Self { factor, cache }
199 }
200
201 pub fn product_cached(&self, other: &CachedFactor) -> Result<Factor> {
203 if let Some(cached) = self
205 .cache
206 .get_product(&self.factor.name, &other.factor.name)
207 {
208 return Ok(cached);
209 }
210
211 let result = self.factor.product(&other.factor)?;
213 self.cache
214 .put_product(&self.factor.name, &other.factor.name, result.clone());
215
216 Ok(result)
217 }
218
219 pub fn marginalize_out_cached(&self, var: &str) -> Result<Factor> {
221 if let Some(cached) = self.cache.get_marginalize(&self.factor.name, var) {
223 return Ok(cached);
224 }
225
226 let result = self.factor.marginalize_out(var)?;
228 self.cache
229 .put_marginalize(&self.factor.name, var, result.clone());
230
231 Ok(result)
232 }
233
234 pub fn divide_cached(&self, other: &CachedFactor) -> Result<Factor> {
236 if let Some(cached) = self.cache.get_divide(&self.factor.name, &other.factor.name) {
238 return Ok(cached);
239 }
240
241 let result = self.factor.divide(&other.factor)?;
243 self.cache
244 .put_divide(&self.factor.name, &other.factor.name, result.clone());
245
246 Ok(result)
247 }
248
249 pub fn reduce_cached(&self, var: &str, value: usize) -> Result<Factor> {
251 if let Some(cached) = self.cache.get_reduce(&self.factor.name, var, value) {
253 return Ok(cached);
254 }
255
256 let result = self.factor.reduce(var, value)?;
258 self.cache
259 .put_reduce(&self.factor.name, var, value, result.clone());
260
261 Ok(result)
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Builds a 2x2 factor named `name` over variables `X` and `Y`, with fixed values.
    fn sample_factor(name: &str) -> Factor {
        let table = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .expect("unwrap")
            .into_dyn();
        let variables = vec!["X".to_string(), "Y".to_string()];
        Factor::new(name.to_string(), variables, table).expect("unwrap")
    }

    #[test]
    fn test_cache_product() {
        let shared = Arc::new(FactorCache::new(100));
        let left = CachedFactor::new(sample_factor("f1"), Arc::clone(&shared));
        let right = CachedFactor::new(sample_factor("f2"), Arc::clone(&shared));

        // First call computes and stores the product: one miss, no hits.
        let first = left.product_cached(&right).expect("unwrap");
        let after_first = shared.stats();
        assert_eq!(after_first.misses, 1);
        assert_eq!(after_first.hits, 0);

        // Second call is served from the cache.
        let second = left.product_cached(&right).expect("unwrap");
        let after_second = shared.stats();
        assert_eq!(after_second.misses, 1);
        assert_eq!(after_second.hits, 1);

        assert_eq!(first.name, second.name);
    }

    #[test]
    fn test_cache_marginalize() {
        let shared = Arc::new(FactorCache::new(100));
        let wrapped = CachedFactor::new(sample_factor("f"), Arc::clone(&shared));

        let _ = wrapped.marginalize_out_cached("Y").expect("unwrap");
        assert_eq!(shared.stats().misses, 1);

        let _ = wrapped.marginalize_out_cached("Y").expect("unwrap");
        assert_eq!(shared.stats().hits, 1);
    }

    #[test]
    fn test_cache_stats() {
        let fresh = FactorCache::new(100).stats();
        assert_eq!(fresh.hits, 0);
        assert_eq!(fresh.misses, 0);
        assert_eq!(fresh.hit_rate, 0.0);
    }

    #[test]
    fn test_cache_clear() {
        let shared = Arc::new(FactorCache::new(100));
        let wrapped = CachedFactor::new(sample_factor("f"), Arc::clone(&shared));

        let _ = wrapped.marginalize_out_cached("Y").expect("unwrap");
        assert_eq!(shared.size(), 1);

        // Clearing drops the entries and resets both counters.
        shared.clear();
        assert_eq!(shared.size(), 0);
        assert_eq!(shared.stats().hits, 0);
        assert_eq!(shared.stats().misses, 0);
    }

    #[test]
    fn test_cache_eviction() {
        let bounded = Arc::new(FactorCache::new(2));

        // Three distinct keys into a two-slot cache must trigger an eviction.
        for &(source, var, label) in &[
            ("f1", "X", "result1"),
            ("f2", "Y", "result2"),
            ("f3", "Z", "result3"),
        ] {
            bounded.put_marginalize(source, var, sample_factor(label));
        }

        assert!(bounded.size() <= 2);
    }
}