// rust_rule_engine/rete/propagation.rs

//! Incremental Propagation Engine (P3 Feature - Advanced)
//!
//! This module implements incremental updates similar to Drools:
//! - Only re-evaluate the rules that depend on the type of a changed fact
//! - Track which rules depend on which fact types, and queue activations
//! - Efficient re-evaluation after inserts, updates and retractions

8use std::collections::{HashMap, HashSet};
9use std::sync::Arc;
10use super::working_memory::{WorkingMemory, FactHandle};
11use super::network::{ReteUlNode, TypedReteUlRule};
12use super::facts::{TypedFacts, FactValue};
13use super::agenda::{AdvancedAgenda, Activation};
14use super::template::TemplateRegistry;
15use super::globals::GlobalsRegistry;
16use super::deffacts::DeffactsRegistry;
17use crate::errors::{Result, RuleEngineError};
18
/// Tracks which rules are affected by which fact types.
///
/// The relation is stored in both directions, so that looking up the rules
/// affected by a fact type (done on every propagation) and looking up the
/// fact types a rule depends on are each a single map access.
#[derive(Debug, Default)]
pub struct RuleDependencyGraph {
    /// Map: fact_type -> set of rule indices that depend on it
    fact_type_to_rules: HashMap<String, HashSet<usize>>,
    /// Map: rule index -> set of fact types it depends on
    rule_to_fact_types: HashMap<usize, HashSet<String>>,
}

impl RuleDependencyGraph {
    /// Creates an empty dependency graph.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that rule `rule_idx` depends on facts of type `fact_type`.
    ///
    /// Registering the same `(rule_idx, fact_type)` pair more than once has no
    /// additional effect, because both directions are stored as sets.
    pub fn add_dependency(&mut self, rule_idx: usize, fact_type: String) {
        self.fact_type_to_rules
            .entry(fact_type.clone())
            .or_default()
            .insert(rule_idx);

        self.rule_to_fact_types
            .entry(rule_idx)
            .or_default()
            .insert(fact_type);
    }

    /// Returns the indices of all rules that depend on `fact_type`.
    ///
    /// Returns an empty set when no rule depends on that type.
    pub fn get_affected_rules(&self, fact_type: &str) -> HashSet<usize> {
        self.fact_type_to_rules
            .get(fact_type)
            .cloned()
            .unwrap_or_default()
    }

    /// Returns the fact types that rule `rule_idx` depends on.
    ///
    /// Returns an empty set for a rule with no registered dependencies.
    pub fn get_rule_dependencies(&self, rule_idx: usize) -> HashSet<String> {
        self.rule_to_fact_types
            .get(&rule_idx)
            .cloned()
            .unwrap_or_default()
    }
}
72
73/// Type alias for custom test functions in RETE engine
74/// Functions take a slice of FactValues and return a FactValue (typically Boolean)
75pub type ReteCustomFunction = Arc<dyn Fn(&[FactValue], &TypedFacts) -> Result<FactValue> + Send + Sync>;
76
77/// Incremental Propagation Engine
78/// Only re-evaluates rules affected by changed facts
79pub struct IncrementalEngine {
80    /// Working memory
81    working_memory: WorkingMemory,
82    /// Rules
83    rules: Vec<TypedReteUlRule>,
84    /// Dependency graph
85    dependencies: RuleDependencyGraph,
86    /// Advanced agenda
87    agenda: AdvancedAgenda,
88    /// Track which facts each rule last matched
89    rule_matched_facts: HashMap<usize, HashSet<FactHandle>>,
90    /// Template registry for type-safe facts
91    templates: TemplateRegistry,
92    /// Global variables registry
93    globals: GlobalsRegistry,
94    /// Deffacts registry for initial facts
95    deffacts: DeffactsRegistry,
96    /// Custom functions for Test CE support
97    custom_functions: HashMap<String, ReteCustomFunction>,
98}
99
100impl IncrementalEngine {
101    /// Create new incremental engine
102    pub fn new() -> Self {
103        Self {
104            working_memory: WorkingMemory::new(),
105            rules: Vec::new(),
106            dependencies: RuleDependencyGraph::new(),
107            agenda: AdvancedAgenda::new(),
108            rule_matched_facts: HashMap::new(),
109            custom_functions: HashMap::new(),
110            templates: TemplateRegistry::new(),
111            globals: GlobalsRegistry::new(),
112            deffacts: DeffactsRegistry::new(),
113        }
114    }
115
116    /// Add rule and register its dependencies
117    pub fn add_rule(&mut self, rule: TypedReteUlRule, depends_on: Vec<String>) {
118        let rule_idx = self.rules.len();
119
120        // Register dependencies
121        for fact_type in depends_on {
122            self.dependencies.add_dependency(rule_idx, fact_type);
123        }
124
125        self.rules.push(rule);
126    }
127
128    /// Insert fact into working memory
129    pub fn insert(&mut self, fact_type: String, data: TypedFacts) -> FactHandle {
130        let handle = self.working_memory.insert(fact_type.clone(), data);
131
132        // Trigger incremental propagation for this fact type
133        self.propagate_changes_for_type(&fact_type);
134
135        handle
136    }
137
138    /// Update fact in working memory
139    pub fn update(&mut self, handle: FactHandle, data: TypedFacts) -> Result<()> {
140        // Get fact type before update
141        let fact_type = self.working_memory
142            .get(&handle)
143            .map(|f| f.fact_type.clone())
144            .ok_or_else(|| RuleEngineError::FieldNotFound {
145                field: format!("FactHandle {} not found", handle),
146            })?;
147
148        self.working_memory.update(handle, data).map_err(|e| RuleEngineError::EvaluationError {
149            message: e,
150        })?;
151
152        // Trigger incremental propagation for this fact type
153        self.propagate_changes_for_type(&fact_type);
154
155        Ok(())
156    }
157
158    /// Retract fact from working memory
159    pub fn retract(&mut self, handle: FactHandle) -> Result<()> {
160        // Get fact type before retract
161        let fact_type = self.working_memory
162            .get(&handle)
163            .map(|f| f.fact_type.clone())
164            .ok_or_else(|| RuleEngineError::FieldNotFound {
165                field: format!("FactHandle {} not found", handle),
166            })?;
167
168        self.working_memory.retract(handle).map_err(|e| RuleEngineError::EvaluationError {
169            message: e,
170        })?;
171
172        // Trigger incremental propagation for this fact type
173        self.propagate_changes_for_type(&fact_type);
174
175        Ok(())
176    }
177
178    /// Propagate changes for a specific fact type (incremental!)
179    fn propagate_changes_for_type(&mut self, fact_type: &str) {
180        // Get affected rules
181        let affected_rules = self.dependencies.get_affected_rules(fact_type);
182
183        if affected_rules.is_empty() {
184            return; // No rules depend on this fact type
185        }
186
187        // Flatten working memory to TypedFacts for evaluation
188        let facts = self.working_memory.to_typed_facts();
189
190        // Re-evaluate only affected rules
191        for &rule_idx in &affected_rules {
192            let rule = &self.rules[rule_idx];
193
194            // Evaluate rule condition
195            let matches = super::network::evaluate_rete_ul_node_typed(&rule.node, &facts);
196
197            if matches {
198                // Create activation
199                let activation = Activation::new(rule.name.clone(), rule.priority)
200                    .with_no_loop(rule.no_loop);
201
202                self.agenda.add_activation(activation);
203            }
204        }
205    }
206
207    /// Propagate changes for all fact types (re-evaluate all rules)
208    fn propagate_changes(&mut self) {
209        // Flatten working memory to TypedFacts for evaluation
210        let facts = self.working_memory.to_typed_facts();
211
212        // Re-evaluate ALL rules with current working memory state
213        for (rule_idx, rule) in self.rules.iter().enumerate() {
214            // Skip if rule has no-loop and already fired
215            if rule.no_loop && self.agenda.has_fired(&rule.name) {
216                continue;
217            }
218            
219            // Evaluate rule condition
220            let matches = super::network::evaluate_rete_ul_node_typed(&rule.node, &facts);
221
222            if matches {
223                // Create activation
224                let activation = Activation::new(rule.name.clone(), rule.priority)
225                    .with_no_loop(rule.no_loop);
226
227                self.agenda.add_activation(activation);
228            }
229        }
230    }
231
232    /// Fire all pending activations
233    pub fn fire_all(&mut self) -> Vec<String> {
234        let mut fired_rules = Vec::new();
235        let max_iterations = 1000; // Prevent infinite loops
236        let mut iteration_count = 0;
237
238        while let Some(activation) = self.agenda.get_next_activation() {
239            iteration_count += 1;
240            if iteration_count > max_iterations {
241                eprintln!("WARNING: Maximum iterations ({}) reached in fire_all(). Possible infinite loop!", max_iterations);
242                break;
243            }
244            
245            // Find rule
246            if let Some((idx, rule)) = self.rules
247                .iter_mut()
248                .enumerate()
249                .find(|(_, r)| r.name == activation.rule_name)
250            {
251                // Execute action on a copy of all facts
252                let original_facts = self.working_memory.to_typed_facts();
253                let mut modified_facts = original_facts.clone();
254                (rule.action)(&mut modified_facts);
255
256                // Update working memory: merge changed fields back into each fact
257                // Get handles and update only the fields that changed
258                let handles: Vec<_> = self.working_memory.get_all_handles();
259                for handle in handles {
260                    if let Some(wm_fact) = self.working_memory.get(&handle) {
261                        // Start with original fact data
262                        let mut updated_data = wm_fact.data.clone();
263
264                        // Merge in any NEW fields from modified_facts
265                        // (fields that were set by the action)
266                        for (key, value) in modified_facts.get_all() {
267                            // Only update if this field wasn't in original OR has changed
268                            if !original_facts.get_all().contains_key(key) ||
269                               original_facts.get(key) != Some(value) {
270                                // Strip any prefixes to get clean field name
271                                let clean_key = if key.contains('.') {
272                                    key.split('.').last().unwrap_or(key)
273                                } else {
274                                    key
275                                };
276                                updated_data.set(clean_key, value.clone());
277                            }
278                        }
279
280                        let _ = self.working_memory.update(handle, updated_data);
281                    }
282                }
283
284                // Re-evaluate matches after working memory update
285                // This allows subsequent rules to see the updated values
286                self.propagate_changes();
287
288                // Track fired rule
289                fired_rules.push(activation.rule_name.clone());
290                self.agenda.mark_rule_fired(&activation);
291            }
292        }
293
294        fired_rules
295    }
296
297    /// Get working memory
298    pub fn working_memory(&self) -> &WorkingMemory {
299        &self.working_memory
300    }
301
302    /// Get mutable working memory
303    pub fn working_memory_mut(&mut self) -> &mut WorkingMemory {
304        &mut self.working_memory
305    }
306
307    /// Get agenda
308    pub fn agenda(&self) -> &AdvancedAgenda {
309        &self.agenda
310    }
311
312    /// Get mutable agenda
313    pub fn agenda_mut(&mut self) -> &mut AdvancedAgenda {
314        &mut self.agenda
315    }
316
317    /// Set conflict resolution strategy
318    ///
319    /// Controls how conflicting rules in the agenda are ordered.
320    /// Available strategies: Salience (default), LEX, MEA, Depth, Breadth, Simplicity, Complexity, Random
321    pub fn set_conflict_resolution_strategy(
322        &mut self,
323        strategy: super::agenda::ConflictResolutionStrategy,
324    ) {
325        self.agenda.set_strategy(strategy);
326    }
327
328    /// Get current conflict resolution strategy
329    pub fn conflict_resolution_strategy(&self) -> super::agenda::ConflictResolutionStrategy {
330        self.agenda.strategy()
331    }
332
333    /// Get statistics
334    pub fn stats(&self) -> IncrementalEngineStats {
335        IncrementalEngineStats {
336            rules: self.rules.len(),
337            working_memory: self.working_memory.stats(),
338            agenda: self.agenda.stats(),
339            dependencies: self.dependencies.fact_type_to_rules.len(),
340        }
341    }
342
343    /// Clear fired flags and reset agenda
344    pub fn reset(&mut self) {
345        self.agenda.reset_fired_flags();
346    }
347
348    /// Get template registry
349    pub fn templates(&self) -> &TemplateRegistry {
350        &self.templates
351    }
352
353    /// Get mutable template registry
354    pub fn templates_mut(&mut self) -> &mut TemplateRegistry {
355        &mut self.templates
356    }
357
358    /// Register a custom function for Test CE support
359    ///
360    /// # Example
361    /// ```
362    /// use rust_rule_engine::rete::{IncrementalEngine, FactValue};
363    ///
364    /// let mut engine = IncrementalEngine::new();
365    /// engine.register_function(
366    ///     "is_valid_email",
367    ///     |args, _facts| {
368    ///         if let Some(FactValue::String(email)) = args.first() {
369    ///             Ok(FactValue::Boolean(email.contains('@')))
370    ///         } else {
371    ///             Ok(FactValue::Boolean(false))
372    ///         }
373    ///     }
374    /// );
375    /// ```
376    pub fn register_function<F>(&mut self, name: &str, func: F)
377    where
378        F: Fn(&[FactValue], &TypedFacts) -> Result<FactValue> + Send + Sync + 'static,
379    {
380        self.custom_functions.insert(name.to_string(), Arc::new(func));
381    }
382
383    /// Get a custom function by name (for Test CE evaluation)
384    pub fn get_function(&self, name: &str) -> Option<&ReteCustomFunction> {
385        self.custom_functions.get(name)
386    }
387
388    /// Get global variables registry
389    pub fn globals(&self) -> &GlobalsRegistry {
390        &self.globals
391    }
392
393    /// Get mutable global variables registry
394    pub fn globals_mut(&mut self) -> &mut GlobalsRegistry {
395        &mut self.globals
396    }
397
398    /// Get deffacts registry
399    pub fn deffacts(&self) -> &DeffactsRegistry {
400        &self.deffacts
401    }
402
403    /// Get mutable deffacts registry
404    pub fn deffacts_mut(&mut self) -> &mut DeffactsRegistry {
405        &mut self.deffacts
406    }
407
408    /// Load all registered deffacts into working memory
409    /// Returns handles of all inserted facts
410    pub fn load_deffacts(&mut self) -> Vec<FactHandle> {
411        let mut handles = Vec::new();
412
413        // Get all facts from all registered deffacts
414        let all_facts = self.deffacts.get_all_facts();
415
416        for (_deffacts_name, fact_instance) in all_facts {
417            // Check if template exists for this fact type
418            let handle = if self.templates.get(&fact_instance.fact_type).is_some() {
419                // Use template validation
420                match self.insert_with_template(&fact_instance.fact_type, fact_instance.data) {
421                    Ok(h) => h,
422                    Err(_) => continue, // Skip invalid facts
423                }
424            } else {
425                // Insert without template validation
426                self.insert(fact_instance.fact_type, fact_instance.data)
427            };
428
429            handles.push(handle);
430        }
431
432        handles
433    }
434
435    /// Load a specific deffacts set by name
436    /// Returns handles of inserted facts or error if deffacts not found
437    pub fn load_deffacts_by_name(&mut self, name: &str) -> crate::errors::Result<Vec<FactHandle>> {
438        // Clone the facts to avoid borrow checker issues
439        let facts_to_insert = {
440            let deffacts = self.deffacts.get(name).ok_or_else(|| {
441                crate::errors::RuleEngineError::EvaluationError {
442                    message: format!("Deffacts '{}' not found", name),
443                }
444            })?;
445            deffacts.facts.clone()
446        };
447
448        let mut handles = Vec::new();
449
450        for fact_instance in facts_to_insert {
451            // Check if template exists for this fact type
452            let handle = if self.templates.get(&fact_instance.fact_type).is_some() {
453                // Use template validation
454                self.insert_with_template(&fact_instance.fact_type, fact_instance.data)?
455            } else {
456                // Insert without template validation
457                self.insert(fact_instance.fact_type, fact_instance.data)
458            };
459
460            handles.push(handle);
461        }
462
463        Ok(handles)
464    }
465
466    /// Reset engine and reload all deffacts (similar to CLIPS reset)
467    /// Clears working memory and agenda, then loads all deffacts
468    pub fn reset_with_deffacts(&mut self) -> Vec<FactHandle> {
469        // Clear working memory and agenda
470        self.working_memory = WorkingMemory::new();
471        self.agenda.clear();
472        self.rule_matched_facts.clear();
473
474        // Reload all deffacts
475        self.load_deffacts()
476    }
477
478    /// Insert a typed fact with template validation
479    pub fn insert_with_template(
480        &mut self,
481        template_name: &str,
482        data: TypedFacts,
483    ) -> crate::errors::Result<FactHandle> {
484        // Validate against template
485        self.templates.validate(template_name, &data)?;
486
487        // Insert into working memory
488        Ok(self.insert(template_name.to_string(), data))
489    }
490}
491
492impl Default for IncrementalEngine {
493    fn default() -> Self {
494        Self::new()
495    }
496}
497
498/// Engine statistics
499#[derive(Debug)]
500pub struct IncrementalEngineStats {
501    pub rules: usize,
502    pub working_memory: super::working_memory::WorkingMemoryStats,
503    pub agenda: super::agenda::AgendaStats,
504    pub dependencies: usize,
505}
506
507impl std::fmt::Display for IncrementalEngineStats {
508    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509        write!(
510            f,
511            "Engine Stats: {} rules, {} fact types tracked\nWM: {}\nAgenda: {}",
512            self.rules,
513            self.dependencies,
514            self.working_memory,
515            self.agenda
516        )
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rete::network::ReteUlNode;
    use crate::rete::alpha::AlphaNode;

    #[test]
    fn test_dependency_graph() {
        let mut graph = RuleDependencyGraph::new();

        graph.add_dependency(0, "Person".to_string());
        graph.add_dependency(1, "Person".to_string());
        graph.add_dependency(1, "Order".to_string());

        let affected = graph.get_affected_rules("Person");
        assert_eq!(affected.len(), 2);
        assert!(affected.contains(&0));
        assert!(affected.contains(&1));

        let deps = graph.get_rule_dependencies(1);
        assert_eq!(deps.len(), 2);
        assert!(deps.contains("Person"));
        assert!(deps.contains("Order"));

        // Unknown fact types and rule indices yield empty sets rather than panicking.
        assert!(graph.get_affected_rules("Unknown").is_empty());
        assert!(graph.get_rule_dependencies(99).is_empty());
    }

    #[test]
    fn test_incremental_propagation() {
        let mut engine = IncrementalEngine::new();

        // Add rule that depends on "Person" type
        let node = ReteUlNode::UlAlpha(AlphaNode {
            field: "Person.age".to_string(),
            operator: ">".to_string(),
            value: "18".to_string(),
        });

        let rule = TypedReteUlRule {
            name: "IsAdult".to_string(),
            node,
            priority: 0,
            no_loop: true,
            action: std::sync::Arc::new(|_| {}),
        };

        engine.add_rule(rule, vec!["Person".to_string()]);

        // Insert Person fact
        let mut person = TypedFacts::new();
        person.set("age", 25i64);
        let handle = engine.insert("Person".to_string(), person);

        // Inserting a matching Person must activate the rule.
        let activations_after_insert = engine.stats().agenda.total_activations;
        assert!(activations_after_insert > 0);

        // Update person so the rule's condition no longer holds.
        let mut updated = TypedFacts::new();
        updated.set("age", 15i64); // Now under 18
        engine.update(handle, updated).unwrap();

        // The rule was re-evaluated (incrementally) but must not be activated again.
        assert_eq!(
            engine.stats().agenda.total_activations,
            activations_after_insert
        );
    }
}