1use serde::{Deserialize, Serialize};
29use std::collections::{HashMap, HashSet, VecDeque};
30use thiserror::Error;
31
/// Errors produced while analyzing, validating, or partitioning a
/// computation graph for parallel execution.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum AutoParallelError {
    /// The dependency graph contains a cycle and cannot be topologically staged.
    #[error("Dependency cycle detected: {0}")]
    DependencyCycle(String),

    /// The graph is malformed, e.g. a node lists a dependency that does not exist.
    #[error("Invalid graph: {0}")]
    InvalidGraph(String),

    /// A cost estimate could not be produced.
    /// NOTE(review): not constructed anywhere in this file's visible code —
    /// presumably raised by external cost-model integrations; confirm.
    #[error("Cost model error: {0}")]
    CostModelError(String),

    /// Work could not be divided among the available workers.
    #[error("Partitioning failed: {0}")]
    PartitioningFailed(String),
}
47
/// Identifier for a node in the computation graph (currently a plain `String`).
pub type NodeId = String;
50
/// How aggressively the analyzer recommends workers
/// (see `AutoParallelizer::recommend_worker_count`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ParallelizationStrategy {
    /// Caps the recommendation at half the configured worker limit.
    Conservative,
    /// Caps the recommendation at the configured worker limit.
    Balanced,
    /// Always recommends the full configured worker limit.
    Aggressive,
    /// Chooses between full and halved recommendations based on the
    /// measured parallelism factor (threshold 2.0).
    CostBased,
}
63
/// Source of per-operation cost estimates.
/// NOTE(review): stored on `AutoParallelizer` but not dispatched on in the
/// visible code — costs currently come from `NodeInfo::estimated_cost`; confirm
/// intended use against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CostModel {
    /// Fixed rule-of-thumb estimates.
    Heuristic,
    /// Estimates derived from recorded execution profiles.
    ProfileBased,
    /// Estimates computed from an analytical model.
    Analytical,
    /// Combination of profile and analytical estimates.
    Hybrid,
}
76
/// Classification of an edge between two graph nodes. All variants are
/// treated identically by the current scheduler (the type is ignored when
/// building the dependency graph).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DependencyType {
    /// Consumer reads the producer's output.
    Data,
    /// Ordering constraint without a data transfer.
    Control,
    /// Shared-memory ordering constraint.
    Memory,
}
87
/// Description of a single operation in the computation graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Unique identifier of this node.
    pub id: NodeId,
    /// Operation type label (used as the key for profile data).
    pub op_type: String,
    /// Estimated execution cost; assumed to be in the same time units as
    /// `overhead_per_task` (microseconds) — TODO confirm with callers.
    pub estimated_cost: f64,
    /// Memory footprint in bytes (summed per stage, used for transfer-time estimates).
    pub memory_size: usize,
    /// Nodes that must complete before this one, with the edge classification.
    pub dependencies: Vec<(NodeId, DependencyType)>,
    /// Whether this node may run concurrently with others.
    /// NOTE(review): not consulted by the visible scheduling code — confirm.
    pub can_parallelize: bool,
}
98
/// A set of nodes with no mutual dependencies that can execute concurrently.
/// Stages are produced in topological order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelStage {
    /// Sequential index of this stage (0-based).
    pub stage_id: usize,
    /// Nodes belonging to this stage.
    pub nodes: Vec<NodeId>,
    /// Maximum estimated cost among the stage's nodes (they run in parallel).
    pub estimated_time: f64,
    /// Sum of the member nodes' memory footprints, in bytes.
    pub memory_requirement: usize,
    /// Stage ids that must complete first; in practice the immediately
    /// preceding stage, or empty for the first stage.
    pub predecessors: Vec<usize>,
}
108
/// Assignment of nodes to one worker, with the worker's accumulated load.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkPartition {
    /// Index of the worker this partition belongs to.
    pub worker_id: usize,
    /// Nodes assigned to this worker, across all stages.
    pub nodes: Vec<NodeId>,
    /// Sum of the assigned nodes' cost weights.
    pub estimated_load: f64,
}
116
/// Result of analyzing a graph's parallelization potential
/// (see `AutoParallelizer::analyze`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelizationAnalysis {
    /// Number of topological stages (equals `stages.len()`).
    pub num_stages: usize,
    /// The topological stages in execution order.
    pub stages: Vec<ParallelStage>,
    /// Sum of each stage's `estimated_time` — the minimum wall-clock time
    /// with unlimited workers.
    pub critical_path_length: f64,
    /// Sum of all nodes' estimated costs (sequential execution time).
    pub total_work: f64,
    /// `total_work / critical_path_length`; 1.0 means no parallelism available.
    pub parallelism_factor: f64,
    /// Estimated scheduling + data-transfer overhead across stages.
    pub communication_overhead: f64,
    /// Worker count suggested by the configured strategy.
    pub recommended_workers: usize,
}
128
/// Concrete execution plan: stages plus per-worker node assignments
/// (see `AutoParallelizer::generate_plan`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelExecutionPlan {
    /// Topological stages in execution order.
    pub stages: Vec<ParallelStage>,
    /// Per-worker node assignments.
    pub partitions: Vec<WorkPartition>,
    /// `total_work / (critical_path + communication_overhead)`.
    pub estimated_speedup: f64,
    /// `avg_load / max_load` across partitions, in (0, 1]; 1.0 is perfectly balanced.
    pub load_balance_ratio: f64,
}
137
/// Analyzes computation graphs and produces parallel execution plans.
///
/// Configure via the builder methods (`with_strategy`, `with_cost_model`,
/// `with_max_workers`), then call `analyze` or `generate_plan`.
pub struct AutoParallelizer {
    /// Governs how many workers are recommended.
    strategy: ParallelizationStrategy,
    /// Cost-estimation mode. NOTE(review): stored but not consulted by the
    /// visible code — confirm intended use.
    cost_model: CostModel,
    /// Upper bound on recommended workers (defaults to the logical CPU count).
    max_workers: usize,
    /// Per-task scheduling overhead; same time units as node costs (microseconds).
    overhead_per_task: f64,
    /// Inter-worker bandwidth, interpreted as GB/s in the transfer-time formula.
    communication_bandwidth: f64,
    /// Exponential moving average of observed execution times, keyed by op type.
    profile_data: HashMap<String, f64>,
}
147
148impl AutoParallelizer {
149 pub fn new() -> Self {
151 Self {
152 strategy: ParallelizationStrategy::Balanced,
153 cost_model: CostModel::Heuristic,
154 max_workers: num_cpus::get(),
155 overhead_per_task: 10.0, communication_bandwidth: 100.0, profile_data: HashMap::new(),
158 }
159 }
160
161 pub fn with_strategy(mut self, strategy: ParallelizationStrategy) -> Self {
163 self.strategy = strategy;
164 self
165 }
166
167 pub fn with_cost_model(mut self, model: CostModel) -> Self {
169 self.cost_model = model;
170 self
171 }
172
173 pub fn with_max_workers(mut self, workers: usize) -> Self {
175 self.max_workers = workers;
176 self
177 }
178
179 pub fn update_profile(&mut self, op_type: String, time_us: f64) {
181 let entry = self.profile_data.entry(op_type).or_insert(0.0);
182 *entry = 0.9 * *entry + 0.1 * time_us; }
184
185 pub fn analyze(
187 &self,
188 nodes: &[NodeInfo],
189 ) -> Result<ParallelizationAnalysis, AutoParallelError> {
190 let dep_graph = self.build_dependency_graph(nodes)?;
192
193 let stages = self.compute_stages(nodes, &dep_graph)?;
195
196 let critical_path_length = self.calculate_critical_path(&stages);
198
199 let total_work: f64 = nodes.iter().map(|n| n.estimated_cost).sum();
201
202 let communication_overhead = self.estimate_communication_overhead(&stages, nodes);
204
205 let parallelism_factor = if critical_path_length > 0.0 {
207 total_work / critical_path_length
208 } else {
209 1.0
210 };
211
212 let recommended_workers = self.recommend_worker_count(parallelism_factor);
214
215 Ok(ParallelizationAnalysis {
216 num_stages: stages.len(),
217 stages,
218 critical_path_length,
219 total_work,
220 parallelism_factor,
221 communication_overhead,
222 recommended_workers,
223 })
224 }
225
226 pub fn generate_plan(
228 &self,
229 nodes: &[NodeInfo],
230 ) -> Result<ParallelExecutionPlan, AutoParallelError> {
231 let analysis = self.analyze(nodes)?;
232
233 let partitions = self.partition_work(&analysis)?;
235
236 let sequential_time = analysis.total_work;
238 let parallel_time = analysis.critical_path_length + analysis.communication_overhead;
239 let estimated_speedup = if parallel_time > 0.0 {
240 sequential_time / parallel_time
241 } else {
242 1.0
243 };
244
245 let load_balance_ratio = self.calculate_load_balance(&partitions);
247
248 Ok(ParallelExecutionPlan {
249 stages: analysis.stages,
250 partitions,
251 estimated_speedup,
252 load_balance_ratio,
253 })
254 }
255
256 fn build_dependency_graph(
258 &self,
259 nodes: &[NodeInfo],
260 ) -> Result<HashMap<NodeId, HashSet<NodeId>>, AutoParallelError> {
261 let mut graph: HashMap<NodeId, HashSet<NodeId>> = HashMap::new();
262
263 for node in nodes {
265 graph.entry(node.id.clone()).or_insert_with(HashSet::new);
266 }
267
268 for node in nodes {
270 for (dep_id, _dep_type) in &node.dependencies {
271 if !graph.contains_key(dep_id) {
272 return Err(AutoParallelError::InvalidGraph(format!(
273 "Unknown dependency: {}",
274 dep_id
275 )));
276 }
277 graph
278 .entry(node.id.clone())
279 .or_insert_with(HashSet::new)
280 .insert(dep_id.clone());
281 }
282 }
283
284 self.check_cycles(&graph)?;
286
287 Ok(graph)
288 }
289
290 fn check_cycles(
292 &self,
293 graph: &HashMap<NodeId, HashSet<NodeId>>,
294 ) -> Result<(), AutoParallelError> {
295 let mut visited = HashSet::new();
296 let mut rec_stack = HashSet::new();
297
298 for node in graph.keys() {
299 if !visited.contains(node) {
300 if self.has_cycle_util(node, graph, &mut visited, &mut rec_stack)? {
301 return Err(AutoParallelError::DependencyCycle(format!(
302 "Cycle detected involving node: {}",
303 node
304 )));
305 }
306 }
307 }
308
309 Ok(())
310 }
311
312 fn has_cycle_util(
313 &self,
314 node: &NodeId,
315 graph: &HashMap<NodeId, HashSet<NodeId>>,
316 visited: &mut HashSet<NodeId>,
317 rec_stack: &mut HashSet<NodeId>,
318 ) -> Result<bool, AutoParallelError> {
319 visited.insert(node.clone());
320 rec_stack.insert(node.clone());
321
322 if let Some(neighbors) = graph.get(node) {
323 for neighbor in neighbors {
324 if !visited.contains(neighbor) {
325 if self.has_cycle_util(neighbor, graph, visited, rec_stack)? {
326 return Ok(true);
327 }
328 } else if rec_stack.contains(neighbor) {
329 return Ok(true);
330 }
331 }
332 }
333
334 rec_stack.remove(node);
335 Ok(false)
336 }
337
338 fn compute_stages(
340 &self,
341 nodes: &[NodeInfo],
342 dep_graph: &HashMap<NodeId, HashSet<NodeId>>,
343 ) -> Result<Vec<ParallelStage>, AutoParallelError> {
344 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
345 let mut node_map: HashMap<NodeId, &NodeInfo> = HashMap::new();
346
347 for node in nodes {
349 node_map.insert(node.id.clone(), node);
350 let deps = dep_graph.get(&node.id).unwrap();
351 in_degree.insert(node.id.clone(), deps.len());
352 }
353
354 let mut stages = Vec::new();
355 let mut current_level: VecDeque<NodeId> = VecDeque::new();
356
357 for (node_id, °ree) in &in_degree {
359 if degree == 0 {
360 current_level.push_back(node_id.clone());
361 }
362 }
363
364 let mut stage_id = 0;
365 while !current_level.is_empty() {
366 let mut stage_nodes = Vec::new();
367 let mut estimated_time: f64 = 0.0;
368 let mut memory_requirement = 0;
369
370 for _ in 0..current_level.len() {
372 if let Some(node_id) = current_level.pop_front() {
373 let node = node_map[&node_id];
374 stage_nodes.push(node_id.clone());
375 estimated_time = estimated_time.max(node.estimated_cost);
376 memory_requirement += node.memory_size;
377
378 for other_id in node_map.keys() {
380 if dep_graph[other_id].contains(&node_id) {
381 if let Some(degree) = in_degree.get_mut(other_id) {
382 *degree -= 1;
383 if *degree == 0 {
384 current_level.push_back(other_id.clone());
385 }
386 }
387 }
388 }
389 }
390 }
391
392 if !stage_nodes.is_empty() {
393 stages.push(ParallelStage {
394 stage_id,
395 nodes: stage_nodes,
396 estimated_time,
397 memory_requirement,
398 predecessors: if stage_id > 0 {
399 vec![stage_id - 1]
400 } else {
401 vec![]
402 },
403 });
404 stage_id += 1;
405 }
406 }
407
408 if stages.iter().map(|s| s.nodes.len()).sum::<usize>() != nodes.len() {
410 return Err(AutoParallelError::DependencyCycle(
411 "Not all nodes were processed - cycle detected".to_string(),
412 ));
413 }
414
415 Ok(stages)
416 }
417
418 fn calculate_critical_path(&self, stages: &[ParallelStage]) -> f64 {
420 stages.iter().map(|s| s.estimated_time).sum()
421 }
422
423 fn estimate_communication_overhead(
425 &self,
426 stages: &[ParallelStage],
427 _nodes: &[NodeInfo],
428 ) -> f64 {
429 let mut overhead = 0.0;
430
431 for stage in stages {
433 if stage.nodes.len() > 1 {
434 overhead += self.overhead_per_task * stage.nodes.len() as f64;
436
437 let transfer_time =
439 stage.memory_requirement as f64 / (self.communication_bandwidth * 1e9) * 1e6;
440 overhead += transfer_time;
441 }
442 }
443
444 overhead
445 }
446
447 fn recommend_worker_count(&self, parallelism_factor: f64) -> usize {
449 let ideal = parallelism_factor.ceil() as usize;
450
451 match self.strategy {
452 ParallelizationStrategy::Conservative => ideal.min(self.max_workers / 2).max(1),
453 ParallelizationStrategy::Balanced => ideal.min(self.max_workers),
454 ParallelizationStrategy::Aggressive => self.max_workers,
455 ParallelizationStrategy::CostBased => {
456 if parallelism_factor > 2.0 {
458 ideal.min(self.max_workers)
459 } else {
460 (ideal / 2).max(1)
461 }
462 }
463 }
464 }
465
466 fn partition_work(
468 &self,
469 analysis: &ParallelizationAnalysis,
470 ) -> Result<Vec<WorkPartition>, AutoParallelError> {
471 let num_workers = analysis.recommended_workers;
472 let mut partitions: Vec<WorkPartition> = (0..num_workers)
473 .map(|i| WorkPartition {
474 worker_id: i,
475 nodes: Vec::new(),
476 estimated_load: 0.0,
477 })
478 .collect();
479
480 for stage in &analysis.stages {
482 let mut stage_nodes: Vec<(NodeId, f64)> = stage
484 .nodes
485 .iter()
486 .map(|id| (id.clone(), 1.0)) .collect();
488 stage_nodes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
489
490 for (node_id, cost) in stage_nodes {
492 let min_partition = partitions
493 .iter_mut()
494 .min_by(|a, b| a.estimated_load.partial_cmp(&b.estimated_load).unwrap())
495 .ok_or_else(|| {
496 AutoParallelError::PartitioningFailed("No partitions available".to_string())
497 })?;
498
499 min_partition.nodes.push(node_id);
500 min_partition.estimated_load += cost;
501 }
502 }
503
504 Ok(partitions)
505 }
506
507 fn calculate_load_balance(&self, partitions: &[WorkPartition]) -> f64 {
509 if partitions.is_empty() {
510 return 1.0;
511 }
512
513 let loads: Vec<f64> = partitions.iter().map(|p| p.estimated_load).collect();
514 let max_load = loads.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
515 let avg_load = loads.iter().sum::<f64>() / loads.len() as f64;
516
517 if max_load > 0.0 {
518 avg_load / max_load
519 } else {
520 1.0
521 }
522 }
523}
524
525impl Default for AutoParallelizer {
526 fn default() -> Self {
527 Self::new()
528 }
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Diamond graph: a -> {b, c} -> d. b and c are independent.
    fn create_test_nodes() -> Vec<NodeInfo> {
        vec![
            NodeInfo {
                id: "a".to_string(),
                op_type: "input".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
            NodeInfo {
                id: "b".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 20.0,
                memory_size: 2000,
                dependencies: vec![("a".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
            NodeInfo {
                id: "c".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 15.0,
                memory_size: 1500,
                dependencies: vec![("a".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
            NodeInfo {
                id: "d".to_string(),
                op_type: "output".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![
                    ("b".to_string(), DependencyType::Data),
                    ("c".to_string(), DependencyType::Data),
                ],
                can_parallelize: false,
            },
        ]
    }

    #[test]
    fn test_auto_parallelizer_creation() {
        let parallelizer = AutoParallelizer::new();
        assert_eq!(parallelizer.strategy, ParallelizationStrategy::Balanced);
        assert_eq!(parallelizer.cost_model, CostModel::Heuristic);
    }

    #[test]
    fn test_builder_pattern() {
        let parallelizer = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Aggressive)
            .with_cost_model(CostModel::ProfileBased)
            .with_max_workers(8);

        assert_eq!(parallelizer.strategy, ParallelizationStrategy::Aggressive);
        assert_eq!(parallelizer.cost_model, CostModel::ProfileBased);
        assert_eq!(parallelizer.max_workers, 8);
    }

    #[test]
    fn test_dependency_graph_building() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let graph = parallelizer.build_dependency_graph(&nodes).unwrap();

        // Edges point from each node to its prerequisites.
        assert_eq!(graph.len(), 4);
        assert!(graph["b"].contains("a"));
        assert!(graph["c"].contains("a"));
        assert!(graph["d"].contains("b"));
        assert!(graph["d"].contains("c"));
    }

    #[test]
    fn test_cycle_detection() {
        let parallelizer = AutoParallelizer::new();

        // a <-> b is the smallest possible cycle.
        let nodes = vec![
            NodeInfo {
                id: "a".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![("b".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
            NodeInfo {
                id: "b".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![("a".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
        ];

        let result = parallelizer.build_dependency_graph(&nodes);
        assert!(result.is_err());
    }

    #[test]
    fn test_stage_computation() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let analysis = parallelizer.analyze(&nodes).unwrap();

        // Diamond staging: [a], [b, c], [d].
        assert_eq!(analysis.num_stages, 3);
        assert_eq!(analysis.stages[0].nodes, vec!["a"]);
        assert_eq!(analysis.stages[1].nodes.len(), 2);
        assert!(analysis.stages[1].nodes.contains(&"b".to_string()));
        assert!(analysis.stages[1].nodes.contains(&"c".to_string()));
        assert_eq!(analysis.stages[2].nodes, vec!["d"]);
    }

    #[test]
    fn test_critical_path_calculation() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let analysis = parallelizer.analyze(&nodes).unwrap();

        // Span = 10 (a) + max(20, 15) (b|c) + 10 (d) = 40.
        assert_eq!(analysis.critical_path_length, 40.0);
    }

    #[test]
    fn test_parallelism_factor() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let analysis = parallelizer.analyze(&nodes).unwrap();

        // Total work 55 / span 40 = 1.375.
        assert!((analysis.parallelism_factor - 1.375).abs() < 0.01);
    }

    #[test]
    fn test_execution_plan_generation() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let plan = parallelizer.generate_plan(&nodes).unwrap();

        assert_eq!(plan.stages.len(), 3);
        assert!(!plan.partitions.is_empty());
        assert!(plan.estimated_speedup > 0.0);
        assert!(plan.load_balance_ratio > 0.0 && plan.load_balance_ratio <= 1.0);
    }

    #[test]
    fn test_profile_update() {
        let mut parallelizer = AutoParallelizer::new();

        parallelizer.update_profile("compute".to_string(), 100.0);
        parallelizer.update_profile("compute".to_string(), 200.0);

        assert!(parallelizer.profile_data.contains_key("compute"));
        let avg = parallelizer.profile_data["compute"];
        // Strengthened: an EMA of two positive samples must be strictly
        // positive and can never exceed the largest sample. (The previous
        // `avg >= 0.0` assertion was vacuously true.)
        assert!(avg > 0.0 && avg <= 200.0);
    }

    #[test]
    fn test_strategy_variations() {
        let nodes = create_test_nodes();

        let conservative = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Conservative)
            .analyze(&nodes)
            .unwrap();

        let aggressive = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Aggressive)
            .analyze(&nodes)
            .unwrap();

        // Aggressive never recommends fewer workers than Conservative.
        assert!(aggressive.recommended_workers >= conservative.recommended_workers);
    }

    #[test]
    fn test_sequential_graph() {
        let parallelizer = AutoParallelizer::new();

        // A pure chain a -> b offers no parallelism.
        let nodes = vec![
            NodeInfo {
                id: "a".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
            NodeInfo {
                id: "b".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![("a".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
        ];

        let analysis = parallelizer.analyze(&nodes).unwrap();

        assert_eq!(analysis.num_stages, 2);
        assert_eq!(analysis.parallelism_factor, 1.0);
    }

    #[test]
    fn test_fully_parallel_graph() {
        let parallelizer = AutoParallelizer::new();

        // Three independent nodes collapse into a single stage.
        let nodes = vec![
            NodeInfo {
                id: "a".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
            NodeInfo {
                id: "b".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
            NodeInfo {
                id: "c".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
        ];

        let analysis = parallelizer.analyze(&nodes).unwrap();

        assert_eq!(analysis.num_stages, 1);
        assert_eq!(analysis.parallelism_factor, 3.0);
    }

    #[test]
    fn test_load_balancing() {
        let parallelizer = AutoParallelizer::new().with_max_workers(2);
        let nodes = create_test_nodes();

        let plan = parallelizer.generate_plan(&nodes).unwrap();

        assert!(plan.partitions.len() > 0);
        assert!(plan.load_balance_ratio > 0.0 && plan.load_balance_ratio <= 1.0);
    }

    #[test]
    fn test_invalid_graph() {
        let parallelizer = AutoParallelizer::new();

        // A dependency on an id that is not in the node set must be rejected.
        let nodes = vec![NodeInfo {
            id: "a".to_string(),
            op_type: "compute".to_string(),
            estimated_cost: 10.0,
            memory_size: 1000,
            dependencies: vec![("unknown".to_string(), DependencyType::Data)],
            can_parallelize: true,
        }];

        let result = parallelizer.build_dependency_graph(&nodes);
        assert!(result.is_err());
    }
}