1#![allow(clippy::type_complexity)]
2
3use crate::rete::alpha::AlphaNode;
4use std::sync::Arc;
5pub fn build_rete_ul_from_condition_group(
7 group: &crate::rete::auto_network::ConditionGroup,
8) -> ReteUlNode {
9 use crate::rete::auto_network::ConditionGroup;
10 match group {
11 ConditionGroup::Single(cond) => ReteUlNode::UlAlpha(AlphaNode {
12 field: cond.field.clone(),
13 operator: cond.operator.clone(),
14 value: cond.value.clone(),
15 }),
16 ConditionGroup::Compound {
17 left,
18 operator,
19 right,
20 } => match operator.as_str() {
21 "AND" => ReteUlNode::UlAnd(
22 Box::new(build_rete_ul_from_condition_group(left)),
23 Box::new(build_rete_ul_from_condition_group(right)),
24 ),
25 "OR" => ReteUlNode::UlOr(
26 Box::new(build_rete_ul_from_condition_group(left)),
27 Box::new(build_rete_ul_from_condition_group(right)),
28 ),
29 _ => ReteUlNode::UlAnd(
30 Box::new(build_rete_ul_from_condition_group(left)),
31 Box::new(build_rete_ul_from_condition_group(right)),
32 ),
33 },
34 ConditionGroup::Not(inner) => {
35 ReteUlNode::UlNot(Box::new(build_rete_ul_from_condition_group(inner)))
36 }
37 ConditionGroup::Exists(inner) => {
38 ReteUlNode::UlExists(Box::new(build_rete_ul_from_condition_group(inner)))
39 }
40 ConditionGroup::Forall(inner) => {
41 ReteUlNode::UlForall(Box::new(build_rete_ul_from_condition_group(inner)))
42 }
43 }
44}
45use std::collections::HashMap;
46
/// Evaluates a textual condition such as `amount > 100` or `status == "open"`
/// against string-valued facts.
///
/// The condition is split at the *leftmost* comparison operator (`==`, `!=`,
/// `>=`, `<=`, `>`, `<`). The left side, trimmed, is the fact name. The right
/// side, trimmed and with any surrounding `"` and then `'` characters removed,
/// is the expected value. The comparison itself is done by
/// [`compare_string_values`].
///
/// Returns `false` when the condition contains no operator or when the named
/// fact is missing.
fn evaluate_condition_string(condition: &str, facts: &HashMap<String, String>) -> bool {
    // Two-character operators come first so that, at any given position,
    // `>=` / `<=` take precedence over `>` / `<`.
    const OPERATORS: [&str; 6] = ["==", "!=", ">=", "<=", ">", "<"];

    let condition = condition.trim();

    // Split at the leftmost operator rather than the first one in list order,
    // so operator characters inside the value (e.g. `name != "a==b"`) cannot
    // hijack the split. `char_indices` keeps every slice on a char boundary.
    let split = condition.char_indices().find_map(|(pos, _)| {
        OPERATORS
            .iter()
            .copied()
            .find(|op| condition[pos..].starts_with(*op))
            .map(|op| (pos, op))
    });

    match split {
        Some((pos, op)) => {
            let field = condition[..pos].trim();
            let value_str = condition[pos + op.len()..]
                .trim()
                .trim_matches('"')
                .trim_matches('\'');
            facts
                .get(field)
                .map_or(false, |field_value| {
                    compare_string_values(field_value, op, value_str)
                })
        }
        None => false,
    }
}

/// Compares a fact value with an expected value using `operator`.
///
/// When both sides parse as `f64` the comparison is numeric. `==` accepts exact
/// equality or an absolute difference below `f64::EPSILON`, and `!=` is its
/// exact negation. Otherwise only `==` and `!=` are supported, as exact string
/// comparisons. Unknown operators yield `false`.
fn compare_string_values(field_value: &str, operator: &str, value_str: &str) -> bool {
    match (field_value.parse::<f64>(), value_str.parse::<f64>()) {
        (Ok(field_num), Ok(val_num)) => {
            // The exact check is needed for infinities: `inf - inf` is NaN,
            // which fails the tolerance test.
            let equal = field_num == val_num || (field_num - val_num).abs() < f64::EPSILON;
            match operator {
                "==" => equal,
                "!=" => !equal,
                ">" => field_num > val_num,
                "<" => field_num < val_num,
                ">=" => field_num >= val_num,
                "<=" => field_num <= val_num,
                _ => false,
            }
        }
        _ => match operator {
            "==" => field_value == value_str,
            "!=" => field_value != value_str,
            _ => false,
        },
    }
}
92
93pub fn evaluate_rete_ul_node(node: &ReteUlNode, facts: &HashMap<String, String>) -> bool {
95 match node {
96 ReteUlNode::UlAlpha(alpha) => {
97 let val = if alpha.field.contains('.') {
98 let parts: Vec<&str> = alpha.field.split('.').collect();
99 if parts.len() == 2 {
100 let prefix = parts[0];
101 let suffix = parts[1];
102 facts
103 .get(&format!("{}.{}", prefix, suffix))
104 .or_else(|| facts.get(&format!("{}:{}", prefix, suffix)))
105 } else {
106 facts.get(&alpha.field)
107 }
108 } else {
109 facts.get(&alpha.field)
110 };
111 if let Some(val) = val {
112 match alpha.operator.as_str() {
113 "==" => val == &alpha.value,
114 "!=" => val != &alpha.value,
115 ">" => {
116 val.parse::<f64>().unwrap_or(0.0)
117 > alpha.value.parse::<f64>().unwrap_or(0.0)
118 }
119 "<" => {
120 val.parse::<f64>().unwrap_or(0.0)
121 < alpha.value.parse::<f64>().unwrap_or(0.0)
122 }
123 ">=" => {
124 val.parse::<f64>().unwrap_or(0.0)
125 >= alpha.value.parse::<f64>().unwrap_or(0.0)
126 }
127 "<=" => {
128 val.parse::<f64>().unwrap_or(0.0)
129 <= alpha.value.parse::<f64>().unwrap_or(0.0)
130 }
131 _ => false,
132 }
133 } else {
134 false
135 }
136 }
137 ReteUlNode::UlAnd(left, right) => {
138 evaluate_rete_ul_node(left, facts) && evaluate_rete_ul_node(right, facts)
139 }
140 ReteUlNode::UlOr(left, right) => {
141 evaluate_rete_ul_node(left, facts) || evaluate_rete_ul_node(right, facts)
142 }
143 ReteUlNode::UlNot(inner) => !evaluate_rete_ul_node(inner, facts),
144 ReteUlNode::UlExists(inner) => {
145 let target_field = match &**inner {
146 ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
147 _ => "".to_string(),
148 };
149 if target_field.contains('.') {
150 let parts: Vec<&str> = target_field.split('.').collect();
151 if parts.len() == 2 {
152 let prefix = parts[0];
153 let suffix = parts[1];
154 let filtered: Vec<_> = facts
155 .iter()
156 .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
157 .collect();
158 filtered.iter().any(|(_, value)| {
159 let mut sub_facts = HashMap::new();
160 sub_facts.insert(target_field.clone(), (*value).clone());
161 evaluate_rete_ul_node(inner, &sub_facts)
162 })
163 } else {
164 facts.iter().any(|(field, value)| {
165 let mut sub_facts = HashMap::new();
166 sub_facts.insert(field.clone(), value.clone());
167 evaluate_rete_ul_node(inner, &sub_facts)
168 })
169 }
170 } else {
171 facts.iter().any(|(field, value)| {
172 let mut sub_facts = HashMap::new();
173 sub_facts.insert(field.clone(), value.clone());
174 evaluate_rete_ul_node(inner, &sub_facts)
175 })
176 }
177 }
178 ReteUlNode::UlForall(inner) => {
179 let target_field = match &**inner {
180 ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
181 _ => "".to_string(),
182 };
183 if target_field.contains('.') {
184 let parts: Vec<&str> = target_field.split('.').collect();
185 if parts.len() == 2 {
186 let prefix = parts[0];
187 let suffix = parts[1];
188 let filtered: Vec<_> = facts
189 .iter()
190 .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
191 .collect();
192 if filtered.is_empty() {
193 return true; }
195 filtered.iter().all(|(_, value)| {
196 let mut sub_facts = HashMap::new();
197 sub_facts.insert(target_field.clone(), (*value).clone());
198 evaluate_rete_ul_node(inner, &sub_facts)
199 })
200 } else {
201 facts.iter().all(|(field, value)| {
202 let mut sub_facts = HashMap::new();
203 sub_facts.insert(field.clone(), value.clone());
204 evaluate_rete_ul_node(inner, &sub_facts)
205 })
206 }
207 } else {
208 facts.iter().all(|(field, value)| {
209 let mut sub_facts = HashMap::new();
210 sub_facts.insert(field.clone(), value.clone());
211 evaluate_rete_ul_node(inner, &sub_facts)
212 })
213 }
214 }
215 ReteUlNode::UlAccumulate {
216 source_pattern,
217 extract_field,
218 source_conditions,
219 function,
220 ..
221 } => {
222 let pattern_prefix = format!("{}.", source_pattern);
225 let mut matching_values = Vec::new();
226
227 let mut instances: std::collections::HashMap<
229 String,
230 std::collections::HashMap<String, String>,
231 > = std::collections::HashMap::new();
232
233 for (key, value) in facts {
234 if key.starts_with(&pattern_prefix) {
235 let parts: Vec<&str> = key
236 .strip_prefix(&pattern_prefix)
237 .unwrap()
238 .split('.')
239 .collect();
240
241 if parts.len() >= 2 {
242 let instance_id = parts[0];
243 let field_name = parts[1..].join(".");
244
245 instances
246 .entry(instance_id.to_string())
247 .or_default()
248 .insert(field_name, value.clone());
249 } else if parts.len() == 1 {
250 instances
251 .entry("default".to_string())
252 .or_default()
253 .insert(parts[0].to_string(), value.clone());
254 }
255 }
256 }
257
258 for (_instance_id, instance_facts) in instances {
260 let mut matches = true;
261
262 for condition_str in source_conditions {
263 if !evaluate_condition_string(condition_str, &instance_facts) {
264 matches = false;
265 break;
266 }
267 }
268
269 if matches {
270 if let Some(value_str) = instance_facts.get(extract_field) {
271 let fact_value = if let Ok(i) = value_str.parse::<i64>() {
273 super::facts::FactValue::Integer(i)
274 } else if let Ok(f) = value_str.parse::<f64>() {
275 super::facts::FactValue::Float(f)
276 } else if let Ok(b) = value_str.parse::<bool>() {
277 super::facts::FactValue::Boolean(b)
278 } else {
279 super::facts::FactValue::String(value_str.clone())
280 };
281 matching_values.push(fact_value);
282 }
283 }
284 }
285
286 let has_results = !matching_values.is_empty();
288
289 match function.as_str() {
290 "count" => has_results, "sum" | "average" | "min" | "max" => {
292 has_results
294 }
295 _ => true, }
297 }
298 ReteUlNode::UlMultiField {
299 field,
300 operation,
301 value,
302 operator,
303 compare_value,
304 } => {
305 let field_value = facts.get(field);
309
310 match operation.as_str() {
311 "empty" => {
312 field_value
314 .map(|v| v.is_empty() || v == "[]")
315 .unwrap_or(true)
316 }
317 "not_empty" => {
318 field_value
320 .map(|v| !v.is_empty() && v != "[]")
321 .unwrap_or(false)
322 }
323 "count" => {
324 if let Some(val) = field_value {
325 let count = if val.starts_with('[') && val.ends_with(']') {
328 let inner = &val[1..val.len() - 1];
329 if inner.trim().is_empty() {
330 0
331 } else {
332 inner.split(',').count()
333 }
334 } else {
335 0
336 };
337
338 if let (Some(op), Some(cmp_val)) = (operator, compare_value) {
340 let cmp_num = cmp_val.parse::<i64>().unwrap_or(0);
341 match op.as_str() {
342 ">" => (count as i64) > cmp_num,
343 "<" => (count as i64) < cmp_num,
344 ">=" => (count as i64) >= cmp_num,
345 "<=" => (count as i64) <= cmp_num,
346 "==" => (count as i64) == cmp_num,
347 "!=" => (count as i64) != cmp_num,
348 _ => false,
349 }
350 } else {
351 count > 0
352 }
353 } else {
354 false
355 }
356 }
357 "contains" => {
358 if let (Some(val), Some(search)) = (field_value, value) {
359 val.contains(search)
361 } else {
362 false
363 }
364 }
365 _ => {
366 false
368 }
369 }
370 }
371 ReteUlNode::UlTerminal(_) => true, }
373}
374
/// Condition tree evaluated by [`evaluate_rete_ul_node`] (string facts) and
/// [`evaluate_rete_ul_node_typed`] (typed facts).
#[derive(Debug, Clone)]
pub enum ReteUlNode {
    /// Single field test (`field operator value`).
    UlAlpha(AlphaNode),
    /// Matches when both children match.
    UlAnd(Box<ReteUlNode>, Box<ReteUlNode>),
    /// Matches when either child matches.
    UlOr(Box<ReteUlNode>, Box<ReteUlNode>),
    /// Matches when the child does not.
    UlNot(Box<ReteUlNode>),
    /// Existential test over the child condition.
    UlExists(Box<ReteUlNode>),
    /// Universal test over the child condition.
    UlForall(Box<ReteUlNode>),
    /// Aggregate over fact "instances" whose keys start with `source_pattern.`.
    UlAccumulate {
        /// Target variable for the aggregate; not read by the evaluators in this file.
        result_var: String,
        /// Key prefix (without the trailing `.`) selecting the source facts.
        source_pattern: String,
        /// Field extracted from each matching instance.
        extract_field: String,
        /// Textual conditions such as `amount > 100` that an instance must satisfy.
        source_conditions: Vec<String>,
        /// Aggregate name; `count`, `sum`, `average`, `min` and `max` are recognised.
        function: String,
        /// Aggregate argument; not read by the evaluators in this file.
        function_arg: String,
    },
    /// Operation on a multi-valued (list) field.
    UlMultiField {
        /// Name of the fact holding the list.
        field: String,
        /// Operation name, e.g. `empty`, `not_empty`, `count`, `contains`.
        operation: String,
        /// Search value used by `contains`.
        value: Option<String>,
        /// Comparison operator applied to the item count by `count`.
        operator: Option<String>,
        /// Right-hand side for `operator`, parsed as an integer.
        compare_value: Option<String>,
    },
    /// Terminal node; always matches. The payload is not read by the evaluators here.
    UlTerminal(String),
}
401
402impl ReteUlNode {
403 pub fn evaluate_typed(&self, facts: &super::facts::TypedFacts) -> bool {
405 evaluate_rete_ul_node_typed(self, facts)
406 }
407}
408
/// Rule registered with a [`ReteUlEngine`]: a condition tree plus an action
/// run on the string fact map when the rule fires.
pub struct ReteUlRule {
    /// Rule name; firing is recorded in the facts as `<name>_fired = "true"`.
    pub name: String,
    /// Condition tree that must evaluate to `true` for the rule to fire.
    pub node: ReteUlNode,
    /// Agenda ordering: higher values fire first within one pass.
    pub priority: i32,
    /// "No-loop" flag; see `fire_rete_ul_rules_with_agenda` for how it is honoured.
    pub no_loop: bool,
    /// Action invoked with mutable access to the engine's facts.
    pub action: Arc<dyn Fn(&mut std::collections::HashMap<String, String>) + Send + Sync>,
}
417
418pub fn fire_rete_ul_rules(
421 rules: &mut [(
422 String,
423 ReteUlNode,
424 Box<dyn FnMut(&mut std::collections::HashMap<String, String>)>,
425 )],
426 facts: &mut std::collections::HashMap<String, String>,
427) -> Vec<String> {
428 let mut fired_rules = Vec::new();
429 let mut changed = true;
430 while changed {
431 changed = false;
432 for (rule_name, node, action) in rules.iter_mut() {
433 let fired_flag = format!("{}_fired", rule_name);
434 if facts.get(&fired_flag) == Some(&"true".to_string()) {
435 continue;
436 }
437 if evaluate_rete_ul_node(node, facts) {
438 action(facts);
439 facts.insert(fired_flag.clone(), "true".to_string());
440 fired_rules.push(rule_name.clone());
441 changed = true;
442 }
443 }
444 }
445 fired_rules
446}
447
448pub fn fire_rete_ul_rules_with_agenda(
450 rules: &mut [ReteUlRule],
451 facts: &mut std::collections::HashMap<String, String>,
452) -> Vec<String> {
453 let mut fired_rules = Vec::new();
454 let mut fired_flags = std::collections::HashSet::new();
455 let max_iterations = 100; let mut iterations = 0;
457
458 loop {
459 iterations += 1;
460 if iterations > max_iterations {
461 eprintln!(
462 "Warning: RETE engine reached max iterations ({})",
463 max_iterations
464 );
465 break;
466 }
467
468 let mut agenda: Vec<usize> = rules
470 .iter()
471 .enumerate()
472 .filter(|(_, rule)| {
473 if fired_flags.contains(&rule.name) {
475 return false;
476 }
477 evaluate_rete_ul_node(&rule.node, facts)
479 })
480 .map(|(i, _)| i)
481 .collect();
482
483 if agenda.is_empty() {
485 break;
486 }
487
488 agenda.sort_by_key(|&i| -rules[i].priority);
490
491 for &i in &agenda {
493 let rule = &mut rules[i];
494
495 (rule.action)(facts);
497
498 fired_rules.push(rule.name.clone());
500 fired_flags.insert(rule.name.clone());
501
502 let fired_flag = format!("{}_fired", rule.name);
503 facts.insert(fired_flag, "true".to_string());
504 }
505
506 if rules.iter().all(|r| r.no_loop) {
508 break;
509 }
510 }
511
512 fired_rules
513}
514
/// String-fact rule engine: a set of [`ReteUlRule`]s plus the
/// `HashMap<String, String>` of facts they are evaluated against.
pub struct ReteUlEngine {
    /// Registered rules, in registration order.
    rules: Vec<ReteUlRule>,
    /// Working memory; also holds `<rule>_fired` / `<rule>_executed` bookkeeping facts.
    facts: std::collections::HashMap<String, String>,
}
521
522impl Default for ReteUlEngine {
523 fn default() -> Self {
524 Self::new()
525 }
526}
527
528impl ReteUlEngine {
529 pub fn new() -> Self {
531 Self {
532 rules: Vec::new(),
533 facts: std::collections::HashMap::new(),
534 }
535 }
536
537 pub fn add_rule_with_action<F>(
539 &mut self,
540 name: String,
541 node: ReteUlNode,
542 priority: i32,
543 no_loop: bool,
544 action: F,
545 ) where
546 F: Fn(&mut std::collections::HashMap<String, String>) + Send + Sync + 'static,
547 {
548 self.rules.push(ReteUlRule {
549 name,
550 node,
551 priority,
552 no_loop,
553 action: Arc::new(action),
554 });
555 }
556
557 pub fn add_rule_from_definition(
559 &mut self,
560 rule: &crate::rete::auto_network::Rule,
561 priority: i32,
562 no_loop: bool,
563 ) {
564 let node = build_rete_ul_from_condition_group(&rule.conditions);
565 let rule_name = rule.name.clone();
566
567 let action = Arc::new(
569 move |facts: &mut std::collections::HashMap<String, String>| {
570 facts.insert(format!("{}_executed", rule_name), "true".to_string());
571 },
572 );
573
574 self.rules.push(ReteUlRule {
575 name: rule.name.clone(),
576 node,
577 priority,
578 no_loop,
579 action,
580 });
581 }
582
583 pub fn set_fact(&mut self, key: String, value: String) {
585 self.facts.insert(key, value);
586 }
587
588 pub fn get_fact(&self, key: &str) -> Option<&String> {
590 self.facts.get(key)
591 }
592
593 pub fn remove_fact(&mut self, key: &str) -> Option<String> {
595 self.facts.remove(key)
596 }
597
598 pub fn get_all_facts(&self) -> &std::collections::HashMap<String, String> {
600 &self.facts
601 }
602
603 pub fn clear_facts(&mut self) {
605 self.facts.clear();
606 }
607
608 pub fn fire_all(&mut self) -> Vec<String> {
610 fire_rete_ul_rules_with_agenda(&mut self.rules, &mut self.facts)
611 }
612
613 pub fn matches(&self, rule_name: &str) -> bool {
615 self.rules
616 .iter()
617 .find(|r| r.name == rule_name)
618 .map(|r| evaluate_rete_ul_node(&r.node, &self.facts))
619 .unwrap_or(false)
620 }
621
622 pub fn get_matching_rules(&self) -> Vec<&str> {
624 self.rules
625 .iter()
626 .filter(|r| evaluate_rete_ul_node(&r.node, &self.facts))
627 .map(|r| r.name.as_str())
628 .collect()
629 }
630
631 pub fn reset_fired_flags(&mut self) {
633 let keys_to_remove: Vec<_> = self
634 .facts
635 .keys()
636 .filter(|k| k.ends_with("_fired") || k.ends_with("_executed"))
637 .cloned()
638 .collect();
639 for key in keys_to_remove {
640 self.facts.remove(&key);
641 }
642 }
643}
644
645use super::facts::{FactValue, TypedFacts};
650
651pub fn evaluate_rete_ul_node_typed(node: &ReteUlNode, facts: &TypedFacts) -> bool {
653 match node {
654 ReteUlNode::UlAlpha(alpha) => alpha.matches_typed(facts),
655 ReteUlNode::UlAnd(left, right) => {
656 evaluate_rete_ul_node_typed(left, facts) && evaluate_rete_ul_node_typed(right, facts)
657 }
658 ReteUlNode::UlOr(left, right) => {
659 evaluate_rete_ul_node_typed(left, facts) || evaluate_rete_ul_node_typed(right, facts)
660 }
661 ReteUlNode::UlNot(inner) => !evaluate_rete_ul_node_typed(inner, facts),
662 ReteUlNode::UlExists(inner) => {
663 let target_field = match &**inner {
664 ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
665 _ => "".to_string(),
666 };
667 if target_field.contains('.') {
668 let parts: Vec<&str> = target_field.split('.').collect();
669 if parts.len() == 2 {
670 let prefix = parts[0];
671 let suffix = parts[1];
672 let filtered: Vec<_> = facts
673 .get_all()
674 .iter()
675 .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
676 .collect();
677 filtered
678 .iter()
679 .any(|(_, _)| evaluate_rete_ul_node_typed(inner, facts))
680 } else {
681 evaluate_rete_ul_node_typed(inner, facts)
682 }
683 } else {
684 evaluate_rete_ul_node_typed(inner, facts)
685 }
686 }
687 ReteUlNode::UlForall(inner) => {
688 let target_field = match &**inner {
689 ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
690 _ => "".to_string(),
691 };
692 if target_field.contains('.') {
693 let parts: Vec<&str> = target_field.split('.').collect();
694 if parts.len() == 2 {
695 let prefix = parts[0];
696 let suffix = parts[1];
697 let filtered: Vec<_> = facts
698 .get_all()
699 .iter()
700 .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
701 .collect();
702 if filtered.is_empty() {
703 return true; }
705 filtered
706 .iter()
707 .all(|(_, _)| evaluate_rete_ul_node_typed(inner, facts))
708 } else {
709 if facts.get_all().is_empty() {
710 return true; }
712 evaluate_rete_ul_node_typed(inner, facts)
713 }
714 } else {
715 if facts.get_all().is_empty() {
716 return true; }
718 evaluate_rete_ul_node_typed(inner, facts)
719 }
720 }
721 ReteUlNode::UlAccumulate {
722 source_pattern,
723 extract_field,
724 source_conditions,
725 function,
726 ..
727 } => {
728 let pattern_prefix = format!("{}.", source_pattern);
731 let mut matching_values = Vec::new();
732
733 let mut instances: std::collections::HashMap<
735 String,
736 std::collections::HashMap<String, FactValue>,
737 > = std::collections::HashMap::new();
738
739 for (key, value) in facts.get_all() {
740 if key.starts_with(&pattern_prefix) {
741 let parts: Vec<&str> = key
742 .strip_prefix(&pattern_prefix)
743 .unwrap()
744 .split('.')
745 .collect();
746
747 if parts.len() >= 2 {
748 let instance_id = parts[0];
749 let field_name = parts[1..].join(".");
750
751 instances
752 .entry(instance_id.to_string())
753 .or_default()
754 .insert(field_name, value.clone());
755 } else if parts.len() == 1 {
756 instances
757 .entry("default".to_string())
758 .or_default()
759 .insert(parts[0].to_string(), value.clone());
760 }
761 }
762 }
763
764 for (_instance_id, instance_facts) in instances {
766 let mut matches = true;
767
768 for condition_str in source_conditions {
769 let string_facts: HashMap<String, String> = instance_facts
771 .iter()
772 .map(|(k, v)| (k.clone(), format!("{:?}", v)))
773 .collect();
774
775 if !evaluate_condition_string(condition_str, &string_facts) {
776 matches = false;
777 break;
778 }
779 }
780
781 if matches {
782 if let Some(value) = instance_facts.get(extract_field) {
783 matching_values.push(value.clone());
784 }
785 }
786 }
787
788 let has_results = !matching_values.is_empty();
790
791 match function.as_str() {
792 "count" => has_results,
793 "sum" | "average" | "min" | "max" => has_results,
794 _ => true,
795 }
796 }
797 ReteUlNode::UlMultiField {
798 field,
799 operation,
800 value,
801 operator,
802 compare_value,
803 } => {
804 use super::facts::FactValue;
806
807 let field_value = facts.get(field);
808
809 match operation.as_str() {
810 "empty" => {
811 if let Some(FactValue::Array(arr)) = field_value {
813 arr.is_empty()
814 } else {
815 true
817 }
818 }
819 "not_empty" => {
820 if let Some(FactValue::Array(arr)) = field_value {
822 !arr.is_empty()
823 } else {
824 false
825 }
826 }
827 "count" => {
828 if let Some(FactValue::Array(arr)) = field_value {
829 let count = arr.len() as i64;
830
831 if let (Some(op), Some(cmp_val)) = (operator, compare_value) {
833 let cmp_num = cmp_val.parse::<i64>().unwrap_or(0);
834 match op.as_str() {
835 ">" => count > cmp_num,
836 "<" => count < cmp_num,
837 ">=" => count >= cmp_num,
838 "<=" => count <= cmp_num,
839 "==" => count == cmp_num,
840 "!=" => count != cmp_num,
841 _ => false,
842 }
843 } else {
844 count > 0
845 }
846 } else {
847 false
848 }
849 }
850 "contains" => {
851 if let (Some(FactValue::Array(arr)), Some(search)) = (field_value, value) {
852 arr.iter().any(|item| match item {
855 FactValue::String(s) => s == search,
856 FactValue::Integer(i) => i.to_string() == *search,
857 FactValue::Float(f) => f.to_string() == *search,
858 FactValue::Boolean(b) => b.to_string() == *search,
859 _ => false,
860 })
861 } else {
862 false
863 }
864 }
865 "first" => {
866 if let Some(FactValue::Array(arr)) = field_value {
868 !arr.is_empty()
869 } else {
870 false
871 }
872 }
873 "last" => {
874 if let Some(FactValue::Array(arr)) = field_value {
876 !arr.is_empty()
877 } else {
878 false
879 }
880 }
881 "collect" => {
882 matches!(field_value, Some(FactValue::Array(_)))
884 }
885 _ => {
886 false
888 }
889 }
890 }
891 ReteUlNode::UlTerminal(_) => true,
892 }
893}
894
/// Rule registered with a [`TypedReteUlEngine`]: a condition tree plus an
/// action run on the typed facts when the rule fires.
pub struct TypedReteUlRule {
    /// Rule name; firing is recorded in the facts as `<name>_fired = true`.
    pub name: String,
    /// Condition tree evaluated with [`evaluate_rete_ul_node_typed`].
    pub node: ReteUlNode,
    /// Agenda ordering: higher values fire first within one pass.
    pub priority: i32,
    /// When `true`, the rule is not fired again once it has fired.
    pub no_loop: bool,
    /// Action run when the rule fires; `fire_all` passes a fresh
    /// `ActionResults` on every call and does not inspect it afterwards.
    pub action: Arc<dyn Fn(&mut TypedFacts, &mut super::ActionResults) + Send + Sync>,
}
903
/// Rule engine over [`TypedFacts`], whose values keep their [`FactValue`] type.
pub struct TypedReteUlEngine {
    /// Registered rules, in registration order.
    rules: Vec<TypedReteUlRule>,
    /// Working memory; also holds `<rule>_fired` / `<rule>_executed` bookkeeping facts.
    facts: TypedFacts,
}
910
911impl TypedReteUlEngine {
912 pub fn new() -> Self {
914 Self {
915 rules: Vec::new(),
916 facts: TypedFacts::new(),
917 }
918 }
919
920 pub fn add_rule_with_action<F>(
922 &mut self,
923 name: String,
924 node: ReteUlNode,
925 priority: i32,
926 no_loop: bool,
927 action: F,
928 ) where
929 F: Fn(&mut TypedFacts, &mut super::ActionResults) + Send + Sync + 'static,
930 {
931 self.rules.push(TypedReteUlRule {
932 name,
933 node,
934 priority,
935 no_loop,
936 action: Arc::new(action),
937 });
938 }
939
940 pub fn add_rule_from_definition(
942 &mut self,
943 rule: &crate::rete::auto_network::Rule,
944 priority: i32,
945 no_loop: bool,
946 ) {
947 let node = build_rete_ul_from_condition_group(&rule.conditions);
948 let rule_name = rule.name.clone();
949
950 let action = Arc::new(
951 move |facts: &mut TypedFacts, _results: &mut super::ActionResults| {
952 facts.set(format!("{}_executed", rule_name), true);
953 },
954 );
955
956 self.rules.push(TypedReteUlRule {
957 name: rule.name.clone(),
958 node,
959 priority,
960 no_loop,
961 action,
962 });
963 }
964
965 pub fn set_fact<K: Into<String>, V: Into<FactValue>>(&mut self, key: K, value: V) {
967 self.facts.set(key, value);
968 }
969
970 pub fn get_fact(&self, key: &str) -> Option<&FactValue> {
972 self.facts.get(key)
973 }
974
975 pub fn remove_fact(&mut self, key: &str) -> Option<FactValue> {
977 self.facts.remove(key)
978 }
979
980 pub fn get_all_facts(&self) -> &TypedFacts {
982 &self.facts
983 }
984
985 pub fn clear_facts(&mut self) {
987 self.facts.clear();
988 }
989
990 pub fn fire_all(&mut self) -> Vec<String> {
992 let mut fired_rules = Vec::new();
993 let mut agenda: Vec<usize>;
994 let mut changed = true;
995 let mut fired_flags = std::collections::HashSet::new();
996
997 while changed {
998 changed = false;
999
1000 agenda = self
1002 .rules
1003 .iter()
1004 .enumerate()
1005 .filter(|(_, rule)| {
1006 let fired_flag = format!("{}_fired", rule.name);
1007 let already_fired = fired_flags.contains(&rule.name)
1008 || self.facts.get(&fired_flag).and_then(|v| v.as_boolean()) == Some(true);
1009 !rule.no_loop || !already_fired
1010 })
1011 .filter(|(_, rule)| evaluate_rete_ul_node_typed(&rule.node, &self.facts))
1012 .map(|(i, _)| i)
1013 .collect();
1014
1015 agenda.sort_by_key(|&i| -self.rules[i].priority);
1017
1018 for &i in &agenda {
1019 let rule = &mut self.rules[i];
1020 let fired_flag = format!("{}_fired", rule.name);
1021 let already_fired = fired_flags.contains(&rule.name)
1022 || self.facts.get(&fired_flag).and_then(|v| v.as_boolean()) == Some(true);
1023
1024 if rule.no_loop && already_fired {
1025 continue;
1026 }
1027
1028 let mut action_results = super::ActionResults::new();
1029 (rule.action)(&mut self.facts, &mut action_results);
1030 fired_rules.push(rule.name.clone());
1034 fired_flags.insert(rule.name.clone());
1035 self.facts.set(fired_flag, true);
1036 changed = true;
1037 }
1038 }
1039
1040 fired_rules
1041 }
1042
1043 pub fn matches(&self, rule_name: &str) -> bool {
1045 self.rules
1046 .iter()
1047 .find(|r| r.name == rule_name)
1048 .map(|r| evaluate_rete_ul_node_typed(&r.node, &self.facts))
1049 .unwrap_or(false)
1050 }
1051
1052 pub fn get_matching_rules(&self) -> Vec<&str> {
1054 self.rules
1055 .iter()
1056 .filter(|r| evaluate_rete_ul_node_typed(&r.node, &self.facts))
1057 .map(|r| r.name.as_str())
1058 .collect()
1059 }
1060
1061 pub fn reset_fired_flags(&mut self) {
1063 let keys_to_remove: Vec<_> = self
1064 .facts
1065 .get_all()
1066 .keys()
1067 .filter(|k| k.ends_with("_fired") || k.ends_with("_executed"))
1068 .cloned()
1069 .collect();
1070 for key in keys_to_remove {
1071 self.facts.remove(&key);
1072 }
1073 }
1074}
1075
1076impl Default for TypedReteUlEngine {
1077 fn default() -> Self {
1078 Self::new()
1079 }
1080}