// rust_rule_engine/backward/conclusion_index.rs

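//! Conclusion index for efficient rule lookup in backward chaining.
//!
//! Instead of linearly scanning every rule to find those that can prove a goal,
//! the index is keyed by conclusions: the facts a rule's actions derive (fields it
//! sets, objects it calls methods on or retracts, workflow keys it writes), each
//! mapped to the set of rule names that derive it. Looking up a goal's exact field
//! is a single hash-map lookup; matching rules that touch the goal's parent object
//! still walks the indexed fields.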
8
9use crate::engine::rule::Rule;
10use crate::types::ActionType;
11use std::collections::{HashMap, HashSet};
12
/// Index from rule conclusions to the rules that derive them, for backward chaining.
///
/// Rather than matching rule conditions, this index matches rule conclusions
/// (what a rule's actions derive) against goals, so a goal can be mapped to the
/// rules that might establish it without examining every rule.
#[derive(Debug, Clone)]
pub struct ConclusionIndex {
    /// Maps each conclusion to the names of the enabled rules that derive it.
    /// Conclusions are: the field written by a `Set` action (e.g. `User.IsVIP`),
    /// both `object.method` and the bare `object` for a `MethodCall`, the object
    /// of a `Retract`, and the key of a `SetWorkflowData`.
    /// Example: "User.IsVIP" -> ["DetermineVIP", "PromoteToVIP"]
    field_to_rules: HashMap<String, HashSet<String>>,

    /// Reverse map from rule name to that rule's conclusions, used to remove the
    /// rule's entries from `field_to_rules` when the rule is removed.
    rule_to_conclusions: HashMap<String, HashSet<String>>,

    /// Number of indexed rules (incremented when a rule is indexed, decremented
    /// when an indexed rule is removed).
    rule_count: usize,
}
29
30impl ConclusionIndex {
31    /// Create a new empty conclusion index
32    pub fn new() -> Self {
33        Self {
34            field_to_rules: HashMap::new(),
35            rule_to_conclusions: HashMap::new(),
36            rule_count: 0,
37        }
38    }
39
40    /// Build index from a collection of rules
41    pub fn from_rules(rules: &[Rule]) -> Self {
42        let mut index = Self::new();
43        for rule in rules {
44            index.add_rule(rule);
45        }
46        index
47    }
48
49    /// Add a rule to the index
50    pub fn add_rule(&mut self, rule: &Rule) {
51        if !rule.enabled {
52            return; // Don't index disabled rules
53        }
54
55        let conclusions = self.extract_conclusions(rule);
56
57        if conclusions.is_empty() {
58            return; // No indexable conclusions
59        }
60
61        // Add bidirectional mappings
62        for conclusion in &conclusions {
63            self.field_to_rules
64                .entry(conclusion.clone())
65                .or_insert_with(HashSet::new)
66                .insert(rule.name.clone());
67        }
68
69        self.rule_to_conclusions.insert(rule.name.clone(), conclusions);
70        self.rule_count += 1;
71    }
72
73    /// Remove a rule from the index
74    pub fn remove_rule(&mut self, rule_name: &str) {
75        if let Some(conclusions) = self.rule_to_conclusions.remove(rule_name) {
76            for conclusion in conclusions {
77                if let Some(rules) = self.field_to_rules.get_mut(&conclusion) {
78                    rules.remove(rule_name);
79                    if rules.is_empty() {
80                        self.field_to_rules.remove(&conclusion);
81                    }
82                }
83            }
84            self.rule_count -= 1;
85        }
86    }
87
88    /// Find candidate rules that could prove a goal
89    ///
90    /// This is the O(1) lookup that replaces O(n) iteration.
91    ///
92    /// # Arguments
93    /// * `goal_pattern` - The goal pattern to prove (e.g., "User.IsVIP == true")
94    ///
95    /// # Returns
96    /// Set of rule names that might be able to derive this goal
97    pub fn find_candidates(&self, goal_pattern: &str) -> HashSet<String> {
98        let mut candidates = HashSet::new();
99
100        // Extract field name from goal pattern
101        // Examples:
102        //   "User.IsVIP == true" -> "User.IsVIP"
103        //   "Order.AutoApproved" -> "Order.AutoApproved"
104        //   "Customer.Status == 'VIP'" -> "Customer.Status"
105        let field = self.extract_field_from_goal(goal_pattern);
106
107        // Direct field match
108        if let Some(rules) = self.field_to_rules.get(field) {
109            candidates.extend(rules.iter().cloned());
110        }
111
112        // Check parent objects for partial matches
113        // Example: "User.IsVIP" also matches rules that set "User.*"
114        if let Some(dot_pos) = field.rfind('.') {
115            let object = &field[..dot_pos];
116
117            // Find all rules that modify any field of this object
118            for (indexed_field, rules) in &self.field_to_rules {
119                if indexed_field.starts_with(object) {
120                    candidates.extend(rules.iter().cloned());
121                }
122            }
123        }
124
125        candidates
126    }
127
128    /// Extract field name from goal pattern
129    fn extract_field_from_goal<'a>(&self, goal_pattern: &'a str) -> &'a str {
130        // Handle comparison operators
131        for op in &["==", "!=", ">=", "<=", ">", "<", " contains ", " matches "] {
132            if let Some(pos) = goal_pattern.find(op) {
133                return goal_pattern[..pos].trim();
134            }
135        }
136
137        // No operator found, return whole pattern
138        goal_pattern.trim()
139    }
140
141    /// Extract all conclusions (facts derived) from a rule
142    fn extract_conclusions(&self, rule: &Rule) -> HashSet<String> {
143        let mut conclusions = HashSet::new();
144
145        for action in &rule.actions {
146            match action {
147                ActionType::Set { field, .. } => {
148                    conclusions.insert(field.clone());
149                }
150                ActionType::MethodCall { object, method, .. } => {
151                    // Method calls might modify object state
152                    conclusions.insert(format!("{}.{}", object, method));
153                    // Also index the object itself
154                    conclusions.insert(object.clone());
155                }
156                ActionType::Retract { object } => {
157                    conclusions.insert(object.clone());
158                }
159                ActionType::SetWorkflowData { key, .. } => {
160                    conclusions.insert(key.clone());
161                }
162                // Log, Custom, ScheduleRule don't directly derive facts
163                _ => {}
164            }
165        }
166
167        conclusions
168    }
169
170    /// Get statistics about the index
171    pub fn stats(&self) -> IndexStats {
172        IndexStats {
173            total_rules: self.rule_count,
174            indexed_fields: self.field_to_rules.len(),
175            avg_rules_per_field: if self.field_to_rules.is_empty() {
176                0.0
177            } else {
178                self.field_to_rules.values().map(|s| s.len()).sum::<usize>() as f64
179                    / self.field_to_rules.len() as f64
180            },
181        }
182    }
183
184    /// Clear the index
185    pub fn clear(&mut self) {
186        self.field_to_rules.clear();
187        self.rule_to_conclusions.clear();
188        self.rule_count = 0;
189    }
190
191    /// Check if index is empty
192    pub fn is_empty(&self) -> bool {
193        self.rule_count == 0
194    }
195}
196
197impl Default for ConclusionIndex {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
/// Statistics about the conclusion index.
///
/// `PartialEq` allows snapshots to be compared directly. `Default` yields the
/// all-zero statistics of an empty index.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct IndexStats {
    /// Total number of indexed rules
    pub total_rules: usize,
    /// Number of unique conclusions (fields, objects, workflow keys) indexed
    pub indexed_fields: usize,
    /// Average number of rules per indexed conclusion (0.0 when nothing is indexed)
    pub avg_rules_per_field: f64,
}
213
214impl std::fmt::Display for IndexStats {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        writeln!(f, "Conclusion Index Statistics:")?;
217        writeln!(f, "  Total Rules: {}", self.total_rules)?;
218        writeln!(f, "  Indexed Fields: {}", self.indexed_fields)?;
219        writeln!(f, "  Avg Rules/Field: {:.2}", self.avg_rules_per_field)?;
220        Ok(())
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::rule::{Condition, ConditionGroup, Rule};
    use crate::types::{Operator, Value};

    /// Builds a test rule with a placeholder condition and a single `Set` action
    /// that writes `true` to `set_field`.
    fn create_test_rule(name: &str, set_field: &str) -> Rule {
        let conditions = ConditionGroup::Single(Condition::new(
            "dummy".to_string(),
            Operator::Equal,
            Value::Boolean(true),
        ));
        let actions = vec![ActionType::Set {
            field: set_field.to_string(),
            value: Value::Boolean(true),
        }];
        Rule::new(name.to_string(), conditions, actions)
    }

    #[test]
    fn test_index_creation() {
        let index = ConclusionIndex::new();
        assert!(index.is_empty());
        assert_eq!(index.rule_count, 0);
        assert_eq!(index.stats().indexed_fields, 0);
    }

    #[test]
    fn test_add_single_rule() {
        let mut index = ConclusionIndex::new();
        let rule = create_test_rule("TestRule", "User.IsVIP");

        index.add_rule(&rule);

        assert_eq!(index.rule_count, 1);
        assert_eq!(index.field_to_rules.len(), 1);
        assert!(index.field_to_rules.contains_key("User.IsVIP"));
    }

    #[test]
    fn test_find_candidates_exact_match() {
        let mut index = ConclusionIndex::new();
        let rule = create_test_rule("DetermineVIP", "User.IsVIP");
        index.add_rule(&rule);

        let candidates = index.find_candidates("User.IsVIP == true");

        assert_eq!(candidates.len(), 1);
        assert!(candidates.contains("DetermineVIP"));
    }

    #[test]
    fn test_find_candidates_multiple_rules() {
        let mut index = ConclusionIndex::new();
        index.add_rule(&create_test_rule("Rule1", "User.IsVIP"));
        index.add_rule(&create_test_rule("Rule2", "User.IsVIP"));
        index.add_rule(&create_test_rule("Rule3", "Order.Status"));

        let candidates = index.find_candidates("User.IsVIP == true");

        assert_eq!(candidates.len(), 2);
        assert!(candidates.contains("Rule1"));
        assert!(candidates.contains("Rule2"));
        assert!(!candidates.contains("Rule3"));
    }

    #[test]
    fn test_find_candidates_same_object_other_field() {
        let mut index = ConclusionIndex::new();
        index.add_rule(&create_test_rule("SetsVIP", "User.IsVIP"));

        // Rules touching any field of the goal's object are candidates.
        let candidates = index.find_candidates("User.Name == 'x'");

        assert!(candidates.contains("SetsVIP"));
    }

    #[test]
    fn test_find_candidates_unrelated_object() {
        let mut index = ConclusionIndex::new();
        index.add_rule(&create_test_rule("SetsVIP", "User.IsVIP"));

        assert!(index.find_candidates("Order.Status == 'open'").is_empty());
    }

    #[test]
    fn test_remove_rule() {
        let mut index = ConclusionIndex::new();
        let rule = create_test_rule("TestRule", "User.IsVIP");
        index.add_rule(&rule);

        assert_eq!(index.rule_count, 1);

        index.remove_rule("TestRule");

        assert_eq!(index.rule_count, 0);
        assert!(index.is_empty());
        assert!(index.field_to_rules.is_empty());
    }

    #[test]
    fn test_remove_rule_keeps_shared_field() {
        let mut index = ConclusionIndex::new();
        index.add_rule(&create_test_rule("Rule1", "User.IsVIP"));
        index.add_rule(&create_test_rule("Rule2", "User.IsVIP"));

        index.remove_rule("Rule1");

        assert_eq!(index.rule_count, 1);
        assert!(index.field_to_rules.contains_key("User.IsVIP"));
        let candidates = index.find_candidates("User.IsVIP == true");
        assert_eq!(candidates.len(), 1);
        assert!(candidates.contains("Rule2"));
    }

    #[test]
    fn test_remove_unknown_rule_is_noop() {
        let mut index = ConclusionIndex::new();
        index.add_rule(&create_test_rule("Rule1", "User.IsVIP"));

        index.remove_rule("DoesNotExist");

        assert_eq!(index.rule_count, 1);
        assert_eq!(index.field_to_rules.len(), 1);
    }

    #[test]
    fn test_extract_field_from_goal() {
        let index = ConclusionIndex::new();

        assert_eq!(index.extract_field_from_goal("User.IsVIP == true"), "User.IsVIP");
        assert_eq!(index.extract_field_from_goal("Order.Amount > 100"), "Order.Amount");
        assert_eq!(index.extract_field_from_goal("User.Name"), "User.Name");
        assert_eq!(
            index.extract_field_from_goal("Customer.Email contains '@'"),
            "Customer.Email"
        );
    }

    #[test]
    fn test_disabled_rules_not_indexed() {
        let mut index = ConclusionIndex::new();
        let mut rule = create_test_rule("DisabledRule", "User.IsVIP");
        rule.enabled = false;

        index.add_rule(&rule);

        assert_eq!(index.rule_count, 0);
        assert!(index.is_empty());
        assert!(index.field_to_rules.is_empty());
    }

    #[test]
    fn test_from_rules_bulk_creation() {
        let rules = vec![
            create_test_rule("Rule1", "User.IsVIP"),
            create_test_rule("Rule2", "Order.Status"),
            create_test_rule("Rule3", "Customer.Rating"),
        ];

        let index = ConclusionIndex::from_rules(&rules);

        assert_eq!(index.rule_count, 3);
        assert_eq!(index.field_to_rules.len(), 3);
    }

    #[test]
    fn test_stats() {
        let mut index = ConclusionIndex::new();
        index.add_rule(&create_test_rule("Rule1", "User.IsVIP"));
        index.add_rule(&create_test_rule("Rule2", "User.IsVIP"));
        index.add_rule(&create_test_rule("Rule3", "Order.Status"));

        let stats = index.stats();

        assert_eq!(stats.total_rules, 3);
        assert_eq!(stats.indexed_fields, 2);
        // Two rules on User.IsVIP plus one on Order.Status, over two fields.
        assert_eq!(stats.avg_rules_per_field, 1.5);
    }

    #[test]
    fn test_stats_empty_and_display() {
        let mut index = ConclusionIndex::new();
        assert_eq!(index.stats().avg_rules_per_field, 0.0);

        index.add_rule(&create_test_rule("Rule1", "User.IsVIP"));
        index.add_rule(&create_test_rule("Rule2", "User.IsVIP"));
        index.add_rule(&create_test_rule("Rule3", "Order.Status"));

        let text = index.stats().to_string();
        assert!(text.contains("Total Rules: 3"));
        assert!(text.contains("Indexed Fields: 2"));
        assert!(text.contains("Avg Rules/Field: 1.50"));
    }
}
356}