rust_logic_graph/rule/
engine.rs

1// Re-export rust-rule-engine types for convenience
2pub use rust_rule_engine::{
3    engine::{
4        facts::Facts,
5        knowledge_base::KnowledgeBase,
6        EngineConfig,
7        RustRuleEngine,
8    },
9    parser::grl::GRLParser,
10    types::Value,
11};
12
13use serde_json::Value as JsonValue;
14use std::collections::HashMap;
15use tracing::debug;
16
17use super::{RuleError, RuleResult};
18
19/// Convenience wrapper around RustRuleEngine with JSON integration
20///
21/// This provides a simplified API for common use cases while maintaining
22/// full access to the underlying rust-rule-engine capabilities.
23///
24/// # Thread Safety
25/// RustRuleEngine is thread-safe (Send + Sync), making it suitable for
26/// use in multi-threaded web services like Axum.
27///
28/// # Example
29/// ```no_run
30/// use rust_logic_graph::RuleEngine;
31/// use std::collections::HashMap;
32/// use serde_json::json;
33///
34/// let mut engine = RuleEngine::new();
35///
36/// let grl = r#"
37///     rule "discount_rule" {
38///         salience 100
39///         when
40///             total > 100
41///         then
42///             discount = total * 0.1;
43///     }
44/// "#;
45///
46/// engine.add_grl_rule(grl).unwrap();
47///
48/// let mut context = HashMap::new();
49/// context.insert("total".to_string(), json!(150.0));
50///
51/// let result = engine.evaluate(&context).unwrap();
52/// ```
53pub struct RuleEngine {
54    engine: RustRuleEngine,
55}
56
57impl RuleEngine {
58    /// Create a new rule engine with default knowledge base
59    pub fn new() -> Self {
60        let kb = KnowledgeBase::new("LogicGraph");
61        Self {
62            engine: RustRuleEngine::new(kb),
63        }
64    }
65
66    /// Create a new rule engine with custom configuration
67    pub fn with_config(config: EngineConfig) -> Self {
68        let kb = KnowledgeBase::new("LogicGraph");
69        Self {
70            engine: RustRuleEngine::with_config(kb, config),
71        }
72    }
73
74    /// Add rules from GRL syntax
75    ///
76    /// # Example
77    /// ```no_run
78    /// use rust_logic_graph::RuleEngine;
79    ///
80    /// let mut engine = RuleEngine::new();
81    /// let grl = r#"
82    ///     rule "high_value_order" {
83    ///         salience 100
84    ///         when
85    ///             order_amount > 1000
86    ///         then
87    ///             priority = "high";
88    ///             requires_approval = true;
89    ///     }
90    /// "#;
91    /// engine.add_grl_rule(grl).unwrap();
92    /// ```
93    pub fn add_grl_rule(&mut self, grl_content: &str) -> Result<(), RuleError> {
94        let rules = GRLParser::parse_rules(grl_content)
95            .map_err(|e| RuleError::Eval(format!("Failed to parse GRL: {}", e)))?;
96
97        let rule_count = rules.len();
98
99        for rule in rules {
100            self.engine
101                .knowledge_base()
102                .add_rule(rule)
103                .map_err(|e| RuleError::Eval(format!("Failed to add rule: {}", e)))?;
104        }
105
106        debug!("Loaded {} GRL rules", rule_count);
107
108        Ok(())
109    }
110
111    /// Evaluate rules with JSON context (convenience method)
112    ///
113    /// For more control, use `inner()` or `inner_mut()` to access the
114    /// underlying RustRuleEngine directly.
115    ///
116    /// # Example
117    /// ```no_run
118    /// use rust_logic_graph::RuleEngine;
119    /// use std::collections::HashMap;
120    /// use serde_json::json;
121    ///
122    /// let mut engine = RuleEngine::new();
123    /// // ... add rules ...
124    ///
125    /// let mut context = HashMap::new();
126    /// context.insert("total".to_string(), json!(150.0));
127    ///
128    /// let result = engine.evaluate(&context).unwrap();
129    /// ```
130    pub fn evaluate(&mut self, context: &HashMap<String, JsonValue>) -> RuleResult {
131        // Convert JSON context to Facts
132        let facts = Facts::new();
133
134        for (key, value) in context {
135            let rr_value = match value {
136                JsonValue::Bool(b) => Value::Boolean(*b),
137                JsonValue::Number(n) => {
138                    if let Some(f) = n.as_f64() {
139                        Value::Number(f)
140                    } else {
141                        continue;
142                    }
143                }
144                JsonValue::String(s) => Value::String(s.clone()),
145                _ => {
146                    debug!("Skipping unsupported value type for key: {}", key);
147                    continue;
148                }
149            };
150
151            facts.set(&key, rr_value);
152        }
153
154        // Execute rules
155        match self.engine.execute(&facts) {
156            Ok(_) => {
157                debug!("Rules executed successfully");
158
159                // Helper to convert Value to JsonValue
160                let convert_value = |val: &Value| -> Option<JsonValue> {
161                    match val {
162                        Value::Boolean(b) => Some(JsonValue::Bool(*b)),
163                        Value::Number(n) => Some(JsonValue::from(*n)),
164                        Value::String(s) => Some(JsonValue::String(s.clone())),
165                        Value::Integer(i) => Some(JsonValue::from(*i)),
166                        _ => None,
167                    }
168                };
169
170                // Get ALL facts from the engine after rule execution
171                // This captures all values set by rules (including Expression Evaluation results)
172                let all_facts = facts.get_all_facts();
173                
174                let mut result = HashMap::new();
175                for (key, value) in all_facts {
176                    if let Some(json_value) = convert_value(&value) {
177                        result.insert(key, json_value);
178                    }
179                }
180
181                Ok(JsonValue::Object(result.into_iter().collect()))
182            }
183            Err(e) => Err(RuleError::Eval(format!("Rule execution failed: {}", e))),
184        }
185    }
186
187    /// Create a rule engine from GRL script
188    ///
189    /// # Example
190    /// ```no_run
191    /// use rust_logic_graph::RuleEngine;
192    ///
193    /// let grl = r#"
194    ///     rule "example" {
195    ///         salience 100
196    ///         when
197    ///             x > 0
198    ///         then
199    ///             y = x * 2;
200    ///     }
201    /// "#;
202    ///
203    /// let mut engine = RuleEngine::from_grl(grl).unwrap();
204    /// ```
205    pub fn from_grl(grl_script: &str) -> Result<Self, RuleError> {
206        let mut engine = Self::new();
207        engine.add_grl_rule(grl_script)?;
208        Ok(engine)
209    }
210
211    /// Get reference to underlying RustRuleEngine for advanced usage
212    ///
213    /// This provides full access to rust-rule-engine features:
214    /// - Custom functions
215    /// - Templates
216    /// - Globals
217    /// - Deffacts
218    /// - Fine-grained control
219    pub fn inner(&self) -> &RustRuleEngine {
220        &self.engine
221    }
222
223    /// Get mutable reference to underlying RustRuleEngine
224    pub fn inner_mut(&mut self) -> &mut RustRuleEngine {
225        &mut self.engine
226    }
227}
228
229impl Default for RuleEngine {
230    fn default() -> Self {
231        Self::new()
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_rule_engine_creation() {
        let _engine = RuleEngine::new();
    }

    #[test]
    fn test_engine_evaluation() {
        let mut engine = RuleEngine::new();

        let grl = r#"
        rule "test_rule" {
            salience 100
            when
                age >= 18
            then
                eligible = true;
        }
        "#;

        engine.add_grl_rule(grl).unwrap();

        let mut context = HashMap::new();
        context.insert("age".to_string(), json!(20));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["eligible"], json!(true));
    }

    #[test]
    fn test_from_grl() {
        let grl = r#"
        rule "test" {
            salience 100
            when
                x > 0
            then
                result = true;
                message = "x is positive";
        }
        "#;

        let mut engine = RuleEngine::from_grl(grl).unwrap();

        let mut context = HashMap::new();
        context.insert("x".to_string(), json!(5));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["result"], json!(true));
        assert_eq!(result["message"], json!("x is positive"));
    }

    #[test]
    fn test_multiple_rules_salience() {
        let mut engine = RuleEngine::new();

        let grl = r#"
        rule "high_priority" {
            salience 100
            when
                value > 100
            then
                priority = "high";
                high_rule_fired = true;
        }

        rule "medium_priority" {
            salience 50
            when
                value > 50 && value <= 100
            then
                priority = "medium";
                medium_rule_fired = true;
        }
        "#;

        engine.add_grl_rule(grl).unwrap();

        // Test with high value
        let mut context = HashMap::new();
        context.insert("value".to_string(), json!(150));

        let result = engine.evaluate(&context).unwrap();
        // Only the high priority rule should fire for value > 100
        assert_eq!(result["priority"], json!("high"));
        assert_eq!(result["high_rule_fired"], json!(true));

        // Test with medium value
        let mut context2 = HashMap::new();
        context2.insert("value".to_string(), json!(75));

        let result2 = engine.evaluate(&context2).unwrap();
        // Only the medium priority rule should fire for 50 < value <= 100
        assert_eq!(result2["priority"], json!("medium"));
        assert_eq!(result2["medium_rule_fired"], json!(true));
    }

    #[test]
    fn test_unsupported_context_values_are_skipped() {
        let mut engine = RuleEngine::new();

        let mut context = HashMap::new();
        context.insert("name".to_string(), json!("alice"));
        context.insert("missing".to_string(), JsonValue::Null);
        context.insert("tags".to_string(), json!(["a", "b"]));

        let result = engine.evaluate(&context).unwrap();
        // Supported inputs round-trip; null and arrays never become facts.
        assert_eq!(result["name"], json!("alice"));
        assert!(result.get("missing").is_none());
        assert!(result.get("tags").is_none());
    }

    #[test]
    fn test_direct_engine_access() {
        let engine = RuleEngine::new();

        // Access underlying RustRuleEngine
        let inner = engine.inner();
        let kb = inner.knowledge_base();

        // Can access knowledge base directly
        assert_eq!(kb.name(), "LogicGraph");
    }
}
347}