// rust_rule_engine/rete/network.rs

use std::collections::HashMap;
use std::sync::Arc;

use crate::rete::alpha::AlphaNode;
use super::facts::{FactValue, TypedFacts};
3/// Chuyển ConditionGroup sang ReteUlNode
4pub fn build_rete_ul_from_condition_group(group: &crate::rete::auto_network::ConditionGroup) -> ReteUlNode {
5    use crate::rete::auto_network::ConditionGroup;
6    match group {
7        ConditionGroup::Single(cond) => {
8            ReteUlNode::UlAlpha(AlphaNode {
9                field: cond.field.clone(),
10                operator: cond.operator.clone(),
11                value: cond.value.clone(),
12            })
13        }
14        ConditionGroup::Compound { left, operator, right } => {
15            match operator.as_str() {
16                "AND" => ReteUlNode::UlAnd(
17                    Box::new(build_rete_ul_from_condition_group(left)),
18                    Box::new(build_rete_ul_from_condition_group(right)),
19                ),
20                "OR" => ReteUlNode::UlOr(
21                    Box::new(build_rete_ul_from_condition_group(left)),
22                    Box::new(build_rete_ul_from_condition_group(right)),
23                ),
24                _ => ReteUlNode::UlAnd(
25                    Box::new(build_rete_ul_from_condition_group(left)),
26                    Box::new(build_rete_ul_from_condition_group(right)),
27                ),
28            }
29        }
30        ConditionGroup::Not(inner) => {
31            ReteUlNode::UlNot(Box::new(build_rete_ul_from_condition_group(inner)))
32        }
33        ConditionGroup::Exists(inner) => {
34            ReteUlNode::UlExists(Box::new(build_rete_ul_from_condition_group(inner)))
35        }
36        ConditionGroup::Forall(inner) => {
37            ReteUlNode::UlForall(Box::new(build_rete_ul_from_condition_group(inner)))
38        }
39    }
40}
/// Helper: evaluate a simple `field <op> value` condition string against `facts`.
///
/// Used for accumulate source conditions. Supported operators are `==`, `!=`, `>=`,
/// `<=`, `>` and `<`. The condition is split at the *leftmost* operator, so a quoted
/// value that itself contains operator characters (e.g. `name != "a==b"`) is parsed
/// correctly. Surrounding double or single quotes are stripped from the value.
///
/// Returns `false` when the condition contains no operator or `facts` lacks the field.
fn evaluate_condition_string(condition: &str, facts: &HashMap<String, String>) -> bool {
    match split_condition(condition.trim()) {
        Some((field, operator, raw_value)) => {
            let value_str = raw_value.trim().trim_matches('"').trim_matches('\'');
            match facts.get(field.trim()) {
                Some(field_value) => compare_string_values(field_value, operator, value_str),
                None => false,
            }
        }
        None => false,
    }
}

/// Split `condition` at its leftmost comparison operator into `(field, operator, value)`.
///
/// At each position two-character operators are tried before one-character ones, so
/// `>=` wins over `>`. Returns `None` when no operator is present.
fn split_condition(condition: &str) -> Option<(&str, &'static str, &str)> {
    const OPERATORS: [&str; 6] = ["==", "!=", ">=", "<=", ">", "<"];
    condition.char_indices().find_map(|(idx, _)| {
        let rest = &condition[idx..];
        OPERATORS
            .iter()
            .find(|op| rest.starts_with(**op))
            .map(|op| (&condition[..idx], *op, &rest[op.len()..]))
    })
}

/// Helper: compare a fact value with an expected value using `operator`.
///
/// When both sides parse as `f64` the comparison is numeric (`==`/`!=` use an
/// `f64::EPSILON` tolerance). Otherwise only `==` and `!=` are supported, as exact
/// string comparisons; ordering operators on non-numeric values return `false`.
fn compare_string_values(field_value: &str, operator: &str, value_str: &str) -> bool {
    match (field_value.parse::<f64>(), value_str.parse::<f64>()) {
        (Ok(lhs), Ok(rhs)) => match operator {
            "==" => (lhs - rhs).abs() < f64::EPSILON,
            "!=" => (lhs - rhs).abs() >= f64::EPSILON,
            ">" => lhs > rhs,
            "<" => lhs < rhs,
            ">=" => lhs >= rhs,
            "<=" => lhs <= rhs,
            _ => false,
        },
        _ => match operator {
            "==" => field_value == value_str,
            "!=" => field_value != value_str,
            _ => false,
        },
    }
}
88
89/// Đánh giá mạng node RETE với facts
90pub fn evaluate_rete_ul_node(node: &ReteUlNode, facts: &HashMap<String, String>) -> bool {
91    match node {
92        ReteUlNode::UlAlpha(alpha) => {
93            let val = if alpha.field.contains('.') {
94                let parts: Vec<&str> = alpha.field.split('.').collect();
95                if parts.len() == 2 {
96                    let prefix = parts[0];
97                    let suffix = parts[1];
98                    facts.get(&format!("{}.{}", prefix, suffix)).or_else(|| facts.get(&format!("{}:{}", prefix, suffix)))
99                } else {
100                    facts.get(&alpha.field)
101                }
102            } else {
103                facts.get(&alpha.field)
104            };
105            if let Some(val) = val {
106                match alpha.operator.as_str() {
107                    "==" => val == &alpha.value,
108                    "!=" => val != &alpha.value,
109                    ">" => val.parse::<f64>().unwrap_or(0.0) > alpha.value.parse::<f64>().unwrap_or(0.0),
110                    "<" => val.parse::<f64>().unwrap_or(0.0) < alpha.value.parse::<f64>().unwrap_or(0.0),
111                    ">=" => val.parse::<f64>().unwrap_or(0.0) >= alpha.value.parse::<f64>().unwrap_or(0.0),
112                    "<=" => val.parse::<f64>().unwrap_or(0.0) <= alpha.value.parse::<f64>().unwrap_or(0.0),
113                    _ => false,
114                }
115            } else {
116                false
117            }
118        }
119        ReteUlNode::UlAnd(left, right) => {
120            evaluate_rete_ul_node(left, facts) && evaluate_rete_ul_node(right, facts)
121        }
122        ReteUlNode::UlOr(left, right) => {
123            evaluate_rete_ul_node(left, facts) || evaluate_rete_ul_node(right, facts)
124        }
125        ReteUlNode::UlNot(inner) => {
126            !evaluate_rete_ul_node(inner, facts)
127        }
128        ReteUlNode::UlExists(inner) => {
129            let target_field = match &**inner {
130                ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
131                _ => "".to_string(),
132            };
133            if target_field.contains('.') {
134                let parts: Vec<&str> = target_field.split('.').collect();
135                if parts.len() == 2 {
136                    let prefix = parts[0];
137                    let suffix = parts[1];
138                    let filtered: Vec<_> = facts.iter()
139                        .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
140                        .collect();
141                    filtered.iter().any(|(_, value)| {
142                        let mut sub_facts = HashMap::new();
143                        sub_facts.insert(target_field.clone(), (*value).clone());
144                        evaluate_rete_ul_node(inner, &sub_facts)
145                    })
146                } else {
147                    facts.iter().any(|(field, value)| {
148                        let mut sub_facts = HashMap::new();
149                        sub_facts.insert(field.clone(), value.clone());
150                        evaluate_rete_ul_node(inner, &sub_facts)
151                    })
152                }
153            } else {
154                facts.iter().any(|(field, value)| {
155                    let mut sub_facts = HashMap::new();
156                    sub_facts.insert(field.clone(), value.clone());
157                    evaluate_rete_ul_node(inner, &sub_facts)
158                })
159            }
160        }
161        ReteUlNode::UlForall(inner) => {
162            let target_field = match &**inner {
163                ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
164                _ => "".to_string(),
165            };
166            if target_field.contains('.') {
167                let parts: Vec<&str> = target_field.split('.').collect();
168                if parts.len() == 2 {
169                    let prefix = parts[0];
170                    let suffix = parts[1];
171                    let filtered: Vec<_> = facts.iter()
172                        .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
173                        .collect();
174                    if filtered.is_empty() {
175                        return true; // Vacuous truth: FORALL on empty set is TRUE
176                    }
177                    filtered.iter().all(|(_, value)| {
178                        let mut sub_facts = HashMap::new();
179                        sub_facts.insert(target_field.clone(), (*value).clone());
180                        evaluate_rete_ul_node(inner, &sub_facts)
181                    })
182                } else {
183                    facts.iter().all(|(field, value)| {
184                        let mut sub_facts = HashMap::new();
185                        sub_facts.insert(field.clone(), value.clone());
186                        evaluate_rete_ul_node(inner, &sub_facts)
187                    })
188                }
189            } else {
190                facts.iter().all(|(field, value)| {
191                    let mut sub_facts = HashMap::new();
192                    sub_facts.insert(field.clone(), value.clone());
193                    evaluate_rete_ul_node(inner, &sub_facts)
194                })
195            }
196        }
197        ReteUlNode::UlAccumulate {
198            source_pattern,
199            extract_field,
200            source_conditions,
201            function,
202            ..
203        } => {
204            // Evaluate accumulate: collect matching facts and run function
205            use super::accumulate::*;
206
207            let pattern_prefix = format!("{}.", source_pattern);
208            let mut matching_values = Vec::new();
209
210            // Group facts by instance
211            let mut instances: std::collections::HashMap<String, std::collections::HashMap<String, String>> =
212                std::collections::HashMap::new();
213
214            for (key, value) in facts {
215                if key.starts_with(&pattern_prefix) {
216                    let parts: Vec<&str> = key.strip_prefix(&pattern_prefix).unwrap().split('.').collect();
217
218                    if parts.len() >= 2 {
219                        let instance_id = parts[0];
220                        let field_name = parts[1..].join(".");
221
222                        instances
223                            .entry(instance_id.to_string())
224                            .or_insert_with(std::collections::HashMap::new)
225                            .insert(field_name, value.clone());
226                    } else if parts.len() == 1 {
227                        instances
228                            .entry("default".to_string())
229                            .or_insert_with(std::collections::HashMap::new)
230                            .insert(parts[0].to_string(), value.clone());
231                    }
232                }
233            }
234
235            // Filter instances by source conditions
236            for (_instance_id, instance_facts) in instances {
237                let mut matches = true;
238
239                for condition_str in source_conditions {
240                    if !evaluate_condition_string(condition_str, &instance_facts) {
241                        matches = false;
242                        break;
243                    }
244                }
245
246                if matches {
247                    if let Some(value_str) = instance_facts.get(extract_field) {
248                        // Convert string to FactValue
249                        let fact_value = if let Ok(i) = value_str.parse::<i64>() {
250                            super::facts::FactValue::Integer(i)
251                        } else if let Ok(f) = value_str.parse::<f64>() {
252                            super::facts::FactValue::Float(f)
253                        } else if let Ok(b) = value_str.parse::<bool>() {
254                            super::facts::FactValue::Boolean(b)
255                        } else {
256                            super::facts::FactValue::String(value_str.clone())
257                        };
258                        matching_values.push(fact_value);
259                    }
260                }
261            }
262
263            // Run accumulate function - result determines if condition passes
264            let has_results = !matching_values.is_empty();
265
266            match function.as_str() {
267                "count" => has_results, // Count passes if there are any matches
268                "sum" | "average" | "min" | "max" => {
269                    // These functions need at least one value
270                    has_results
271                }
272                _ => true, // Unknown function - allow to continue
273            }
274        }
275        ReteUlNode::UlTerminal(_) => true // Rule match
276    }
277}
278
279/// RETE-UL: Unified Logic Node
280#[derive(Debug, Clone)]
281pub enum ReteUlNode {
282    UlAlpha(AlphaNode),
283    UlAnd(Box<ReteUlNode>, Box<ReteUlNode>),
284    UlOr(Box<ReteUlNode>, Box<ReteUlNode>),
285    UlNot(Box<ReteUlNode>),
286    UlExists(Box<ReteUlNode>),
287    UlForall(Box<ReteUlNode>),
288    UlAccumulate {
289        result_var: String,
290        source_pattern: String,
291        extract_field: String,
292        source_conditions: Vec<String>,
293        function: String,
294        function_arg: String,
295    },
296    UlTerminal(String), // Rule name
297}
298
299impl ReteUlNode {
300    /// Evaluate with typed facts (convenience method)
301    pub fn evaluate_typed(&self, facts: &super::facts::TypedFacts) -> bool {
302        evaluate_rete_ul_node_typed(self, facts)
303    }
304}
305
306/// RETE-UL Rule Struct
307pub struct ReteUlRule {
308    pub name: String,
309    pub node: ReteUlNode,
310    pub priority: i32,
311    pub no_loop: bool,
312    pub action: Arc<dyn Fn(&mut std::collections::HashMap<String, String>) + Send + Sync>,
313}
314
315/// Drools-style RETE-UL rule firing loop
316/// Fires all matching rules, updates facts, repeats until no more rules can fire
317pub fn fire_rete_ul_rules(
318    rules: &mut [(String, ReteUlNode, Box<dyn FnMut(&mut std::collections::HashMap<String, String>)>)],
319    facts: &mut std::collections::HashMap<String, String>,
320) -> Vec<String> {
321    let mut fired_rules = Vec::new();
322    let mut changed = true;
323    while changed {
324        changed = false;
325        for (rule_name, node, action) in rules.iter_mut() {
326            let fired_flag = format!("{}_fired", rule_name);
327            if facts.get(&fired_flag) == Some(&"true".to_string()) {
328                continue;
329            }
330            if evaluate_rete_ul_node(node, facts) {
331                action(facts);
332                facts.insert(fired_flag.clone(), "true".to_string());
333                fired_rules.push(rule_name.clone());
334                changed = true;
335            }
336        }
337    }
338    fired_rules
339}
340
341/// Drools-style RETE-UL rule firing loop with agenda and control
342pub fn fire_rete_ul_rules_with_agenda(
343    rules: &mut [ReteUlRule],
344    facts: &mut std::collections::HashMap<String, String>,
345) -> Vec<String> {
346    let mut fired_rules = Vec::new();
347    let mut fired_flags = std::collections::HashSet::new();
348    let max_iterations = 100; // Prevent infinite loops
349    let mut iterations = 0;
350
351    loop {
352        iterations += 1;
353        if iterations > max_iterations {
354            eprintln!("Warning: RETE engine reached max iterations ({})", max_iterations);
355            break;
356        }
357
358        // Build agenda: rules that match and haven't been fired yet
359        let mut agenda: Vec<usize> = rules
360            .iter()
361            .enumerate()
362            .filter(|(_, rule)| {
363                // Check if rule already fired
364                if fired_flags.contains(&rule.name) {
365                    return false;
366                }
367                // Check if rule matches current facts
368                evaluate_rete_ul_node(&rule.node, facts)
369            })
370            .map(|(i, _)| i)
371            .collect();
372
373        // If no rules to fire, we're done
374        if agenda.is_empty() {
375            break;
376        }
377
378        // Sort agenda by priority (descending)
379        agenda.sort_by_key(|&i| -rules[i].priority);
380
381        // Fire all rules in agenda
382        for &i in &agenda {
383            let rule = &mut rules[i];
384
385            // Execute rule action
386            (rule.action)(facts);
387
388            // Mark as fired
389            fired_rules.push(rule.name.clone());
390            fired_flags.insert(rule.name.clone());
391
392            let fired_flag = format!("{}_fired", rule.name);
393            facts.insert(fired_flag, "true".to_string());
394        }
395
396        // If no_loop is enabled for all rules, stop after one iteration
397        if rules.iter().all(|r| r.no_loop) {
398            break;
399        }
400    }
401
402    fired_rules
403}
404
405/// RETE-UL Engine with cached nodes (Performance optimized!)
406/// This engine builds RETE nodes once and reuses them, avoiding expensive rebuilds
407pub struct ReteUlEngine {
408    rules: Vec<ReteUlRule>,
409    facts: std::collections::HashMap<String, String>,
410}
411
412impl ReteUlEngine {
413    /// Create new engine from Rule definitions (nodes are built and cached once)
414    pub fn new() -> Self {
415        Self {
416            rules: Vec::new(),
417            facts: std::collections::HashMap::new(),
418        }
419    }
420
421    /// Add a rule with custom action closure
422    pub fn add_rule_with_action<F>(
423        &mut self,
424        name: String,
425        node: ReteUlNode,
426        priority: i32,
427        no_loop: bool,
428        action: F,
429    ) where
430        F: Fn(&mut std::collections::HashMap<String, String>) + Send + Sync + 'static,
431    {
432        self.rules.push(ReteUlRule {
433            name,
434            node,
435            priority,
436            no_loop,
437            action: Arc::new(action),
438        });
439    }
440
441    /// Add a rule from Rule definition (auto-build node once and cache)
442    pub fn add_rule_from_definition(
443        &mut self,
444        rule: &crate::rete::auto_network::Rule,
445        priority: i32,
446        no_loop: bool,
447    ) {
448        let node = build_rete_ul_from_condition_group(&rule.conditions);
449        let rule_name = rule.name.clone();
450
451        // Default action: just mark as fired
452        let action = Arc::new(move |facts: &mut std::collections::HashMap<String, String>| {
453            facts.insert(format!("{}_executed", rule_name), "true".to_string());
454        });
455
456        self.rules.push(ReteUlRule {
457            name: rule.name.clone(),
458            node,
459            priority,
460            no_loop,
461            action,
462        });
463    }
464
465    /// Set a fact
466    pub fn set_fact(&mut self, key: String, value: String) {
467        self.facts.insert(key, value);
468    }
469
470    /// Get a fact
471    pub fn get_fact(&self, key: &str) -> Option<&String> {
472        self.facts.get(key)
473    }
474
475    /// Remove a fact
476    pub fn remove_fact(&mut self, key: &str) -> Option<String> {
477        self.facts.remove(key)
478    }
479
480    /// Get all facts
481    pub fn get_all_facts(&self) -> &std::collections::HashMap<String, String> {
482        &self.facts
483    }
484
485    /// Clear all facts
486    pub fn clear_facts(&mut self) {
487        self.facts.clear();
488    }
489
490    /// Fire all rules with agenda (using cached nodes - NO rebuild!)
491    pub fn fire_all(&mut self) -> Vec<String> {
492        fire_rete_ul_rules_with_agenda(&mut self.rules, &mut self.facts)
493    }
494
495    /// Check if a specific rule matches current facts (without firing)
496    pub fn matches(&self, rule_name: &str) -> bool {
497        self.rules
498            .iter()
499            .find(|r| r.name == rule_name)
500            .map(|r| evaluate_rete_ul_node(&r.node, &self.facts))
501            .unwrap_or(false)
502    }
503
504    /// Get all matching rules (without firing)
505    pub fn get_matching_rules(&self) -> Vec<&str> {
506        self.rules
507            .iter()
508            .filter(|r| evaluate_rete_ul_node(&r.node, &self.facts))
509            .map(|r| r.name.as_str())
510            .collect()
511    }
512
513    /// Reset fired flags (allow rules to fire again)
514    pub fn reset_fired_flags(&mut self) {
515        let keys_to_remove: Vec<_> = self.facts
516            .keys()
517            .filter(|k| k.ends_with("_fired") || k.ends_with("_executed"))
518            .cloned()
519            .collect();
520        for key in keys_to_remove {
521            self.facts.remove(&key);
522        }
523    }
524}
525
526// ============================================================================
527// Typed Facts Support (NEW!)
528// ============================================================================
529
530use super::facts::{FactValue, TypedFacts};
531
532/// Evaluate RETE-UL node with typed facts (NEW!)
533pub fn evaluate_rete_ul_node_typed(node: &ReteUlNode, facts: &TypedFacts) -> bool {
534    match node {
535        ReteUlNode::UlAlpha(alpha) => {
536            alpha.matches_typed(facts)
537        }
538        ReteUlNode::UlAnd(left, right) => {
539            evaluate_rete_ul_node_typed(left, facts) && evaluate_rete_ul_node_typed(right, facts)
540        }
541        ReteUlNode::UlOr(left, right) => {
542            evaluate_rete_ul_node_typed(left, facts) || evaluate_rete_ul_node_typed(right, facts)
543        }
544        ReteUlNode::UlNot(inner) => {
545            !evaluate_rete_ul_node_typed(inner, facts)
546        }
547        ReteUlNode::UlExists(inner) => {
548            let target_field = match &**inner {
549                ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
550                _ => "".to_string(),
551            };
552            if target_field.contains('.') {
553                let parts: Vec<&str> = target_field.split('.').collect();
554                if parts.len() == 2 {
555                    let prefix = parts[0];
556                    let suffix = parts[1];
557                    let filtered: Vec<_> = facts.get_all().iter()
558                        .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
559                        .collect();
560                    filtered.iter().any(|(_, _)| {
561                        evaluate_rete_ul_node_typed(inner, facts)
562                    })
563                } else {
564                    evaluate_rete_ul_node_typed(inner, facts)
565                }
566            } else {
567                evaluate_rete_ul_node_typed(inner, facts)
568            }
569        }
570        ReteUlNode::UlForall(inner) => {
571            let target_field = match &**inner {
572                ReteUlNode::UlAlpha(alpha) => alpha.field.clone(),
573                _ => "".to_string(),
574            };
575            if target_field.contains('.') {
576                let parts: Vec<&str> = target_field.split('.').collect();
577                if parts.len() == 2 {
578                    let prefix = parts[0];
579                    let suffix = parts[1];
580                    let filtered: Vec<_> = facts.get_all().iter()
581                        .filter(|(k, _)| k.starts_with(prefix) && k.ends_with(suffix))
582                        .collect();
583                    if filtered.is_empty() {
584                        return true; // Vacuous truth
585                    }
586                    filtered.iter().all(|(_, _)| {
587                        evaluate_rete_ul_node_typed(inner, facts)
588                    })
589                } else {
590                    if facts.get_all().is_empty() {
591                        return true; // Vacuous truth
592                    }
593                    evaluate_rete_ul_node_typed(inner, facts)
594                }
595            } else {
596                if facts.get_all().is_empty() {
597                    return true; // Vacuous truth
598                }
599                evaluate_rete_ul_node_typed(inner, facts)
600            }
601        }
602        ReteUlNode::UlAccumulate {
603            source_pattern,
604            extract_field,
605            source_conditions,
606            function,
607            ..
608        } => {
609            // Evaluate accumulate with typed facts
610            use super::accumulate::*;
611
612            let pattern_prefix = format!("{}.", source_pattern);
613            let mut matching_values = Vec::new();
614
615            // Group facts by instance
616            let mut instances: std::collections::HashMap<String, std::collections::HashMap<String, FactValue>> =
617                std::collections::HashMap::new();
618
619            for (key, value) in facts.get_all() {
620                if key.starts_with(&pattern_prefix) {
621                    let parts: Vec<&str> = key.strip_prefix(&pattern_prefix).unwrap().split('.').collect();
622
623                    if parts.len() >= 2 {
624                        let instance_id = parts[0];
625                        let field_name = parts[1..].join(".");
626
627                        instances
628                            .entry(instance_id.to_string())
629                            .or_insert_with(std::collections::HashMap::new)
630                            .insert(field_name, value.clone());
631                    } else if parts.len() == 1 {
632                        instances
633                            .entry("default".to_string())
634                            .or_insert_with(std::collections::HashMap::new)
635                            .insert(parts[0].to_string(), value.clone());
636                    }
637                }
638            }
639
640            // Filter instances by source conditions
641            for (_instance_id, instance_facts) in instances {
642                let mut matches = true;
643
644                for condition_str in source_conditions {
645                    // Convert FactValues to strings for condition evaluation
646                    let string_facts: HashMap<String, String> = instance_facts
647                        .iter()
648                        .map(|(k, v)| (k.clone(), format!("{:?}", v)))
649                        .collect();
650
651                    if !evaluate_condition_string(condition_str, &string_facts) {
652                        matches = false;
653                        break;
654                    }
655                }
656
657                if matches {
658                    if let Some(value) = instance_facts.get(extract_field) {
659                        matching_values.push(value.clone());
660                    }
661                }
662            }
663
664            // Run accumulate function - result determines if condition passes
665            let has_results = !matching_values.is_empty();
666
667            match function.as_str() {
668                "count" => has_results,
669                "sum" | "average" | "min" | "max" => has_results,
670                _ => true,
671            }
672        }
673        ReteUlNode::UlTerminal(_) => true
674    }
675}
676
677/// Typed RETE-UL Rule
678pub struct TypedReteUlRule {
679    pub name: String,
680    pub node: ReteUlNode,
681    pub priority: i32,
682    pub no_loop: bool,
683    pub action: Arc<dyn Fn(&mut TypedFacts) + Send + Sync>,
684}
685
686/// Typed RETE-UL Engine with cached nodes (Performance + Type Safety!)
687/// This is the recommended engine for new code
688pub struct TypedReteUlEngine {
689    rules: Vec<TypedReteUlRule>,
690    facts: TypedFacts,
691}
692
693impl TypedReteUlEngine {
694    /// Create new typed engine
695    pub fn new() -> Self {
696        Self {
697            rules: Vec::new(),
698            facts: TypedFacts::new(),
699        }
700    }
701
702    /// Add a rule with custom action
703    pub fn add_rule_with_action<F>(
704        &mut self,
705        name: String,
706        node: ReteUlNode,
707        priority: i32,
708        no_loop: bool,
709        action: F,
710    ) where
711        F: Fn(&mut TypedFacts) + Send + Sync + 'static,
712    {
713        self.rules.push(TypedReteUlRule {
714            name,
715            node,
716            priority,
717            no_loop,
718            action: Arc::new(action),
719        });
720    }
721
722    /// Add a rule from Rule definition
723    pub fn add_rule_from_definition(
724        &mut self,
725        rule: &crate::rete::auto_network::Rule,
726        priority: i32,
727        no_loop: bool,
728    ) {
729        let node = build_rete_ul_from_condition_group(&rule.conditions);
730        let rule_name = rule.name.clone();
731
732        let action = Arc::new(move |facts: &mut TypedFacts| {
733            facts.set(format!("{}_executed", rule_name), true);
734        });
735
736        self.rules.push(TypedReteUlRule {
737            name: rule.name.clone(),
738            node,
739            priority,
740            no_loop,
741            action,
742        });
743    }
744
745    /// Set a fact with typed value
746    pub fn set_fact<K: Into<String>, V: Into<FactValue>>(&mut self, key: K, value: V) {
747        self.facts.set(key, value);
748    }
749
750    /// Get a fact
751    pub fn get_fact(&self, key: &str) -> Option<&FactValue> {
752        self.facts.get(key)
753    }
754
755    /// Remove a fact
756    pub fn remove_fact(&mut self, key: &str) -> Option<FactValue> {
757        self.facts.remove(key)
758    }
759
760    /// Get all facts
761    pub fn get_all_facts(&self) -> &TypedFacts {
762        &self.facts
763    }
764
765    /// Clear all facts
766    pub fn clear_facts(&mut self) {
767        self.facts.clear();
768    }
769
770    /// Fire all rules with agenda (using cached nodes + typed evaluation!)
771    pub fn fire_all(&mut self) -> Vec<String> {
772        let mut fired_rules = Vec::new();
773        let mut agenda: Vec<usize>;
774        let mut changed = true;
775        let mut fired_flags = std::collections::HashSet::new();
776
777        while changed {
778            changed = false;
779
780            // Build agenda: rules that match and not fired
781            agenda = self.rules.iter().enumerate()
782                .filter(|(_, rule)| {
783                    let fired_flag = format!("{}_fired", rule.name);
784                    let already_fired = fired_flags.contains(&rule.name) ||
785                        self.facts.get(&fired_flag).and_then(|v| v.as_boolean()) == Some(true);
786                    !rule.no_loop || !already_fired
787                })
788                .filter(|(_, rule)| evaluate_rete_ul_node_typed(&rule.node, &self.facts))
789                .map(|(i, _)| i)
790                .collect();
791
792            // Sort by priority (descending)
793            agenda.sort_by_key(|&i| -self.rules[i].priority);
794
795            for &i in &agenda {
796                let rule = &mut self.rules[i];
797                let fired_flag = format!("{}_fired", rule.name);
798                let already_fired = fired_flags.contains(&rule.name) ||
799                    self.facts.get(&fired_flag).and_then(|v| v.as_boolean()) == Some(true);
800
801                if rule.no_loop && already_fired {
802                    continue;
803                }
804
805                (rule.action)(&mut self.facts);
806                fired_rules.push(rule.name.clone());
807                fired_flags.insert(rule.name.clone());
808                self.facts.set(fired_flag, true);
809                changed = true;
810            }
811        }
812
813        fired_rules
814    }
815
816    /// Check if a specific rule matches current facts
817    pub fn matches(&self, rule_name: &str) -> bool {
818        self.rules
819            .iter()
820            .find(|r| r.name == rule_name)
821            .map(|r| evaluate_rete_ul_node_typed(&r.node, &self.facts))
822            .unwrap_or(false)
823    }
824
825    /// Get all matching rules
826    pub fn get_matching_rules(&self) -> Vec<&str> {
827        self.rules
828            .iter()
829            .filter(|r| evaluate_rete_ul_node_typed(&r.node, &self.facts))
830            .map(|r| r.name.as_str())
831            .collect()
832    }
833
834    /// Reset fired flags
835    pub fn reset_fired_flags(&mut self) {
836        let keys_to_remove: Vec<_> = self.facts.get_all()
837            .keys()
838            .filter(|k| k.ends_with("_fired") || k.ends_with("_executed"))
839            .cloned()
840            .collect();
841        for key in keys_to_remove {
842            self.facts.remove(&key);
843        }
844    }
845}
846
847impl Default for TypedReteUlEngine {
848    fn default() -> Self {
849        Self::new()
850    }
851}
852