rust_rule_engine/rete/
memoization.rs

//! Memoization support for RETE-UL evaluation.
//!
//! This module caches node evaluation results so that evaluating the same
//! node against the same facts more than once reuses the stored result
//! instead of re-running the evaluation, which can improve performance for
//! complex rule networks.
6
7use std::collections::HashMap;
8use std::hash::{Hash, Hasher};
9use std::collections::hash_map::DefaultHasher;
10use super::network::ReteUlNode;
11use super::facts::TypedFacts;
12
13/// Compute hash for facts (for memoization key)
14fn compute_facts_hash(facts: &TypedFacts) -> u64 {
15    let mut hasher = DefaultHasher::new();
16    let mut sorted_facts: Vec<_> = facts.get_all().iter().collect();
17    sorted_facts.sort_by_key(|(k, _)| *k);
18
19    for (key, value) in sorted_facts {
20        key.hash(&mut hasher);
21        value.as_string().hash(&mut hasher);
22    }
23
24    hasher.finish()
25}
26
27/// Compute hash for a node (for memoization key)
28fn compute_node_hash(node: &ReteUlNode) -> u64 {
29    let mut hasher = DefaultHasher::new();
30    // Simple hash based on node type and structure
31    format!("{:?}", node).hash(&mut hasher);
32    hasher.finish()
33}
34
/// Memoization cache for RETE-UL evaluation.
///
/// Stores boolean evaluation results keyed by a pair of 64-bit hashes
/// (node hash, facts hash) and counts how many lookups hit or missed.
#[derive(Debug)]
pub struct MemoizedEvaluator {
    /// Cached results, keyed by `(node hash, facts hash)`.
    cache: HashMap<(u64, u64), bool>,
    /// Number of lookups answered from the cache.
    hits: usize,
    /// Number of lookups that had to run the evaluation function.
    misses: usize,
}
41
42impl MemoizedEvaluator {
43    /// Create new memoized evaluator
44    pub fn new() -> Self {
45        Self {
46            cache: HashMap::new(),
47            hits: 0,
48            misses: 0,
49        }
50    }
51
52    /// Evaluate node with memoization
53    pub fn evaluate(
54        &mut self,
55        node: &ReteUlNode,
56        facts: &TypedFacts,
57        eval_fn: impl FnOnce(&ReteUlNode, &TypedFacts) -> bool,
58    ) -> bool {
59        let node_hash = compute_node_hash(node);
60        let facts_hash = compute_facts_hash(facts);
61        let key = (node_hash, facts_hash);
62
63        if let Some(&result) = self.cache.get(&key) {
64            self.hits += 1;
65            return result;
66        }
67
68        self.misses += 1;
69        let result = eval_fn(node, facts);
70        self.cache.insert(key, result);
71        result
72    }
73
74    /// Get cache statistics
75    pub fn stats(&self) -> MemoStats {
76        MemoStats {
77            cache_size: self.cache.len(),
78            hits: self.hits,
79            misses: self.misses,
80            hit_rate: if self.hits + self.misses > 0 {
81                self.hits as f64 / (self.hits + self.misses) as f64
82            } else {
83                0.0
84            },
85        }
86    }
87
88    /// Clear the cache
89    pub fn clear(&mut self) {
90        self.cache.clear();
91        self.hits = 0;
92        self.misses = 0;
93    }
94
95    /// Get cache size
96    pub fn cache_size(&self) -> usize {
97        self.cache.len()
98    }
99}
100
101impl Default for MemoizedEvaluator {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
/// Snapshot of memoization statistics, as returned by
/// [`MemoizedEvaluator::stats`].
///
/// `Default` yields the all-zero snapshot of an evaluator that has done no
/// lookups.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct MemoStats {
    /// Number of entries currently held in the cache.
    pub cache_size: usize,
    /// Lookups answered from the cache.
    pub hits: usize,
    /// Lookups that required running the evaluation function.
    pub misses: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookups have happened.
    pub hit_rate: f64,
}
115
116impl std::fmt::Display for MemoStats {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        write!(
119            f,
120            "Memo Stats: {} entries, {} hits, {} misses, {:.2}% hit rate",
121            self.cache_size,
122            self.hits,
123            self.misses,
124            self.hit_rate * 100.0
125        )
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rete::alpha::AlphaNode;
    use crate::rete::facts::TypedFacts;
    use crate::rete::network::ReteUlNode;

    /// Build the `age > 18` alpha node shared by the tests below.
    fn adult_node() -> ReteUlNode {
        ReteUlNode::UlAlpha(AlphaNode {
            field: "age".to_string(),
            operator: ">".to_string(),
            value: "18".to_string(),
        })
    }

    /// Build a fact set whose only fact is `age`.
    fn facts_with_age(age: i64) -> TypedFacts {
        let mut facts = TypedFacts::new();
        facts.set("age", age);
        facts
    }

    #[test]
    fn test_memoization() {
        let mut evaluator = MemoizedEvaluator::new();
        let facts = facts_with_age(25);
        let node = adult_node();

        // First evaluation - cache miss
        let mut eval_count = 0;
        let result1 = evaluator.evaluate(&node, &facts, |n, f| {
            eval_count += 1;
            n.evaluate_typed(f)
        });
        assert!(result1);
        assert_eq!(eval_count, 1);

        // Second evaluation - cache hit
        let result2 = evaluator.evaluate(&node, &facts, |n, f| {
            eval_count += 1;
            n.evaluate_typed(f)
        });
        assert!(result2);
        assert_eq!(eval_count, 1); // Should not re-evaluate!

        let stats = evaluator.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate, 0.5);
    }

    #[test]
    fn test_different_facts_miss() {
        let mut evaluator = MemoizedEvaluator::new();
        let node = adult_node();
        let mut eval_count = 0;

        // Distinct fact values must map to distinct cache keys.
        assert!(evaluator.evaluate(&node, &facts_with_age(25), |_, _| {
            eval_count += 1;
            true
        }));
        assert!(!evaluator.evaluate(&node, &facts_with_age(10), |_, _| {
            eval_count += 1;
            false
        }));
        assert_eq!(eval_count, 2);

        let stats = evaluator.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.cache_size, 2);
    }

    #[test]
    fn test_clear_resets_cache_and_counters() {
        let mut evaluator = MemoizedEvaluator::new();
        let node = adult_node();
        let facts = facts_with_age(25);

        assert!(evaluator.evaluate(&node, &facts, |_, _| true));
        assert_eq!(evaluator.cache_size(), 1);

        evaluator.clear();
        let stats = evaluator.stats();
        assert_eq!(stats.cache_size, 0);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate, 0.0);

        // The cleared entry must be recomputed, not served from the cache.
        let mut eval_count = 0;
        assert!(evaluator.evaluate(&node, &facts, |_, _| {
            eval_count += 1;
            true
        }));
        assert_eq!(eval_count, 1);
        assert_eq!(evaluator.stats().misses, 1);
    }
}
170}