rust_logic_graph/rule/
engine.rs1pub use rust_rule_engine::{
3 engine::{
4 facts::Facts,
5 knowledge_base::KnowledgeBase,
6 EngineConfig,
7 RustRuleEngine,
8 },
9 parser::grl::GRLParser,
10 types::Value,
11};
12
13use serde_json::Value as JsonValue;
14use std::collections::HashMap;
15use tracing::debug;
16
17use super::{RuleError, RuleResult};
18
19pub struct RuleEngine {
54 engine: RustRuleEngine,
55}
56
57impl RuleEngine {
58 pub fn new() -> Self {
60 let kb = KnowledgeBase::new("LogicGraph");
61 Self {
62 engine: RustRuleEngine::new(kb),
63 }
64 }
65
66 pub fn with_config(config: EngineConfig) -> Self {
68 let kb = KnowledgeBase::new("LogicGraph");
69 Self {
70 engine: RustRuleEngine::with_config(kb, config),
71 }
72 }
73
74 pub fn add_grl_rule(&mut self, grl_content: &str) -> Result<(), RuleError> {
94 let start = std::time::Instant::now();
95 debug!("⏱️ [GRL Parse] Starting GRLParser::parse_rules()...");
96
97 let parse_start = std::time::Instant::now();
98 let rules = GRLParser::parse_rules(grl_content)
99 .map_err(|e| RuleError::Eval(format!("Failed to parse GRL: {}", e)))?;
100 let parse_elapsed = parse_start.elapsed();
101
102 let rule_count = rules.len();
103 debug!(" ✅ GRLParser::parse_rules() took {:.3}s for {} rules",
104 parse_elapsed.as_secs_f64(), rule_count);
105
106 debug!("⏱️ [GRL Add] Adding {} rules to knowledge_base...", rule_count);
107 let add_start = std::time::Instant::now();
108
109 for (idx, rule) in rules.into_iter().enumerate() {
110 let rule_start = std::time::Instant::now();
111 self.engine
112 .knowledge_base()
113 .add_rule(rule)
114 .map_err(|e| RuleError::Eval(format!("Failed to add rule: {}", e)))?;
115 let rule_elapsed = rule_start.elapsed();
116
117 if rule_elapsed.as_millis() > 10 {
118 debug!(" Rule #{} took {:.3}ms", idx + 1, rule_elapsed.as_secs_f64() * 1000.0);
119 }
120 }
121
122 let add_elapsed = add_start.elapsed();
123 debug!(" ✅ add_rule() loop took {:.3}s", add_elapsed.as_secs_f64());
124
125 let total_elapsed = start.elapsed();
126 debug!("✅ [GRL Total] Loaded {} GRL rules in {:.3}s", rule_count, total_elapsed.as_secs_f64());
127
128 Ok(())
129 }
130
131 pub fn evaluate(&mut self, context: &HashMap<String, JsonValue>) -> RuleResult {
151 let facts = Facts::new();
153
154 for (key, value) in context {
155 let rr_value = match value {
156 JsonValue::Bool(b) => Value::Boolean(*b),
157 JsonValue::Number(n) => {
158 if let Some(f) = n.as_f64() {
159 Value::Number(f)
160 } else {
161 continue;
162 }
163 }
164 JsonValue::String(s) => Value::String(s.clone()),
165 _ => {
166 debug!("Skipping unsupported value type for key: {}", key);
167 continue;
168 }
169 };
170
171 facts.set(&key, rr_value);
172 }
173
174 match self.engine.execute(&facts) {
176 Ok(_) => {
177 debug!("Rules executed successfully");
178
179 let convert_value = |val: &Value| -> Option<JsonValue> {
181 match val {
182 Value::Boolean(b) => Some(JsonValue::Bool(*b)),
183 Value::Number(n) => Some(JsonValue::from(*n)),
184 Value::String(s) => Some(JsonValue::String(s.clone())),
185 Value::Integer(i) => Some(JsonValue::from(*i)),
186 _ => None,
187 }
188 };
189
190 let all_facts = facts.get_all_facts();
193
194 let mut result = HashMap::new();
195 for (key, value) in all_facts {
196 if let Some(json_value) = convert_value(&value) {
197 result.insert(key, json_value);
198 }
199 }
200
201 Ok(JsonValue::Object(result.into_iter().collect()))
202 }
203 Err(e) => Err(RuleError::Eval(format!("Rule execution failed: {}", e))),
204 }
205 }
206
207 pub fn from_grl(grl_script: &str) -> Result<Self, RuleError> {
226 let mut engine = Self::new();
227 engine.add_grl_rule(grl_script)?;
228 Ok(engine)
229 }
230
231 pub fn inner(&self) -> &RustRuleEngine {
240 &self.engine
241 }
242
243 pub fn inner_mut(&mut self) -> &mut RustRuleEngine {
245 &mut self.engine
246 }
247}
248
249impl Default for RuleEngine {
250 fn default() -> Self {
251 Self::new()
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_rule_engine_creation() {
        let _engine = RuleEngine::new();
    }

    #[test]
    fn test_engine_evaluation() {
        let mut engine = RuleEngine::new();

        let grl = r#"
        rule "test_rule" {
            salience 100
            when
                age >= 18
            then
                eligible = true;
        }
        "#;

        engine.add_grl_rule(grl).unwrap();

        let mut context = HashMap::new();
        context.insert("age".to_string(), json!(20));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["eligible"], json!(true));
    }

    #[test]
    fn test_from_grl() {
        let grl = r#"
        rule "test" {
            salience 100
            when
                x > 0
            then
                result = true;
                message = "x is positive";
        }
        "#;

        let mut engine = RuleEngine::from_grl(grl).unwrap();

        let mut context = HashMap::new();
        context.insert("x".to_string(), json!(5));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["result"], json!(true));
        assert_eq!(result["message"], json!("x is positive"));
    }

    #[test]
    fn test_multiple_rules_salience() {
        let mut engine = RuleEngine::new();

        let grl = r#"
        rule "high_priority" {
            salience 100
            when
                value > 100
            then
                priority = "high";
                high_rule_fired = true;
        }

        rule "medium_priority" {
            salience 50
            when
                value > 50 && value <= 100
            then
                priority = "medium";
                medium_rule_fired = true;
        }
        "#;

        engine.add_grl_rule(grl).unwrap();

        // Only the high-priority condition matches 150.
        let mut context = HashMap::new();
        context.insert("value".to_string(), json!(150));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["priority"], json!("high"));
        assert_eq!(result["high_rule_fired"], json!(true));

        // Only the medium-priority condition matches 75.
        let mut context2 = HashMap::new();
        context2.insert("value".to_string(), json!(75));

        let result2 = engine.evaluate(&context2).unwrap();
        assert_eq!(result2["priority"], json!("medium"));
        assert_eq!(result2["medium_rule_fired"], json!(true));
    }

    #[test]
    fn test_unsupported_context_values_are_skipped() {
        let grl = r#"
        rule "adult" {
            salience 10
            when
                age >= 18
            then
                adult = true;
        }
        "#;

        let mut engine = RuleEngine::from_grl(grl).unwrap();

        let mut context = HashMap::new();
        context.insert("age".to_string(), json!(30));
        context.insert("tags".to_string(), json!(["a", "b"]));
        context.insert("missing".to_string(), json!(null));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["adult"], json!(true));
        assert!(result.get("tags").is_none());
        assert!(result.get("missing").is_none());
    }

    #[test]
    fn test_direct_engine_access() {
        let engine = RuleEngine::new();

        let inner = engine.inner();
        let kb = inner.knowledge_base();

        assert_eq!(kb.name(), "LogicGraph");
    }
}