rust_logic_graph/rule/
mod.rs1mod engine;
2
3pub use engine::{RuleEngine, GrlRule};
4
5use serde_json::Value;
6use thiserror::Error;
7use std::collections::HashMap;
8
9#[derive(Debug, Error)]
10pub enum RuleError {
11 #[error("Rule evaluation failed: {0}")]
12 Eval(String),
13
14 #[error("Missing variable in context: {0}")]
15 MissingVariable(String),
16
17 #[error("Type mismatch: {0}")]
18 TypeMismatch(String),
19
20 #[error("Invalid expression: {0}")]
21 InvalidExpression(String),
22}
23
24pub type RuleResult = Result<Value, RuleError>;
25
/// A named condition expression that is evaluated against a variable context.
///
/// `condition` holds the raw source text; it is re-read on every call to
/// [`Rule::evaluate`], so no parsed form is cached.
#[derive(Debug, Clone)]
pub struct Rule {
    /// Identifier of the rule.
    pub id: String,
    /// Source text of the condition expression.
    pub condition: String,
}
33
34impl Rule {
35 pub fn new(id: impl Into<String>, condition: impl Into<String>) -> Self {
36 Self {
37 id: id.into(),
38 condition: condition.into(),
39 }
40 }
41
42 pub fn evaluate(&self, context: &HashMap<String, Value>) -> RuleResult {
44 let condition = self.condition.trim();
45
46 if condition == "true" {
48 return Ok(Value::Bool(true));
49 }
50 if condition == "false" {
51 return Ok(Value::Bool(false));
52 }
53
54 if !condition.contains(' ') && !condition.contains(['>', '<', '=', '!']) {
56 return context
57 .get(condition)
58 .cloned()
59 .ok_or_else(|| RuleError::MissingVariable(condition.to_string()));
60 }
61
62 if let Some(result) = self.evaluate_comparison(condition, context) {
64 return result;
65 }
66
67 if condition.contains("&&") || condition.contains("||") {
69 return self.evaluate_logical(condition, context);
70 }
71
72 Ok(Value::Bool(true))
74 }
75
76 fn evaluate_comparison(&self, expr: &str, context: &HashMap<String, Value>) -> Option<RuleResult> {
77 for op in ["==", "!=", ">=", "<=", ">", "<"] {
78 if let Some((left, right)) = expr.split_once(op) {
79 let left = left.trim();
80 let right = right.trim();
81
82 let left_val = self.get_value(left, context).ok()?;
83 let right_val = self.get_value(right, context).ok()?;
84
85 let result = match op {
86 "==" => left_val == right_val,
87 "!=" => left_val != right_val,
88 ">" => self.compare_values(&left_val, &right_val, std::cmp::Ordering::Greater),
89 "<" => self.compare_values(&left_val, &right_val, std::cmp::Ordering::Less),
90 ">=" => {
91 self.compare_values(&left_val, &right_val, std::cmp::Ordering::Greater)
92 || left_val == right_val
93 }
94 "<=" => {
95 self.compare_values(&left_val, &right_val, std::cmp::Ordering::Less)
96 || left_val == right_val
97 }
98 _ => false,
99 };
100
101 return Some(Ok(Value::Bool(result)));
102 }
103 }
104 None
105 }
106
107 fn evaluate_logical(&self, expr: &str, context: &HashMap<String, Value>) -> RuleResult {
108 if let Some((left, right)) = expr.split_once("&&") {
109 let left_result = Rule::new("temp", left.trim()).evaluate(context)?;
110 let right_result = Rule::new("temp", right.trim()).evaluate(context)?;
111
112 if let (Value::Bool(l), Value::Bool(r)) = (left_result, right_result) {
113 return Ok(Value::Bool(l && r));
114 }
115 }
116
117 if let Some((left, right)) = expr.split_once("||") {
118 let left_result = Rule::new("temp", left.trim()).evaluate(context)?;
119 let right_result = Rule::new("temp", right.trim()).evaluate(context)?;
120
121 if let (Value::Bool(l), Value::Bool(r)) = (left_result, right_result) {
122 return Ok(Value::Bool(l || r));
123 }
124 }
125
126 Err(RuleError::InvalidExpression(expr.to_string()))
127 }
128
129 fn get_value(&self, s: &str, context: &HashMap<String, Value>) -> RuleResult {
130 if let Ok(num) = s.parse::<i64>() {
132 return Ok(Value::Number(num.into()));
133 }
134
135 if let Ok(num) = s.parse::<f64>() {
137 if let Some(n) = serde_json::Number::from_f64(num) {
138 return Ok(Value::Number(n));
139 }
140 }
141
142 if s == "true" {
144 return Ok(Value::Bool(true));
145 }
146 if s == "false" {
147 return Ok(Value::Bool(false));
148 }
149
150 if s.starts_with('"') && s.ends_with('"') {
152 return Ok(Value::String(s[1..s.len() - 1].to_string()));
153 }
154
155 context
157 .get(s)
158 .cloned()
159 .ok_or_else(|| RuleError::MissingVariable(s.to_string()))
160 }
161
162 fn compare_values(&self, left: &Value, right: &Value, ordering: std::cmp::Ordering) -> bool {
163 match (left, right) {
164 (Value::Number(l), Value::Number(r)) => {
165 if let (Some(l), Some(r)) = (l.as_f64(), r.as_f64()) {
166 return l.partial_cmp(&r) == Some(ordering);
167 }
168 }
169 (Value::String(l), Value::String(r)) => {
170 return l.cmp(r) == ordering;
171 }
172 _ => {}
173 }
174 false
175 }
176}
177
/// Unit tests for [`Rule::evaluate`].
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an evaluation context from `(name, value)` pairs.
    fn context_of(entries: Vec<(&str, Value)>) -> HashMap<String, Value> {
        entries
            .into_iter()
            .map(|(name, value)| (name.to_string(), value))
            .collect()
    }

    #[test]
    fn test_simple_boolean() {
        let outcome = Rule::new("r1", "true").evaluate(&HashMap::new());
        assert_eq!(outcome.unwrap(), Value::Bool(true));
    }

    #[test]
    fn test_comparison() {
        let ctx = context_of(vec![("age", Value::Number(25.into()))]);

        let adult = Rule::new("r1", "age > 18");
        assert_eq!(adult.evaluate(&ctx).unwrap(), Value::Bool(true));

        let under_twenty = Rule::new("r2", "age < 20");
        assert_eq!(under_twenty.evaluate(&ctx).unwrap(), Value::Bool(false));
    }

    #[test]
    fn test_logical_and() {
        let ctx = context_of(vec![
            ("active", Value::Bool(true)),
            ("verified", Value::Bool(true)),
        ]);

        let both = Rule::new("r1", "active && verified");
        assert_eq!(both.evaluate(&ctx).unwrap(), Value::Bool(true));
    }
}
210}