rust_logic_graph/rule/
engine.rs1pub use rust_rule_engine::{
3 engine::{facts::Facts, knowledge_base::KnowledgeBase, EngineConfig, RustRuleEngine},
4 types::Value,
5 GRLParser, };
7
8use serde_json::Value as JsonValue;
9use std::collections::HashMap;
10use tracing::debug;
11
12use super::{RuleError, RuleResult};
13
14pub struct RuleEngine {
49 engine: RustRuleEngine,
50}
51
52impl RuleEngine {
53 pub fn new() -> Self {
55 let kb = KnowledgeBase::new("LogicGraph");
56 Self {
57 engine: RustRuleEngine::new(kb),
58 }
59 }
60
61 pub fn with_config(config: EngineConfig) -> Self {
63 let kb = KnowledgeBase::new("LogicGraph");
64 Self {
65 engine: RustRuleEngine::with_config(kb, config),
66 }
67 }
68
69 pub fn add_grl_rule(&mut self, grl_content: &str) -> Result<(), RuleError> {
89 let start = std::time::Instant::now();
90 debug!("⏱️ [GRL Parse] Starting GRLParser::parse_rules()...");
91
92 let parse_start = std::time::Instant::now();
93 let rules = GRLParser::parse_rules(grl_content)
94 .map_err(|e| RuleError::Eval(format!("Failed to parse GRL: {}", e)))?;
95 let parse_elapsed = parse_start.elapsed();
96
97 let rule_count = rules.len();
98 debug!(
99 " ✅ GRLParser::parse_rules() took {:.3}s for {} rules",
100 parse_elapsed.as_secs_f64(),
101 rule_count
102 );
103
104 debug!(
105 "⏱️ [GRL Add] Adding {} rules to knowledge_base...",
106 rule_count
107 );
108 let add_start = std::time::Instant::now();
109
110 for (idx, rule) in rules.into_iter().enumerate() {
111 let rule_start = std::time::Instant::now();
112 self.engine
113 .knowledge_base()
114 .add_rule(rule)
115 .map_err(|e| RuleError::Eval(format!("Failed to add rule: {}", e)))?;
116 let rule_elapsed = rule_start.elapsed();
117
118 if rule_elapsed.as_millis() > 10 {
119 debug!(
120 " Rule #{} took {:.3}ms",
121 idx + 1,
122 rule_elapsed.as_secs_f64() * 1000.0
123 );
124 }
125 }
126
127 let add_elapsed = add_start.elapsed();
128 debug!(
129 " ✅ add_rule() loop took {:.3}s",
130 add_elapsed.as_secs_f64()
131 );
132
133 let total_elapsed = start.elapsed();
134 debug!(
135 "✅ [GRL Total] Loaded {} GRL rules in {:.3}s",
136 rule_count,
137 total_elapsed.as_secs_f64()
138 );
139
140 Ok(())
141 }
142
143 pub fn evaluate(&mut self, context: &HashMap<String, JsonValue>) -> RuleResult {
163 let facts = Facts::new();
165
166 for (key, value) in context {
167 let rr_value = match value {
168 JsonValue::Bool(b) => Value::Boolean(*b),
169 JsonValue::Number(n) => {
170 if let Some(f) = n.as_f64() {
171 Value::Number(f)
172 } else {
173 continue;
174 }
175 }
176 JsonValue::String(s) => Value::String(s.clone()),
177 _ => {
178 debug!("Skipping unsupported value type for key: {}", key);
179 continue;
180 }
181 };
182
183 facts.set(&key, rr_value);
184 }
185
186 match self.engine.execute(&facts) {
188 Ok(_) => {
189 debug!("Rules executed successfully");
190
191 let convert_value = |val: &Value| -> Option<JsonValue> {
193 match val {
194 Value::Boolean(b) => Some(JsonValue::Bool(*b)),
195 Value::Number(n) => Some(JsonValue::from(*n)),
196 Value::String(s) => Some(JsonValue::String(s.clone())),
197 Value::Integer(i) => Some(JsonValue::from(*i)),
198 _ => None,
199 }
200 };
201
202 let all_facts = facts.get_all_facts();
205
206 let mut result = HashMap::new();
207 for (key, value) in all_facts {
208 if let Some(json_value) = convert_value(&value) {
209 result.insert(key, json_value);
210 }
211 }
212
213 Ok(JsonValue::Object(result.into_iter().collect()))
214 }
215 Err(e) => Err(RuleError::Eval(format!("Rule execution failed: {}", e))),
216 }
217 }
218
219 pub fn from_grl(grl_script: &str) -> Result<Self, RuleError> {
238 let mut engine = Self::new();
239 engine.add_grl_rule(grl_script)?;
240 Ok(engine)
241 }
242
243 pub fn inner(&self) -> &RustRuleEngine {
252 &self.engine
253 }
254
255 pub fn inner_mut(&mut self) -> &mut RustRuleEngine {
257 &mut self.engine
258 }
259}
260
261impl Default for RuleEngine {
262 fn default() -> Self {
263 Self::new()
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_rule_engine_creation() {
        let _engine = RuleEngine::new();
    }

    #[test]
    fn test_engine_evaluation() {
        let mut engine = RuleEngine::new();

        let grl = r#"
        rule "test_rule" {
            salience 100
            when
                age >= 18
            then
                eligible = true;
        }
        "#;

        engine.add_grl_rule(grl).unwrap();

        let mut context = HashMap::new();
        context.insert("age".to_string(), json!(20));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["eligible"], json!(true));
    }

    #[test]
    fn test_from_grl() {
        let grl = r#"
        rule "test" {
            salience 100
            when
                x > 0
            then
                result = true;
                message = "x is positive";
        }
        "#;

        let mut engine = RuleEngine::from_grl(grl).unwrap();

        let mut context = HashMap::new();
        context.insert("x".to_string(), json!(5));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["result"], json!(true));
        assert_eq!(result["message"], json!("x is positive"));
    }

    #[test]
    fn test_multiple_rules_salience() {
        let mut engine = RuleEngine::new();

        let grl = r#"
        rule "high_priority" {
            salience 100
            when
                value > 100
            then
                priority = "high";
                high_rule_fired = true;
        }

        rule "medium_priority" {
            salience 50
            when
                value > 50 && value <= 100
            then
                priority = "medium";
                medium_rule_fired = true;
        }
        "#;

        engine.add_grl_rule(grl).unwrap();

        // Only the high-priority condition holds for 150.
        let mut context = HashMap::new();
        context.insert("value".to_string(), json!(150));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["priority"], json!("high"));
        assert_eq!(result["high_rule_fired"], json!(true));
        assert!(result.get("medium_rule_fired").is_none());

        // Only the medium-priority condition holds for 75. Facts are rebuilt
        // per call, so nothing from the first evaluation leaks in.
        let mut context2 = HashMap::new();
        context2.insert("value".to_string(), json!(75));

        let result2 = engine.evaluate(&context2).unwrap();
        assert_eq!(result2["priority"], json!("medium"));
        assert_eq!(result2["medium_rule_fired"], json!(true));
        assert!(result2.get("high_rule_fired").is_none());
    }

    #[test]
    fn test_unsupported_context_values_are_skipped() {
        let grl = r#"
        rule "test" {
            salience 100
            when
                x > 0
            then
                result = true;
        }
        "#;

        let mut engine = RuleEngine::from_grl(grl).unwrap();

        let mut context = HashMap::new();
        context.insert("x".to_string(), json!(5));
        context.insert("label".to_string(), json!("kept"));
        context.insert("nothing".to_string(), json!(null));
        context.insert("list".to_string(), json!([1, 2, 3]));
        context.insert("nested".to_string(), json!({ "a": 1 }));

        let result = engine.evaluate(&context).unwrap();
        assert_eq!(result["result"], json!(true));
        assert_eq!(result["label"], json!("kept"));
        assert!(result.get("nothing").is_none());
        assert!(result.get("list").is_none());
        assert!(result.get("nested").is_none());
    }

    #[test]
    fn test_direct_engine_access() {
        let engine = RuleEngine::new();

        let inner = engine.inner();
        let kb = inner.knowledge_base();

        assert_eq!(kb.name(), "LogicGraph");
    }
}