rust_logic_graph/rule/
engine.rs1pub use rust_rule_engine::{
3 engine::{
4 facts::Facts,
5 knowledge_base::KnowledgeBase,
6 EngineConfig,
7 RustRuleEngine,
8 },
9 parser::grl::GRLParser,
10 types::Value,
11};
12
13use serde_json::Value as JsonValue;
14use std::collections::HashMap;
15use tracing::debug;
16
17use super::{RuleError, RuleResult};
18
19pub struct RuleEngine {
54 engine: RustRuleEngine,
55}
56
57impl RuleEngine {
58 pub fn new() -> Self {
60 let kb = KnowledgeBase::new("LogicGraph");
61 Self {
62 engine: RustRuleEngine::new(kb),
63 }
64 }
65
66 pub fn with_config(config: EngineConfig) -> Self {
68 let kb = KnowledgeBase::new("LogicGraph");
69 Self {
70 engine: RustRuleEngine::with_config(kb, config),
71 }
72 }
73
74 pub fn add_grl_rule(&mut self, grl_content: &str) -> Result<(), RuleError> {
94 let rules = GRLParser::parse_rules(grl_content)
95 .map_err(|e| RuleError::Eval(format!("Failed to parse GRL: {}", e)))?;
96
97 let rule_count = rules.len();
98
99 for rule in rules {
100 self.engine
101 .knowledge_base()
102 .add_rule(rule)
103 .map_err(|e| RuleError::Eval(format!("Failed to add rule: {}", e)))?;
104 }
105
106 debug!("Loaded {} GRL rules", rule_count);
107
108 Ok(())
109 }
110
111 pub fn evaluate(&mut self, context: &HashMap<String, JsonValue>) -> RuleResult {
131 let facts = Facts::new();
133
134 for (key, value) in context {
135 let rr_value = match value {
136 JsonValue::Bool(b) => Value::Boolean(*b),
137 JsonValue::Number(n) => {
138 if let Some(f) = n.as_f64() {
139 Value::Number(f)
140 } else {
141 continue;
142 }
143 }
144 JsonValue::String(s) => Value::String(s.clone()),
145 _ => {
146 debug!("Skipping unsupported value type for key: {}", key);
147 continue;
148 }
149 };
150
151 facts.set(&key, rr_value);
152 }
153
154 match self.engine.execute(&facts) {
156 Ok(_) => {
157 debug!("Rules executed successfully");
158
159 let convert_value = |val: &Value| -> Option<JsonValue> {
161 match val {
162 Value::Boolean(b) => Some(JsonValue::Bool(*b)),
163 Value::Number(n) => Some(JsonValue::from(*n)),
164 Value::String(s) => Some(JsonValue::String(s.clone())),
165 Value::Integer(i) => Some(JsonValue::from(*i)),
166 _ => None,
167 }
168 };
169
170 let all_facts = facts.get_all_facts();
173
174 let mut result = HashMap::new();
175 for (key, value) in all_facts {
176 if let Some(json_value) = convert_value(&value) {
177 result.insert(key, json_value);
178 }
179 }
180
181 Ok(JsonValue::Object(result.into_iter().collect()))
182 }
183 Err(e) => Err(RuleError::Eval(format!("Rule execution failed: {}", e))),
184 }
185 }
186
187 pub fn from_grl(grl_script: &str) -> Result<Self, RuleError> {
206 let mut engine = Self::new();
207 engine.add_grl_rule(grl_script)?;
208 Ok(engine)
209 }
210
211 pub fn inner(&self) -> &RustRuleEngine {
220 &self.engine
221 }
222
223 pub fn inner_mut(&mut self) -> &mut RustRuleEngine {
225 &mut self.engine
226 }
227}
228
229impl Default for RuleEngine {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_rule_engine_creation() {
        // Both constructors must yield an engine backed by the "LogicGraph" knowledge base.
        let engine = RuleEngine::new();
        assert_eq!(engine.inner().knowledge_base().name(), "LogicGraph");

        let engine = RuleEngine::default();
        assert_eq!(engine.inner().knowledge_base().name(), "LogicGraph");
    }

    #[test]
    fn test_engine_evaluation() {
        let mut engine = RuleEngine::new();

        let grl = r#"
        rule "test_rule" {
            salience 100
            when
                age >= 18
            then
                eligible = true;
        }
        "#;

        engine.add_grl_rule(grl).unwrap();

        let mut context = HashMap::new();
        context.insert("age".to_string(), json!(20));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["eligible"], true);
    }

    #[test]
    fn test_from_grl() {
        let grl = r#"
        rule "test" {
            salience 100
            when
                x > 0
            then
                result = true;
                message = "x is positive";
        }
        "#;

        let mut engine = RuleEngine::from_grl(grl).unwrap();

        let mut context = HashMap::new();
        context.insert("x".to_string(), json!(5));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["result"], true);
        assert_eq!(result["message"], "x is positive");
    }

    /// The two rules have mutually exclusive conditions, so this checks that the
    /// right rule is selected per input (and that facts from one evaluation do
    /// not leak into the next), not salience ordering between competing rules.
    #[test]
    fn test_multiple_rules_salience() {
        let mut engine = RuleEngine::new();

        let grl = r#"
        rule "high_priority" {
            salience 100
            when
                value > 100
            then
                priority = "high";
                high_rule_fired = true;
        }

        rule "medium_priority" {
            salience 50
            when
                value > 50 && value <= 100
            then
                priority = "medium";
                medium_rule_fired = true;
        }
        "#;

        engine.add_grl_rule(grl).unwrap();

        let mut context = HashMap::new();
        context.insert("value".to_string(), json!(150));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["priority"], "high");
        assert_eq!(result["high_rule_fired"], true);
        assert!(result.get("medium_rule_fired").is_none());

        let mut context2 = HashMap::new();
        context2.insert("value".to_string(), json!(75));

        let result2 = engine.evaluate(&context2).unwrap();
        assert_eq!(result2["priority"], "medium");
        assert_eq!(result2["medium_rule_fired"], true);
        assert!(result2.get("high_rule_fired").is_none());
    }

    #[test]
    fn test_unsupported_context_values_are_skipped() {
        let grl = r#"
        rule "adult" {
            salience 10
            when
                age >= 18
            then
                adult = true;
        }
        "#;

        let mut engine = RuleEngine::from_grl(grl).unwrap();

        let mut context = HashMap::new();
        context.insert("age".to_string(), json!(30));
        context.insert("tags".to_string(), json!(["a", "b"]));
        context.insert("profile".to_string(), json!({ "name": "x" }));
        context.insert("nothing".to_string(), JsonValue::Null);

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["adult"], true);
        assert!(result.get("tags").is_none());
        assert!(result.get("profile").is_none());
        assert!(result.get("nothing").is_none());
    }

    #[test]
    fn test_direct_engine_access() {
        let engine = RuleEngine::new();

        let inner = engine.inner();
        let kb = inner.knowledge_base();

        assert_eq!(kb.name(), "LogicGraph");
    }
}