rust_rule_engine/rete/
memoization.rs1use std::collections::HashMap;
8use std::hash::{Hash, Hasher};
9use std::collections::hash_map::DefaultHasher;
10use super::network::ReteUlNode;
11use super::facts::TypedFacts;
12
13fn compute_facts_hash(facts: &TypedFacts) -> u64 {
15 let mut hasher = DefaultHasher::new();
16 let mut sorted_facts: Vec<_> = facts.get_all().iter().collect();
17 sorted_facts.sort_by_key(|(k, _)| *k);
18
19 for (key, value) in sorted_facts {
20 key.hash(&mut hasher);
21 value.as_string().hash(&mut hasher);
22 }
23
24 hasher.finish()
25}
26
27fn compute_node_hash(node: &ReteUlNode) -> u64 {
29 let mut hasher = DefaultHasher::new();
30 format!("{:?}", node).hash(&mut hasher);
32 hasher.finish()
33}
34
/// Caches boolean results of node evaluations and counts cache hits and misses.
///
/// Entries are keyed by a pair of fingerprints: one for the node and one for
/// the facts it was evaluated against. Entries are only removed by
/// [`MemoizedEvaluator::clear`]; there is no size bound or eviction.
#[derive(Debug)]
pub struct MemoizedEvaluator {
    /// Maps `(node_hash, facts_hash)` to the cached evaluation result.
    cache: HashMap<(u64, u64), bool>,
    /// Number of lookups answered from `cache`.
    hits: usize,
    /// Number of lookups that had to run the evaluation callback.
    misses: usize,
}
41
42impl MemoizedEvaluator {
43 pub fn new() -> Self {
45 Self {
46 cache: HashMap::new(),
47 hits: 0,
48 misses: 0,
49 }
50 }
51
52 pub fn evaluate(
54 &mut self,
55 node: &ReteUlNode,
56 facts: &TypedFacts,
57 eval_fn: impl FnOnce(&ReteUlNode, &TypedFacts) -> bool,
58 ) -> bool {
59 let node_hash = compute_node_hash(node);
60 let facts_hash = compute_facts_hash(facts);
61 let key = (node_hash, facts_hash);
62
63 if let Some(&result) = self.cache.get(&key) {
64 self.hits += 1;
65 return result;
66 }
67
68 self.misses += 1;
69 let result = eval_fn(node, facts);
70 self.cache.insert(key, result);
71 result
72 }
73
74 pub fn stats(&self) -> MemoStats {
76 MemoStats {
77 cache_size: self.cache.len(),
78 hits: self.hits,
79 misses: self.misses,
80 hit_rate: if self.hits + self.misses > 0 {
81 self.hits as f64 / (self.hits + self.misses) as f64
82 } else {
83 0.0
84 },
85 }
86 }
87
88 pub fn clear(&mut self) {
90 self.cache.clear();
91 self.hits = 0;
92 self.misses = 0;
93 }
94
95 pub fn cache_size(&self) -> usize {
97 self.cache.len()
98 }
99}
100
101impl Default for MemoizedEvaluator {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
/// Point-in-time snapshot of a [`MemoizedEvaluator`]'s cache statistics.
///
/// `Default` yields the all-zero snapshot, which is what `stats()` reports for
/// a freshly created evaluator.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct MemoStats {
    /// Number of entries currently held in the cache.
    pub cache_size: usize,
    /// Number of lookups answered from the cache.
    pub hits: usize,
    /// Number of lookups that required running the evaluation callback.
    pub misses: usize,
    /// `hits / (hits + misses)` in `[0.0, 1.0]`; `0.0` when no lookups happened.
    pub hit_rate: f64,
}
115
116impl std::fmt::Display for MemoStats {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 write!(
119 f,
120 "Memo Stats: {} entries, {} hits, {} misses, {:.2}% hit rate",
121 self.cache_size,
122 self.hits,
123 self.misses,
124 self.hit_rate * 100.0
125 )
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rete::alpha::AlphaNode;
    use crate::rete::facts::TypedFacts;
    use crate::rete::network::ReteUlNode;

    /// Builds the `age > 18` alpha node shared by the tests below.
    fn adult_node() -> ReteUlNode {
        ReteUlNode::UlAlpha(AlphaNode {
            field: "age".to_string(),
            operator: ">".to_string(),
            value: "18".to_string(),
        })
    }

    /// Builds a fact set containing only an `age` fact.
    fn facts_with_age(age: i64) -> TypedFacts {
        let mut facts = TypedFacts::new();
        facts.set("age", age);
        facts
    }

    #[test]
    fn test_memoization() {
        let mut evaluator = MemoizedEvaluator::new();
        let facts = facts_with_age(25);
        let node = adult_node();

        let mut eval_count = 0;
        // First call is a miss and runs the callback.
        let result1 = evaluator.evaluate(&node, &facts, |n, f| {
            eval_count += 1;
            n.evaluate_typed(f)
        });
        assert!(result1);
        assert_eq!(eval_count, 1);

        // Second call with identical inputs is served from the cache.
        let result2 = evaluator.evaluate(&node, &facts, |n, f| {
            eval_count += 1;
            n.evaluate_typed(f)
        });
        assert!(result2);
        assert_eq!(eval_count, 1);

        let stats = evaluator.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate, 0.5);
    }

    #[test]
    fn test_distinct_facts_are_cached_separately() {
        let mut evaluator = MemoizedEvaluator::new();
        let node = adult_node();
        let young = facts_with_age(10);
        let old = facts_with_age(30);

        let mut eval_count = 0;
        evaluator.evaluate(&node, &young, |n, f| {
            eval_count += 1;
            n.evaluate_typed(f)
        });
        evaluator.evaluate(&node, &old, |n, f| {
            eval_count += 1;
            n.evaluate_typed(f)
        });

        // Different facts must not share a cache entry.
        assert_eq!(eval_count, 2);
        let stats = evaluator.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 2);
        assert_eq!(evaluator.cache_size(), 2);
    }

    #[test]
    fn test_stats_before_any_lookup() {
        let stats = MemoizedEvaluator::new().stats();
        assert_eq!(stats.cache_size, 0);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate, 0.0);
    }

    #[test]
    fn test_clear_resets_cache_and_counters() {
        let mut evaluator = MemoizedEvaluator::new();
        let node = adult_node();
        let facts = facts_with_age(25);
        evaluator.evaluate(&node, &facts, |n, f| n.evaluate_typed(f));
        evaluator.evaluate(&node, &facts, |n, f| n.evaluate_typed(f));

        evaluator.clear();

        let stats = evaluator.stats();
        assert_eq!(stats.cache_size, 0);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate, 0.0);

        // After clearing, the same lookup is a miss again.
        let mut ran = false;
        evaluator.evaluate(&node, &facts, |n, f| {
            ran = true;
            n.evaluate_typed(f)
        });
        assert!(ran);
    }
}