1#![allow(clippy::type_complexity)]
2
3use crate::rete::alpha::AlphaNode;
4use std::sync::Arc;
5pub fn build_rete_ul_from_condition_group(
7 group: &crate::rete::auto_network::ConditionGroup,
8) -> ReteUlNode {
9 use crate::rete::auto_network::ConditionGroup;
10 match group {
11 ConditionGroup::Single(cond) => ReteUlNode::UlAlpha(AlphaNode {
12 field: cond.field.clone(),
13 operator: cond.operator.clone(),
14 value: cond.value.clone(),
15 }),
16 ConditionGroup::Compound {
17 left,
18 operator,
19 right,
20 } => match operator.as_str() {
21 "AND" => ReteUlNode::UlAnd(
22 Box::new(build_rete_ul_from_condition_group(left)),
23 Box::new(build_rete_ul_from_condition_group(right)),
24 ),
25 "OR" => ReteUlNode::UlOr(
26 Box::new(build_rete_ul_from_condition_group(left)),
27 Box::new(build_rete_ul_from_condition_group(right)),
28 ),
29 _ => ReteUlNode::UlAnd(
30 Box::new(build_rete_ul_from_condition_group(left)),
31 Box::new(build_rete_ul_from_condition_group(right)),
32 ),
33 },
34 ConditionGroup::Not(inner) => {
35 ReteUlNode::UlNot(Box::new(build_rete_ul_from_condition_group(inner)))
36 }
37 ConditionGroup::Exists(inner) => {
38 ReteUlNode::UlExists(Box::new(build_rete_ul_from_condition_group(inner)))
39 }
40 ConditionGroup::Forall(inner) => {
41 ReteUlNode::UlForall(Box::new(build_rete_ul_from_condition_group(inner)))
42 }
43 }
44}
45use std::collections::HashMap;
46
/// Evaluates one textual condition such as `amount > 100` or
/// `status == "active"` against a flat map of string facts.
///
/// The condition is split at the left-most comparison operator (`==`, `!=`,
/// `>=`, `<=`, `>`, `<`). The left side, trimmed, is the fact key. The right
/// side, trimmed and stripped of surrounding double and then single quotes,
/// is the expected value. The comparison is delegated to
/// [`compare_string_values`].
///
/// Returns `false` when the condition contains no operator or when the key is
/// not present in `facts`.
fn evaluate_condition_string(condition: &str, facts: &HashMap<String, String>) -> bool {
    // Two-character operators are listed first: when several candidates start
    // at the same index (`>=` and `>`, `<=` and `<`), the longer one must win.
    const OPERATORS: [&str; 6] = ["==", "!=", ">=", "<=", ">", "<"];

    let condition = condition.trim();

    // Split at the left-most operator occurrence, not at the first operator of
    // the list that appears anywhere. Otherwise an operator inside the quoted
    // value (e.g. `code != "a==b"`) would be taken as the condition's operator.
    // `min_by_key` returns the first of several equal minima, which keeps the
    // longer-operator tie-break encoded in the ordering of `OPERATORS`.
    let split = OPERATORS
        .iter()
        .filter_map(|&op| condition.find(op).map(|pos| (pos, op)))
        .min_by_key(|&(pos, _)| pos);

    let (pos, op) = match split {
        Some(found) => found,
        None => return false,
    };

    let field = condition[..pos].trim();
    let value_str = condition[pos + op.len()..]
        .trim()
        .trim_matches('"')
        .trim_matches('\'');

    facts
        .get(field)
        .map_or(false, |field_value| compare_string_values(field_value, op, value_str))
}

/// Compares `field_value` against `value_str` with `operator`.
///
/// If both sides parse as `f64`, the comparison is numeric. `==` and `!=` use
/// an absolute tolerance of `f64::EPSILON`, and all six operators are
/// supported. Otherwise only `==` and `!=` are supported, as exact string
/// comparisons. Any unsupported operator yields `false`.
fn compare_string_values(field_value: &str, operator: &str, value_str: &str) -> bool {
    match (field_value.parse::<f64>(), value_str.parse::<f64>()) {
        (Ok(lhs), Ok(rhs)) => match operator {
            "==" => (lhs - rhs).abs() < f64::EPSILON,
            "!=" => (lhs - rhs).abs() >= f64::EPSILON,
            ">" => lhs > rhs,
            "<" => lhs < rhs,
            ">=" => lhs >= rhs,
            "<=" => lhs <= rhs,
            _ => false,
        },
        _ => match operator {
            "==" => field_value == value_str,
            "!=" => field_value != value_str,
            _ => false,
        },
    }
}
92
93pub fn evaluate_rete_ul_node(node: &ReteUlNode, facts: &HashMap<String, String>) -> bool {
95 match node {
96 ReteUlNode::UlAlpha(alpha) => {
97 let val = if alpha.field.contains('.') {
98 let parts: Vec<&str> = alpha.field.split('.').collect();
99 if parts.len() == 2 {
100 let prefix = parts[0];
101 let suffix = parts[1];
102 facts
103 .get(&format!("{}.{}", prefix, suffix))
104 .or_else(|| facts.get(&format!("{}:{}", prefix, suffix)))
105 } else {
106 facts.get(&alpha.field)
107 }
108 } else {
109 facts.get(&alpha.field)
110 };
111 if let Some(val) = val {
112 match alpha.operator.as_str() {
113 "==" => val == &alpha.value,
114 "!=" => val != &alpha.value,
115 ">" => {
116 val.parse::<f64>().unwrap_or(0.0)
117 > alpha.value.parse::<f64>().unwrap_or(0.0)
118 }
119 "<" => {
120 val.parse::<f64>().unwrap_or(0.0)
121 < alpha.value.parse::<f64>().unwrap_or(0.0)
122 }
123 ">=" => {
124 val.parse::<f64>().unwrap_or(0.0)
125 >= alpha.value.parse::<f64>().unwrap_or(0.0)
126 }
127 "<=" => {
128 val.parse::<f64>().unwrap_or(0.0)
129 <= alpha.value.parse::<f64>().unwrap_or(0.0)
130 }
131 _ => false,
132 }
133 } else {
134 false
135 }
136 }
137 ReteUlNode::UlAnd(left, right) => {
138 evaluate_rete_ul_node(left, facts) && evaluate_rete_ul_node(right, facts)
139 }
140 ReteUlNode::UlOr(left, right) => {
141 evaluate_rete_ul_node(left, facts) || evaluate_rete_ul_node(right, facts)
142 }
143 ReteUlNode::UlNot(inner) => !evaluate_rete_ul_node(inner, facts),
144 ReteUlNode::UlExists(inner) => {
145 let target_field = match &**inner {
146 ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
147 _ => "".to_string(),
148 };
149 if target_field.contains('.') {
150 let parts: Vec<&str> = target_field.split('.').collect();
151 if parts.len() == 2 {
152 let prefix = parts[0];
153 let suffix = parts[1];
154 let filtered: Vec<_> = facts
155 .iter()
156 .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
157 .collect();
158 filtered.iter().any(|(_, value)| {
159 let mut sub_facts = HashMap::new();
160 sub_facts.insert(target_field.clone(), (*value).clone());
161 evaluate_rete_ul_node(inner, &sub_facts)
162 })
163 } else {
164 facts.iter().any(|(field, value)| {
165 let mut sub_facts = HashMap::new();
166 sub_facts.insert(field.clone(), value.clone());
167 evaluate_rete_ul_node(inner, &sub_facts)
168 })
169 }
170 } else {
171 facts.iter().any(|(field, value)| {
172 let mut sub_facts = HashMap::new();
173 sub_facts.insert(field.clone(), value.clone());
174 evaluate_rete_ul_node(inner, &sub_facts)
175 })
176 }
177 }
178 ReteUlNode::UlForall(inner) => {
179 let target_field = match &**inner {
180 ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
181 _ => "".to_string(),
182 };
183 if target_field.contains('.') {
184 let parts: Vec<&str> = target_field.split('.').collect();
185 if parts.len() == 2 {
186 let prefix = parts[0];
187 let suffix = parts[1];
188 let filtered: Vec<_> = facts
189 .iter()
190 .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
191 .collect();
192 if filtered.is_empty() {
193 return true; }
195 filtered.iter().all(|(_, value)| {
196 let mut sub_facts = HashMap::new();
197 sub_facts.insert(target_field.clone(), (*value).clone());
198 evaluate_rete_ul_node(inner, &sub_facts)
199 })
200 } else {
201 facts.iter().all(|(field, value)| {
202 let mut sub_facts = HashMap::new();
203 sub_facts.insert(field.clone(), value.clone());
204 evaluate_rete_ul_node(inner, &sub_facts)
205 })
206 }
207 } else {
208 facts.iter().all(|(field, value)| {
209 let mut sub_facts = HashMap::new();
210 sub_facts.insert(field.clone(), value.clone());
211 evaluate_rete_ul_node(inner, &sub_facts)
212 })
213 }
214 }
215 ReteUlNode::UlAccumulate {
216 source_pattern,
217 extract_field,
218 source_conditions,
219 function,
220 ..
221 } => {
222 let pattern_prefix = format!("{}.", source_pattern);
225 let mut matching_values = Vec::new();
226
227 let mut instances: std::collections::HashMap<
229 String,
230 std::collections::HashMap<String, String>,
231 > = std::collections::HashMap::new();
232
233 for (key, value) in facts {
234 if key.starts_with(&pattern_prefix) {
235 let parts: Vec<&str> = key
236 .strip_prefix(&pattern_prefix)
237 .unwrap()
238 .split('.')
239 .collect();
240
241 if parts.len() >= 2 {
242 let instance_id = parts[0];
243 let field_name = parts[1..].join(".");
244
245 instances
246 .entry(instance_id.to_string())
247 .or_default()
248 .insert(field_name, value.clone());
249 } else if parts.len() == 1 {
250 instances
251 .entry("default".to_string())
252 .or_default()
253 .insert(parts[0].to_string(), value.clone());
254 }
255 }
256 }
257
258 for (_instance_id, instance_facts) in instances {
260 let mut matches = true;
261
262 for condition_str in source_conditions {
263 if !evaluate_condition_string(condition_str, &instance_facts) {
264 matches = false;
265 break;
266 }
267 }
268
269 if matches {
270 if let Some(value_str) = instance_facts.get(extract_field) {
271 let fact_value = if let Ok(i) = value_str.parse::<i64>() {
273 super::facts::FactValue::Integer(i)
274 } else if let Ok(f) = value_str.parse::<f64>() {
275 super::facts::FactValue::Float(f)
276 } else if let Ok(b) = value_str.parse::<bool>() {
277 super::facts::FactValue::Boolean(b)
278 } else {
279 super::facts::FactValue::String(value_str.clone())
280 };
281 matching_values.push(fact_value);
282 }
283 }
284 }
285
286 let has_results = !matching_values.is_empty();
288
289 match function.as_str() {
290 "count" => has_results, "sum" | "average" | "min" | "max" => {
292 has_results
294 }
295 _ => true, }
297 }
298 ReteUlNode::UlMultiField {
299 field,
300 operation,
301 value,
302 operator,
303 compare_value,
304 } => {
305 let field_value = facts.get(field);
309
310 match operation.as_str() {
311 "empty" => {
312 field_value
314 .map(|v| v.is_empty() || v == "[]")
315 .unwrap_or(true)
316 }
317 "not_empty" => {
318 field_value
320 .map(|v| !v.is_empty() && v != "[]")
321 .unwrap_or(false)
322 }
323 "count" => {
324 if let Some(val) = field_value {
325 let count = if val.starts_with('[') && val.ends_with(']') {
328 let inner = &val[1..val.len() - 1];
329 if inner.trim().is_empty() {
330 0
331 } else {
332 inner.split(',').count()
333 }
334 } else {
335 0
336 };
337
338 if let (Some(op), Some(cmp_val)) = (operator, compare_value) {
340 let cmp_num = cmp_val.parse::<i64>().unwrap_or(0);
341 match op.as_str() {
342 ">" => (count as i64) > cmp_num,
343 "<" => (count as i64) < cmp_num,
344 ">=" => (count as i64) >= cmp_num,
345 "<=" => (count as i64) <= cmp_num,
346 "==" => (count as i64) == cmp_num,
347 "!=" => (count as i64) != cmp_num,
348 _ => false,
349 }
350 } else {
351 count > 0
352 }
353 } else {
354 false
355 }
356 }
357 "contains" => {
358 if let (Some(val), Some(search)) = (field_value, value) {
359 val.contains(search)
361 } else {
362 false
363 }
364 }
365 _ => {
366 false
368 }
369 }
370 }
371 #[cfg(feature = "streaming")]
372 ReteUlNode::UlStream { .. } => {
373 true
376 }
377 ReteUlNode::UlTerminal(_) => true, }
379}
380
/// A node of a RETE-UL condition tree.
///
/// Trees are built by `build_rete_ul_from_condition_group` and evaluated by
/// `evaluate_rete_ul_node` (string facts) or `evaluate_rete_ul_node_typed`
/// (typed facts).
#[derive(Debug, Clone)]
pub enum ReteUlNode {
    /// Single field test: compares the fact named by the node's `field` with
    /// its `value` using its `operator`.
    UlAlpha(AlphaNode),
    /// Logical conjunction of two sub-trees.
    UlAnd(Box<ReteUlNode>, Box<ReteUlNode>),
    /// Logical disjunction of two sub-trees.
    UlOr(Box<ReteUlNode>, Box<ReteUlNode>),
    /// Logical negation of a sub-tree.
    UlNot(Box<ReteUlNode>),
    /// Existential quantifier over the inner node (binding rules differ
    /// between the string and typed evaluators).
    UlExists(Box<ReteUlNode>),
    /// Universal quantifier over the inner node (binding rules differ between
    /// the string and typed evaluators).
    UlForall(Box<ReteUlNode>),
    /// Aggregation over fact instances keyed `source_pattern.<instance>.<field>`.
    UlAccumulate {
        // Not read by the evaluators in this file (they match it with `..`).
        result_var: String,
        // Fact-key prefix; keys starting with `"{source_pattern}."` are grouped by instance.
        source_pattern: String,
        // Field an instance must carry to contribute a value.
        extract_field: String,
        // Condition strings such as `amount > 100`, parsed by `evaluate_condition_string`.
        source_conditions: Vec<String>,
        // Aggregate name; the evaluators only recognise `count`, `sum`,
        // `average`, `min` and `max` (other names always match).
        function: String,
        // Not read by the evaluators in this file (they match it with `..`).
        function_arg: String,
    },
    /// Test on a multi-valued (list) fact.
    UlMultiField {
        // Fact key holding the list.
        field: String,
        // `empty`, `not_empty`, `count`, `contains`; the typed evaluator also
        // accepts `first`, `last` and `collect`. Unknown operations never match.
        operation: String,
        // Search value for `contains`.
        value: Option<String>,
        // Comparison operator for `count` (`>`, `<`, `>=`, `<=`, `==`, `!=`).
        operator: Option<String>,
        // Integer the element count is compared with for `count`.
        compare_value: Option<String>,
    },
    /// Stream source pattern; both evaluators in this file treat it as
    /// always matching.
    #[cfg(feature = "streaming")]
    UlStream {
        var_name: String,
        event_type: Option<String>,
        stream_name: String,
        window: Option<StreamWindowSpec>,
    },
    /// Leaf that both evaluators treat as always matching. The meaning of the
    /// `String` payload is not visible in this file (presumably a rule or
    /// terminal name — TODO confirm).
    UlTerminal(String),
}
414
/// Window specification attached to a `ReteUlNode::UlStream` node.
///
/// NOTE(review): neither evaluator in this file interprets the window; its
/// semantics are defined by the streaming code elsewhere.
#[cfg(feature = "streaming")]
#[derive(Debug, Clone, PartialEq)]
pub struct StreamWindowSpec {
    /// Window duration (how it is applied is defined outside this file).
    pub duration: std::time::Duration,
    /// Kind of window.
    pub window_type: StreamWindowTypeRete,
}

/// Kinds of stream window usable in a [`StreamWindowSpec`].
#[cfg(feature = "streaming")]
#[derive(Debug, Clone, PartialEq)]
pub enum StreamWindowTypeRete {
    Sliding,
    Tumbling,
    /// Session window carrying a `timeout` (presumably the inactivity gap
    /// that closes a session — TODO confirm against the streaming module).
    Session { timeout: std::time::Duration },
}
429
430impl ReteUlNode {
431 pub fn evaluate_typed(&self, facts: &super::facts::TypedFacts) -> bool {
433 evaluate_rete_ul_node_typed(self, facts)
434 }
435}
436
/// A rule for [`ReteUlEngine`]: a condition tree plus an action over string facts.
pub struct ReteUlRule {
    /// Rule identifier; firing writes the fact `"{name}_fired" = "true"`.
    pub name: String,
    /// Condition tree, evaluated with `evaluate_rete_ul_node`.
    pub node: ReteUlNode,
    /// Higher values fire earlier within one agenda pass.
    pub priority: i32,
    /// In `fire_rete_ul_rules_with_agenda`, when every rule has this set the
    /// engine stops after the first pass.
    pub no_loop: bool,
    /// Callback run when the rule fires; it receives the mutable fact map.
    pub action: Arc<dyn Fn(&mut std::collections::HashMap<String, String>) + Send + Sync>,
}
445
446pub fn fire_rete_ul_rules(
449 rules: &mut [(
450 String,
451 ReteUlNode,
452 Box<dyn FnMut(&mut std::collections::HashMap<String, String>)>,
453 )],
454 facts: &mut std::collections::HashMap<String, String>,
455) -> Vec<String> {
456 let mut fired_rules = Vec::new();
457 let mut changed = true;
458 while changed {
459 changed = false;
460 for (rule_name, node, action) in rules.iter_mut() {
461 let fired_flag = format!("{}_fired", rule_name);
462 if facts.get(&fired_flag) == Some(&"true".to_string()) {
463 continue;
464 }
465 if evaluate_rete_ul_node(node, facts) {
466 action(facts);
467 facts.insert(fired_flag.clone(), "true".to_string());
468 fired_rules.push(rule_name.clone());
469 changed = true;
470 }
471 }
472 }
473 fired_rules
474}
475
476pub fn fire_rete_ul_rules_with_agenda(
478 rules: &mut [ReteUlRule],
479 facts: &mut std::collections::HashMap<String, String>,
480) -> Vec<String> {
481 let mut fired_rules = Vec::new();
482 let mut fired_flags = std::collections::HashSet::new();
483 let max_iterations = 100; let mut iterations = 0;
485
486 loop {
487 iterations += 1;
488 if iterations > max_iterations {
489 eprintln!(
490 "Warning: RETE engine reached max iterations ({})",
491 max_iterations
492 );
493 break;
494 }
495
496 let mut agenda: Vec<usize> = rules
498 .iter()
499 .enumerate()
500 .filter(|(_, rule)| {
501 if fired_flags.contains(&rule.name) {
503 return false;
504 }
505 evaluate_rete_ul_node(&rule.node, facts)
507 })
508 .map(|(i, _)| i)
509 .collect();
510
511 if agenda.is_empty() {
513 break;
514 }
515
516 agenda.sort_by_key(|&i| -rules[i].priority);
518
519 for &i in &agenda {
521 let rule = &mut rules[i];
522
523 (rule.action)(facts);
525
526 fired_rules.push(rule.name.clone());
528 fired_flags.insert(rule.name.clone());
529
530 let fired_flag = format!("{}_fired", rule.name);
531 facts.insert(fired_flag, "true".to_string());
532 }
533
534 if rules.iter().all(|r| r.no_loop) {
536 break;
537 }
538 }
539
540 fired_rules
541}
542
543pub struct ReteUlEngine {
546 rules: Vec<ReteUlRule>,
547 facts: std::collections::HashMap<String, String>,
548}
549
550impl Default for ReteUlEngine {
551 fn default() -> Self {
552 Self::new()
553 }
554}
555
556impl ReteUlEngine {
557 pub fn new() -> Self {
559 Self {
560 rules: Vec::new(),
561 facts: std::collections::HashMap::new(),
562 }
563 }
564
565 pub fn add_rule_with_action<F>(
567 &mut self,
568 name: String,
569 node: ReteUlNode,
570 priority: i32,
571 no_loop: bool,
572 action: F,
573 ) where
574 F: Fn(&mut std::collections::HashMap<String, String>) + Send + Sync + 'static,
575 {
576 self.rules.push(ReteUlRule {
577 name,
578 node,
579 priority,
580 no_loop,
581 action: Arc::new(action),
582 });
583 }
584
585 pub fn add_rule_from_definition(
587 &mut self,
588 rule: &crate::rete::auto_network::Rule,
589 priority: i32,
590 no_loop: bool,
591 ) {
592 let node = build_rete_ul_from_condition_group(&rule.conditions);
593 let rule_name = rule.name.clone();
594
595 let action = Arc::new(
597 move |facts: &mut std::collections::HashMap<String, String>| {
598 facts.insert(format!("{}_executed", rule_name), "true".to_string());
599 },
600 );
601
602 self.rules.push(ReteUlRule {
603 name: rule.name.clone(),
604 node,
605 priority,
606 no_loop,
607 action,
608 });
609 }
610
611 pub fn set_fact(&mut self, key: String, value: String) {
613 self.facts.insert(key, value);
614 }
615
616 pub fn get_fact(&self, key: &str) -> Option<&String> {
618 self.facts.get(key)
619 }
620
621 pub fn remove_fact(&mut self, key: &str) -> Option<String> {
623 self.facts.remove(key)
624 }
625
626 pub fn get_all_facts(&self) -> &std::collections::HashMap<String, String> {
628 &self.facts
629 }
630
631 pub fn clear_facts(&mut self) {
633 self.facts.clear();
634 }
635
636 pub fn fire_all(&mut self) -> Vec<String> {
638 fire_rete_ul_rules_with_agenda(&mut self.rules, &mut self.facts)
639 }
640
641 pub fn matches(&self, rule_name: &str) -> bool {
643 self.rules
644 .iter()
645 .find(|r| r.name == rule_name)
646 .map(|r| evaluate_rete_ul_node(&r.node, &self.facts))
647 .unwrap_or(false)
648 }
649
650 pub fn get_matching_rules(&self) -> Vec<&str> {
652 self.rules
653 .iter()
654 .filter(|r| evaluate_rete_ul_node(&r.node, &self.facts))
655 .map(|r| r.name.as_str())
656 .collect()
657 }
658
659 pub fn reset_fired_flags(&mut self) {
661 let keys_to_remove: Vec<_> = self
662 .facts
663 .keys()
664 .filter(|k| k.ends_with("_fired") || k.ends_with("_executed"))
665 .cloned()
666 .collect();
667 for key in keys_to_remove {
668 self.facts.remove(&key);
669 }
670 }
671}
672
673use super::facts::{FactValue, TypedFacts};
678
679pub fn evaluate_rete_ul_node_typed(node: &ReteUlNode, facts: &TypedFacts) -> bool {
681 match node {
682 ReteUlNode::UlAlpha(alpha) => alpha.matches_typed(facts),
683 ReteUlNode::UlAnd(left, right) => {
684 evaluate_rete_ul_node_typed(left, facts) && evaluate_rete_ul_node_typed(right, facts)
685 }
686 ReteUlNode::UlOr(left, right) => {
687 evaluate_rete_ul_node_typed(left, facts) || evaluate_rete_ul_node_typed(right, facts)
688 }
689 ReteUlNode::UlNot(inner) => !evaluate_rete_ul_node_typed(inner, facts),
690 ReteUlNode::UlExists(inner) => {
691 let target_field = match &**inner {
692 ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
693 _ => "".to_string(),
694 };
695 if target_field.contains('.') {
696 let parts: Vec<&str> = target_field.split('.').collect();
697 if parts.len() == 2 {
698 let prefix = parts[0];
699 let suffix = parts[1];
700 let filtered: Vec<_> = facts
701 .get_all()
702 .iter()
703 .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
704 .collect();
705 filtered
706 .iter()
707 .any(|(_, _)| evaluate_rete_ul_node_typed(inner, facts))
708 } else {
709 evaluate_rete_ul_node_typed(inner, facts)
710 }
711 } else {
712 evaluate_rete_ul_node_typed(inner, facts)
713 }
714 }
715 ReteUlNode::UlForall(inner) => {
716 let target_field = match &**inner {
717 ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
718 _ => "".to_string(),
719 };
720 if target_field.contains('.') {
721 let parts: Vec<&str> = target_field.split('.').collect();
722 if parts.len() == 2 {
723 let prefix = parts[0];
724 let suffix = parts[1];
725 let filtered: Vec<_> = facts
726 .get_all()
727 .iter()
728 .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
729 .collect();
730 if filtered.is_empty() {
731 return true; }
733 filtered
734 .iter()
735 .all(|(_, _)| evaluate_rete_ul_node_typed(inner, facts))
736 } else {
737 if facts.get_all().is_empty() {
738 return true; }
740 evaluate_rete_ul_node_typed(inner, facts)
741 }
742 } else {
743 if facts.get_all().is_empty() {
744 return true; }
746 evaluate_rete_ul_node_typed(inner, facts)
747 }
748 }
749 ReteUlNode::UlAccumulate {
750 source_pattern,
751 extract_field,
752 source_conditions,
753 function,
754 ..
755 } => {
756 let pattern_prefix = format!("{}.", source_pattern);
759 let mut matching_values = Vec::new();
760
761 let mut instances: std::collections::HashMap<
763 String,
764 std::collections::HashMap<String, FactValue>,
765 > = std::collections::HashMap::new();
766
767 for (key, value) in facts.get_all() {
768 if key.starts_with(&pattern_prefix) {
769 let parts: Vec<&str> = key
770 .strip_prefix(&pattern_prefix)
771 .unwrap()
772 .split('.')
773 .collect();
774
775 if parts.len() >= 2 {
776 let instance_id = parts[0];
777 let field_name = parts[1..].join(".");
778
779 instances
780 .entry(instance_id.to_string())
781 .or_default()
782 .insert(field_name, value.clone());
783 } else if parts.len() == 1 {
784 instances
785 .entry("default".to_string())
786 .or_default()
787 .insert(parts[0].to_string(), value.clone());
788 }
789 }
790 }
791
792 for (_instance_id, instance_facts) in instances {
794 let mut matches = true;
795
796 for condition_str in source_conditions {
797 let string_facts: HashMap<String, String> = instance_facts
799 .iter()
800 .map(|(k, v)| (k.clone(), format!("{:?}", v)))
801 .collect();
802
803 if !evaluate_condition_string(condition_str, &string_facts) {
804 matches = false;
805 break;
806 }
807 }
808
809 if matches {
810 if let Some(value) = instance_facts.get(extract_field) {
811 matching_values.push(value.clone());
812 }
813 }
814 }
815
816 let has_results = !matching_values.is_empty();
818
819 match function.as_str() {
820 "count" => has_results,
821 "sum" | "average" | "min" | "max" => has_results,
822 _ => true,
823 }
824 }
825 ReteUlNode::UlMultiField {
826 field,
827 operation,
828 value,
829 operator,
830 compare_value,
831 } => {
832 use super::facts::FactValue;
834
835 let field_value = facts.get(field);
836
837 match operation.as_str() {
838 "empty" => {
839 if let Some(FactValue::Array(arr)) = field_value {
841 arr.is_empty()
842 } else {
843 true
845 }
846 }
847 "not_empty" => {
848 if let Some(FactValue::Array(arr)) = field_value {
850 !arr.is_empty()
851 } else {
852 false
853 }
854 }
855 "count" => {
856 if let Some(FactValue::Array(arr)) = field_value {
857 let count = arr.len() as i64;
858
859 if let (Some(op), Some(cmp_val)) = (operator, compare_value) {
861 let cmp_num = cmp_val.parse::<i64>().unwrap_or(0);
862 match op.as_str() {
863 ">" => count > cmp_num,
864 "<" => count < cmp_num,
865 ">=" => count >= cmp_num,
866 "<=" => count <= cmp_num,
867 "==" => count == cmp_num,
868 "!=" => count != cmp_num,
869 _ => false,
870 }
871 } else {
872 count > 0
873 }
874 } else {
875 false
876 }
877 }
878 "contains" => {
879 if let (Some(FactValue::Array(arr)), Some(search)) = (field_value, value) {
880 arr.iter().any(|item| match item {
883 FactValue::String(s) => s == search,
884 FactValue::Integer(i) => i.to_string() == *search,
885 FactValue::Float(f) => f.to_string() == *search,
886 FactValue::Boolean(b) => b.to_string() == *search,
887 _ => false,
888 })
889 } else {
890 false
891 }
892 }
893 "first" => {
894 if let Some(FactValue::Array(arr)) = field_value {
896 !arr.is_empty()
897 } else {
898 false
899 }
900 }
901 "last" => {
902 if let Some(FactValue::Array(arr)) = field_value {
904 !arr.is_empty()
905 } else {
906 false
907 }
908 }
909 "collect" => {
910 matches!(field_value, Some(FactValue::Array(_)))
912 }
913 _ => {
914 false
916 }
917 }
918 }
919 #[cfg(feature = "streaming")]
920 ReteUlNode::UlStream { .. } => {
921 true
924 }
925 ReteUlNode::UlTerminal(_) => true,
926 }
927}
928
/// A rule for [`TypedReteUlEngine`]: a condition tree plus an action over typed facts.
pub struct TypedReteUlRule {
    /// Rule identifier; `TypedReteUlEngine::fire_all` records `"{name}_fired"`
    /// as a boolean fact after the rule fires.
    pub name: String,
    /// Condition tree, evaluated with `evaluate_rete_ul_node_typed`.
    pub node: ReteUlNode,
    /// Higher values fire earlier within a pass of `TypedReteUlEngine::fire_all`.
    pub priority: i32,
    /// When set, the rule is skipped once it has fired (recorded in the
    /// current `fire_all` call or via a `true` `"{name}_fired"` fact).
    pub no_loop: bool,
    /// Callback run on firing; receives the facts and a freshly created `ActionResults`.
    pub action: Arc<dyn Fn(&mut TypedFacts, &mut super::ActionResults) + Send + Sync>,
}
937
938pub struct TypedReteUlEngine {
941 rules: Vec<TypedReteUlRule>,
942 facts: TypedFacts,
943}
944
945impl TypedReteUlEngine {
946 pub fn new() -> Self {
948 Self {
949 rules: Vec::new(),
950 facts: TypedFacts::new(),
951 }
952 }
953
954 pub fn add_rule_with_action<F>(
956 &mut self,
957 name: String,
958 node: ReteUlNode,
959 priority: i32,
960 no_loop: bool,
961 action: F,
962 ) where
963 F: Fn(&mut TypedFacts, &mut super::ActionResults) + Send + Sync + 'static,
964 {
965 self.rules.push(TypedReteUlRule {
966 name,
967 node,
968 priority,
969 no_loop,
970 action: Arc::new(action),
971 });
972 }
973
974 pub fn add_rule_from_definition(
976 &mut self,
977 rule: &crate::rete::auto_network::Rule,
978 priority: i32,
979 no_loop: bool,
980 ) {
981 let node = build_rete_ul_from_condition_group(&rule.conditions);
982 let rule_name = rule.name.clone();
983
984 let action = Arc::new(
985 move |facts: &mut TypedFacts, _results: &mut super::ActionResults| {
986 facts.set(format!("{}_executed", rule_name), true);
987 },
988 );
989
990 self.rules.push(TypedReteUlRule {
991 name: rule.name.clone(),
992 node,
993 priority,
994 no_loop,
995 action,
996 });
997 }
998
999 pub fn set_fact<K: Into<String>, V: Into<FactValue>>(&mut self, key: K, value: V) {
1001 self.facts.set(key, value);
1002 }
1003
1004 pub fn get_fact(&self, key: &str) -> Option<&FactValue> {
1006 self.facts.get(key)
1007 }
1008
1009 pub fn remove_fact(&mut self, key: &str) -> Option<FactValue> {
1011 self.facts.remove(key)
1012 }
1013
1014 pub fn get_all_facts(&self) -> &TypedFacts {
1016 &self.facts
1017 }
1018
1019 pub fn clear_facts(&mut self) {
1021 self.facts.clear();
1022 }
1023
1024 pub fn fire_all(&mut self) -> Vec<String> {
1026 let mut fired_rules = Vec::new();
1027 let mut agenda: Vec<usize>;
1028 let mut changed = true;
1029 let mut fired_flags = std::collections::HashSet::new();
1030
1031 while changed {
1032 changed = false;
1033
1034 agenda = self
1036 .rules
1037 .iter()
1038 .enumerate()
1039 .filter(|(_, rule)| {
1040 let fired_flag = format!("{}_fired", rule.name);
1041 let already_fired = fired_flags.contains(&rule.name)
1042 || self.facts.get(&fired_flag).and_then(|v| v.as_boolean()) == Some(true);
1043 !rule.no_loop || !already_fired
1044 })
1045 .filter(|(_, rule)| evaluate_rete_ul_node_typed(&rule.node, &self.facts))
1046 .map(|(i, _)| i)
1047 .collect();
1048
1049 agenda.sort_by_key(|&i| -self.rules[i].priority);
1051
1052 for &i in &agenda {
1053 let rule = &mut self.rules[i];
1054 let fired_flag = format!("{}_fired", rule.name);
1055 let already_fired = fired_flags.contains(&rule.name)
1056 || self.facts.get(&fired_flag).and_then(|v| v.as_boolean()) == Some(true);
1057
1058 if rule.no_loop && already_fired {
1059 continue;
1060 }
1061
1062 let mut action_results = super::ActionResults::new();
1063 (rule.action)(&mut self.facts, &mut action_results);
1064 fired_rules.push(rule.name.clone());
1068 fired_flags.insert(rule.name.clone());
1069 self.facts.set(fired_flag, true);
1070 changed = true;
1071 }
1072 }
1073
1074 fired_rules
1075 }
1076
1077 pub fn matches(&self, rule_name: &str) -> bool {
1079 self.rules
1080 .iter()
1081 .find(|r| r.name == rule_name)
1082 .map(|r| evaluate_rete_ul_node_typed(&r.node, &self.facts))
1083 .unwrap_or(false)
1084 }
1085
1086 pub fn get_matching_rules(&self) -> Vec<&str> {
1088 self.rules
1089 .iter()
1090 .filter(|r| evaluate_rete_ul_node_typed(&r.node, &self.facts))
1091 .map(|r| r.name.as_str())
1092 .collect()
1093 }
1094
1095 pub fn reset_fired_flags(&mut self) {
1097 let keys_to_remove: Vec<_> = self
1098 .facts
1099 .get_all()
1100 .keys()
1101 .filter(|k| k.ends_with("_fired") || k.ends_with("_executed"))
1102 .cloned()
1103 .collect();
1104 for key in keys_to_remove {
1105 self.facts.remove(&key);
1106 }
1107 }
1108}
1109
1110impl Default for TypedReteUlEngine {
1111 fn default() -> Self {
1112 Self::new()
1113 }
1114}