rust_logic_graph/rule/
mod.rs1mod engine;
2
3pub use engine::RuleEngine;
5
6pub use engine::{EngineConfig, Facts, GRLParser, KnowledgeBase, RustRuleEngine, Value};
8
9use serde_json::Value as JsonValue;
10use std::collections::HashMap;
11use thiserror::Error;
12
13#[derive(Debug, Error)]
14pub enum RuleError {
15 #[error("Rule evaluation failed: {0}")]
16 Eval(String),
17
18 #[error("Missing variable in context: {0}")]
19 MissingVariable(String),
20
21 #[error("Type mismatch: {0}")]
22 TypeMismatch(String),
23
24 #[error("Invalid expression: {0}")]
25 InvalidExpression(String),
26}
27
28pub type RuleResult = Result<JsonValue, RuleError>;
29
/// A named condition expression that can be evaluated against a variable context.
#[derive(Clone, Debug)]
pub struct Rule {
    /// Identifier of the rule.
    pub id: String,
    /// Source text of the condition expression.
    pub condition: String,
}
37
38impl Rule {
39 pub fn new(id: impl Into<String>, condition: impl Into<String>) -> Self {
40 Self {
41 id: id.into(),
42 condition: condition.into(),
43 }
44 }
45
46 pub fn evaluate(&self, context: &HashMap<String, JsonValue>) -> RuleResult {
48 let condition = self.condition.trim();
49
50 if condition == "true" {
52 return Ok(JsonValue::Bool(true));
53 }
54 if condition == "false" {
55 return Ok(JsonValue::Bool(false));
56 }
57
58 if !condition.contains(' ') && !condition.contains(['>', '<', '=', '!']) {
60 return context
61 .get(condition)
62 .cloned()
63 .ok_or_else(|| RuleError::MissingVariable(condition.to_string()));
64 }
65
66 if let Some(result) = self.evaluate_comparison(condition, context) {
68 return result;
69 }
70
71 if condition.contains("&&") || condition.contains("||") {
73 return self.evaluate_logical(condition, context);
74 }
75
76 Ok(JsonValue::Bool(true))
78 }
79
80 fn evaluate_comparison(
81 &self,
82 expr: &str,
83 context: &HashMap<String, JsonValue>,
84 ) -> Option<RuleResult> {
85 for op in ["==", "!=", ">=", "<=", ">", "<"] {
86 if let Some((left, right)) = expr.split_once(op) {
87 let left = left.trim();
88 let right = right.trim();
89
90 let left_val = self.get_value(left, context).ok()?;
91 let right_val = self.get_value(right, context).ok()?;
92
93 let result = match op {
94 "==" => left_val == right_val,
95 "!=" => left_val != right_val,
96 ">" => self.compare_values(&left_val, &right_val, std::cmp::Ordering::Greater),
97 "<" => self.compare_values(&left_val, &right_val, std::cmp::Ordering::Less),
98 ">=" => {
99 self.compare_values(&left_val, &right_val, std::cmp::Ordering::Greater)
100 || left_val == right_val
101 }
102 "<=" => {
103 self.compare_values(&left_val, &right_val, std::cmp::Ordering::Less)
104 || left_val == right_val
105 }
106 _ => false,
107 };
108
109 return Some(Ok(JsonValue::Bool(result)));
110 }
111 }
112 None
113 }
114
115 fn evaluate_logical(&self, expr: &str, context: &HashMap<String, JsonValue>) -> RuleResult {
116 if let Some((left, right)) = expr.split_once("&&") {
117 let left_result = Rule::new("temp", left.trim()).evaluate(context)?;
118 let right_result = Rule::new("temp", right.trim()).evaluate(context)?;
119
120 if let (JsonValue::Bool(l), JsonValue::Bool(r)) = (left_result, right_result) {
121 return Ok(JsonValue::Bool(l && r));
122 }
123 }
124
125 if let Some((left, right)) = expr.split_once("||") {
126 let left_result = Rule::new("temp", left.trim()).evaluate(context)?;
127 let right_result = Rule::new("temp", right.trim()).evaluate(context)?;
128
129 if let (JsonValue::Bool(l), JsonValue::Bool(r)) = (left_result, right_result) {
130 return Ok(JsonValue::Bool(l || r));
131 }
132 }
133
134 Err(RuleError::InvalidExpression(expr.to_string()))
135 }
136
137 fn get_value(&self, s: &str, context: &HashMap<String, JsonValue>) -> RuleResult {
138 if let Ok(num) = s.parse::<i64>() {
140 return Ok(JsonValue::Number(num.into()));
141 }
142
143 if let Ok(num) = s.parse::<f64>() {
145 if let Some(n) = serde_json::Number::from_f64(num) {
146 return Ok(JsonValue::Number(n));
147 }
148 }
149
150 if s == "true" {
152 return Ok(JsonValue::Bool(true));
153 }
154 if s == "false" {
155 return Ok(JsonValue::Bool(false));
156 }
157
158 if s.starts_with('"') && s.ends_with('"') {
160 return Ok(JsonValue::String(s[1..s.len() - 1].to_string()));
161 }
162
163 context
165 .get(s)
166 .cloned()
167 .ok_or_else(|| RuleError::MissingVariable(s.to_string()))
168 }
169
170 fn compare_values(
171 &self,
172 left: &JsonValue,
173 right: &JsonValue,
174 ordering: std::cmp::Ordering,
175 ) -> bool {
176 match (left, right) {
177 (JsonValue::Number(l), JsonValue::Number(r)) => {
178 if let (Some(l), Some(r)) = (l.as_f64(), r.as_f64()) {
179 return l.partial_cmp(&r) == Some(ordering);
180 }
181 }
182 (JsonValue::String(l), JsonValue::String(r)) => {
183 return l.cmp(r) == ordering;
184 }
185 _ => {}
186 }
187 false
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an evaluation context from `(name, value)` pairs.
    fn context_from(entries: Vec<(&str, JsonValue)>) -> HashMap<String, JsonValue> {
        entries
            .into_iter()
            .map(|(name, value)| (name.to_string(), value))
            .collect()
    }

    #[test]
    fn test_simple_boolean() {
        let empty: HashMap<String, JsonValue> = HashMap::new();
        let outcome = Rule::new("r1", "true").evaluate(&empty).unwrap();
        assert_eq!(outcome, JsonValue::Bool(true));
    }

    #[test]
    fn test_comparison() {
        let ctx = context_from(vec![("age", JsonValue::Number(25.into()))]);

        let adult = Rule::new("r1", "age > 18");
        assert_eq!(adult.evaluate(&ctx).unwrap(), JsonValue::Bool(true));

        let under_twenty = Rule::new("r2", "age < 20");
        assert_eq!(under_twenty.evaluate(&ctx).unwrap(), JsonValue::Bool(false));
    }

    #[test]
    fn test_logical_and() {
        let ctx = context_from(vec![
            ("active", JsonValue::Bool(true)),
            ("verified", JsonValue::Bool(true)),
        ]);

        let both = Rule::new("r1", "active && verified");
        assert_eq!(both.evaluate(&ctx).unwrap(), JsonValue::Bool(true));
    }
}