rust_logic_graph/parallel/
mod.rs1use anyhow::Result;
6use std::collections::{HashMap, HashSet};
7use tracing::{debug, info, warn};
8
9use crate::core::{Graph, GraphDef};
10use crate::node::Node;
11use crate::rule::Rule;
12
/// Configuration accepted by [`ParallelExecutor::new`].
#[derive(Clone, Debug)]
pub struct ParallelConfig {
    /// Intended cap on concurrently running nodes (10 by default).
    ///
    /// NOTE(review): nothing in this module reads this value yet.
    pub max_concurrent: usize,
    /// Verbosity flag (off by default); presumably meant to enable extra
    /// diagnostics. Not read anywhere in this module.
    pub verbose: bool,
}
21
22impl Default for ParallelConfig {
23 fn default() -> Self {
24 Self {
25 max_concurrent: 10,
26 verbose: false,
27 }
28 }
29}
30
/// One level of the dependency schedule produced by
/// [`ParallelExecutor::identify_layers`].
#[derive(Clone, Debug)]
pub struct ExecutionLayer {
    /// Zero-based position of this layer in the schedule.
    pub layer_index: usize,
    /// Ids of the nodes assigned to this layer.
    pub node_ids: Vec<String>,
}
37
38pub struct ParallelExecutor {
40 nodes: HashMap<String, Box<dyn Node>>,
41 _config: ParallelConfig,
42}
43
44impl ParallelExecutor {
45 pub fn new(config: ParallelConfig) -> Self {
47 Self {
48 nodes: HashMap::new(),
49 _config: config,
50 }
51 }
52
53 pub fn register_node(&mut self, node: Box<dyn Node>) {
55 let id = node.id().to_string();
56 self.nodes.insert(id, node);
57 }
58
59 pub fn identify_layers(&self, def: &GraphDef) -> Result<Vec<ExecutionLayer>> {
62 info!("ParallelExecutor: Identifying execution layers");
63
64 let mut adj_list: HashMap<String, Vec<String>> = HashMap::new();
66 let mut in_degree: HashMap<String, usize> = HashMap::new();
67
68 for node_id in def.nodes.keys() {
70 in_degree.insert(node_id.clone(), 0);
71 adj_list.insert(node_id.clone(), Vec::new());
72 }
73
74 for edge in &def.edges {
76 adj_list
77 .entry(edge.from.clone())
78 .or_insert_with(Vec::new)
79 .push(edge.to.clone());
80
81 *in_degree.entry(edge.to.clone()).or_insert(0) += 1;
82 }
83
84 let mut layers = Vec::new();
86 let mut current_layer_index = 0;
87 let mut processed = HashSet::new();
88
89 loop {
90 let current_layer_nodes: Vec<String> = in_degree
92 .iter()
93 .filter(|(id, °ree)| degree == 0 && !processed.contains(*id))
94 .map(|(id, _)| id.clone())
95 .collect();
96
97 if current_layer_nodes.is_empty() {
98 break;
99 }
100
101 debug!(
102 "Layer {}: {} nodes can execute in parallel: {:?}",
103 current_layer_index,
104 current_layer_nodes.len(),
105 current_layer_nodes
106 );
107
108 layers.push(ExecutionLayer {
109 layer_index: current_layer_index,
110 node_ids: current_layer_nodes.clone(),
111 });
112
113 for node_id in ¤t_layer_nodes {
115 processed.insert(node_id.clone());
116
117 if let Some(neighbors) = adj_list.get(node_id) {
119 for neighbor in neighbors {
120 if let Some(degree) = in_degree.get_mut(neighbor) {
121 *degree = degree.saturating_sub(1);
122 }
123 }
124 }
125 }
126
127 current_layer_index += 1;
128 }
129
130 let unprocessed: Vec<_> = def
132 .nodes
133 .keys()
134 .filter(|id| !processed.contains(*id))
135 .collect();
136
137 if !unprocessed.is_empty() {
138 warn!(
139 "Some nodes could not be scheduled (possible cycle): {:?}",
140 unprocessed
141 );
142 }
143
144 info!(
145 "ParallelExecutor: Identified {} execution layers with total {} nodes",
146 layers.len(),
147 processed.len()
148 );
149
150 Ok(layers)
151 }
152
153 async fn execute_layer(
155 &self,
156 layer: &ExecutionLayer,
157 graph: &mut Graph,
158 ) -> Result<Vec<String>> {
159 info!(
160 "ParallelExecutor: Executing layer {} with {} nodes",
161 layer.layer_index,
162 layer.node_ids.len()
163 );
164
165 let mut successful_nodes = Vec::new();
166
167 for node_id in &layer.node_ids {
172 let should_execute = self.check_incoming_rules(node_id, graph);
174
175 if !should_execute {
176 info!("Skipping node '{}' due to failed rule", node_id);
177 continue;
178 }
179
180 if let Some(node) = self.nodes.get(node_id) {
181 info!("Executing node '{}'", node_id);
182
183 match node.run(&mut graph.context).await {
184 Ok(_) => {
185 info!("Node '{}' executed successfully", node_id);
186 successful_nodes.push(node_id.clone());
187 }
188 Err(e) => {
189 warn!("Node '{}' execution failed: {:?}", node_id, e);
190 }
191 }
192 } else {
193 warn!("Node '{}' not found in executor", node_id);
194 }
195 }
196
197 info!(
198 "Layer {} completed: {}/{} nodes successful",
199 layer.layer_index,
200 successful_nodes.len(),
201 layer.node_ids.len()
202 );
203
204 Ok(successful_nodes)
205 }
206
207 fn check_incoming_rules(&self, node_id: &str, graph: &Graph) -> bool {
209 let incoming_edges: Vec<_> = graph.def.edges.iter().filter(|e| e.to == node_id).collect();
210
211 for edge in &incoming_edges {
212 if let Some(rule_id) = &edge.rule {
213 let rule = Rule::new(rule_id, "true");
214
215 match rule.evaluate(&graph.context.data) {
216 Ok(result) => {
217 if let serde_json::Value::Bool(false) = result {
218 debug!(
219 "Rule '{}' for edge {} -> {} evaluated to false",
220 rule_id, edge.from, edge.to
221 );
222 return false;
223 }
224 }
225 Err(e) => {
226 warn!(
227 "Rule '{}' evaluation failed: {}. Assuming true.",
228 rule_id, e
229 );
230 }
231 }
232 }
233 }
234
235 true
236 }
237
238 pub async fn execute(&self, graph: &mut Graph) -> Result<()> {
240 info!("ParallelExecutor: Starting parallel graph execution");
241
242 let layers = self.identify_layers(&graph.def)?;
244
245 if layers.is_empty() {
246 warn!("No execution layers found");
247 return Ok(());
248 }
249
250 info!("ParallelExecutor: Executing {} layers", layers.len());
251
252 let mut total_executed = 0;
253
254 for layer in layers {
256 let successful_nodes = self.execute_layer(&layer, graph).await?;
257 total_executed += successful_nodes.len();
258 }
259
260 info!(
261 "ParallelExecutor: Completed parallel execution. Total nodes executed: {}",
262 total_executed
263 );
264
265 Ok(())
266 }
267
268 pub fn get_parallelism_stats(&self, def: &GraphDef) -> Result<ParallelismStats> {
270 let layers = self.identify_layers(def)?;
271
272 let total_nodes = def.nodes.len();
273 let max_parallel_nodes = layers
274 .iter()
275 .map(|layer| layer.node_ids.len())
276 .max()
277 .unwrap_or(0);
278
279 let sequential_time = total_nodes; let parallel_time = layers.len(); let speedup = if parallel_time > 0 {
283 sequential_time as f64 / parallel_time as f64
284 } else {
285 1.0
286 };
287
288 Ok(ParallelismStats {
289 total_nodes,
290 num_layers: layers.len(),
291 max_parallel_nodes,
292 avg_parallel_nodes: if !layers.is_empty() {
293 total_nodes as f64 / layers.len() as f64
294 } else {
295 0.0
296 },
297 theoretical_speedup: speedup,
298 layers,
299 })
300 }
301}
302
303impl Default for ParallelExecutor {
304 fn default() -> Self {
305 Self::new(ParallelConfig::default())
306 }
307}
308
309#[derive(Debug)]
311pub struct ParallelismStats {
312 pub total_nodes: usize,
313 pub num_layers: usize,
314 pub max_parallel_nodes: usize,
315 pub avg_parallel_nodes: f64,
316 pub theoretical_speedup: f64,
317 pub layers: Vec<ExecutionLayer>,
318}
319
320impl ParallelismStats {
321 pub fn print_summary(&self) {
322 println!("\n=== Parallelism Analysis ===");
323 println!("Total nodes: {}", self.total_nodes);
324 println!("Execution layers: {}", self.num_layers);
325 println!("Max parallel nodes: {}", self.max_parallel_nodes);
326 println!(
327 "Avg parallel nodes per layer: {:.2}",
328 self.avg_parallel_nodes
329 );
330 println!("Theoretical speedup: {:.2}x", self.theoretical_speedup);
331 println!("\nLayer breakdown:");
332 for layer in &self.layers {
333 println!(
334 " Layer {}: {} nodes - {:?}",
335 layer.layer_index,
336 layer.node_ids.len(),
337 layer.node_ids
338 );
339 }
340 println!("===========================\n");
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Edge;
    use crate::node::NodeType;
    use std::collections::HashMap;

    /// Builds a node map in which every id is a `NodeType::RuleNode`.
    fn rule_nodes(ids: &[&str]) -> HashMap<String, NodeType> {
        ids.iter()
            .map(|id| (id.to_string(), NodeType::RuleNode))
            .collect()
    }

    /// Builds an unconditional edge `from -> to`.
    fn edge(from: &str, to: &str) -> Edge {
        Edge {
            from: from.to_string(),
            to: to.to_string(),
            rule: None,
        }
    }

    /// Returns a layer's ids sorted, so assertions do not depend on how the
    /// executor orders nodes within a layer.
    fn sorted_ids(layer: &ExecutionLayer) -> Vec<&str> {
        let mut ids: Vec<&str> = layer.node_ids.iter().map(String::as_str).collect();
        ids.sort_unstable();
        ids
    }

    #[test]
    fn test_layer_identification() {
        // Diamond: A -> {B, C} -> D
        let def = GraphDef::from_node_types(
            rule_nodes(&["A", "B", "C", "D"]),
            vec![
                edge("A", "B"),
                edge("A", "C"),
                edge("B", "D"),
                edge("C", "D"),
            ],
        );
        let layers = ParallelExecutor::default().identify_layers(&def).unwrap();

        assert_eq!(layers.len(), 3);
        assert_eq!(sorted_ids(&layers[0]), ["A"]);
        assert_eq!(sorted_ids(&layers[1]), ["B", "C"]);
        assert_eq!(sorted_ids(&layers[2]), ["D"]);
        for (index, layer) in layers.iter().enumerate() {
            assert_eq!(layer.layer_index, index);
        }
    }

    #[test]
    fn test_cycle_nodes_are_not_scheduled() {
        // R -> A, plus the cycle A <-> B: only R can ever run.
        let def = GraphDef::from_node_types(
            rule_nodes(&["R", "A", "B"]),
            vec![edge("R", "A"), edge("A", "B"), edge("B", "A")],
        );
        let layers = ParallelExecutor::default().identify_layers(&def).unwrap();

        assert_eq!(layers.len(), 1);
        assert_eq!(sorted_ids(&layers[0]), ["R"]);
    }

    #[test]
    fn test_parallelism_stats() {
        // Chain A -> B -> C -> D: no parallelism available.
        let def = GraphDef::from_node_types(
            rule_nodes(&["A", "B", "C", "D"]),
            vec![edge("A", "B"), edge("B", "C"), edge("C", "D")],
        );
        let stats = ParallelExecutor::default()
            .get_parallelism_stats(&def)
            .unwrap();

        assert_eq!(stats.total_nodes, 4);
        assert_eq!(stats.num_layers, 4);
        assert_eq!(stats.max_parallel_nodes, 1);
        assert_eq!(stats.avg_parallel_nodes, 1.0);
        assert_eq!(stats.theoretical_speedup, 1.0);
    }

    #[test]
    fn test_independent_nodes_share_one_layer() {
        // No edges: every node can run in the first layer.
        let def = GraphDef::from_node_types(rule_nodes(&["A", "B", "C", "D"]), vec![]);
        let stats = ParallelExecutor::default()
            .get_parallelism_stats(&def)
            .unwrap();

        assert_eq!(stats.total_nodes, 4);
        assert_eq!(stats.num_layers, 1);
        assert_eq!(stats.max_parallel_nodes, 4);
        assert_eq!(stats.avg_parallel_nodes, 4.0);
        assert_eq!(stats.theoretical_speedup, 4.0);
    }
}