1use std::collections::{HashMap, HashSet};
9use std::sync::Arc;
10use super::working_memory::{WorkingMemory, FactHandle};
11use super::network::{ReteUlNode, TypedReteUlRule};
12use super::facts::{TypedFacts, FactValue};
13use super::agenda::{AdvancedAgenda, Activation};
14use super::template::TemplateRegistry;
15use super::globals::GlobalsRegistry;
16use super::deffacts::DeffactsRegistry;
17use crate::errors::{Result, RuleEngineError};
18
/// Bidirectional index between rule indices and the fact types they depend on.
///
/// Both directions are kept in sync by [`RuleDependencyGraph::add_dependency`],
/// so lookups by fact type and by rule index are each a single map access.
#[derive(Debug, Default)]
pub struct RuleDependencyGraph {
    /// Fact type name -> indices of the rules that depend on that type.
    fact_type_to_rules: HashMap<String, HashSet<usize>>,
    /// Rule index -> names of the fact types that rule depends on.
    rule_to_fact_types: HashMap<usize, HashSet<String>>,
}

impl RuleDependencyGraph {
    /// Creates an empty dependency graph.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that the rule at `rule_idx` depends on facts of type `fact_type`.
    ///
    /// Adding the same pair more than once has no further effect because both
    /// directions are stored as sets.
    pub fn add_dependency(&mut self, rule_idx: usize, fact_type: String) {
        // The name is needed as a key in one map and as a set member in the
        // other, hence the single clone.
        self.fact_type_to_rules
            .entry(fact_type.clone())
            .or_default()
            .insert(rule_idx);

        self.rule_to_fact_types
            .entry(rule_idx)
            .or_default()
            .insert(fact_type);
    }

    /// Returns the indices of all rules that depend on `fact_type`.
    ///
    /// The result is an owned copy; it is empty when the type is unknown.
    pub fn get_affected_rules(&self, fact_type: &str) -> HashSet<usize> {
        self.fact_type_to_rules
            .get(fact_type)
            .cloned()
            .unwrap_or_default()
    }

    /// Returns the fact types the rule at `rule_idx` depends on.
    ///
    /// The result is an owned copy; it is empty when the rule is unknown.
    pub fn get_rule_dependencies(&self, rule_idx: usize) -> HashSet<String> {
        self.rule_to_fact_types
            .get(&rule_idx)
            .cloned()
            .unwrap_or_default()
    }
}
72
73pub type ReteCustomFunction = Arc<dyn Fn(&[FactValue], &TypedFacts) -> Result<FactValue> + Send + Sync>;
76
77pub struct IncrementalEngine {
80 working_memory: WorkingMemory,
82 rules: Vec<TypedReteUlRule>,
84 dependencies: RuleDependencyGraph,
86 agenda: AdvancedAgenda,
88 rule_matched_facts: HashMap<usize, HashSet<FactHandle>>,
90 templates: TemplateRegistry,
92 globals: GlobalsRegistry,
94 deffacts: DeffactsRegistry,
96 custom_functions: HashMap<String, ReteCustomFunction>,
98}
99
100impl IncrementalEngine {
101 pub fn new() -> Self {
103 Self {
104 working_memory: WorkingMemory::new(),
105 rules: Vec::new(),
106 dependencies: RuleDependencyGraph::new(),
107 agenda: AdvancedAgenda::new(),
108 rule_matched_facts: HashMap::new(),
109 custom_functions: HashMap::new(),
110 templates: TemplateRegistry::new(),
111 globals: GlobalsRegistry::new(),
112 deffacts: DeffactsRegistry::new(),
113 }
114 }
115
116 pub fn add_rule(&mut self, rule: TypedReteUlRule, depends_on: Vec<String>) {
118 let rule_idx = self.rules.len();
119
120 for fact_type in depends_on {
122 self.dependencies.add_dependency(rule_idx, fact_type);
123 }
124
125 self.rules.push(rule);
126 }
127
128 pub fn insert(&mut self, fact_type: String, data: TypedFacts) -> FactHandle {
130 let handle = self.working_memory.insert(fact_type.clone(), data);
131
132 self.propagate_changes_for_type(&fact_type);
134
135 handle
136 }
137
138 pub fn update(&mut self, handle: FactHandle, data: TypedFacts) -> Result<()> {
140 let fact_type = self.working_memory
142 .get(&handle)
143 .map(|f| f.fact_type.clone())
144 .ok_or_else(|| RuleEngineError::FieldNotFound {
145 field: format!("FactHandle {} not found", handle),
146 })?;
147
148 self.working_memory.update(handle, data).map_err(|e| RuleEngineError::EvaluationError {
149 message: e,
150 })?;
151
152 self.propagate_changes_for_type(&fact_type);
154
155 Ok(())
156 }
157
158 pub fn retract(&mut self, handle: FactHandle) -> Result<()> {
160 let fact_type = self.working_memory
162 .get(&handle)
163 .map(|f| f.fact_type.clone())
164 .ok_or_else(|| RuleEngineError::FieldNotFound {
165 field: format!("FactHandle {} not found", handle),
166 })?;
167
168 self.working_memory.retract(handle).map_err(|e| RuleEngineError::EvaluationError {
169 message: e,
170 })?;
171
172 self.propagate_changes_for_type(&fact_type);
174
175 Ok(())
176 }
177
178 fn propagate_changes_for_type(&mut self, fact_type: &str) {
180 let affected_rules = self.dependencies.get_affected_rules(fact_type);
182
183 if affected_rules.is_empty() {
184 return; }
186
187 let facts = self.working_memory.to_typed_facts();
189
190 for &rule_idx in &affected_rules {
192 let rule = &self.rules[rule_idx];
193
194 let matches = super::network::evaluate_rete_ul_node_typed(&rule.node, &facts);
196
197 if matches {
198 let activation = Activation::new(rule.name.clone(), rule.priority)
200 .with_no_loop(rule.no_loop);
201
202 self.agenda.add_activation(activation);
203 }
204 }
205 }
206
207 fn propagate_changes(&mut self) {
209 let facts = self.working_memory.to_typed_facts();
211
212 for (rule_idx, rule) in self.rules.iter().enumerate() {
214 if rule.no_loop && self.agenda.has_fired(&rule.name) {
216 continue;
217 }
218
219 let matches = super::network::evaluate_rete_ul_node_typed(&rule.node, &facts);
221
222 if matches {
223 let activation = Activation::new(rule.name.clone(), rule.priority)
225 .with_no_loop(rule.no_loop);
226
227 self.agenda.add_activation(activation);
228 }
229 }
230 }
231
232 pub fn fire_all(&mut self) -> Vec<String> {
234 let mut fired_rules = Vec::new();
235 let max_iterations = 1000; let mut iteration_count = 0;
237
238 while let Some(activation) = self.agenda.get_next_activation() {
239 iteration_count += 1;
240 if iteration_count > max_iterations {
241 eprintln!("WARNING: Maximum iterations ({}) reached in fire_all(). Possible infinite loop!", max_iterations);
242 break;
243 }
244
245 if let Some((idx, rule)) = self.rules
247 .iter_mut()
248 .enumerate()
249 .find(|(_, r)| r.name == activation.rule_name)
250 {
251 let original_facts = self.working_memory.to_typed_facts();
253 let mut modified_facts = original_facts.clone();
254 (rule.action)(&mut modified_facts);
255
256 let handles: Vec<_> = self.working_memory.get_all_handles();
259 for handle in handles {
260 if let Some(wm_fact) = self.working_memory.get(&handle) {
261 let mut updated_data = wm_fact.data.clone();
263
264 for (key, value) in modified_facts.get_all() {
267 if !original_facts.get_all().contains_key(key) ||
269 original_facts.get(key) != Some(value) {
270 let clean_key = if key.contains('.') {
272 key.split('.').last().unwrap_or(key)
273 } else {
274 key
275 };
276 updated_data.set(clean_key, value.clone());
277 }
278 }
279
280 let _ = self.working_memory.update(handle, updated_data);
281 }
282 }
283
284 self.propagate_changes();
287
288 fired_rules.push(activation.rule_name.clone());
290 self.agenda.mark_rule_fired(&activation);
291 }
292 }
293
294 fired_rules
295 }
296
297 pub fn working_memory(&self) -> &WorkingMemory {
299 &self.working_memory
300 }
301
302 pub fn working_memory_mut(&mut self) -> &mut WorkingMemory {
304 &mut self.working_memory
305 }
306
307 pub fn agenda(&self) -> &AdvancedAgenda {
309 &self.agenda
310 }
311
312 pub fn agenda_mut(&mut self) -> &mut AdvancedAgenda {
314 &mut self.agenda
315 }
316
317 pub fn set_conflict_resolution_strategy(
322 &mut self,
323 strategy: super::agenda::ConflictResolutionStrategy,
324 ) {
325 self.agenda.set_strategy(strategy);
326 }
327
328 pub fn conflict_resolution_strategy(&self) -> super::agenda::ConflictResolutionStrategy {
330 self.agenda.strategy()
331 }
332
333 pub fn stats(&self) -> IncrementalEngineStats {
335 IncrementalEngineStats {
336 rules: self.rules.len(),
337 working_memory: self.working_memory.stats(),
338 agenda: self.agenda.stats(),
339 dependencies: self.dependencies.fact_type_to_rules.len(),
340 }
341 }
342
343 pub fn reset(&mut self) {
345 self.agenda.reset_fired_flags();
346 }
347
348 pub fn templates(&self) -> &TemplateRegistry {
350 &self.templates
351 }
352
353 pub fn templates_mut(&mut self) -> &mut TemplateRegistry {
355 &mut self.templates
356 }
357
358 pub fn register_function<F>(&mut self, name: &str, func: F)
377 where
378 F: Fn(&[FactValue], &TypedFacts) -> Result<FactValue> + Send + Sync + 'static,
379 {
380 self.custom_functions.insert(name.to_string(), Arc::new(func));
381 }
382
383 pub fn get_function(&self, name: &str) -> Option<&ReteCustomFunction> {
385 self.custom_functions.get(name)
386 }
387
388 pub fn globals(&self) -> &GlobalsRegistry {
390 &self.globals
391 }
392
393 pub fn globals_mut(&mut self) -> &mut GlobalsRegistry {
395 &mut self.globals
396 }
397
398 pub fn deffacts(&self) -> &DeffactsRegistry {
400 &self.deffacts
401 }
402
403 pub fn deffacts_mut(&mut self) -> &mut DeffactsRegistry {
405 &mut self.deffacts
406 }
407
408 pub fn load_deffacts(&mut self) -> Vec<FactHandle> {
411 let mut handles = Vec::new();
412
413 let all_facts = self.deffacts.get_all_facts();
415
416 for (_deffacts_name, fact_instance) in all_facts {
417 let handle = if self.templates.get(&fact_instance.fact_type).is_some() {
419 match self.insert_with_template(&fact_instance.fact_type, fact_instance.data) {
421 Ok(h) => h,
422 Err(_) => continue, }
424 } else {
425 self.insert(fact_instance.fact_type, fact_instance.data)
427 };
428
429 handles.push(handle);
430 }
431
432 handles
433 }
434
435 pub fn load_deffacts_by_name(&mut self, name: &str) -> crate::errors::Result<Vec<FactHandle>> {
438 let facts_to_insert = {
440 let deffacts = self.deffacts.get(name).ok_or_else(|| {
441 crate::errors::RuleEngineError::EvaluationError {
442 message: format!("Deffacts '{}' not found", name),
443 }
444 })?;
445 deffacts.facts.clone()
446 };
447
448 let mut handles = Vec::new();
449
450 for fact_instance in facts_to_insert {
451 let handle = if self.templates.get(&fact_instance.fact_type).is_some() {
453 self.insert_with_template(&fact_instance.fact_type, fact_instance.data)?
455 } else {
456 self.insert(fact_instance.fact_type, fact_instance.data)
458 };
459
460 handles.push(handle);
461 }
462
463 Ok(handles)
464 }
465
466 pub fn reset_with_deffacts(&mut self) -> Vec<FactHandle> {
469 self.working_memory = WorkingMemory::new();
471 self.agenda.clear();
472 self.rule_matched_facts.clear();
473
474 self.load_deffacts()
476 }
477
478 pub fn insert_with_template(
480 &mut self,
481 template_name: &str,
482 data: TypedFacts,
483 ) -> crate::errors::Result<FactHandle> {
484 self.templates.validate(template_name, &data)?;
486
487 Ok(self.insert(template_name.to_string(), data))
489 }
490}
491
492impl Default for IncrementalEngine {
493 fn default() -> Self {
494 Self::new()
495 }
496}
497
498#[derive(Debug)]
500pub struct IncrementalEngineStats {
501 pub rules: usize,
502 pub working_memory: super::working_memory::WorkingMemoryStats,
503 pub agenda: super::agenda::AgendaStats,
504 pub dependencies: usize,
505}
506
507impl std::fmt::Display for IncrementalEngineStats {
508 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509 write!(
510 f,
511 "Engine Stats: {} rules, {} fact types tracked\nWM: {}\nAgenda: {}",
512 self.rules,
513 self.dependencies,
514 self.working_memory,
515 self.agenda
516 )
517 }
518}
519
520#[cfg(test)]
521mod tests {
522 use super::*;
523 use crate::rete::network::ReteUlNode;
524 use crate::rete::alpha::AlphaNode;
525
/// The graph must answer lookups in both directions after registration.
#[test]
fn test_dependency_graph() {
    let mut graph = RuleDependencyGraph::new();

    let registrations = [(0usize, "Person"), (1usize, "Person"), (1usize, "Order")];
    for &(rule_idx, fact_type) in registrations.iter() {
        graph.add_dependency(rule_idx, fact_type.to_string());
    }

    // Fact type -> rules.
    let expected_rules: HashSet<usize> = [0usize, 1usize].iter().copied().collect();
    assert_eq!(graph.get_affected_rules("Person"), expected_rules);

    // Rule -> fact types.
    let expected_types: HashSet<String> =
        ["Person", "Order"].iter().map(|name| name.to_string()).collect();
    assert_eq!(graph.get_rule_dependencies(1), expected_types);
}
544
/// Inserting a matching fact must queue an activation, and updating that
/// fact afterwards must succeed.
#[test]
fn test_incremental_propagation() {
    let mut engine = IncrementalEngine::new();

    // Rule under test: Person.age > 18.
    let is_adult = TypedReteUlRule {
        name: "IsAdult".to_string(),
        node: ReteUlNode::UlAlpha(AlphaNode {
            field: "Person.age".to_string(),
            operator: ">".to_string(),
            value: "18".to_string(),
        }),
        priority: 0,
        no_loop: true,
        action: std::sync::Arc::new(|_| {}),
    };
    engine.add_rule(is_adult, vec!["Person".to_string()]);

    // A 25-year-old satisfies the rule, so the insert must queue an activation.
    let mut adult = TypedFacts::new();
    adult.set("age", 25i64);
    let handle = engine.insert("Person".to_string(), adult);

    assert!(engine.stats().agenda.total_activations > 0);

    // Updating the same fact to a non-matching age must not fail.
    let mut minor = TypedFacts::new();
    minor.set("age", 15i64);
    engine.update(handle, minor).unwrap();
}
582}