1use crate::engine::{facts::Facts, knowledge_base::KnowledgeBase, rule::Rule};
2use crate::errors::{Result, RuleEngineError};
3use crate::types::{ActionType, Value};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex, RwLock};
6use std::thread;
7use std::time::{Duration, Instant};
8
/// Tuning knobs for [`ParallelRuleEngine`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParallelConfig {
    /// Master switch: when `false`, every salience level is executed sequentially.
    pub enabled: bool,
    /// Upper bound on the number of worker threads used for a single salience level.
    pub max_threads: usize,
    /// Minimum number of rules a salience level must contain before it is run in parallel.
    pub min_rules_per_thread: usize,
    /// Request for dependency analysis between rules.
    ///
    /// NOTE(review): not read anywhere in this module — confirm whether it is consumed elsewhere.
    pub dependency_analysis: bool,
}
21
22impl Default for ParallelConfig {
23 fn default() -> Self {
24 Self {
25 enabled: true,
26 max_threads: num_cpus::get(),
27 min_rules_per_thread: 2,
28 dependency_analysis: true,
29 }
30 }
31}
32
/// Registry of user-supplied functions, keyed by the name they were registered under.
///
/// Each function receives the call's argument values and the current facts and returns either a
/// value or an error. The `Send + Sync` bounds let the registry be shared with worker threads.
type CustomFunctionMap =
    HashMap<String, Box<dyn Fn(&[Value], &Facts) -> Result<Value> + Send + Sync>>;
36
/// Outcome of evaluating a single rule during a [`ParallelRuleEngine`] run.
#[derive(Debug, Clone)]
pub struct RuleExecutionContext {
    /// Copy of the rule that was evaluated.
    pub rule: Rule,
    /// Whether the rule was considered to have fired (its actions were run).
    pub fired: bool,
    /// Error information recorded for this rule, if any.
    pub error: Option<String>,
    /// Wall-clock time spent evaluating this rule and running its actions.
    pub execution_time: Duration,
}
49
/// Rule engine that processes rules one salience level at a time, highest salience first,
/// optionally spreading the rules of a single level across worker threads.
///
/// Custom functions invoked by `ActionType::Call` actions live in an `RwLock`-protected registry
/// so that worker threads can read it concurrently.
pub struct ParallelRuleEngine {
    /// Parallelism settings.
    config: ParallelConfig,
    /// Name → function registry consulted by `Call` actions.
    custom_functions: Arc<RwLock<CustomFunctionMap>>,
}
55
56impl ParallelRuleEngine {
57 pub fn new(config: ParallelConfig) -> Self {
59 Self {
60 config,
61 custom_functions: Arc::new(RwLock::new(HashMap::new())),
62 }
63 }
64
65 pub fn register_function<F>(&mut self, name: &str, func: F)
67 where
68 F: Fn(&[Value], &Facts) -> Result<Value> + Send + Sync + 'static,
69 {
70 let mut functions = self.custom_functions.write().unwrap();
71 functions.insert(name.to_string(), Box::new(func));
72 }
73
74 pub fn execute_parallel(
76 &self,
77 knowledge_base: &KnowledgeBase,
78 facts: &Facts,
79 debug_mode: bool,
80 ) -> Result<ParallelExecutionResult> {
81 let start_time = Instant::now();
82
83 if debug_mode {
84 println!(
85 "๐ Starting parallel rule execution with {} rules",
86 knowledge_base.get_rules().len()
87 );
88 }
89
90 let salience_groups = self.group_rules_by_salience(&knowledge_base.get_rules());
92
93 let mut total_fired = 0;
94 let mut total_evaluated = 0;
95 let mut execution_contexts = Vec::new();
96
97 let mut salience_levels: Vec<_> = salience_groups.keys().copied().collect();
99 salience_levels.sort_by(|a, b| b.cmp(a)); for salience in salience_levels {
102 let rules_at_level = &salience_groups[&salience];
103
104 if debug_mode {
105 println!(
106 "โก Processing {} rules at salience level {}",
107 rules_at_level.len(),
108 salience
109 );
110 }
111
112 let should_parallelize = self.should_parallelize(rules_at_level);
114
115 let contexts = if should_parallelize {
116 self.execute_rules_parallel(rules_at_level, facts, debug_mode)?
117 } else {
118 self.execute_rules_sequential(rules_at_level, facts, debug_mode)?
119 };
120
121 for context in &contexts {
123 total_evaluated += 1;
124 if context.fired {
125 total_fired += 1;
126 }
127 }
128
129 execution_contexts.extend(contexts);
130 }
131
132 Ok(ParallelExecutionResult {
133 total_rules_evaluated: total_evaluated,
134 total_rules_fired: total_fired,
135 execution_time: start_time.elapsed(),
136 parallel_speedup: self.calculate_speedup(&execution_contexts),
137 execution_contexts,
138 })
139 }
140
141 fn group_rules_by_salience(&self, rules: &[Rule]) -> HashMap<i32, Vec<Rule>> {
143 let mut groups = HashMap::new();
144 for rule in rules {
145 if rule.enabled {
146 groups
147 .entry(rule.salience)
148 .or_insert_with(Vec::new)
149 .push(rule.clone());
150 }
151 }
152 groups
153 }
154
155 fn should_parallelize(&self, rules: &[Rule]) -> bool {
157 self.config.enabled && rules.len() >= self.config.min_rules_per_thread && rules.len() >= 2
158 }
159
160 fn execute_rules_parallel(
162 &self,
163 rules: &[Rule],
164 facts: &Facts,
165 debug_mode: bool,
166 ) -> Result<Vec<RuleExecutionContext>> {
167 let results = Arc::new(Mutex::new(Vec::new()));
168 let facts_arc = Arc::new(facts.clone());
169 let functions_arc = Arc::clone(&self.custom_functions);
170
171 let chunk_size = rules.len().div_ceil(self.config.max_threads);
173 let chunks: Vec<_> = rules.chunks(chunk_size).collect();
174
175 let handles: Vec<_> = chunks
176 .into_iter()
177 .enumerate()
178 .map(|(thread_id, chunk)| {
179 let chunk = chunk.to_vec();
180 let results_clone = Arc::clone(&results);
181 let facts_clone = Arc::clone(&facts_arc);
182 let functions_clone = Arc::clone(&functions_arc);
183
184 thread::spawn(move || {
185 if debug_mode {
186 println!(" ๐งต Thread {} processing {} rules", thread_id, chunk.len());
187 }
188
189 let mut thread_results = Vec::new();
190 for rule in chunk {
191 let start = Instant::now();
192 let fired = Self::evaluate_rule_conditions(&rule, &facts_clone);
193
194 if fired {
195 if debug_mode {
196 println!(" ๐ฅ Rule '{}' fired", rule.name);
197 }
198
199 for action in &rule.actions {
201 if let Err(e) = Self::execute_action_parallel(
202 action,
203 &facts_clone,
204 &functions_clone,
205 ) {
206 if debug_mode {
207 println!(" โ Action failed: {}", e);
208 }
209 }
210 }
211 }
212
213 thread_results.push(RuleExecutionContext {
214 rule: rule.clone(),
215 fired,
216 error: None,
217 execution_time: start.elapsed(),
218 });
219 }
220
221 let mut results = results_clone.lock().unwrap();
222 results.extend(thread_results);
223 })
224 })
225 .collect();
226
227 for handle in handles {
229 handle
230 .join()
231 .map_err(|_| RuleEngineError::EvaluationError {
232 message: "Thread panicked during parallel execution".to_string(),
233 })?;
234 }
235
236 let results = results.lock().unwrap();
237 Ok(results.clone())
238 }
239
240 fn execute_rules_sequential(
242 &self,
243 rules: &[Rule],
244 facts: &Facts,
245 debug_mode: bool,
246 ) -> Result<Vec<RuleExecutionContext>> {
247 let mut contexts = Vec::new();
248 let functions_arc = Arc::clone(&self.custom_functions);
249
250 for rule in rules {
251 let start = Instant::now();
252 let fired = Self::evaluate_rule_conditions(rule, facts);
253
254 if fired && debug_mode {
255 println!(" ๐ฅ Rule '{}' fired", rule.name);
256 }
257
258 if fired {
259 for action in &rule.actions {
261 if let Err(e) = Self::execute_action_parallel(action, facts, &functions_arc) {
262 if debug_mode {
263 println!(" โ Action failed: {}", e);
264 }
265 }
266 }
267 }
268
269 contexts.push(RuleExecutionContext {
270 rule: rule.clone(),
271 fired,
272 error: None,
273 execution_time: start.elapsed(),
274 });
275 }
276
277 Ok(contexts)
278 }
279
280 fn evaluate_rule_conditions(rule: &Rule, _facts: &Facts) -> bool {
282 !rule.actions.is_empty()
285 }
286
287 fn execute_action_parallel(
289 action: &ActionType,
290 facts: &Facts,
291 functions: &Arc<RwLock<CustomFunctionMap>>,
292 ) -> Result<()> {
293 match action {
294 ActionType::Call { function, args } => {
295 let functions_guard = functions.read().unwrap();
296 if let Some(func) = functions_guard.get(function) {
297 let _result = func(args, facts)?;
298 }
299 Ok(())
300 }
301 ActionType::MethodCall { .. } => {
302 Ok(())
304 }
305 ActionType::Set { .. } => {
306 Ok(())
308 }
309 ActionType::Log { message } => {
310 println!(" ๐ {}", message);
311 Ok(())
312 }
313 ActionType::Update { .. } => {
314 Ok(())
316 }
317 ActionType::Custom { .. } => {
318 Ok(())
320 }
321 }
322 }
323
324 fn calculate_speedup(&self, contexts: &[RuleExecutionContext]) -> f64 {
326 if contexts.is_empty() {
327 return 1.0;
328 }
329
330 let total_time: Duration = contexts.iter().map(|c| c.execution_time).sum();
331 let max_time = contexts
332 .iter()
333 .map(|c| c.execution_time)
334 .max()
335 .unwrap_or(Duration::ZERO);
336
337 if max_time.as_nanos() > 0 {
338 total_time.as_nanos() as f64 / max_time.as_nanos() as f64
339 } else {
340 1.0
341 }
342 }
343}
344
/// Aggregate outcome of one [`ParallelRuleEngine::execute_parallel`] run.
#[derive(Debug)]
pub struct ParallelExecutionResult {
    /// Number of rules evaluated (one per entry in `execution_contexts`).
    pub total_rules_evaluated: usize,
    /// Number of evaluated rules whose `fired` flag is set.
    pub total_rules_fired: usize,
    /// Wall-clock time of the whole run.
    pub execution_time: Duration,
    /// Per-rule outcomes, grouped by salience level with the highest level first.
    pub execution_contexts: Vec<RuleExecutionContext>,
    /// Speedup estimate: summed per-rule time divided by the longest single per-rule time.
    pub parallel_speedup: f64,
}
359
360impl ParallelExecutionResult {
361 pub fn get_stats(&self) -> String {
363 format!(
364 "๐ Parallel Execution Stats:\n Rules evaluated: {}\n Rules fired: {}\n Execution time: {:?}\n Parallel speedup: {:.2}x",
365 self.total_rules_evaluated,
366 self.total_rules_fired,
367 self.execution_time,
368 self.parallel_speedup
369 )
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::rule::{Condition, ConditionGroup};
    use crate::types::{Operator, Value};

    /// Builds a rule named `name` with a trivial `test == true` condition and no actions.
    fn rule(name: &str) -> Rule {
        Rule::new(
            name.to_string(),
            ConditionGroup::Single(Condition::new(
                "test".to_string(),
                Operator::Equal,
                Value::Boolean(true),
            )),
            vec![],
        )
    }

    #[test]
    fn test_parallel_config_default() {
        let config = ParallelConfig::default();
        assert!(config.enabled);
        assert!(config.max_threads > 0);
        assert_eq!(config.min_rules_per_thread, 2);
        assert!(config.dependency_analysis);
    }

    #[test]
    fn test_parallel_engine_creation() {
        let config = ParallelConfig::default();
        let engine = ParallelRuleEngine::new(config);
        assert!(engine.custom_functions.read().unwrap().is_empty());
    }

    #[test]
    fn test_salience_grouping() {
        let engine = ParallelRuleEngine::new(ParallelConfig::default());

        let rules = vec![
            rule("Rule1").with_priority(10),
            rule("Rule2").with_priority(10),
            rule("Rule3").with_priority(5),
        ];

        let groups = engine.group_rules_by_salience(&rules);
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[&10].len(), 2);
        assert_eq!(groups[&5].len(), 1);
    }

    #[test]
    fn test_salience_grouping_skips_disabled_rules() {
        let engine = ParallelRuleEngine::new(ParallelConfig::default());

        let mut disabled = rule("Disabled").with_priority(10);
        disabled.enabled = false;
        let rules = vec![rule("Enabled").with_priority(10), disabled];

        let groups = engine.group_rules_by_salience(&rules);
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[&10].len(), 1);
        assert_eq!(groups[&10][0].name, "Enabled");
    }

    #[test]
    fn test_salience_grouping_empty() {
        let engine = ParallelRuleEngine::new(ParallelConfig::default());
        assert!(engine.group_rules_by_salience(&[]).is_empty());
    }

    #[test]
    fn test_should_parallelize_respects_config() {
        let rules = vec![rule("A"), rule("B"), rule("C")];

        let parallel = ParallelRuleEngine::new(ParallelConfig {
            max_threads: 4,
            ..ParallelConfig::default()
        });
        assert!(parallel.should_parallelize(&rules));
        assert!(!parallel.should_parallelize(&rules[..1]));

        let disabled = ParallelRuleEngine::new(ParallelConfig {
            enabled: false,
            max_threads: 4,
            ..ParallelConfig::default()
        });
        assert!(!disabled.should_parallelize(&rules));

        let strict = ParallelRuleEngine::new(ParallelConfig {
            max_threads: 4,
            min_rules_per_thread: 4,
            ..ParallelConfig::default()
        });
        assert!(!strict.should_parallelize(&rules));
    }
}