rust_logic_graph/parallel/
mod.rs1use std::collections::{HashMap, HashSet};
6use anyhow::Result;
7use tracing::{info, debug, warn};
8
9use crate::core::{Graph, GraphDef};
10use crate::node::Node;
11use crate::rule::Rule;
12
/// Tuning knobs for a [`ParallelExecutor`].
///
/// NOTE(review): neither field is read by any code visible in this module;
/// confirm whether other modules consult them before relying on either.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Intended cap on how many nodes run at the same time (default `10`).
    pub max_concurrent: usize,
    /// Verbose-output flag (default `false`).
    pub verbose: bool,
}

impl Default for ParallelConfig {
    /// Returns `max_concurrent = 10`, `verbose = false`.
    fn default() -> Self {
        let max_concurrent = 10;
        let verbose = false;
        ParallelConfig {
            max_concurrent,
            verbose,
        }
    }
}
30
/// One scheduling layer: nodes none of which has an edge from another node of
/// the same layer, so they only depend on nodes in earlier layers.
#[derive(Debug, Clone)]
pub struct ExecutionLayer {
    /// Zero-based position of this layer in execution order.
    pub layer_index: usize,
    /// Ids of the nodes that belong to this layer.
    pub node_ids: Vec<String>,
}
37
38pub struct ParallelExecutor {
40 nodes: HashMap<String, Box<dyn Node>>,
41 _config: ParallelConfig,
42}
43
44impl ParallelExecutor {
45 pub fn new(config: ParallelConfig) -> Self {
47 Self {
48 nodes: HashMap::new(),
49 _config: config,
50 }
51 }
52
53 pub fn register_node(&mut self, node: Box<dyn Node>) {
55 let id = node.id().to_string();
56 self.nodes.insert(id, node);
57 }
58
59 pub fn identify_layers(&self, def: &GraphDef) -> Result<Vec<ExecutionLayer>> {
62 info!("ParallelExecutor: Identifying execution layers");
63
64 let mut adj_list: HashMap<String, Vec<String>> = HashMap::new();
66 let mut in_degree: HashMap<String, usize> = HashMap::new();
67
68 for node_id in def.nodes.keys() {
70 in_degree.insert(node_id.clone(), 0);
71 adj_list.insert(node_id.clone(), Vec::new());
72 }
73
74 for edge in &def.edges {
76 adj_list
77 .entry(edge.from.clone())
78 .or_insert_with(Vec::new)
79 .push(edge.to.clone());
80
81 *in_degree.entry(edge.to.clone()).or_insert(0) += 1;
82 }
83
84 let mut layers = Vec::new();
86 let mut current_layer_index = 0;
87 let mut processed = HashSet::new();
88
89 loop {
90 let current_layer_nodes: Vec<String> = in_degree
92 .iter()
93 .filter(|(id, °ree)| degree == 0 && !processed.contains(*id))
94 .map(|(id, _)| id.clone())
95 .collect();
96
97 if current_layer_nodes.is_empty() {
98 break;
99 }
100
101 debug!(
102 "Layer {}: {} nodes can execute in parallel: {:?}",
103 current_layer_index,
104 current_layer_nodes.len(),
105 current_layer_nodes
106 );
107
108 layers.push(ExecutionLayer {
109 layer_index: current_layer_index,
110 node_ids: current_layer_nodes.clone(),
111 });
112
113 for node_id in ¤t_layer_nodes {
115 processed.insert(node_id.clone());
116
117 if let Some(neighbors) = adj_list.get(node_id) {
119 for neighbor in neighbors {
120 if let Some(degree) = in_degree.get_mut(neighbor) {
121 *degree = degree.saturating_sub(1);
122 }
123 }
124 }
125 }
126
127 current_layer_index += 1;
128 }
129
130 let unprocessed: Vec<_> = def
132 .nodes
133 .keys()
134 .filter(|id| !processed.contains(*id))
135 .collect();
136
137 if !unprocessed.is_empty() {
138 warn!(
139 "Some nodes could not be scheduled (possible cycle): {:?}",
140 unprocessed
141 );
142 }
143
144 info!(
145 "ParallelExecutor: Identified {} execution layers with total {} nodes",
146 layers.len(),
147 processed.len()
148 );
149
150 Ok(layers)
151 }
152
153 async fn execute_layer(
155 &self,
156 layer: &ExecutionLayer,
157 graph: &mut Graph,
158 ) -> Result<Vec<String>> {
159 info!(
160 "ParallelExecutor: Executing layer {} with {} nodes",
161 layer.layer_index,
162 layer.node_ids.len()
163 );
164
165 let mut successful_nodes = Vec::new();
166
167 for node_id in &layer.node_ids {
172 let should_execute = self.check_incoming_rules(node_id, graph);
174
175 if !should_execute {
176 info!("Skipping node '{}' due to failed rule", node_id);
177 continue;
178 }
179
180 if let Some(node) = self.nodes.get(node_id) {
181 info!("Executing node '{}'", node_id);
182
183 match node.run(&mut graph.context).await {
184 Ok(_) => {
185 info!("Node '{}' executed successfully", node_id);
186 successful_nodes.push(node_id.clone());
187 }
188 Err(e) => {
189 warn!("Node '{}' execution failed: {:?}", node_id, e);
190 }
191 }
192 } else {
193 warn!("Node '{}' not found in executor", node_id);
194 }
195 }
196
197 info!(
198 "Layer {} completed: {}/{} nodes successful",
199 layer.layer_index,
200 successful_nodes.len(),
201 layer.node_ids.len()
202 );
203
204 Ok(successful_nodes)
205 }
206
207 fn check_incoming_rules(&self, node_id: &str, graph: &Graph) -> bool {
209 let incoming_edges: Vec<_> = graph
210 .def
211 .edges
212 .iter()
213 .filter(|e| e.to == node_id)
214 .collect();
215
216 for edge in &incoming_edges {
217 if let Some(rule_id) = &edge.rule {
218 let rule = Rule::new(rule_id, "true");
219
220 match rule.evaluate(&graph.context.data) {
221 Ok(result) => {
222 if let serde_json::Value::Bool(false) = result {
223 debug!(
224 "Rule '{}' for edge {} -> {} evaluated to false",
225 rule_id, edge.from, edge.to
226 );
227 return false;
228 }
229 }
230 Err(e) => {
231 warn!(
232 "Rule '{}' evaluation failed: {}. Assuming true.",
233 rule_id, e
234 );
235 }
236 }
237 }
238 }
239
240 true
241 }
242
243 pub async fn execute(&self, graph: &mut Graph) -> Result<()> {
245 info!("ParallelExecutor: Starting parallel graph execution");
246
247 let layers = self.identify_layers(&graph.def)?;
249
250 if layers.is_empty() {
251 warn!("No execution layers found");
252 return Ok(());
253 }
254
255 info!(
256 "ParallelExecutor: Executing {} layers",
257 layers.len()
258 );
259
260 let mut total_executed = 0;
261
262 for layer in layers {
264 let successful_nodes = self.execute_layer(&layer, graph).await?;
265 total_executed += successful_nodes.len();
266 }
267
268 info!(
269 "ParallelExecutor: Completed parallel execution. Total nodes executed: {}",
270 total_executed
271 );
272
273 Ok(())
274 }
275
276 pub fn get_parallelism_stats(&self, def: &GraphDef) -> Result<ParallelismStats> {
278 let layers = self.identify_layers(def)?;
279
280 let total_nodes = def.nodes.len();
281 let max_parallel_nodes = layers
282 .iter()
283 .map(|layer| layer.node_ids.len())
284 .max()
285 .unwrap_or(0);
286
287 let sequential_time = total_nodes; let parallel_time = layers.len(); let speedup = if parallel_time > 0 {
291 sequential_time as f64 / parallel_time as f64
292 } else {
293 1.0
294 };
295
296 Ok(ParallelismStats {
297 total_nodes,
298 num_layers: layers.len(),
299 max_parallel_nodes,
300 avg_parallel_nodes: if !layers.is_empty() {
301 total_nodes as f64 / layers.len() as f64
302 } else {
303 0.0
304 },
305 theoretical_speedup: speedup,
306 layers,
307 })
308 }
309}
310
311impl Default for ParallelExecutor {
312 fn default() -> Self {
313 Self::new(ParallelConfig::default())
314 }
315}
316
317#[derive(Debug)]
319pub struct ParallelismStats {
320 pub total_nodes: usize,
321 pub num_layers: usize,
322 pub max_parallel_nodes: usize,
323 pub avg_parallel_nodes: f64,
324 pub theoretical_speedup: f64,
325 pub layers: Vec<ExecutionLayer>,
326}
327
328impl ParallelismStats {
329 pub fn print_summary(&self) {
330 println!("\n=== Parallelism Analysis ===");
331 println!("Total nodes: {}", self.total_nodes);
332 println!("Execution layers: {}", self.num_layers);
333 println!("Max parallel nodes: {}", self.max_parallel_nodes);
334 println!("Avg parallel nodes per layer: {:.2}", self.avg_parallel_nodes);
335 println!("Theoretical speedup: {:.2}x", self.theoretical_speedup);
336 println!("\nLayer breakdown:");
337 for layer in &self.layers {
338 println!(
339 " Layer {}: {} nodes - {:?}",
340 layer.layer_index,
341 layer.node_ids.len(),
342 layer.node_ids
343 );
344 }
345 println!("===========================\n");
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::NodeType;
    use std::collections::HashMap;

    /// Builds a `GraphDef` containing `nodes` (all `RuleNode`s) joined by
    /// unconditional `edges` given as `(from, to)` pairs.
    fn build_def(nodes: &[&str], edges: &[(&str, &str)]) -> GraphDef {
        let mut def = GraphDef {
            nodes: HashMap::new(),
            edges: Vec::new(),
        };
        for &id in nodes {
            def.nodes.insert(id.to_string(), NodeType::RuleNode);
        }
        for &(from, to) in edges {
            def.edges.push(crate::core::Edge {
                from: from.to_string(),
                to: to.to_string(),
                rule: None,
            });
        }
        def
    }

    /// Returns a layer's ids sorted, so assertions do not depend on the order
    /// in which `identify_layers` emits them.
    fn sorted_ids(layer: &ExecutionLayer) -> Vec<&str> {
        let mut ids: Vec<&str> = layer.node_ids.iter().map(String::as_str).collect();
        ids.sort_unstable();
        ids
    }

    #[test]
    fn test_layer_identification() {
        // Diamond: A -> {B, C} -> D
        let def = build_def(
            &["A", "B", "C", "D"],
            &[("A", "B"), ("A", "C"), ("B", "D"), ("C", "D")],
        );

        let executor = ParallelExecutor::default();
        let layers = executor.identify_layers(&def).unwrap();

        assert_eq!(layers.len(), 3);
        assert_eq!(sorted_ids(&layers[0]), vec!["A"]);
        assert_eq!(sorted_ids(&layers[1]), vec!["B", "C"]);
        assert_eq!(sorted_ids(&layers[2]), vec!["D"]);
        for (expected_index, layer) in layers.iter().enumerate() {
            assert_eq!(layer.layer_index, expected_index);
        }
    }

    #[test]
    fn test_cycle_nodes_are_not_scheduled() {
        // A <-> B is a cycle; C is independent and must still be scheduled.
        let def = build_def(&["A", "B", "C"], &[("A", "B"), ("B", "A")]);

        let executor = ParallelExecutor::default();
        let layers = executor.identify_layers(&def).unwrap();

        assert_eq!(layers.len(), 1);
        assert_eq!(sorted_ids(&layers[0]), vec!["C"]);
    }

    #[test]
    fn test_parallelism_stats() {
        // Linear chain: A -> B -> C -> D, so no parallelism at all.
        let def = build_def(
            &["A", "B", "C", "D"],
            &[("A", "B"), ("B", "C"), ("C", "D")],
        );

        let executor = ParallelExecutor::default();
        let stats = executor.get_parallelism_stats(&def).unwrap();

        assert_eq!(stats.total_nodes, 4);
        assert_eq!(stats.num_layers, 4);
        assert_eq!(stats.max_parallel_nodes, 1);
        assert_eq!(stats.theoretical_speedup, 1.0);
    }

    #[test]
    fn test_parallel_graph_stats() {
        // Four independent nodes: everything fits in one layer.
        let def = build_def(&["A", "B", "C", "D"], &[]);

        let executor = ParallelExecutor::default();
        let stats = executor.get_parallelism_stats(&def).unwrap();

        assert_eq!(stats.total_nodes, 4);
        assert_eq!(stats.num_layers, 1);
        assert_eq!(stats.max_parallel_nodes, 4);
        assert_eq!(stats.theoretical_speedup, 4.0);
    }
}