rust_logic_graph/rule/
mod.rs1mod engine;
2
3pub use engine::RuleEngine;
5
6pub use engine::{Facts, KnowledgeBase, GRLParser, EngineConfig, RustRuleEngine, Value};
8
9use serde_json::Value as JsonValue;
10use thiserror::Error;
11use std::collections::HashMap;
12
13#[derive(Debug, Error)]
14pub enum RuleError {
15 #[error("Rule evaluation failed: {0}")]
16 Eval(String),
17
18 #[error("Missing variable in context: {0}")]
19 MissingVariable(String),
20
21 #[error("Type mismatch: {0}")]
22 TypeMismatch(String),
23
24 #[error("Invalid expression: {0}")]
25 InvalidExpression(String),
26}
27
28pub type RuleResult = Result<JsonValue, RuleError>;
29
/// A named condition expression that is evaluated against a map of variables.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Rule {
    /// Identifier of the rule.
    pub id: String,
    /// Source text of the condition expression.
    pub condition: String,
}
37
38impl Rule {
39 pub fn new(id: impl Into<String>, condition: impl Into<String>) -> Self {
40 Self {
41 id: id.into(),
42 condition: condition.into(),
43 }
44 }
45
46 pub fn evaluate(&self, context: &HashMap<String, JsonValue>) -> RuleResult {
48 let condition = self.condition.trim();
49
50 if condition == "true" {
52 return Ok(JsonValue::Bool(true));
53 }
54 if condition == "false" {
55 return Ok(JsonValue::Bool(false));
56 }
57
58 if !condition.contains(' ') && !condition.contains(['>', '<', '=', '!']) {
60 return context
61 .get(condition)
62 .cloned()
63 .ok_or_else(|| RuleError::MissingVariable(condition.to_string()));
64 }
65
66 if let Some(result) = self.evaluate_comparison(condition, context) {
68 return result;
69 }
70
71 if condition.contains("&&") || condition.contains("||") {
73 return self.evaluate_logical(condition, context);
74 }
75
76 Ok(JsonValue::Bool(true))
78 }
79
80 fn evaluate_comparison(&self, expr: &str, context: &HashMap<String, JsonValue>) -> Option<RuleResult> {
81 for op in ["==", "!=", ">=", "<=", ">", "<"] {
82 if let Some((left, right)) = expr.split_once(op) {
83 let left = left.trim();
84 let right = right.trim();
85
86 let left_val = self.get_value(left, context).ok()?;
87 let right_val = self.get_value(right, context).ok()?;
88
89 let result = match op {
90 "==" => left_val == right_val,
91 "!=" => left_val != right_val,
92 ">" => self.compare_values(&left_val, &right_val, std::cmp::Ordering::Greater),
93 "<" => self.compare_values(&left_val, &right_val, std::cmp::Ordering::Less),
94 ">=" => {
95 self.compare_values(&left_val, &right_val, std::cmp::Ordering::Greater)
96 || left_val == right_val
97 }
98 "<=" => {
99 self.compare_values(&left_val, &right_val, std::cmp::Ordering::Less)
100 || left_val == right_val
101 }
102 _ => false,
103 };
104
105 return Some(Ok(JsonValue::Bool(result)));
106 }
107 }
108 None
109 }
110
111 fn evaluate_logical(&self, expr: &str, context: &HashMap<String, JsonValue>) -> RuleResult {
112 if let Some((left, right)) = expr.split_once("&&") {
113 let left_result = Rule::new("temp", left.trim()).evaluate(context)?;
114 let right_result = Rule::new("temp", right.trim()).evaluate(context)?;
115
116 if let (JsonValue::Bool(l), JsonValue::Bool(r)) = (left_result, right_result) {
117 return Ok(JsonValue::Bool(l && r));
118 }
119 }
120
121 if let Some((left, right)) = expr.split_once("||") {
122 let left_result = Rule::new("temp", left.trim()).evaluate(context)?;
123 let right_result = Rule::new("temp", right.trim()).evaluate(context)?;
124
125 if let (JsonValue::Bool(l), JsonValue::Bool(r)) = (left_result, right_result) {
126 return Ok(JsonValue::Bool(l || r));
127 }
128 }
129
130 Err(RuleError::InvalidExpression(expr.to_string()))
131 }
132
133 fn get_value(&self, s: &str, context: &HashMap<String, JsonValue>) -> RuleResult {
134 if let Ok(num) = s.parse::<i64>() {
136 return Ok(JsonValue::Number(num.into()));
137 }
138
139 if let Ok(num) = s.parse::<f64>() {
141 if let Some(n) = serde_json::Number::from_f64(num) {
142 return Ok(JsonValue::Number(n));
143 }
144 }
145
146 if s == "true" {
148 return Ok(JsonValue::Bool(true));
149 }
150 if s == "false" {
151 return Ok(JsonValue::Bool(false));
152 }
153
154 if s.starts_with('"') && s.ends_with('"') {
156 return Ok(JsonValue::String(s[1..s.len() - 1].to_string()));
157 }
158
159 context
161 .get(s)
162 .cloned()
163 .ok_or_else(|| RuleError::MissingVariable(s.to_string()))
164 }
165
166 fn compare_values(&self, left: &JsonValue, right: &JsonValue, ordering: std::cmp::Ordering) -> bool {
167 match (left, right) {
168 (JsonValue::Number(l), JsonValue::Number(r)) => {
169 if let (Some(l), Some(r)) = (l.as_f64(), r.as_f64()) {
170 return l.partial_cmp(&r) == Some(ordering);
171 }
172 }
173 (JsonValue::String(l), JsonValue::String(r)) => {
174 return l.cmp(r) == ordering;
175 }
176 _ => {}
177 }
178 false
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an evaluation context from `(name, value)` pairs.
    fn ctx(pairs: &[(&str, JsonValue)]) -> HashMap<String, JsonValue> {
        pairs
            .iter()
            .map(|(name, value)| ((*name).to_string(), value.clone()))
            .collect()
    }

    #[test]
    fn test_simple_boolean() {
        let context = HashMap::new();
        assert_eq!(Rule::new("r1", "true").evaluate(&context).unwrap(), JsonValue::Bool(true));
        assert_eq!(Rule::new("r2", "false").evaluate(&context).unwrap(), JsonValue::Bool(false));
    }

    #[test]
    fn test_comparison() {
        let context = ctx(&[("age", JsonValue::Number(25.into()))]);

        let rule = Rule::new("r1", "age > 18");
        assert_eq!(rule.evaluate(&context).unwrap(), JsonValue::Bool(true));

        let rule = Rule::new("r2", "age < 20");
        assert_eq!(rule.evaluate(&context).unwrap(), JsonValue::Bool(false));

        let rule = Rule::new("r3", "age >= 25");
        assert_eq!(rule.evaluate(&context).unwrap(), JsonValue::Bool(true));
    }

    #[test]
    fn test_string_equality() {
        let context = ctx(&[("name", JsonValue::String("bob".to_string()))]);
        let rule = Rule::new("r1", "name == \"bob\"");
        assert_eq!(rule.evaluate(&context).unwrap(), JsonValue::Bool(true));
    }

    #[test]
    fn test_logical_and() {
        let context = ctx(&[
            ("active", JsonValue::Bool(true)),
            ("verified", JsonValue::Bool(true)),
        ]);
        let rule = Rule::new("r1", "active && verified");
        assert_eq!(rule.evaluate(&context).unwrap(), JsonValue::Bool(true));
    }

    #[test]
    fn test_logical_or() {
        let context = ctx(&[
            ("active", JsonValue::Bool(false)),
            ("verified", JsonValue::Bool(true)),
        ]);
        let rule = Rule::new("r1", "active || verified");
        assert_eq!(rule.evaluate(&context).unwrap(), JsonValue::Bool(true));
    }

    #[test]
    fn test_comparison_inside_logical() {
        let context = ctx(&[
            ("age", JsonValue::Number(25.into())),
            ("active", JsonValue::Bool(true)),
        ]);
        let rule = Rule::new("r1", "age > 18 && active");
        assert_eq!(rule.evaluate(&context).unwrap(), JsonValue::Bool(true));
    }

    #[test]
    fn test_missing_variable() {
        let rule = Rule::new("r1", "missing");
        match rule.evaluate(&HashMap::new()) {
            Err(RuleError::MissingVariable(name)) => assert_eq!(name, "missing"),
            other => panic!("expected MissingVariable, got {:?}", other),
        }
    }
}