1use serde::{Deserialize, Serialize};
29use std::collections::{HashMap, HashSet, VecDeque};
30use thiserror::Error;
31
/// Errors produced while analyzing or partitioning a computation graph.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum AutoParallelError {
    /// The dependency graph contains a cycle; the message names an involved node.
    #[error("Dependency cycle detected: {0}")]
    DependencyCycle(String),

    /// A node references a dependency id that does not exist in the graph.
    #[error("Invalid graph: {0}")]
    InvalidGraph(String),

    /// A cost-model computation failed.
    #[error("Cost model error: {0}")]
    CostModelError(String),

    /// Work partitioning could not assign nodes to workers.
    #[error("Partitioning failed: {0}")]
    PartitioningFailed(String),
}
47
/// Identifier of a node in the computation graph; currently a plain string.
pub type NodeId = String;
50
/// How aggressively to scale work out across workers
/// (interpreted by `recommend_worker_count`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ParallelizationStrategy {
    /// Use at most half of `max_workers` (never less than 1).
    Conservative,
    /// Match worker count to the measured parallelism, capped at `max_workers`.
    Balanced,
    /// Always use `max_workers`, regardless of measured parallelism.
    Aggressive,
    /// Scale out fully only when the parallelism factor exceeds 2.0;
    /// otherwise halve the ideal count.
    CostBased,
}
63
/// Source of per-node cost estimates.
///
/// NOTE(review): the selected model is stored on the parallelizer but not yet
/// consulted anywhere — analysis reads `NodeInfo::estimated_cost` directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CostModel {
    /// Rule-of-thumb estimates.
    Heuristic,
    /// Estimates derived from recorded execution profiles.
    ProfileBased,
    /// Closed-form analytical estimates.
    Analytical,
    /// Combination of the above approaches.
    Hybrid,
}
76
/// Kind of edge between two nodes in the dependency graph.
///
/// NOTE(review): the current analysis treats all kinds identically — the edge
/// type is ignored (`_dep_type`) when the graph is built.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DependencyType {
    /// Consumer reads data produced by the dependency.
    Data,
    /// Ordering imposed by control flow.
    Control,
    /// Ordering imposed by shared-memory access.
    Memory,
}
87
/// Description of one operation in the computation graph to be parallelized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Unique identifier of this node within the graph.
    pub id: NodeId,
    /// Operation type label (e.g. "input", "compute"); also used as the
    /// profile-data key in `update_profile`.
    pub op_type: String,
    /// Estimated execution cost — presumably microseconds, matching
    /// `update_profile`'s `time_us`; TODO confirm units.
    pub estimated_cost: f64,
    /// Memory footprint — presumably bytes (divided by GB/s bandwidth in the
    /// overhead estimate); TODO confirm.
    pub memory_size: usize,
    /// Nodes this node depends on, each tagged with the dependency kind.
    pub dependencies: Vec<(NodeId, DependencyType)>,
    /// Whether this node may run concurrently with others.
    /// NOTE(review): not currently consulted by the analysis.
    pub can_parallelize: bool,
}
98
/// A set of nodes that may execute concurrently once all earlier stages
/// have completed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelStage {
    /// Zero-based position of this stage in the schedule.
    pub stage_id: usize,
    /// Nodes runnable in parallel within this stage.
    pub nodes: Vec<NodeId>,
    /// Stage duration: the maximum cost among its member nodes.
    pub estimated_time: f64,
    /// Sum of the member nodes' memory sizes.
    pub memory_requirement: usize,
    /// Stage ids this stage waits on (currently always just the previous stage).
    pub predecessors: Vec<usize>,
}
108
/// The slice of work assigned to a single worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkPartition {
    /// Index of the worker this partition belongs to.
    pub worker_id: usize,
    /// Nodes assigned to this worker.
    pub nodes: Vec<NodeId>,
    /// Accumulated weight of the assigned nodes.
    pub estimated_load: f64,
}
116
/// Result of analyzing a computation graph for parallel execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelizationAnalysis {
    /// Number of sequential stages (equals `stages.len()`).
    pub num_stages: usize,
    /// The staged schedule, in execution order.
    pub stages: Vec<ParallelStage>,
    /// Sum of per-stage times: the span with unlimited workers.
    pub critical_path_length: f64,
    /// Sum of all node costs (i.e. sequential execution time).
    pub total_work: f64,
    /// `total_work / critical_path_length` — average available parallelism.
    pub parallelism_factor: f64,
    /// Estimated scheduling plus data-transfer overhead.
    pub communication_overhead: f64,
    /// Worker count suggested by the configured strategy.
    pub recommended_workers: usize,
}
128
/// Concrete plan for executing a graph: the schedule plus worker assignments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelExecutionPlan {
    /// Sequential stages of parallel work.
    pub stages: Vec<ParallelStage>,
    /// Per-worker node assignments.
    pub partitions: Vec<WorkPartition>,
    /// `total_work / (critical_path + communication_overhead)`.
    pub estimated_speedup: f64,
    /// Average/maximum partition load; 1.0 means perfectly balanced.
    pub load_balance_ratio: f64,
}
137
/// Analyzes task dependency graphs and produces parallel execution plans.
///
/// Configure via the builder-style `with_*` methods; defaults come from `new`.
pub struct AutoParallelizer {
    /// How aggressively to scale out across workers.
    strategy: ParallelizationStrategy,
    /// Configured cost model. NOTE(review): stored but not yet consulted.
    cost_model: CostModel,
    /// Upper bound on recommended workers (defaults to the logical CPU count).
    max_workers: usize,
    /// Fixed scheduling overhead charged per task — presumably microseconds;
    /// TODO confirm units.
    overhead_per_task: f64,
    /// Link bandwidth for transfer estimates — presumably GB/s, given the
    /// `* 1e9` scaling in the overhead calculation.
    communication_bandwidth: f64,
    /// Exponential moving average of observed per-op execution times,
    /// keyed by op type.
    profile_data: HashMap<String, f64>,
}
147
148impl AutoParallelizer {
149 pub fn new() -> Self {
151 Self {
152 strategy: ParallelizationStrategy::Balanced,
153 cost_model: CostModel::Heuristic,
154 max_workers: num_cpus::get(),
155 overhead_per_task: 10.0, communication_bandwidth: 100.0, profile_data: HashMap::new(),
158 }
159 }
160
161 pub fn with_strategy(mut self, strategy: ParallelizationStrategy) -> Self {
163 self.strategy = strategy;
164 self
165 }
166
167 pub fn with_cost_model(mut self, model: CostModel) -> Self {
169 self.cost_model = model;
170 self
171 }
172
173 pub fn with_max_workers(mut self, workers: usize) -> Self {
175 self.max_workers = workers;
176 self
177 }
178
179 pub fn update_profile(&mut self, op_type: String, time_us: f64) {
181 let entry = self.profile_data.entry(op_type).or_insert(0.0);
182 *entry = 0.9 * *entry + 0.1 * time_us; }
184
185 pub fn analyze(
187 &self,
188 nodes: &[NodeInfo],
189 ) -> Result<ParallelizationAnalysis, AutoParallelError> {
190 let dep_graph = self.build_dependency_graph(nodes)?;
192
193 let stages = self.compute_stages(nodes, &dep_graph)?;
195
196 let critical_path_length = self.calculate_critical_path(&stages);
198
199 let total_work: f64 = nodes.iter().map(|n| n.estimated_cost).sum();
201
202 let communication_overhead = self.estimate_communication_overhead(&stages, nodes);
204
205 let parallelism_factor = if critical_path_length > 0.0 {
207 total_work / critical_path_length
208 } else {
209 1.0
210 };
211
212 let recommended_workers = self.recommend_worker_count(parallelism_factor);
214
215 Ok(ParallelizationAnalysis {
216 num_stages: stages.len(),
217 stages,
218 critical_path_length,
219 total_work,
220 parallelism_factor,
221 communication_overhead,
222 recommended_workers,
223 })
224 }
225
226 pub fn generate_plan(
228 &self,
229 nodes: &[NodeInfo],
230 ) -> Result<ParallelExecutionPlan, AutoParallelError> {
231 let analysis = self.analyze(nodes)?;
232
233 let partitions = self.partition_work(&analysis)?;
235
236 let sequential_time = analysis.total_work;
238 let parallel_time = analysis.critical_path_length + analysis.communication_overhead;
239 let estimated_speedup = if parallel_time > 0.0 {
240 sequential_time / parallel_time
241 } else {
242 1.0
243 };
244
245 let load_balance_ratio = self.calculate_load_balance(&partitions);
247
248 Ok(ParallelExecutionPlan {
249 stages: analysis.stages,
250 partitions,
251 estimated_speedup,
252 load_balance_ratio,
253 })
254 }
255
256 fn build_dependency_graph(
258 &self,
259 nodes: &[NodeInfo],
260 ) -> Result<HashMap<NodeId, HashSet<NodeId>>, AutoParallelError> {
261 let mut graph: HashMap<NodeId, HashSet<NodeId>> = HashMap::new();
262
263 for node in nodes {
265 graph.entry(node.id.clone()).or_insert_with(HashSet::new);
266 }
267
268 for node in nodes {
270 for (dep_id, _dep_type) in &node.dependencies {
271 if !graph.contains_key(dep_id) {
272 return Err(AutoParallelError::InvalidGraph(format!(
273 "Unknown dependency: {}",
274 dep_id
275 )));
276 }
277 graph
278 .entry(node.id.clone())
279 .or_insert_with(HashSet::new)
280 .insert(dep_id.clone());
281 }
282 }
283
284 self.check_cycles(&graph)?;
286
287 Ok(graph)
288 }
289
290 fn check_cycles(
292 &self,
293 graph: &HashMap<NodeId, HashSet<NodeId>>,
294 ) -> Result<(), AutoParallelError> {
295 let mut visited = HashSet::new();
296 let mut rec_stack = HashSet::new();
297
298 for node in graph.keys() {
299 if !visited.contains(node) {
300 if self.has_cycle_util(node, graph, &mut visited, &mut rec_stack)? {
301 return Err(AutoParallelError::DependencyCycle(format!(
302 "Cycle detected involving node: {}",
303 node
304 )));
305 }
306 }
307 }
308
309 Ok(())
310 }
311
312 fn has_cycle_util(
313 &self,
314 node: &NodeId,
315 graph: &HashMap<NodeId, HashSet<NodeId>>,
316 visited: &mut HashSet<NodeId>,
317 rec_stack: &mut HashSet<NodeId>,
318 ) -> Result<bool, AutoParallelError> {
319 visited.insert(node.clone());
320 rec_stack.insert(node.clone());
321
322 if let Some(neighbors) = graph.get(node) {
323 for neighbor in neighbors {
324 if !visited.contains(neighbor) {
325 if self.has_cycle_util(neighbor, graph, visited, rec_stack)? {
326 return Ok(true);
327 }
328 } else if rec_stack.contains(neighbor) {
329 return Ok(true);
330 }
331 }
332 }
333
334 rec_stack.remove(node);
335 Ok(false)
336 }
337
338 fn compute_stages(
340 &self,
341 nodes: &[NodeInfo],
342 dep_graph: &HashMap<NodeId, HashSet<NodeId>>,
343 ) -> Result<Vec<ParallelStage>, AutoParallelError> {
344 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
345 let mut node_map: HashMap<NodeId, &NodeInfo> = HashMap::new();
346
347 for node in nodes {
349 node_map.insert(node.id.clone(), node);
350 let deps = dep_graph
351 .get(&node.id)
352 .expect("dep_graph built from same nodes");
353 in_degree.insert(node.id.clone(), deps.len());
354 }
355
356 let mut stages = Vec::new();
357 let mut current_level: VecDeque<NodeId> = VecDeque::new();
358
359 for (node_id, °ree) in &in_degree {
361 if degree == 0 {
362 current_level.push_back(node_id.clone());
363 }
364 }
365
366 let mut stage_id = 0;
367 while !current_level.is_empty() {
368 let mut stage_nodes = Vec::new();
369 let mut estimated_time: f64 = 0.0;
370 let mut memory_requirement = 0;
371
372 for _ in 0..current_level.len() {
374 if let Some(node_id) = current_level.pop_front() {
375 let node = node_map[&node_id];
376 stage_nodes.push(node_id.clone());
377 estimated_time = estimated_time.max(node.estimated_cost);
378 memory_requirement += node.memory_size;
379
380 for other_id in node_map.keys() {
382 if dep_graph[other_id].contains(&node_id) {
383 if let Some(degree) = in_degree.get_mut(other_id) {
384 *degree -= 1;
385 if *degree == 0 {
386 current_level.push_back(other_id.clone());
387 }
388 }
389 }
390 }
391 }
392 }
393
394 if !stage_nodes.is_empty() {
395 stages.push(ParallelStage {
396 stage_id,
397 nodes: stage_nodes,
398 estimated_time,
399 memory_requirement,
400 predecessors: if stage_id > 0 {
401 vec![stage_id - 1]
402 } else {
403 vec![]
404 },
405 });
406 stage_id += 1;
407 }
408 }
409
410 if stages.iter().map(|s| s.nodes.len()).sum::<usize>() != nodes.len() {
412 return Err(AutoParallelError::DependencyCycle(
413 "Not all nodes were processed - cycle detected".to_string(),
414 ));
415 }
416
417 Ok(stages)
418 }
419
420 fn calculate_critical_path(&self, stages: &[ParallelStage]) -> f64 {
422 stages.iter().map(|s| s.estimated_time).sum()
423 }
424
425 fn estimate_communication_overhead(
427 &self,
428 stages: &[ParallelStage],
429 _nodes: &[NodeInfo],
430 ) -> f64 {
431 let mut overhead = 0.0;
432
433 for stage in stages {
435 if stage.nodes.len() > 1 {
436 overhead += self.overhead_per_task * stage.nodes.len() as f64;
438
439 let transfer_time =
441 stage.memory_requirement as f64 / (self.communication_bandwidth * 1e9) * 1e6;
442 overhead += transfer_time;
443 }
444 }
445
446 overhead
447 }
448
449 fn recommend_worker_count(&self, parallelism_factor: f64) -> usize {
451 let ideal = parallelism_factor.ceil() as usize;
452
453 match self.strategy {
454 ParallelizationStrategy::Conservative => ideal.min(self.max_workers / 2).max(1),
455 ParallelizationStrategy::Balanced => ideal.min(self.max_workers),
456 ParallelizationStrategy::Aggressive => self.max_workers,
457 ParallelizationStrategy::CostBased => {
458 if parallelism_factor > 2.0 {
460 ideal.min(self.max_workers)
461 } else {
462 (ideal / 2).max(1)
463 }
464 }
465 }
466 }
467
468 fn partition_work(
470 &self,
471 analysis: &ParallelizationAnalysis,
472 ) -> Result<Vec<WorkPartition>, AutoParallelError> {
473 let num_workers = analysis.recommended_workers;
474 let mut partitions: Vec<WorkPartition> = (0..num_workers)
475 .map(|i| WorkPartition {
476 worker_id: i,
477 nodes: Vec::new(),
478 estimated_load: 0.0,
479 })
480 .collect();
481
482 for stage in &analysis.stages {
484 let mut stage_nodes: Vec<(NodeId, f64)> = stage
486 .nodes
487 .iter()
488 .map(|id| (id.clone(), 1.0)) .collect();
490 stage_nodes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
491
492 for (node_id, cost) in stage_nodes {
494 let min_partition = partitions
495 .iter_mut()
496 .min_by(|a, b| {
497 a.estimated_load
498 .partial_cmp(&b.estimated_load)
499 .unwrap_or(std::cmp::Ordering::Equal)
500 })
501 .ok_or_else(|| {
502 AutoParallelError::PartitioningFailed("No partitions available".to_string())
503 })?;
504
505 min_partition.nodes.push(node_id);
506 min_partition.estimated_load += cost;
507 }
508 }
509
510 Ok(partitions)
511 }
512
513 fn calculate_load_balance(&self, partitions: &[WorkPartition]) -> f64 {
515 if partitions.is_empty() {
516 return 1.0;
517 }
518
519 let loads: Vec<f64> = partitions.iter().map(|p| p.estimated_load).collect();
520 let max_load = loads.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
521 let avg_load = loads.iter().sum::<f64>() / loads.len() as f64;
522
523 if max_load > 0.0 {
524 avg_load / max_load
525 } else {
526 1.0
527 }
528 }
529}
530
impl Default for AutoParallelizer {
    /// Equivalent to [`AutoParallelizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
536
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a diamond graph: a -> {b, c} -> d.
    /// Costs: a=10, b=20, c=15, d=10 (total 55).
    fn create_test_nodes() -> Vec<NodeInfo> {
        vec![
            NodeInfo {
                id: "a".to_string(),
                op_type: "input".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
            NodeInfo {
                id: "b".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 20.0,
                memory_size: 2000,
                dependencies: vec![("a".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
            NodeInfo {
                id: "c".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 15.0,
                memory_size: 1500,
                dependencies: vec![("a".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
            NodeInfo {
                id: "d".to_string(),
                op_type: "output".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![
                    ("b".to_string(), DependencyType::Data),
                    ("c".to_string(), DependencyType::Data),
                ],
                can_parallelize: false,
            },
        ]
    }

    // `new()` should install the documented defaults.
    #[test]
    fn test_auto_parallelizer_creation() {
        let parallelizer = AutoParallelizer::new();
        assert_eq!(parallelizer.strategy, ParallelizationStrategy::Balanced);
        assert_eq!(parallelizer.cost_model, CostModel::Heuristic);
    }

    // Builder methods should each override their respective field.
    #[test]
    fn test_builder_pattern() {
        let parallelizer = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Aggressive)
            .with_cost_model(CostModel::ProfileBased)
            .with_max_workers(8);

        assert_eq!(parallelizer.strategy, ParallelizationStrategy::Aggressive);
        assert_eq!(parallelizer.cost_model, CostModel::ProfileBased);
        assert_eq!(parallelizer.max_workers, 8);
    }

    // The graph maps each node to the set of nodes it depends on.
    #[test]
    fn test_dependency_graph_building() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let graph = parallelizer.build_dependency_graph(&nodes).expect("unwrap");

        assert_eq!(graph.len(), 4);
        assert!(graph["b"].contains("a"));
        assert!(graph["c"].contains("a"));
        assert!(graph["d"].contains("b"));
        assert!(graph["d"].contains("c"));
    }

    // A two-node mutual dependency (a <-> b) must be rejected.
    #[test]
    fn test_cycle_detection() {
        let parallelizer = AutoParallelizer::new();

        let nodes = vec![
            NodeInfo {
                id: "a".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![("b".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
            NodeInfo {
                id: "b".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![("a".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
        ];

        let result = parallelizer.build_dependency_graph(&nodes);
        assert!(result.is_err());
    }

    // The diamond should yield three stages: [a], [b, c], [d].
    #[test]
    fn test_stage_computation() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let analysis = parallelizer.analyze(&nodes).expect("unwrap");

        assert_eq!(analysis.num_stages, 3);
        assert_eq!(analysis.stages[0].nodes, vec!["a"]);
        assert_eq!(analysis.stages[1].nodes.len(), 2);
        assert!(analysis.stages[1].nodes.contains(&"b".to_string()));
        assert!(analysis.stages[1].nodes.contains(&"c".to_string()));
        assert_eq!(analysis.stages[2].nodes, vec!["d"]);
    }

    // Critical path = 10 (a) + max(20, 15) (b|c) + 10 (d) = 40.
    #[test]
    fn test_critical_path_calculation() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let analysis = parallelizer.analyze(&nodes).expect("unwrap");

        assert_eq!(analysis.critical_path_length, 40.0);
    }

    // Parallelism factor = total work / critical path = 55 / 40 = 1.375.
    #[test]
    fn test_parallelism_factor() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let analysis = parallelizer.analyze(&nodes).expect("unwrap");

        assert!((analysis.parallelism_factor - 1.375).abs() < 0.01);
    }

    // End-to-end plan generation: sane stage count, partitions, and metrics.
    #[test]
    fn test_execution_plan_generation() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let plan = parallelizer.generate_plan(&nodes).expect("unwrap");

        assert_eq!(plan.stages.len(), 3);
        assert!(!plan.partitions.is_empty());
        assert!(plan.estimated_speedup > 0.0);
        assert!(plan.load_balance_ratio > 0.0 && plan.load_balance_ratio <= 1.0);
    }

    // Repeated observations of the same op should keep a non-negative average.
    // NOTE(review): the `>= 0.0` assertion is deliberately weak so it holds
    // regardless of how the EMA is seeded.
    #[test]
    fn test_profile_update() {
        let mut parallelizer = AutoParallelizer::new();

        parallelizer.update_profile("compute".to_string(), 100.0);
        parallelizer.update_profile("compute".to_string(), 200.0);

        assert!(parallelizer.profile_data.contains_key("compute"));
        let avg = parallelizer.profile_data["compute"];
        assert!(avg >= 0.0);
    }

    // Aggressive must never recommend fewer workers than Conservative.
    #[test]
    fn test_strategy_variations() {
        let nodes = create_test_nodes();

        let conservative = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Conservative)
            .analyze(&nodes)
            .expect("unwrap");

        let aggressive = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Aggressive)
            .analyze(&nodes)
            .expect("unwrap");

        assert!(aggressive.recommended_workers >= conservative.recommended_workers);
    }

    // A pure chain (a -> b) has no parallelism: factor exactly 1.0.
    #[test]
    fn test_sequential_graph() {
        let parallelizer = AutoParallelizer::new();

        let nodes = vec![
            NodeInfo {
                id: "a".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
            NodeInfo {
                id: "b".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![("a".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
        ];

        let analysis = parallelizer.analyze(&nodes).expect("unwrap");

        assert_eq!(analysis.num_stages, 2);
        assert_eq!(analysis.parallelism_factor, 1.0);
    }

    // Three independent equal-cost nodes: one stage, factor exactly 3.0.
    #[test]
    fn test_fully_parallel_graph() {
        let parallelizer = AutoParallelizer::new();

        let nodes = vec![
            NodeInfo {
                id: "a".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
            NodeInfo {
                id: "b".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
            NodeInfo {
                id: "c".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
        ];

        let analysis = parallelizer.analyze(&nodes).expect("unwrap");

        assert_eq!(analysis.num_stages, 1);
        assert_eq!(analysis.parallelism_factor, 3.0);
    }

    // With a 2-worker cap the load-balance ratio must stay in (0, 1].
    #[test]
    fn test_load_balancing() {
        let parallelizer = AutoParallelizer::new().with_max_workers(2);
        let nodes = create_test_nodes();

        let plan = parallelizer.generate_plan(&nodes).expect("unwrap");

        assert!(plan.partitions.len() > 0);
        assert!(plan.load_balance_ratio > 0.0 && plan.load_balance_ratio <= 1.0);
    }

    // Referencing a nonexistent dependency id must be an error.
    #[test]
    fn test_invalid_graph() {
        let parallelizer = AutoParallelizer::new();

        let nodes = vec![NodeInfo {
            id: "a".to_string(),
            op_type: "compute".to_string(),
            estimated_cost: 10.0,
            memory_size: 1000,
            dependencies: vec![("unknown".to_string(), DependencyType::Data)],
            can_parallelize: true,
        }];

        let result = parallelizer.build_dependency_graph(&nodes);
        assert!(result.is_err());
    }
}