tensorlogic_quantrs_hooks/
cache.rs1use crate::error::Result;
8use crate::factor::Factor;
9use std::collections::HashMap;
10use std::hash::Hash;
11use std::sync::{Arc, Mutex};
12
/// Identifies one memoized factor operation.
///
/// Operands are referred to by name only; the key carries no information
/// about the factor's contents.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum CacheKey {
    /// Product of two factors: `(f1_name, f2_name)`.
    Product(String, String),
    /// Marginalization of a variable out of a factor: `(factor_name, var)`.
    Marginalize(String, String),
    /// Division of one factor by another: `(f1_name, f2_name)`.
    Divide(String, String),
    /// Reduction of a factor at a fixed variable value: `(factor_name, var, value)`.
    Reduce(String, String, usize),
}
25
26pub struct FactorCache {
31 cache: Arc<Mutex<HashMap<CacheKey, Factor>>>,
33 max_size: usize,
35 hits: Arc<Mutex<usize>>,
37 misses: Arc<Mutex<usize>>,
38}
39
40impl Default for FactorCache {
41 fn default() -> Self {
42 Self::new(1000)
43 }
44}
45
46impl FactorCache {
47 pub fn new(max_size: usize) -> Self {
49 Self {
50 cache: Arc::new(Mutex::new(HashMap::new())),
51 max_size,
52 hits: Arc::new(Mutex::new(0)),
53 misses: Arc::new(Mutex::new(0)),
54 }
55 }
56
57 pub fn get_product(&self, f1_name: &str, f2_name: &str) -> Option<Factor> {
59 let key = CacheKey::Product(f1_name.to_string(), f2_name.to_string());
60 self.get(&key)
61 }
62
63 pub fn put_product(&self, f1_name: &str, f2_name: &str, result: Factor) {
65 let key = CacheKey::Product(f1_name.to_string(), f2_name.to_string());
66 self.put(key, result);
67 }
68
69 pub fn get_marginalize(&self, factor_name: &str, var: &str) -> Option<Factor> {
71 let key = CacheKey::Marginalize(factor_name.to_string(), var.to_string());
72 self.get(&key)
73 }
74
75 pub fn put_marginalize(&self, factor_name: &str, var: &str, result: Factor) {
77 let key = CacheKey::Marginalize(factor_name.to_string(), var.to_string());
78 self.put(key, result);
79 }
80
81 pub fn get_divide(&self, f1_name: &str, f2_name: &str) -> Option<Factor> {
83 let key = CacheKey::Divide(f1_name.to_string(), f2_name.to_string());
84 self.get(&key)
85 }
86
87 pub fn put_divide(&self, f1_name: &str, f2_name: &str, result: Factor) {
89 let key = CacheKey::Divide(f1_name.to_string(), f2_name.to_string());
90 self.put(key, result);
91 }
92
93 pub fn get_reduce(&self, factor_name: &str, var: &str, value: usize) -> Option<Factor> {
95 let key = CacheKey::Reduce(factor_name.to_string(), var.to_string(), value);
96 self.get(&key)
97 }
98
99 pub fn put_reduce(&self, factor_name: &str, var: &str, value: usize, result: Factor) {
101 let key = CacheKey::Reduce(factor_name.to_string(), var.to_string(), value);
102 self.put(key, result);
103 }
104
105 fn get(&self, key: &CacheKey) -> Option<Factor> {
107 let cache = self.cache.lock().unwrap();
108 if let Some(factor) = cache.get(key) {
109 *self.hits.lock().unwrap() += 1;
110 Some(factor.clone())
111 } else {
112 *self.misses.lock().unwrap() += 1;
113 None
114 }
115 }
116
117 fn put(&self, key: CacheKey, factor: Factor) {
119 let mut cache = self.cache.lock().unwrap();
120
121 if cache.len() >= self.max_size {
123 if let Some(first_key) = cache.keys().next().cloned() {
124 cache.remove(&first_key);
125 }
126 }
127
128 cache.insert(key, factor);
129 }
130
131 pub fn clear(&self) {
133 self.cache.lock().unwrap().clear();
134 *self.hits.lock().unwrap() = 0;
135 *self.misses.lock().unwrap() = 0;
136 }
137
138 pub fn stats(&self) -> CacheStats {
140 let hits = *self.hits.lock().unwrap();
141 let misses = *self.misses.lock().unwrap();
142 let size = self.cache.lock().unwrap().len();
143
144 CacheStats {
145 hits,
146 misses,
147 size,
148 hit_rate: if hits + misses > 0 {
149 hits as f64 / (hits + misses) as f64
150 } else {
151 0.0
152 },
153 }
154 }
155
156 pub fn size(&self) -> usize {
158 self.cache.lock().unwrap().len()
159 }
160}
161
/// Point-in-time counters describing a factor cache.
///
/// This is plain data: it is `Copy`, comparable with `==`, and its `Default`
/// is the all-zero snapshot.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct CacheStats {
    /// Number of lookups that found an entry.
    pub hits: usize,
    /// Number of lookups that found nothing.
    pub misses: usize,
    /// Number of entries stored when the snapshot was taken.
    pub size: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookup has been recorded.
    pub hit_rate: f64,
}
174
175pub struct CachedFactor {
179 pub factor: Factor,
181 cache: Arc<FactorCache>,
183}
184
185impl CachedFactor {
186 pub fn new(factor: Factor, cache: Arc<FactorCache>) -> Self {
188 Self { factor, cache }
189 }
190
191 pub fn product_cached(&self, other: &CachedFactor) -> Result<Factor> {
193 if let Some(cached) = self
195 .cache
196 .get_product(&self.factor.name, &other.factor.name)
197 {
198 return Ok(cached);
199 }
200
201 let result = self.factor.product(&other.factor)?;
203 self.cache
204 .put_product(&self.factor.name, &other.factor.name, result.clone());
205
206 Ok(result)
207 }
208
209 pub fn marginalize_out_cached(&self, var: &str) -> Result<Factor> {
211 if let Some(cached) = self.cache.get_marginalize(&self.factor.name, var) {
213 return Ok(cached);
214 }
215
216 let result = self.factor.marginalize_out(var)?;
218 self.cache
219 .put_marginalize(&self.factor.name, var, result.clone());
220
221 Ok(result)
222 }
223
224 pub fn divide_cached(&self, other: &CachedFactor) -> Result<Factor> {
226 if let Some(cached) = self.cache.get_divide(&self.factor.name, &other.factor.name) {
228 return Ok(cached);
229 }
230
231 let result = self.factor.divide(&other.factor)?;
233 self.cache
234 .put_divide(&self.factor.name, &other.factor.name, result.clone());
235
236 Ok(result)
237 }
238
239 pub fn reduce_cached(&self, var: &str, value: usize) -> Result<Factor> {
241 if let Some(cached) = self.cache.get_reduce(&self.factor.name, var, value) {
243 return Ok(cached);
244 }
245
246 let result = self.factor.reduce(var, value)?;
248 self.cache
249 .put_reduce(&self.factor.name, var, value, result.clone());
250
251 Ok(result)
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Builds a 2x2 factor named `name` over the variables `X` and `Y`.
    fn create_test_factor(name: &str) -> Factor {
        let values = vec![0.1, 0.2, 0.3, 0.4];
        let array = Array::from_shape_vec(vec![2, 2], values)
            .unwrap()
            .into_dyn();
        Factor::new(
            name.to_string(),
            vec!["X".to_string(), "Y".to_string()],
            array,
        )
        .unwrap()
    }

    #[test]
    fn test_cache_product() {
        let cache = Arc::new(FactorCache::new(100));
        let f1 = CachedFactor::new(create_test_factor("f1"), cache.clone());
        let f2 = CachedFactor::new(create_test_factor("f2"), cache.clone());

        // First call has nothing to reuse: one miss, and the result is stored.
        let result1 = f1.product_cached(&f2).unwrap();
        let stats1 = cache.stats();
        assert_eq!(stats1.misses, 1);
        assert_eq!(stats1.hits, 0);
        assert_eq!(stats1.size, 1);

        // Second call is served from the cache: one hit, no new entry.
        let result2 = f1.product_cached(&f2).unwrap();
        let stats2 = cache.stats();
        assert_eq!(stats2.misses, 1);
        assert_eq!(stats2.hits, 1);
        assert_eq!(stats2.size, 1);
        assert!((stats2.hit_rate - 0.5).abs() < f64::EPSILON);

        assert_eq!(result1.name, result2.name);
    }

    #[test]
    fn test_cache_marginalize() {
        let cache = Arc::new(FactorCache::new(100));
        let f = CachedFactor::new(create_test_factor("f"), cache.clone());

        let result1 = f.marginalize_out_cached("Y").unwrap();
        let stats1 = cache.stats();
        assert_eq!(stats1.misses, 1);
        assert_eq!(stats1.hits, 0);

        let result2 = f.marginalize_out_cached("Y").unwrap();
        let stats2 = cache.stats();
        assert_eq!(stats2.misses, 1);
        assert_eq!(stats2.hits, 1);

        assert_eq!(result1.name, result2.name);
    }

    #[test]
    fn test_cache_stats() {
        let cache = FactorCache::new(100);
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.size, 0);
        assert_eq!(stats.hit_rate, 0.0);
    }

    #[test]
    fn test_cache_clear() {
        let cache = Arc::new(FactorCache::new(100));
        let f = CachedFactor::new(create_test_factor("f"), cache.clone());

        let _ = f.marginalize_out_cached("Y").unwrap();
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.stats().misses, 1);

        cache.clear();
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = Arc::new(FactorCache::new(2));

        cache.put_marginalize("f1", "X", create_test_factor("result1"));
        cache.put_marginalize("f2", "Y", create_test_factor("result2"));
        cache.put_marginalize("f3", "Z", create_test_factor("result3"));

        // Exactly one entry is evicted to make room, and the entry that
        // triggered the eviction is the one that must still be present.
        assert_eq!(cache.size(), 2);
        let newest = cache.get_marginalize("f3", "Z");
        assert_eq!(newest.map(|f| f.name), Some("result3".to_string()));
    }
}