1use super::agenda::{Activation, AdvancedAgenda};
9use super::deffacts::DeffactsRegistry;
10use super::facts::{FactValue, TypedFacts};
11use super::globals::GlobalsRegistry;
12use super::network::TypedReteUlRule;
13use super::template::TemplateRegistry;
14use super::tms::TruthMaintenanceSystem;
15use super::working_memory::{FactHandle, WorkingMemory};
16use crate::errors::{Result, RuleEngineError};
17use std::collections::{HashMap, HashSet};
18use std::sync::Arc;
19
/// Bidirectional index between rules (identified by their position in the
/// engine's rule list) and the fact types those rules depend on.
///
/// Used for incremental propagation: when a fact of some type changes, only the
/// rules registered against that type are re-evaluated.
#[derive(Debug, Default)]
pub struct RuleDependencyGraph {
    /// Fact type name -> indices of the rules that depend on it.
    fact_type_to_rules: HashMap<String, HashSet<usize>>,
    /// Rule index -> names of the fact types it depends on.
    rule_to_fact_types: HashMap<usize, HashSet<String>>,
}

impl RuleDependencyGraph {
    /// Creates an empty dependency graph.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that rule `rule_idx` depends on facts of type `fact_type`.
    ///
    /// Both directions of the index are updated; registering the same pair
    /// twice is a no-op because both sides are sets.
    pub fn add_dependency(&mut self, rule_idx: usize, fact_type: String) {
        self.fact_type_to_rules
            .entry(fact_type.clone())
            .or_default()
            .insert(rule_idx);

        self.rule_to_fact_types
            .entry(rule_idx)
            .or_default()
            .insert(fact_type);
    }

    /// Returns the indices of all rules that depend on `fact_type`.
    ///
    /// Returns an empty set for an unknown fact type. The result is an owned
    /// copy, so callers may mutate the graph while holding it.
    pub fn get_affected_rules(&self, fact_type: &str) -> HashSet<usize> {
        self.fact_type_to_rules
            .get(fact_type)
            .cloned()
            .unwrap_or_default()
    }

    /// Returns the fact types that rule `rule_idx` depends on.
    ///
    /// Returns an empty set for an unknown rule index.
    pub fn get_rule_dependencies(&self, rule_idx: usize) -> HashSet<String> {
        self.rule_to_fact_types
            .get(&rule_idx)
            .cloned()
            .unwrap_or_default()
    }
}
73
74pub type ReteCustomFunction =
77 Arc<dyn Fn(&[FactValue], &TypedFacts) -> Result<FactValue> + Send + Sync>;
78
79pub struct IncrementalEngine {
82 working_memory: WorkingMemory,
84 rules: Vec<TypedReteUlRule>,
86 dependencies: RuleDependencyGraph,
88 agenda: AdvancedAgenda,
90 rule_matched_facts: HashMap<usize, HashSet<FactHandle>>,
92 templates: TemplateRegistry,
94 globals: GlobalsRegistry,
96 deffacts: DeffactsRegistry,
98 custom_functions: HashMap<String, ReteCustomFunction>,
100 tms: TruthMaintenanceSystem,
102}
103
104impl IncrementalEngine {
105 pub fn new() -> Self {
107 Self {
108 working_memory: WorkingMemory::new(),
109 rules: Vec::new(),
110 dependencies: RuleDependencyGraph::new(),
111 agenda: AdvancedAgenda::new(),
112 rule_matched_facts: HashMap::new(),
113 custom_functions: HashMap::new(),
114 templates: TemplateRegistry::new(),
115 globals: GlobalsRegistry::new(),
116 deffacts: DeffactsRegistry::new(),
117 tms: TruthMaintenanceSystem::new(),
118 }
119 }
120
121 pub fn add_rule(&mut self, rule: TypedReteUlRule, depends_on: Vec<String>) {
123 let rule_idx = self.rules.len();
124
125 for fact_type in depends_on {
127 self.dependencies.add_dependency(rule_idx, fact_type);
128 }
129
130 self.rules.push(rule);
131 }
132
133 pub fn insert(&mut self, fact_type: String, data: TypedFacts) -> FactHandle {
135 let handle = self.working_memory.insert(fact_type.clone(), data);
136
137 self.tms.add_explicit_justification(handle);
139
140 self.propagate_changes_for_type(&fact_type);
142
143 handle
144 }
145
146 pub fn update(&mut self, handle: FactHandle, data: TypedFacts) -> Result<()> {
148 let fact_type = self
150 .working_memory
151 .get(&handle)
152 .map(|f| f.fact_type.clone())
153 .ok_or_else(|| RuleEngineError::FieldNotFound {
154 field: format!("FactHandle {} not found", handle),
155 })?;
156
157 self.working_memory
158 .update(handle, data)
159 .map_err(|e| RuleEngineError::EvaluationError { message: e })?;
160
161 self.propagate_changes_for_type(&fact_type);
163
164 Ok(())
165 }
166
167 pub fn retract(&mut self, handle: FactHandle) -> Result<()> {
169 let fact_type = self
171 .working_memory
172 .get(&handle)
173 .map(|f| f.fact_type.clone())
174 .ok_or_else(|| RuleEngineError::FieldNotFound {
175 field: format!("FactHandle {} not found", handle),
176 })?;
177
178 self.working_memory
179 .retract(handle)
180 .map_err(|e| RuleEngineError::EvaluationError { message: e })?;
181
182 let cascaded_facts = self.tms.retract_with_cascade(handle);
184
185 for cascaded_handle in cascaded_facts {
187 if let Ok(fact_type) = self
188 .working_memory
189 .get(&cascaded_handle)
190 .map(|f| f.fact_type.clone())
191 .ok_or_else(|| RuleEngineError::FieldNotFound {
192 field: format!("FactHandle {} not found", cascaded_handle),
193 })
194 {
195 let _ = self.working_memory.retract(cascaded_handle);
196 self.propagate_changes_for_type(&fact_type);
198 }
199 }
200
201 self.propagate_changes_for_type(&fact_type);
203
204 Ok(())
205 }
206
207 pub fn insert_explicit(&mut self, fact_type: String, data: TypedFacts) -> FactHandle {
210 let handle = self.working_memory.insert(fact_type.clone(), data);
211
212 self.tms.add_explicit_justification(handle);
214
215 self.propagate_changes_for_type(&fact_type);
217
218 handle
219 }
220
221 pub fn insert_logical(
230 &mut self,
231 fact_type: String,
232 data: TypedFacts,
233 source_rule: String,
234 premise_handles: Vec<FactHandle>,
235 ) -> FactHandle {
236 let handle = self.working_memory.insert(fact_type.clone(), data);
237
238 self.tms
240 .add_logical_justification(handle, source_rule, premise_handles);
241
242 self.propagate_changes_for_type(&fact_type);
244
245 handle
246 }
247
248 pub fn resolve_premise_keys(&self, premise_keys: Vec<String>) -> Vec<FactHandle> {
253 let mut handles = Vec::new();
254
255 for key in premise_keys {
256 if let Some(eq_pos) = key.find('=') {
258 let left = &key[..eq_pos];
259 let value_part = &key[eq_pos + 1..];
260
261 if let Some(dot_pos) = left.find('.') {
262 let fact_type = &left[..dot_pos];
263 let field = &left[dot_pos + 1..];
264
265 let facts = self.working_memory.get_by_type(fact_type);
267 if value_part.is_empty() {
269 if let Some(fact) = facts.iter().rev().find(|f| !f.metadata.retracted) {
271 handles.push(fact.handle);
272 continue;
273 }
274 } else {
275 fn parse_literal(s: &str) -> super::facts::FactValue {
279 let s = s.trim();
280 if s == "true" {
281 return super::facts::FactValue::Boolean(true);
282 }
283 if s == "false" {
284 return super::facts::FactValue::Boolean(false);
285 }
286 if (s.starts_with('"') && s.ends_with('"'))
288 || (s.starts_with('\'') && s.ends_with('\''))
289 {
290 return super::facts::FactValue::String(
291 s[1..s.len() - 1].to_string(),
292 );
293 }
294 if let Ok(i) = s.parse::<i64>() {
296 return super::facts::FactValue::Integer(i);
297 }
298 if let Ok(f) = s.parse::<f64>() {
300 return super::facts::FactValue::Float(f);
301 }
302
303 super::facts::FactValue::String(s.to_string())
305 }
306
307 fn fact_value_equal(
308 a: &super::facts::FactValue,
309 b: &super::facts::FactValue,
310 ) -> bool {
311 use super::facts::FactValue;
312 match (a, b) {
313 (FactValue::Boolean(x), FactValue::Boolean(y)) => x == y,
314 (FactValue::Integer(x), FactValue::Integer(y)) => x == y,
315 (FactValue::Float(x), FactValue::Float(y)) => (x - y).abs() < 1e-9,
316 (FactValue::Integer(x), FactValue::Float(y)) => {
318 ((*x as f64) - *y).abs() < 1e-9
319 }
320 (FactValue::Float(x), FactValue::Integer(y)) => {
321 (*x - (*y as f64)).abs() < 1e-9
322 }
323 (FactValue::String(x), FactValue::String(y)) => x == y,
324 _ => a.as_string() == b.as_string(),
326 }
327 }
328
329 let expected = parse_literal(value_part);
330
331 if let Some(fact) = facts.iter().rev().find(|fact| {
333 if fact.metadata.retracted {
334 return false;
335 }
336 if let Some(fv) = fact.data.get(field) {
337 fact_value_equal(fv, &expected) || fv.as_string() == value_part
338 } else {
339 false
340 }
341 }) {
342 handles.push(fact.handle);
343 }
344 }
345 }
346 }
347 }
348
349 handles
350 }
351
352 pub fn tms(&self) -> &TruthMaintenanceSystem {
354 &self.tms
355 }
356
357 pub fn tms_mut(&mut self) -> &mut TruthMaintenanceSystem {
359 &mut self.tms
360 }
361
362 fn propagate_changes_for_type(&mut self, fact_type: &str) {
364 let affected_rules = self.dependencies.get_affected_rules(fact_type);
366
367 if affected_rules.is_empty() {
368 return; }
370
371 let facts_of_type = self.working_memory.get_by_type(fact_type);
373
374 for &rule_idx in &affected_rules {
376 let rule = &self.rules[rule_idx];
377
378 for fact in &facts_of_type {
380 let mut single_fact_data = TypedFacts::new();
382 for (key, value) in fact.data.get_all() {
383 single_fact_data.set(format!("{}.{}", fact_type, key), value.clone());
384 }
385 single_fact_data.set_fact_handle(fact_type.to_string(), fact.handle);
387
388 let matches =
390 super::network::evaluate_rete_ul_node_typed(&rule.node, &single_fact_data);
391
392 if matches {
393 let activation = Activation::new(rule.name.clone(), rule.priority)
395 .with_no_loop(rule.no_loop)
396 .with_matched_fact(fact.handle);
397
398 self.agenda.add_activation(activation);
399 }
400 }
401 }
402 }
403
404 fn propagate_changes(&mut self) {
406 let fact_types: Vec<String> = self
408 .working_memory
409 .get_all_facts()
410 .iter()
411 .map(|f| f.fact_type.clone())
412 .collect::<std::collections::HashSet<_>>()
413 .into_iter()
414 .collect();
415
416 for fact_type in fact_types {
418 let facts_of_type = self.working_memory.get_by_type(&fact_type);
419
420 for rule in self.rules.iter() {
421 if rule.no_loop && self.agenda.has_fired(&rule.name) {
423 continue;
424 }
425
426 for fact in &facts_of_type {
428 let mut single_fact_data = TypedFacts::new();
429 for (key, value) in fact.data.get_all() {
430 single_fact_data.set(format!("{}.{}", fact_type, key), value.clone());
431 }
432
433 let matches =
434 super::network::evaluate_rete_ul_node_typed(&rule.node, &single_fact_data);
435
436 if matches {
437 let activation = Activation::new(rule.name.clone(), rule.priority)
438 .with_no_loop(rule.no_loop)
439 .with_matched_fact(fact.handle);
440
441 self.agenda.add_activation(activation);
442 }
443 }
444 }
445 }
446 }
447
448 pub fn fire_all(&mut self) -> Vec<String> {
450 let mut fired_rules = Vec::new();
451 let max_iterations = 1000; let mut iteration_count = 0;
453
454 while let Some(activation) = self.agenda.get_next_activation() {
455 iteration_count += 1;
456 if iteration_count > max_iterations {
457 eprintln!("WARNING: Maximum iterations ({}) reached in fire_all(). Possible infinite loop!", max_iterations);
458 break;
459 }
460
461 if let Some((_idx, rule)) = self
463 .rules
464 .iter_mut()
465 .enumerate()
466 .find(|(_, r)| r.name == activation.rule_name)
467 {
468 if let Some(matched_handle) = activation.matched_fact_handle {
470 if self.working_memory.get(&matched_handle).is_none() {
471 continue;
473 }
474 }
475
476 let original_facts = self.working_memory.to_typed_facts();
478 let mut modified_facts = original_facts.clone();
479
480 if let Some(matched_handle) = activation.matched_fact_handle {
482 if let Some(fact) = self.working_memory.get(&matched_handle) {
484 modified_facts.set_fact_handle(fact.fact_type.clone(), matched_handle);
485 }
486 }
487
488 let mut action_results = super::ActionResults::new();
489 (rule.action)(&mut modified_facts, &mut action_results);
490
491 let mut updates_by_type: HashMap<String, Vec<(String, FactValue)>> = HashMap::new();
494
495 for (key, value) in modified_facts.get_all() {
496 if let Some(original_value) = original_facts.get(key) {
499 if original_value != value {
500 let parts: Vec<&str> = key.split('.').collect();
502 if parts.len() >= 2 {
503 let fact_type = parts[0].to_string();
504 let field = if parts.len() == 2 {
506 parts[1].to_string()
507 } else {
508 parts[parts.len() - 1].to_string()
509 };
510
511 updates_by_type
512 .entry(fact_type)
513 .or_default()
514 .push((field, value.clone()));
515 }
516 }
517 } else {
518 let parts: Vec<&str> = key.split('.').collect();
520 if parts.len() >= 2 {
521 let fact_type = parts[0].to_string();
522 let field = if parts.len() == 2 {
523 parts[1].to_string()
524 } else {
525 parts[parts.len() - 1].to_string()
526 };
527
528 updates_by_type
529 .entry(fact_type)
530 .or_default()
531 .push((field, value.clone()));
532 }
533 }
534 }
535
536 for (fact_type, field_updates) in updates_by_type {
538 let fact_handles: Vec<FactHandle> = self
540 .working_memory
541 .get_by_type(&fact_type)
542 .iter()
543 .map(|f| f.handle)
544 .collect();
545
546 for handle in fact_handles {
547 if let Some(fact) = self.working_memory.get(&handle) {
548 let mut updated_data = fact.data.clone();
549
550 for (field, value) in &field_updates {
552 updated_data.set(field, value.clone());
553 }
554
555 let _ = self.working_memory.update(handle, updated_data);
556 }
557 }
558 }
559
560 self.propagate_changes();
563
564 self.process_action_results(action_results);
566
567 fired_rules.push(activation.rule_name.clone());
569 self.agenda.mark_rule_fired(&activation);
570 }
571 }
572
573 fired_rules
574 }
575
576 fn process_action_results(&mut self, results: super::ActionResults) {
578 for result in results.results {
579 match result {
580 super::ActionResult::Retract(handle) => {
581 if let Err(e) = self.retract(handle) {
583 eprintln!("❌ Failed to retract fact {:?}: {}", handle, e);
584 }
585 }
586 super::ActionResult::RetractByType(fact_type) => {
587 let facts_of_type = self.working_memory.get_by_type(&fact_type);
589 if let Some(fact) = facts_of_type.first() {
590 let handle = fact.handle;
591 if let Err(e) = self.retract(handle) {
592 eprintln!("❌ Failed to retract fact {:?}: {}", handle, e);
593 }
594 }
595 }
596 super::ActionResult::Update(handle) => {
597 if let Some(fact) = self.working_memory.get(&handle) {
599 let fact_type = fact.fact_type.clone();
600 self.propagate_changes_for_type(&fact_type);
601 }
602 }
603 super::ActionResult::ActivateAgendaGroup(group) => {
604 self.agenda.set_focus(group);
606 }
607 super::ActionResult::InsertFact { fact_type, data } => {
608 self.insert_explicit(fact_type, data);
610 }
611 super::ActionResult::InsertLogicalFact {
612 fact_type,
613 data,
614 rule_name,
615 premises,
616 } => {
617 let _handle = self.insert_logical(fact_type, data, rule_name, premises);
619 }
620 super::ActionResult::CallFunction {
621 function_name,
622 args,
623 } => {
624 if let Some(func) = self.custom_functions.get(&function_name) {
626 let fact_values: Vec<FactValue> =
628 args.iter().map(|s| FactValue::String(s.clone())).collect();
629
630 let all_facts = self.working_memory.to_typed_facts();
632 match func(&fact_values, &all_facts) {
633 Ok(_) => println!("✅ Called function: {}", function_name),
634 Err(e) => eprintln!("❌ Function {} failed: {}", function_name, e),
635 }
636 } else {
637 println!("🔧 Function call queued: {}({:?})", function_name, args);
639 }
640 }
641 super::ActionResult::ScheduleRule {
642 rule_name,
643 delay_ms,
644 } => {
645 println!("⏰ Rule scheduled: {} after {}ms", rule_name, delay_ms);
647 }
649 super::ActionResult::None => {
650 }
652 }
653 }
654 }
655
656 pub fn working_memory(&self) -> &WorkingMemory {
658 &self.working_memory
659 }
660
661 pub fn working_memory_mut(&mut self) -> &mut WorkingMemory {
663 &mut self.working_memory
664 }
665
666 pub fn agenda(&self) -> &AdvancedAgenda {
668 &self.agenda
669 }
670
671 pub fn agenda_mut(&mut self) -> &mut AdvancedAgenda {
673 &mut self.agenda
674 }
675
676 pub fn set_conflict_resolution_strategy(
681 &mut self,
682 strategy: super::agenda::ConflictResolutionStrategy,
683 ) {
684 self.agenda.set_strategy(strategy);
685 }
686
687 pub fn conflict_resolution_strategy(&self) -> super::agenda::ConflictResolutionStrategy {
689 self.agenda.strategy()
690 }
691
692 pub fn stats(&self) -> IncrementalEngineStats {
694 IncrementalEngineStats {
695 rules: self.rules.len(),
696 working_memory: self.working_memory.stats(),
697 agenda: self.agenda.stats(),
698 dependencies: self.dependencies.fact_type_to_rules.len(),
699 }
700 }
701
702 pub fn reset(&mut self) {
704 self.agenda.reset_fired_flags();
705 }
706
707 pub fn templates(&self) -> &TemplateRegistry {
709 &self.templates
710 }
711
712 pub fn templates_mut(&mut self) -> &mut TemplateRegistry {
714 &mut self.templates
715 }
716
717 pub fn register_function<F>(&mut self, name: &str, func: F)
736 where
737 F: Fn(&[FactValue], &TypedFacts) -> Result<FactValue> + Send + Sync + 'static,
738 {
739 self.custom_functions
740 .insert(name.to_string(), Arc::new(func));
741 }
742
743 pub fn get_function(&self, name: &str) -> Option<&ReteCustomFunction> {
745 self.custom_functions.get(name)
746 }
747
748 pub fn globals(&self) -> &GlobalsRegistry {
750 &self.globals
751 }
752
753 pub fn globals_mut(&mut self) -> &mut GlobalsRegistry {
755 &mut self.globals
756 }
757
758 pub fn deffacts(&self) -> &DeffactsRegistry {
760 &self.deffacts
761 }
762
763 pub fn deffacts_mut(&mut self) -> &mut DeffactsRegistry {
765 &mut self.deffacts
766 }
767
768 pub fn load_deffacts(&mut self) -> Vec<FactHandle> {
771 let mut handles = Vec::new();
772
773 let all_facts = self.deffacts.get_all_facts();
775
776 for (_deffacts_name, fact_instance) in all_facts {
777 let handle = if self.templates.get(&fact_instance.fact_type).is_some() {
779 match self.insert_with_template(&fact_instance.fact_type, fact_instance.data) {
781 Ok(h) => h,
782 Err(_) => continue, }
784 } else {
785 self.insert(fact_instance.fact_type, fact_instance.data)
787 };
788
789 handles.push(handle);
790 }
791
792 handles
793 }
794
795 pub fn load_deffacts_by_name(&mut self, name: &str) -> crate::errors::Result<Vec<FactHandle>> {
798 let facts_to_insert = {
800 let deffacts = self.deffacts.get(name).ok_or_else(|| {
801 crate::errors::RuleEngineError::EvaluationError {
802 message: format!("Deffacts '{}' not found", name),
803 }
804 })?;
805 deffacts.facts.clone()
806 };
807
808 let mut handles = Vec::new();
809
810 for fact_instance in facts_to_insert {
811 let handle = if self.templates.get(&fact_instance.fact_type).is_some() {
813 self.insert_with_template(&fact_instance.fact_type, fact_instance.data)?
815 } else {
816 self.insert(fact_instance.fact_type, fact_instance.data)
818 };
819
820 handles.push(handle);
821 }
822
823 Ok(handles)
824 }
825
826 pub fn reset_with_deffacts(&mut self) -> Vec<FactHandle> {
829 self.working_memory = WorkingMemory::new();
831 self.agenda.clear();
832 self.rule_matched_facts.clear();
833
834 self.load_deffacts()
836 }
837
838 pub fn insert_with_template(
840 &mut self,
841 template_name: &str,
842 data: TypedFacts,
843 ) -> crate::errors::Result<FactHandle> {
844 self.templates.validate(template_name, &data)?;
846
847 Ok(self.insert(template_name.to_string(), data))
849 }
850}
851
852impl Default for IncrementalEngine {
853 fn default() -> Self {
854 Self::new()
855 }
856}
857
858#[derive(Debug)]
860pub struct IncrementalEngineStats {
861 pub rules: usize,
862 pub working_memory: super::working_memory::WorkingMemoryStats,
863 pub agenda: super::agenda::AgendaStats,
864 pub dependencies: usize,
865}
866
867impl std::fmt::Display for IncrementalEngineStats {
868 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
869 write!(
870 f,
871 "Engine Stats: {} rules, {} fact types tracked\nWM: {}\nAgenda: {}",
872 self.rules, self.dependencies, self.working_memory, self.agenda
873 )
874 }
875}
876
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rete::alpha::AlphaNode;
    use crate::rete::network::ReteUlNode;

    /// Builds a no-op, no-loop rule matching `Person` facts with `age > 18`.
    fn adult_rule() -> TypedReteUlRule {
        TypedReteUlRule {
            name: "IsAdult".to_string(),
            node: ReteUlNode::UlAlpha(AlphaNode {
                field: "Person.age".to_string(),
                operator: ">".to_string(),
                value: "18".to_string(),
            }),
            priority: 0,
            no_loop: true,
            action: std::sync::Arc::new(|_, _| {}),
        }
    }

    #[test]
    fn test_dependency_graph() {
        let mut graph = RuleDependencyGraph::new();

        graph.add_dependency(0, "Person".to_string());
        graph.add_dependency(1, "Person".to_string());
        graph.add_dependency(1, "Order".to_string());

        let affected = graph.get_affected_rules("Person");
        assert_eq!(affected.len(), 2);
        assert!(affected.contains(&0));
        assert!(affected.contains(&1));

        let deps = graph.get_rule_dependencies(1);
        assert_eq!(deps.len(), 2);
        assert!(deps.contains("Person"));
        assert!(deps.contains("Order"));
    }

    #[test]
    fn test_incremental_propagation() {
        let mut engine = IncrementalEngine::new();
        engine.add_rule(adult_rule(), vec!["Person".to_string()]);

        let mut person = TypedFacts::new();
        person.set("age", 25i64);
        let handle = engine.insert("Person".to_string(), person);

        // Inserting a matching fact must queue an activation.
        let stats = engine.stats();
        assert!(stats.agenda.total_activations > 0);

        let mut updated = TypedFacts::new();
        updated.set("age", 15i64);
        engine.update(handle, updated).unwrap();

        // The update must actually be stored in working memory.
        let stored = engine
            .working_memory()
            .get(&handle)
            .expect("updated fact must still be present");
        assert!(matches!(
            stored.data.get("age"),
            Some(FactValue::Integer(15))
        ));
    }

    #[test]
    fn test_resolve_premise_keys() {
        let mut engine = IncrementalEngine::new();

        let mut person = TypedFacts::new();
        person.set("age", 25i64);
        let handle = engine.insert("Person".to_string(), person);

        let resolved = engine.resolve_premise_keys(vec![
            "Person.age=25".to_string(), // literal match
            "Person.age=".to_string(),   // latest live fact of the type
            "Person.age=99".to_string(), // no match
            "Order.total=1".to_string(), // unknown type
            "malformed".to_string(),     // no '='
        ]);

        assert_eq!(resolved, vec![handle, handle]);
    }
}