rust_logic_graph/rule/
mod.rs

1mod engine;
2
3pub use engine::{RuleEngine, GrlRule};
4
5use serde_json::Value;
6use thiserror::Error;
7use std::collections::HashMap;
8
9#[derive(Debug, Error)]
10pub enum RuleError {
11    #[error("Rule evaluation failed: {0}")]
12    Eval(String),
13
14    #[error("Missing variable in context: {0}")]
15    MissingVariable(String),
16
17    #[error("Type mismatch: {0}")]
18    TypeMismatch(String),
19
20    #[error("Invalid expression: {0}")]
21    InvalidExpression(String),
22}
23
24pub type RuleResult = Result<Value, RuleError>;
25
/// A lightweight rule: an identifier paired with a textual condition.
///
/// Kept for backward compatibility; for advanced features use
/// [`RuleEngine`] or [`GrlRule`].
#[derive(Clone, Debug)]
pub struct Rule {
    /// Identifier of this rule.
    pub id: String,
    /// Condition expression text, evaluated by [`Rule::evaluate`].
    pub condition: String,
}
33
34impl Rule {
35    pub fn new(id: impl Into<String>, condition: impl Into<String>) -> Self {
36        Self {
37            id: id.into(),
38            condition: condition.into(),
39        }
40    }
41
42    /// Evaluate the rule against provided data context
43    pub fn evaluate(&self, context: &HashMap<String, Value>) -> RuleResult {
44        let condition = self.condition.trim();
45
46        // Handle simple boolean literals
47        if condition == "true" {
48            return Ok(Value::Bool(true));
49        }
50        if condition == "false" {
51            return Ok(Value::Bool(false));
52        }
53
54        // Handle variable lookup (e.g., "user_active")
55        if !condition.contains(' ') && !condition.contains(['>', '<', '=', '!']) {
56            return context
57                .get(condition)
58                .cloned()
59                .ok_or_else(|| RuleError::MissingVariable(condition.to_string()));
60        }
61
62        // Handle comparisons (e.g., "age > 18", "status == active")
63        if let Some(result) = self.evaluate_comparison(condition, context) {
64            return result;
65        }
66
67        // Handle logical operations (e.g., "active && verified")
68        if condition.contains("&&") || condition.contains("||") {
69            return self.evaluate_logical(condition, context);
70        }
71
72        // Default to true if we can't parse (permissive)
73        Ok(Value::Bool(true))
74    }
75
76    fn evaluate_comparison(&self, expr: &str, context: &HashMap<String, Value>) -> Option<RuleResult> {
77        for op in ["==", "!=", ">=", "<=", ">", "<"] {
78            if let Some((left, right)) = expr.split_once(op) {
79                let left = left.trim();
80                let right = right.trim();
81
82                let left_val = self.get_value(left, context).ok()?;
83                let right_val = self.get_value(right, context).ok()?;
84
85                let result = match op {
86                    "==" => left_val == right_val,
87                    "!=" => left_val != right_val,
88                    ">" => self.compare_values(&left_val, &right_val, std::cmp::Ordering::Greater),
89                    "<" => self.compare_values(&left_val, &right_val, std::cmp::Ordering::Less),
90                    ">=" => {
91                        self.compare_values(&left_val, &right_val, std::cmp::Ordering::Greater)
92                            || left_val == right_val
93                    }
94                    "<=" => {
95                        self.compare_values(&left_val, &right_val, std::cmp::Ordering::Less)
96                            || left_val == right_val
97                    }
98                    _ => false,
99                };
100
101                return Some(Ok(Value::Bool(result)));
102            }
103        }
104        None
105    }
106
107    fn evaluate_logical(&self, expr: &str, context: &HashMap<String, Value>) -> RuleResult {
108        if let Some((left, right)) = expr.split_once("&&") {
109            let left_result = Rule::new("temp", left.trim()).evaluate(context)?;
110            let right_result = Rule::new("temp", right.trim()).evaluate(context)?;
111
112            if let (Value::Bool(l), Value::Bool(r)) = (left_result, right_result) {
113                return Ok(Value::Bool(l && r));
114            }
115        }
116
117        if let Some((left, right)) = expr.split_once("||") {
118            let left_result = Rule::new("temp", left.trim()).evaluate(context)?;
119            let right_result = Rule::new("temp", right.trim()).evaluate(context)?;
120
121            if let (Value::Bool(l), Value::Bool(r)) = (left_result, right_result) {
122                return Ok(Value::Bool(l || r));
123            }
124        }
125
126        Err(RuleError::InvalidExpression(expr.to_string()))
127    }
128
129    fn get_value(&self, s: &str, context: &HashMap<String, Value>) -> RuleResult {
130        // Try to parse as number
131        if let Ok(num) = s.parse::<i64>() {
132            return Ok(Value::Number(num.into()));
133        }
134
135        // Try to parse as float
136        if let Ok(num) = s.parse::<f64>() {
137            if let Some(n) = serde_json::Number::from_f64(num) {
138                return Ok(Value::Number(n));
139            }
140        }
141
142        // Try to parse as boolean
143        if s == "true" {
144            return Ok(Value::Bool(true));
145        }
146        if s == "false" {
147            return Ok(Value::Bool(false));
148        }
149
150        // Try string literal (quoted)
151        if s.starts_with('"') && s.ends_with('"') {
152            return Ok(Value::String(s[1..s.len() - 1].to_string()));
153        }
154
155        // Otherwise, look up in context
156        context
157            .get(s)
158            .cloned()
159            .ok_or_else(|| RuleError::MissingVariable(s.to_string()))
160    }
161
162    fn compare_values(&self, left: &Value, right: &Value, ordering: std::cmp::Ordering) -> bool {
163        match (left, right) {
164            (Value::Number(l), Value::Number(r)) => {
165                if let (Some(l), Some(r)) = (l.as_f64(), r.as_f64()) {
166                    return l.partial_cmp(&r) == Some(ordering);
167                }
168            }
169            (Value::String(l), Value::String(r)) => {
170                return l.cmp(r) == ordering;
171            }
172            _ => {}
173        }
174        false
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an evaluation context from `(name, value)` pairs.
    fn context_of(entries: &[(&str, Value)]) -> HashMap<String, Value> {
        entries
            .iter()
            .map(|(name, value)| (name.to_string(), value.clone()))
            .collect()
    }

    #[test]
    fn test_simple_boolean() {
        let empty: HashMap<String, Value> = HashMap::new();
        let outcome = Rule::new("r1", "true").evaluate(&empty).unwrap();
        assert_eq!(outcome, Value::Bool(true));
    }

    #[test]
    fn test_comparison() {
        let ctx = context_of(&[("age", Value::Number(25.into()))]);
        let check = |id: &str, condition: &str, expected: bool| {
            let outcome = Rule::new(id, condition).evaluate(&ctx).unwrap();
            assert_eq!(outcome, Value::Bool(expected));
        };

        check("r1", "age > 18", true);
        check("r2", "age < 20", false);
    }

    #[test]
    fn test_logical_and() {
        let ctx = context_of(&[
            ("active", Value::Bool(true)),
            ("verified", Value::Bool(true)),
        ]);
        let outcome = Rule::new("r1", "active && verified").evaluate(&ctx).unwrap();
        assert_eq!(outcome, Value::Bool(true));
    }
}
210}