1use crate::rete::alpha::AlphaNode;
2use std::sync::Arc;
3pub fn build_rete_ul_from_condition_group(group: &crate::rete::auto_network::ConditionGroup) -> ReteUlNode {
5 use crate::rete::auto_network::ConditionGroup;
6 match group {
7 ConditionGroup::Single(cond) => {
8 ReteUlNode::UlAlpha(AlphaNode {
9 field: cond.field.clone(),
10 operator: cond.operator.clone(),
11 value: cond.value.clone(),
12 })
13 }
14 ConditionGroup::Compound { left, operator, right } => {
15 match operator.as_str() {
16 "AND" => ReteUlNode::UlAnd(
17 Box::new(build_rete_ul_from_condition_group(left)),
18 Box::new(build_rete_ul_from_condition_group(right)),
19 ),
20 "OR" => ReteUlNode::UlOr(
21 Box::new(build_rete_ul_from_condition_group(left)),
22 Box::new(build_rete_ul_from_condition_group(right)),
23 ),
24 _ => ReteUlNode::UlAnd(
25 Box::new(build_rete_ul_from_condition_group(left)),
26 Box::new(build_rete_ul_from_condition_group(right)),
27 ),
28 }
29 }
30 ConditionGroup::Not(inner) => {
31 ReteUlNode::UlNot(Box::new(build_rete_ul_from_condition_group(inner)))
32 }
33 ConditionGroup::Exists(inner) => {
34 ReteUlNode::UlExists(Box::new(build_rete_ul_from_condition_group(inner)))
35 }
36 ConditionGroup::Forall(inner) => {
37 ReteUlNode::UlForall(Box::new(build_rete_ul_from_condition_group(inner)))
38 }
39 }
40}
41use std::collections::HashMap;
42
/// Comparison operators understood by `evaluate_condition_string`.
///
/// Two-character operators are listed first so that, at a given position,
/// `>=` / `<=` win over `>` / `<`.
const CONDITION_OPERATORS: [&str; 6] = ["==", "!=", ">=", "<=", ">", "<"];

/// Splits `condition` at its left-most comparison operator.
///
/// Returns `(field, operator, value)` with the field and value text untrimmed,
/// or `None` when the condition contains no operator.
fn split_condition_at_operator(condition: &str) -> Option<(&str, &'static str, &str)> {
    condition.char_indices().find_map(|(pos, _)| {
        let rest = &condition[pos..];
        CONDITION_OPERATORS
            .iter()
            .find(|op| rest.starts_with(**op))
            .map(|op| (&condition[..pos], *op, &rest[op.len()..]))
    })
}

/// Evaluates a textual condition such as `age >= 18` or `status == 'active'`
/// against a string fact map.
///
/// The condition is split at its left-most operator. Scanning by position
/// matters because a quoted value may itself contain an operator, e.g.
/// `status != 'a==b'`. The field is trimmed. The value is trimmed and stripped
/// of surrounding `"` and `'` quotes.
///
/// Returns `false` when no operator is present or the field is not in `facts`.
fn evaluate_condition_string(condition: &str, facts: &HashMap<String, String>) -> bool {
    match split_condition_at_operator(condition.trim()) {
        Some((field, operator, raw_value)) => {
            let value_str = raw_value.trim().trim_matches('"').trim_matches('\'');
            facts
                .get(field.trim())
                .map_or(false, |field_value| compare_string_values(field_value, operator, value_str))
        }
        None => false,
    }
}

/// Compares a fact value with a condition value using `operator`.
///
/// When both sides parse as `f64`, the comparison is numeric:
/// * `==` is true for exactly equal values, or values within `f64::EPSILON`.
/// * `!=` is its exact negation.
///
/// Otherwise only `==` / `!=` are supported and compare the raw strings. Any
/// other operator yields `false`.
fn compare_string_values(field_value: &str, operator: &str, value_str: &str) -> bool {
    match (field_value.parse::<f64>(), value_str.parse::<f64>()) {
        (Ok(lhs), Ok(rhs)) => {
            // Exact comparison first: the difference of two equal infinities is NaN,
            // which would otherwise make `inf == inf` false.
            let equal = lhs == rhs || (lhs - rhs).abs() < f64::EPSILON;
            match operator {
                "==" => equal,
                "!=" => !equal,
                ">" => lhs > rhs,
                "<" => lhs < rhs,
                ">=" => lhs >= rhs,
                "<=" => lhs <= rhs,
                _ => false,
            }
        }
        _ => match operator {
            "==" => field_value == value_str,
            "!=" => field_value != value_str,
            _ => false,
        },
    }
}
88
89pub fn evaluate_rete_ul_node(node: &ReteUlNode, facts: &HashMap<String, String>) -> bool {
91 match node {
92 ReteUlNode::UlAlpha(alpha) => {
93 let val = if alpha.field.contains('.') {
94 let parts: Vec<&str> = alpha.field.split('.').collect();
95 if parts.len() == 2 {
96 let prefix = parts[0];
97 let suffix = parts[1];
98 facts.get(&format!("{}.{}", prefix, suffix)).or_else(|| facts.get(&format!("{}:{}", prefix, suffix)))
99 } else {
100 facts.get(&alpha.field)
101 }
102 } else {
103 facts.get(&alpha.field)
104 };
105 if let Some(val) = val {
106 match alpha.operator.as_str() {
107 "==" => val == &alpha.value,
108 "!=" => val != &alpha.value,
109 ">" => val.parse::<f64>().unwrap_or(0.0) > alpha.value.parse::<f64>().unwrap_or(0.0),
110 "<" => val.parse::<f64>().unwrap_or(0.0) < alpha.value.parse::<f64>().unwrap_or(0.0),
111 ">=" => val.parse::<f64>().unwrap_or(0.0) >= alpha.value.parse::<f64>().unwrap_or(0.0),
112 "<=" => val.parse::<f64>().unwrap_or(0.0) <= alpha.value.parse::<f64>().unwrap_or(0.0),
113 _ => false,
114 }
115 } else {
116 false
117 }
118 }
119 ReteUlNode::UlAnd(left, right) => {
120 evaluate_rete_ul_node(left, facts) && evaluate_rete_ul_node(right, facts)
121 }
122 ReteUlNode::UlOr(left, right) => {
123 evaluate_rete_ul_node(left, facts) || evaluate_rete_ul_node(right, facts)
124 }
125 ReteUlNode::UlNot(inner) => {
126 !evaluate_rete_ul_node(inner, facts)
127 }
128 ReteUlNode::UlExists(inner) => {
129 let target_field = match &**inner {
130 ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
131 _ => "".to_string(),
132 };
133 if target_field.contains('.') {
134 let parts: Vec<&str> = target_field.split('.').collect();
135 if parts.len() == 2 {
136 let prefix = parts[0];
137 let suffix = parts[1];
138 let filtered: Vec<_> = facts.iter()
139 .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
140 .collect();
141 filtered.iter().any(|(_, value)| {
142 let mut sub_facts = HashMap::new();
143 sub_facts.insert(target_field.clone(), (*value).clone());
144 evaluate_rete_ul_node(inner, &sub_facts)
145 })
146 } else {
147 facts.iter().any(|(field, value)| {
148 let mut sub_facts = HashMap::new();
149 sub_facts.insert(field.clone(), value.clone());
150 evaluate_rete_ul_node(inner, &sub_facts)
151 })
152 }
153 } else {
154 facts.iter().any(|(field, value)| {
155 let mut sub_facts = HashMap::new();
156 sub_facts.insert(field.clone(), value.clone());
157 evaluate_rete_ul_node(inner, &sub_facts)
158 })
159 }
160 }
161 ReteUlNode::UlForall(inner) => {
162 let target_field = match &**inner {
163 ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
164 _ => "".to_string(),
165 };
166 if target_field.contains('.') {
167 let parts: Vec<&str> = target_field.split('.').collect();
168 if parts.len() == 2 {
169 let prefix = parts[0];
170 let suffix = parts[1];
171 let filtered: Vec<_> = facts.iter()
172 .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
173 .collect();
174 if filtered.is_empty() {
175 return true; }
177 filtered.iter().all(|(_, value)| {
178 let mut sub_facts = HashMap::new();
179 sub_facts.insert(target_field.clone(), (*value).clone());
180 evaluate_rete_ul_node(inner, &sub_facts)
181 })
182 } else {
183 facts.iter().all(|(field, value)| {
184 let mut sub_facts = HashMap::new();
185 sub_facts.insert(field.clone(), value.clone());
186 evaluate_rete_ul_node(inner, &sub_facts)
187 })
188 }
189 } else {
190 facts.iter().all(|(field, value)| {
191 let mut sub_facts = HashMap::new();
192 sub_facts.insert(field.clone(), value.clone());
193 evaluate_rete_ul_node(inner, &sub_facts)
194 })
195 }
196 }
197 ReteUlNode::UlAccumulate {
198 source_pattern,
199 extract_field,
200 source_conditions,
201 function,
202 ..
203 } => {
204 use super::accumulate::*;
206
207 let pattern_prefix = format!("{}.", source_pattern);
208 let mut matching_values = Vec::new();
209
210 let mut instances: std::collections::HashMap<String, std::collections::HashMap<String, String>> =
212 std::collections::HashMap::new();
213
214 for (key, value) in facts {
215 if key.starts_with(&pattern_prefix) {
216 let parts: Vec<&str> = key.strip_prefix(&pattern_prefix).unwrap().split('.').collect();
217
218 if parts.len() >= 2 {
219 let instance_id = parts[0];
220 let field_name = parts[1..].join(".");
221
222 instances
223 .entry(instance_id.to_string())
224 .or_insert_with(std::collections::HashMap::new)
225 .insert(field_name, value.clone());
226 } else if parts.len() == 1 {
227 instances
228 .entry("default".to_string())
229 .or_insert_with(std::collections::HashMap::new)
230 .insert(parts[0].to_string(), value.clone());
231 }
232 }
233 }
234
235 for (_instance_id, instance_facts) in instances {
237 let mut matches = true;
238
239 for condition_str in source_conditions {
240 if !evaluate_condition_string(condition_str, &instance_facts) {
241 matches = false;
242 break;
243 }
244 }
245
246 if matches {
247 if let Some(value_str) = instance_facts.get(extract_field) {
248 let fact_value = if let Ok(i) = value_str.parse::<i64>() {
250 super::facts::FactValue::Integer(i)
251 } else if let Ok(f) = value_str.parse::<f64>() {
252 super::facts::FactValue::Float(f)
253 } else if let Ok(b) = value_str.parse::<bool>() {
254 super::facts::FactValue::Boolean(b)
255 } else {
256 super::facts::FactValue::String(value_str.clone())
257 };
258 matching_values.push(fact_value);
259 }
260 }
261 }
262
263 let has_results = !matching_values.is_empty();
265
266 match function.as_str() {
267 "count" => has_results, "sum" | "average" | "min" | "max" => {
269 has_results
271 }
272 _ => true, }
274 }
275 ReteUlNode::UlMultiField { field, operation, value, operator, compare_value } => {
276 let field_value = facts.get(field);
280
281 match operation.as_str() {
282 "empty" => {
283 field_value.map(|v| v.is_empty() || v == "[]").unwrap_or(true)
285 }
286 "not_empty" => {
287 field_value.map(|v| !v.is_empty() && v != "[]").unwrap_or(false)
289 }
290 "count" => {
291 if let Some(val) = field_value {
292 let count = if val.starts_with('[') && val.ends_with(']') {
295 let inner = &val[1..val.len()-1];
296 if inner.trim().is_empty() {
297 0
298 } else {
299 inner.split(',').count()
300 }
301 } else {
302 0
303 };
304
305 if let (Some(op), Some(cmp_val)) = (operator, compare_value) {
307 let cmp_num = cmp_val.parse::<i64>().unwrap_or(0);
308 match op.as_str() {
309 ">" => (count as i64) > cmp_num,
310 "<" => (count as i64) < cmp_num,
311 ">=" => (count as i64) >= cmp_num,
312 "<=" => (count as i64) <= cmp_num,
313 "==" => (count as i64) == cmp_num,
314 "!=" => (count as i64) != cmp_num,
315 _ => false,
316 }
317 } else {
318 count > 0
319 }
320 } else {
321 false
322 }
323 }
324 "contains" => {
325 if let (Some(val), Some(search)) = (field_value, value) {
326 val.contains(search)
328 } else {
329 false
330 }
331 }
332 _ => {
333 false
335 }
336 }
337 }
338 ReteUlNode::UlTerminal(_) => true }
340}
341
342#[derive(Debug, Clone)]
344pub enum ReteUlNode {
345 UlAlpha(AlphaNode),
346 UlAnd(Box<ReteUlNode>, Box<ReteUlNode>),
347 UlOr(Box<ReteUlNode>, Box<ReteUlNode>),
348 UlNot(Box<ReteUlNode>),
349 UlExists(Box<ReteUlNode>),
350 UlForall(Box<ReteUlNode>),
351 UlAccumulate {
352 result_var: String,
353 source_pattern: String,
354 extract_field: String,
355 source_conditions: Vec<String>,
356 function: String,
357 function_arg: String,
358 },
359 UlMultiField {
360 field: String,
361 operation: String, value: Option<String>, operator: Option<String>, compare_value: Option<String>, },
366 UlTerminal(String), }
368
369impl ReteUlNode {
370 pub fn evaluate_typed(&self, facts: &super::facts::TypedFacts) -> bool {
372 evaluate_rete_ul_node_typed(self, facts)
373 }
374}
375
376pub struct ReteUlRule {
378 pub name: String,
379 pub node: ReteUlNode,
380 pub priority: i32,
381 pub no_loop: bool,
382 pub action: Arc<dyn Fn(&mut std::collections::HashMap<String, String>) + Send + Sync>,
383}
384
385pub fn fire_rete_ul_rules(
388 rules: &mut [(String, ReteUlNode, Box<dyn FnMut(&mut std::collections::HashMap<String, String>)>)],
389 facts: &mut std::collections::HashMap<String, String>,
390) -> Vec<String> {
391 let mut fired_rules = Vec::new();
392 let mut changed = true;
393 while changed {
394 changed = false;
395 for (rule_name, node, action) in rules.iter_mut() {
396 let fired_flag = format!("{}_fired", rule_name);
397 if facts.get(&fired_flag) == Some(&"true".to_string()) {
398 continue;
399 }
400 if evaluate_rete_ul_node(node, facts) {
401 action(facts);
402 facts.insert(fired_flag.clone(), "true".to_string());
403 fired_rules.push(rule_name.clone());
404 changed = true;
405 }
406 }
407 }
408 fired_rules
409}
410
411pub fn fire_rete_ul_rules_with_agenda(
413 rules: &mut [ReteUlRule],
414 facts: &mut std::collections::HashMap<String, String>,
415) -> Vec<String> {
416 let mut fired_rules = Vec::new();
417 let mut fired_flags = std::collections::HashSet::new();
418 let max_iterations = 100; let mut iterations = 0;
420
421 loop {
422 iterations += 1;
423 if iterations > max_iterations {
424 eprintln!("Warning: RETE engine reached max iterations ({})", max_iterations);
425 break;
426 }
427
428 let mut agenda: Vec<usize> = rules
430 .iter()
431 .enumerate()
432 .filter(|(_, rule)| {
433 if fired_flags.contains(&rule.name) {
435 return false;
436 }
437 evaluate_rete_ul_node(&rule.node, facts)
439 })
440 .map(|(i, _)| i)
441 .collect();
442
443 if agenda.is_empty() {
445 break;
446 }
447
448 agenda.sort_by_key(|&i| -rules[i].priority);
450
451 for &i in &agenda {
453 let rule = &mut rules[i];
454
455 (rule.action)(facts);
457
458 fired_rules.push(rule.name.clone());
460 fired_flags.insert(rule.name.clone());
461
462 let fired_flag = format!("{}_fired", rule.name);
463 facts.insert(fired_flag, "true".to_string());
464 }
465
466 if rules.iter().all(|r| r.no_loop) {
468 break;
469 }
470 }
471
472 fired_rules
473}
474
475pub struct ReteUlEngine {
478 rules: Vec<ReteUlRule>,
479 facts: std::collections::HashMap<String, String>,
480}
481
482impl ReteUlEngine {
483 pub fn new() -> Self {
485 Self {
486 rules: Vec::new(),
487 facts: std::collections::HashMap::new(),
488 }
489 }
490
491 pub fn add_rule_with_action<F>(
493 &mut self,
494 name: String,
495 node: ReteUlNode,
496 priority: i32,
497 no_loop: bool,
498 action: F,
499 ) where
500 F: Fn(&mut std::collections::HashMap<String, String>) + Send + Sync + 'static,
501 {
502 self.rules.push(ReteUlRule {
503 name,
504 node,
505 priority,
506 no_loop,
507 action: Arc::new(action),
508 });
509 }
510
511 pub fn add_rule_from_definition(
513 &mut self,
514 rule: &crate::rete::auto_network::Rule,
515 priority: i32,
516 no_loop: bool,
517 ) {
518 let node = build_rete_ul_from_condition_group(&rule.conditions);
519 let rule_name = rule.name.clone();
520
521 let action = Arc::new(move |facts: &mut std::collections::HashMap<String, String>| {
523 facts.insert(format!("{}_executed", rule_name), "true".to_string());
524 });
525
526 self.rules.push(ReteUlRule {
527 name: rule.name.clone(),
528 node,
529 priority,
530 no_loop,
531 action,
532 });
533 }
534
535 pub fn set_fact(&mut self, key: String, value: String) {
537 self.facts.insert(key, value);
538 }
539
540 pub fn get_fact(&self, key: &str) -> Option<&String> {
542 self.facts.get(key)
543 }
544
545 pub fn remove_fact(&mut self, key: &str) -> Option<String> {
547 self.facts.remove(key)
548 }
549
550 pub fn get_all_facts(&self) -> &std::collections::HashMap<String, String> {
552 &self.facts
553 }
554
555 pub fn clear_facts(&mut self) {
557 self.facts.clear();
558 }
559
560 pub fn fire_all(&mut self) -> Vec<String> {
562 fire_rete_ul_rules_with_agenda(&mut self.rules, &mut self.facts)
563 }
564
565 pub fn matches(&self, rule_name: &str) -> bool {
567 self.rules
568 .iter()
569 .find(|r| r.name == rule_name)
570 .map(|r| evaluate_rete_ul_node(&r.node, &self.facts))
571 .unwrap_or(false)
572 }
573
574 pub fn get_matching_rules(&self) -> Vec<&str> {
576 self.rules
577 .iter()
578 .filter(|r| evaluate_rete_ul_node(&r.node, &self.facts))
579 .map(|r| r.name.as_str())
580 .collect()
581 }
582
583 pub fn reset_fired_flags(&mut self) {
585 let keys_to_remove: Vec<_> = self.facts
586 .keys()
587 .filter(|k| k.ends_with("_fired") || k.ends_with("_executed"))
588 .cloned()
589 .collect();
590 for key in keys_to_remove {
591 self.facts.remove(&key);
592 }
593 }
594}
595
596use super::facts::{FactValue, TypedFacts};
601
602pub fn evaluate_rete_ul_node_typed(node: &ReteUlNode, facts: &TypedFacts) -> bool {
604 match node {
605 ReteUlNode::UlAlpha(alpha) => {
606 alpha.matches_typed(facts)
607 }
608 ReteUlNode::UlAnd(left, right) => {
609 evaluate_rete_ul_node_typed(left, facts) && evaluate_rete_ul_node_typed(right, facts)
610 }
611 ReteUlNode::UlOr(left, right) => {
612 evaluate_rete_ul_node_typed(left, facts) || evaluate_rete_ul_node_typed(right, facts)
613 }
614 ReteUlNode::UlNot(inner) => {
615 !evaluate_rete_ul_node_typed(inner, facts)
616 }
617 ReteUlNode::UlExists(inner) => {
618 let target_field = match &**inner {
619 ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
620 _ => "".to_string(),
621 };
622 if target_field.contains('.') {
623 let parts: Vec<&str> = target_field.split('.').collect();
624 if parts.len() == 2 {
625 let prefix = parts[0];
626 let suffix = parts[1];
627 let filtered: Vec<_> = facts.get_all().iter()
628 .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
629 .collect();
630 filtered.iter().any(|(_, _)| {
631 evaluate_rete_ul_node_typed(inner, facts)
632 })
633 } else {
634 evaluate_rete_ul_node_typed(inner, facts)
635 }
636 } else {
637 evaluate_rete_ul_node_typed(inner, facts)
638 }
639 }
640 ReteUlNode::UlForall(inner) => {
641 let target_field = match &**inner {
642 ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
643 _ => "".to_string(),
644 };
645 if target_field.contains('.') {
646 let parts: Vec<&str> = target_field.split('.').collect();
647 if parts.len() == 2 {
648 let prefix = parts[0];
649 let suffix = parts[1];
650 let filtered: Vec<_> = facts.get_all().iter()
651 .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
652 .collect();
653 if filtered.is_empty() {
654 return true; }
656 filtered.iter().all(|(_, _)| {
657 evaluate_rete_ul_node_typed(inner, facts)
658 })
659 } else {
660 if facts.get_all().is_empty() {
661 return true; }
663 evaluate_rete_ul_node_typed(inner, facts)
664 }
665 } else {
666 if facts.get_all().is_empty() {
667 return true; }
669 evaluate_rete_ul_node_typed(inner, facts)
670 }
671 }
672 ReteUlNode::UlAccumulate {
673 source_pattern,
674 extract_field,
675 source_conditions,
676 function,
677 ..
678 } => {
679 use super::accumulate::*;
681
682 let pattern_prefix = format!("{}.", source_pattern);
683 let mut matching_values = Vec::new();
684
685 let mut instances: std::collections::HashMap<String, std::collections::HashMap<String, FactValue>> =
687 std::collections::HashMap::new();
688
689 for (key, value) in facts.get_all() {
690 if key.starts_with(&pattern_prefix) {
691 let parts: Vec<&str> = key.strip_prefix(&pattern_prefix).unwrap().split('.').collect();
692
693 if parts.len() >= 2 {
694 let instance_id = parts[0];
695 let field_name = parts[1..].join(".");
696
697 instances
698 .entry(instance_id.to_string())
699 .or_insert_with(std::collections::HashMap::new)
700 .insert(field_name, value.clone());
701 } else if parts.len() == 1 {
702 instances
703 .entry("default".to_string())
704 .or_insert_with(std::collections::HashMap::new)
705 .insert(parts[0].to_string(), value.clone());
706 }
707 }
708 }
709
710 for (_instance_id, instance_facts) in instances {
712 let mut matches = true;
713
714 for condition_str in source_conditions {
715 let string_facts: HashMap<String, String> = instance_facts
717 .iter()
718 .map(|(k, v)| (k.clone(), format!("{:?}", v)))
719 .collect();
720
721 if !evaluate_condition_string(condition_str, &string_facts) {
722 matches = false;
723 break;
724 }
725 }
726
727 if matches {
728 if let Some(value) = instance_facts.get(extract_field) {
729 matching_values.push(value.clone());
730 }
731 }
732 }
733
734 let has_results = !matching_values.is_empty();
736
737 match function.as_str() {
738 "count" => has_results,
739 "sum" | "average" | "min" | "max" => has_results,
740 _ => true,
741 }
742 }
743 ReteUlNode::UlMultiField { field, operation, value, operator, compare_value } => {
744 use super::facts::FactValue;
746
747 let field_value = facts.get(field);
748
749 match operation.as_str() {
750 "empty" => {
751 if let Some(FactValue::Array(arr)) = field_value {
753 arr.is_empty()
754 } else {
755 true
757 }
758 }
759 "not_empty" => {
760 if let Some(FactValue::Array(arr)) = field_value {
762 !arr.is_empty()
763 } else {
764 false
765 }
766 }
767 "count" => {
768 if let Some(FactValue::Array(arr)) = field_value {
769 let count = arr.len() as i64;
770
771 if let (Some(op), Some(cmp_val)) = (operator, compare_value) {
773 let cmp_num = cmp_val.parse::<i64>().unwrap_or(0);
774 match op.as_str() {
775 ">" => count > cmp_num,
776 "<" => count < cmp_num,
777 ">=" => count >= cmp_num,
778 "<=" => count <= cmp_num,
779 "==" => count == cmp_num,
780 "!=" => count != cmp_num,
781 _ => false,
782 }
783 } else {
784 count > 0
785 }
786 } else {
787 false
788 }
789 }
790 "contains" => {
791 if let (Some(FactValue::Array(arr)), Some(search)) = (field_value, value) {
792 arr.iter().any(|item| {
795 match item {
796 FactValue::String(s) => s == search,
797 FactValue::Integer(i) => i.to_string() == *search,
798 FactValue::Float(f) => f.to_string() == *search,
799 FactValue::Boolean(b) => b.to_string() == *search,
800 _ => false,
801 }
802 })
803 } else {
804 false
805 }
806 }
807 "first" => {
808 if let Some(FactValue::Array(arr)) = field_value {
810 !arr.is_empty()
811 } else {
812 false
813 }
814 }
815 "last" => {
816 if let Some(FactValue::Array(arr)) = field_value {
818 !arr.is_empty()
819 } else {
820 false
821 }
822 }
823 "collect" => {
824 matches!(field_value, Some(FactValue::Array(_)))
826 }
827 _ => {
828 false
830 }
831 }
832 }
833 ReteUlNode::UlTerminal(_) => true
834 }
835}
836
837pub struct TypedReteUlRule {
839 pub name: String,
840 pub node: ReteUlNode,
841 pub priority: i32,
842 pub no_loop: bool,
843 pub action: Arc<dyn Fn(&mut TypedFacts) + Send + Sync>,
844}
845
846pub struct TypedReteUlEngine {
849 rules: Vec<TypedReteUlRule>,
850 facts: TypedFacts,
851}
852
853impl TypedReteUlEngine {
854 pub fn new() -> Self {
856 Self {
857 rules: Vec::new(),
858 facts: TypedFacts::new(),
859 }
860 }
861
862 pub fn add_rule_with_action<F>(
864 &mut self,
865 name: String,
866 node: ReteUlNode,
867 priority: i32,
868 no_loop: bool,
869 action: F,
870 ) where
871 F: Fn(&mut TypedFacts) + Send + Sync + 'static,
872 {
873 self.rules.push(TypedReteUlRule {
874 name,
875 node,
876 priority,
877 no_loop,
878 action: Arc::new(action),
879 });
880 }
881
882 pub fn add_rule_from_definition(
884 &mut self,
885 rule: &crate::rete::auto_network::Rule,
886 priority: i32,
887 no_loop: bool,
888 ) {
889 let node = build_rete_ul_from_condition_group(&rule.conditions);
890 let rule_name = rule.name.clone();
891
892 let action = Arc::new(move |facts: &mut TypedFacts| {
893 facts.set(format!("{}_executed", rule_name), true);
894 });
895
896 self.rules.push(TypedReteUlRule {
897 name: rule.name.clone(),
898 node,
899 priority,
900 no_loop,
901 action,
902 });
903 }
904
905 pub fn set_fact<K: Into<String>, V: Into<FactValue>>(&mut self, key: K, value: V) {
907 self.facts.set(key, value);
908 }
909
910 pub fn get_fact(&self, key: &str) -> Option<&FactValue> {
912 self.facts.get(key)
913 }
914
915 pub fn remove_fact(&mut self, key: &str) -> Option<FactValue> {
917 self.facts.remove(key)
918 }
919
920 pub fn get_all_facts(&self) -> &TypedFacts {
922 &self.facts
923 }
924
925 pub fn clear_facts(&mut self) {
927 self.facts.clear();
928 }
929
930 pub fn fire_all(&mut self) -> Vec<String> {
932 let mut fired_rules = Vec::new();
933 let mut agenda: Vec<usize>;
934 let mut changed = true;
935 let mut fired_flags = std::collections::HashSet::new();
936
937 while changed {
938 changed = false;
939
940 agenda = self.rules.iter().enumerate()
942 .filter(|(_, rule)| {
943 let fired_flag = format!("{}_fired", rule.name);
944 let already_fired = fired_flags.contains(&rule.name) ||
945 self.facts.get(&fired_flag).and_then(|v| v.as_boolean()) == Some(true);
946 !rule.no_loop || !already_fired
947 })
948 .filter(|(_, rule)| evaluate_rete_ul_node_typed(&rule.node, &self.facts))
949 .map(|(i, _)| i)
950 .collect();
951
952 agenda.sort_by_key(|&i| -self.rules[i].priority);
954
955 for &i in &agenda {
956 let rule = &mut self.rules[i];
957 let fired_flag = format!("{}_fired", rule.name);
958 let already_fired = fired_flags.contains(&rule.name) ||
959 self.facts.get(&fired_flag).and_then(|v| v.as_boolean()) == Some(true);
960
961 if rule.no_loop && already_fired {
962 continue;
963 }
964
965 (rule.action)(&mut self.facts);
966 fired_rules.push(rule.name.clone());
967 fired_flags.insert(rule.name.clone());
968 self.facts.set(fired_flag, true);
969 changed = true;
970 }
971 }
972
973 fired_rules
974 }
975
976 pub fn matches(&self, rule_name: &str) -> bool {
978 self.rules
979 .iter()
980 .find(|r| r.name == rule_name)
981 .map(|r| evaluate_rete_ul_node_typed(&r.node, &self.facts))
982 .unwrap_or(false)
983 }
984
985 pub fn get_matching_rules(&self) -> Vec<&str> {
987 self.rules
988 .iter()
989 .filter(|r| evaluate_rete_ul_node_typed(&r.node, &self.facts))
990 .map(|r| r.name.as_str())
991 .collect()
992 }
993
994 pub fn reset_fired_flags(&mut self) {
996 let keys_to_remove: Vec<_> = self.facts.get_all()
997 .keys()
998 .filter(|k| k.ends_with("_fired") || k.ends_with("_executed"))
999 .cloned()
1000 .collect();
1001 for key in keys_to_remove {
1002 self.facts.remove(&key);
1003 }
1004 }
1005}
1006
1007impl Default for TypedReteUlEngine {
1008 fn default() -> Self {
1009 Self::new()
1010 }
1011}
1012