rust_logic_graph/rule/mod.rs

1mod engine;
2
3// Re-export the convenience wrapper
4pub use engine::RuleEngine;
5
6// Re-export rust-rule-engine types for advanced usage
7pub use engine::{Facts, KnowledgeBase, GRLParser, EngineConfig, RustRuleEngine, Value};
8
9use serde_json::Value as JsonValue;
10use thiserror::Error;
11use std::collections::HashMap;
12
13#[derive(Debug, Error)]
14pub enum RuleError {
15    #[error("Rule evaluation failed: {0}")]
16    Eval(String),
17
18    #[error("Missing variable in context: {0}")]
19    MissingVariable(String),
20
21    #[error("Type mismatch: {0}")]
22    TypeMismatch(String),
23
24    #[error("Invalid expression: {0}")]
25    InvalidExpression(String),
26}
27
28pub type RuleResult = Result<JsonValue, RuleError>;
29
/// Minimal condition-based rule, kept for backward compatibility.
///
/// `condition` stores the expression text that [`Rule::evaluate`] interprets.
/// For advanced rule processing use [`RuleEngine`] directly.
#[derive(Clone, Debug)]
pub struct Rule {
    /// Identifier of this rule.
    pub id: String,
    /// Source text of the condition expression.
    pub condition: String,
}
37
38impl Rule {
39    pub fn new(id: impl Into<String>, condition: impl Into<String>) -> Self {
40        Self {
41            id: id.into(),
42            condition: condition.into(),
43        }
44    }
45
46    /// Evaluate the rule against provided data context
47    pub fn evaluate(&self, context: &HashMap<String, JsonValue>) -> RuleResult {
48        let condition = self.condition.trim();
49
50        // Handle simple boolean literals
51        if condition == "true" {
52            return Ok(JsonValue::Bool(true));
53        }
54        if condition == "false" {
55            return Ok(JsonValue::Bool(false));
56        }
57
58        // Handle variable lookup (e.g., "user_active")
59        if !condition.contains(' ') && !condition.contains(['>', '<', '=', '!']) {
60            return context
61                .get(condition)
62                .cloned()
63                .ok_or_else(|| RuleError::MissingVariable(condition.to_string()));
64        }
65
66        // Handle comparisons (e.g., "age > 18", "status == active")
67        if let Some(result) = self.evaluate_comparison(condition, context) {
68            return result;
69        }
70
71        // Handle logical operations (e.g., "active && verified")
72        if condition.contains("&&") || condition.contains("||") {
73            return self.evaluate_logical(condition, context);
74        }
75
76        // Default to true if we can't parse (permissive)
77        Ok(JsonValue::Bool(true))
78    }
79
80    fn evaluate_comparison(&self, expr: &str, context: &HashMap<String, JsonValue>) -> Option<RuleResult> {
81        for op in ["==", "!=", ">=", "<=", ">", "<"] {
82            if let Some((left, right)) = expr.split_once(op) {
83                let left = left.trim();
84                let right = right.trim();
85
86                let left_val = self.get_value(left, context).ok()?;
87                let right_val = self.get_value(right, context).ok()?;
88
89                let result = match op {
90                    "==" => left_val == right_val,
91                    "!=" => left_val != right_val,
92                    ">" => self.compare_values(&left_val, &right_val, std::cmp::Ordering::Greater),
93                    "<" => self.compare_values(&left_val, &right_val, std::cmp::Ordering::Less),
94                    ">=" => {
95                        self.compare_values(&left_val, &right_val, std::cmp::Ordering::Greater)
96                            || left_val == right_val
97                    }
98                    "<=" => {
99                        self.compare_values(&left_val, &right_val, std::cmp::Ordering::Less)
100                            || left_val == right_val
101                    }
102                    _ => false,
103                };
104
105                return Some(Ok(JsonValue::Bool(result)));
106            }
107        }
108        None
109    }
110
111    fn evaluate_logical(&self, expr: &str, context: &HashMap<String, JsonValue>) -> RuleResult {
112        if let Some((left, right)) = expr.split_once("&&") {
113            let left_result = Rule::new("temp", left.trim()).evaluate(context)?;
114            let right_result = Rule::new("temp", right.trim()).evaluate(context)?;
115
116            if let (JsonValue::Bool(l), JsonValue::Bool(r)) = (left_result, right_result) {
117                return Ok(JsonValue::Bool(l && r));
118            }
119        }
120
121        if let Some((left, right)) = expr.split_once("||") {
122            let left_result = Rule::new("temp", left.trim()).evaluate(context)?;
123            let right_result = Rule::new("temp", right.trim()).evaluate(context)?;
124
125            if let (JsonValue::Bool(l), JsonValue::Bool(r)) = (left_result, right_result) {
126                return Ok(JsonValue::Bool(l || r));
127            }
128        }
129
130        Err(RuleError::InvalidExpression(expr.to_string()))
131    }
132
133    fn get_value(&self, s: &str, context: &HashMap<String, JsonValue>) -> RuleResult {
134        // Try to parse as number
135        if let Ok(num) = s.parse::<i64>() {
136            return Ok(JsonValue::Number(num.into()));
137        }
138
139        // Try to parse as float
140        if let Ok(num) = s.parse::<f64>() {
141            if let Some(n) = serde_json::Number::from_f64(num) {
142                return Ok(JsonValue::Number(n));
143            }
144        }
145
146        // Try to parse as boolean
147        if s == "true" {
148            return Ok(JsonValue::Bool(true));
149        }
150        if s == "false" {
151            return Ok(JsonValue::Bool(false));
152        }
153
154        // Try string literal (quoted)
155        if s.starts_with('"') && s.ends_with('"') {
156            return Ok(JsonValue::String(s[1..s.len() - 1].to_string()));
157        }
158
159        // Otherwise, look up in context
160        context
161            .get(s)
162            .cloned()
163            .ok_or_else(|| RuleError::MissingVariable(s.to_string()))
164    }
165
166    fn compare_values(&self, left: &JsonValue, right: &JsonValue, ordering: std::cmp::Ordering) -> bool {
167        match (left, right) {
168            (JsonValue::Number(l), JsonValue::Number(r)) => {
169                if let (Some(l), Some(r)) = (l.as_f64(), r.as_f64()) {
170                    return l.partial_cmp(&r) == Some(ordering);
171                }
172            }
173            (JsonValue::String(l), JsonValue::String(r)) => {
174                return l.cmp(r) == ordering;
175            }
176            _ => {}
177        }
178        false
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an evaluation context from `(name, value)` pairs.
    fn context_of(entries: &[(&str, JsonValue)]) -> HashMap<String, JsonValue> {
        entries
            .iter()
            .map(|(name, value)| (String::from(*name), value.clone()))
            .collect()
    }

    #[test]
    fn test_simple_boolean() {
        let verdict = Rule::new("r1", "true").evaluate(&context_of(&[]));
        assert_eq!(verdict.unwrap(), JsonValue::Bool(true));
    }

    #[test]
    fn test_comparison() {
        let ctx = context_of(&[("age", JsonValue::from(25))]);

        let adult = Rule::new("r1", "age > 18");
        assert_eq!(adult.evaluate(&ctx).unwrap(), JsonValue::Bool(true));

        let young = Rule::new("r2", "age < 20");
        assert_eq!(young.evaluate(&ctx).unwrap(), JsonValue::Bool(false));
    }

    #[test]
    fn test_logical_and() {
        let ctx = context_of(&[
            ("active", JsonValue::Bool(true)),
            ("verified", JsonValue::Bool(true)),
        ]);

        let both = Rule::new("r1", "active && verified");
        assert_eq!(both.evaluate(&ctx).unwrap(), JsonValue::Bool(true));
    }
}