rust_logic_graph/rule/mod.rs

mod engine;

// Re-export the convenience wrapper
pub use engine::RuleEngine;

// Re-export rust-rule-engine types for advanced usage
pub use engine::{EngineConfig, Facts, GRLParser, KnowledgeBase, RustRuleEngine, Value};

use serde_json::Value as JsonValue;
use std::cmp::Ordering;
use std::collections::HashMap;
use thiserror::Error;

13#[derive(Debug, Error)]
14pub enum RuleError {
15    #[error("Rule evaluation failed: {0}")]
16    Eval(String),
17
18    #[error("Missing variable in context: {0}")]
19    MissingVariable(String),
20
21    #[error("Type mismatch: {0}")]
22    TypeMismatch(String),
23
24    #[error("Invalid expression: {0}")]
25    InvalidExpression(String),
26}
27
28pub type RuleResult = Result<JsonValue, RuleError>;
29
/// A lightweight, string-based rule kept for backward compatibility.
///
/// Reach for [`RuleEngine`] when more advanced rule features are needed.
#[derive(Clone, Debug)]
pub struct Rule {
    /// Identifier of this rule.
    pub id: String,
    /// Condition text, e.g. `age > 18` or `active && verified`.
    pub condition: String,
}

38impl Rule {
39    pub fn new(id: impl Into<String>, condition: impl Into<String>) -> Self {
40        Self {
41            id: id.into(),
42            condition: condition.into(),
43        }
44    }
45
46    /// Evaluate the rule against provided data context
47    pub fn evaluate(&self, context: &HashMap<String, JsonValue>) -> RuleResult {
48        let condition = self.condition.trim();
49
50        // Handle simple boolean literals
51        if condition == "true" {
52            return Ok(JsonValue::Bool(true));
53        }
54        if condition == "false" {
55            return Ok(JsonValue::Bool(false));
56        }
57
58        // Handle variable lookup (e.g., "user_active")
59        if !condition.contains(' ') && !condition.contains(['>', '<', '=', '!']) {
60            return context
61                .get(condition)
62                .cloned()
63                .ok_or_else(|| RuleError::MissingVariable(condition.to_string()));
64        }
65
66        // Handle comparisons (e.g., "age > 18", "status == active")
67        if let Some(result) = self.evaluate_comparison(condition, context) {
68            return result;
69        }
70
71        // Handle logical operations (e.g., "active && verified")
72        if condition.contains("&&") || condition.contains("||") {
73            return self.evaluate_logical(condition, context);
74        }
75
76        // Default to true if we can't parse (permissive)
77        Ok(JsonValue::Bool(true))
78    }
79
80    fn evaluate_comparison(
81        &self,
82        expr: &str,
83        context: &HashMap<String, JsonValue>,
84    ) -> Option<RuleResult> {
85        for op in ["==", "!=", ">=", "<=", ">", "<"] {
86            if let Some((left, right)) = expr.split_once(op) {
87                let left = left.trim();
88                let right = right.trim();
89
90                let left_val = self.get_value(left, context).ok()?;
91                let right_val = self.get_value(right, context).ok()?;
92
93                let result = match op {
94                    "==" => left_val == right_val,
95                    "!=" => left_val != right_val,
96                    ">" => self.compare_values(&left_val, &right_val, std::cmp::Ordering::Greater),
97                    "<" => self.compare_values(&left_val, &right_val, std::cmp::Ordering::Less),
98                    ">=" => {
99                        self.compare_values(&left_val, &right_val, std::cmp::Ordering::Greater)
100                            || left_val == right_val
101                    }
102                    "<=" => {
103                        self.compare_values(&left_val, &right_val, std::cmp::Ordering::Less)
104                            || left_val == right_val
105                    }
106                    _ => false,
107                };
108
109                return Some(Ok(JsonValue::Bool(result)));
110            }
111        }
112        None
113    }
114
115    fn evaluate_logical(&self, expr: &str, context: &HashMap<String, JsonValue>) -> RuleResult {
116        if let Some((left, right)) = expr.split_once("&&") {
117            let left_result = Rule::new("temp", left.trim()).evaluate(context)?;
118            let right_result = Rule::new("temp", right.trim()).evaluate(context)?;
119
120            if let (JsonValue::Bool(l), JsonValue::Bool(r)) = (left_result, right_result) {
121                return Ok(JsonValue::Bool(l && r));
122            }
123        }
124
125        if let Some((left, right)) = expr.split_once("||") {
126            let left_result = Rule::new("temp", left.trim()).evaluate(context)?;
127            let right_result = Rule::new("temp", right.trim()).evaluate(context)?;
128
129            if let (JsonValue::Bool(l), JsonValue::Bool(r)) = (left_result, right_result) {
130                return Ok(JsonValue::Bool(l || r));
131            }
132        }
133
134        Err(RuleError::InvalidExpression(expr.to_string()))
135    }
136
137    fn get_value(&self, s: &str, context: &HashMap<String, JsonValue>) -> RuleResult {
138        // Try to parse as number
139        if let Ok(num) = s.parse::<i64>() {
140            return Ok(JsonValue::Number(num.into()));
141        }
142
143        // Try to parse as float
144        if let Ok(num) = s.parse::<f64>() {
145            if let Some(n) = serde_json::Number::from_f64(num) {
146                return Ok(JsonValue::Number(n));
147            }
148        }
149
150        // Try to parse as boolean
151        if s == "true" {
152            return Ok(JsonValue::Bool(true));
153        }
154        if s == "false" {
155            return Ok(JsonValue::Bool(false));
156        }
157
158        // Try string literal (quoted)
159        if s.starts_with('"') && s.ends_with('"') {
160            return Ok(JsonValue::String(s[1..s.len() - 1].to_string()));
161        }
162
163        // Otherwise, look up in context
164        context
165            .get(s)
166            .cloned()
167            .ok_or_else(|| RuleError::MissingVariable(s.to_string()))
168    }
169
170    fn compare_values(
171        &self,
172        left: &JsonValue,
173        right: &JsonValue,
174        ordering: std::cmp::Ordering,
175    ) -> bool {
176        match (left, right) {
177            (JsonValue::Number(l), JsonValue::Number(r)) => {
178                if let (Some(l), Some(r)) = (l.as_f64(), r.as_f64()) {
179                    return l.partial_cmp(&r) == Some(ordering);
180                }
181            }
182            (JsonValue::String(l), JsonValue::String(r)) => {
183                return l.cmp(r) == ordering;
184            }
185            _ => {}
186        }
187        false
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an evaluation context from `(name, value)` pairs.
    fn ctx(pairs: &[(&str, JsonValue)]) -> HashMap<String, JsonValue> {
        pairs
            .iter()
            .map(|(name, value)| ((*name).to_owned(), value.clone()))
            .collect()
    }

    #[test]
    fn test_simple_boolean() {
        let empty = ctx(&[]);
        let outcome = Rule::new("r1", "true").evaluate(&empty).unwrap();
        assert_eq!(outcome, JsonValue::Bool(true));
    }

    #[test]
    fn test_comparison() {
        let data = ctx(&[("age", JsonValue::Number(25.into()))]);

        let adult = Rule::new("r1", "age > 18");
        assert_eq!(adult.evaluate(&data).unwrap(), JsonValue::Bool(true));

        let young = Rule::new("r2", "age < 20");
        assert_eq!(young.evaluate(&data).unwrap(), JsonValue::Bool(false));
    }

    #[test]
    fn test_logical_and() {
        let data = ctx(&[
            ("active", JsonValue::Bool(true)),
            ("verified", JsonValue::Bool(true)),
        ]);

        let both = Rule::new("r1", "active && verified");
        assert_eq!(both.evaluate(&data).unwrap(), JsonValue::Bool(true));
    }
}