1use anyhow::Result;
2use std::collections::HashMap;
3use trustformers_core::parallel::CommunicationBackend;
4use trustformers_core::tensor::Tensor;
5
/// Communication pattern used to combine gradients across ranks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregationStrategy {
    /// Reduce towards the root of a binary tree of ranks, then broadcast back down.
    BinaryTree,
    /// Pass data around a ring of ranks (reduce-scatter phase, then all-gather phase).
    Ring,
    /// Exchange with one partner per stage, where the partner is `rank XOR 2^stage`.
    Butterfly,
    /// Choose one of the concrete strategies at call time from world size and payload size.
    Adaptive,
}
23
/// Static description of this process's place in the cluster plus aggregation options.
#[derive(Debug, Clone)]
pub struct HierarchicalConfig {
    /// Number of nodes in the cluster.
    pub num_nodes: usize,
    /// Number of devices on each node.
    pub devices_per_node: usize,
    /// Index of the node this process runs on.
    pub node_rank: usize,
    /// Index of this device within its node.
    pub local_rank: usize,
    /// Cluster-wide rank; `new` sets it to `node_rank * devices_per_node + local_rank`.
    pub global_rank: usize,
    /// Aggregation strategy to use; `Adaptive` defers the choice to call time.
    pub strategy: AggregationStrategy,
    /// Communication backend selection.
    pub comm_backend: CommunicationBackend,
    /// Compression toggle.
    /// NOTE(review): not read by the aggregator code visible in this file.
    pub enable_compression: bool,
    /// Compression threshold.
    /// NOTE(review): not read by the aggregator code visible in this file.
    pub compression_threshold: f32,
    /// When set, the aggregator creates a `FaultDetector` at construction time.
    pub enable_fault_tolerance: bool,
    /// Communication timeout in milliseconds; copied into the fault detector's threshold.
    pub comm_timeout_ms: u64,
}
49
impl Default for HierarchicalConfig {
    /// Single-node, single-device configuration at rank 0 with the adaptive strategy,
    /// the MPI backend, compression and fault tolerance enabled, and a 30 s timeout.
    fn default() -> Self {
        Self {
            num_nodes: 1,
            devices_per_node: 1,
            node_rank: 0,
            local_rank: 0,
            global_rank: 0,
            strategy: AggregationStrategy::Adaptive,
            comm_backend: CommunicationBackend::Mpi,
            enable_compression: true,
            compression_threshold: 0.1,
            enable_fault_tolerance: true,
            // 30 000 ms.
            comm_timeout_ms: 30000,
        }
    }
}
67
68impl HierarchicalConfig {
69 pub fn new(
70 num_nodes: usize,
71 devices_per_node: usize,
72 node_rank: usize,
73 local_rank: usize,
74 ) -> Self {
75 let global_rank = node_rank * devices_per_node + local_rank;
76 Self {
77 num_nodes,
78 devices_per_node,
79 node_rank,
80 local_rank,
81 global_rank,
82 ..Default::default()
83 }
84 }
85
86 pub fn world_size(&self) -> usize {
87 self.num_nodes * self.devices_per_node
88 }
89
90 pub fn is_master(&self) -> bool {
91 self.global_rank == 0
92 }
93
94 pub fn is_node_master(&self) -> bool {
95 self.local_rank == 0
96 }
97}
98
/// Gradient aggregator that owns the precomputed communication structures and
/// running statistics for one rank.
pub struct HierarchicalAggregator {
    /// Cluster layout and options this aggregator was built from.
    config: HierarchicalConfig,
    // Stored at construction but not read by any method visible in this file.
    #[allow(dead_code)]
    node_topology: NodeTopology,
    /// Tree, ring and butterfly structures for this rank.
    communication_groups: CommunicationGroups,
    /// Running statistics, updated after each all-reduce.
    aggregation_stats: AggregationStats,
    // Populated when fault tolerance is enabled but not read by any method visible
    // in this file.
    #[allow(dead_code)]
    fault_detector: Option<FaultDetector>,
}
109
/// Node-to-node connectivity model, stored as square matrices indexed `[from][to]`.
///
/// NOTE(review): the units of the bandwidth and latency figures are not stated
/// anywhere in the code — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct NodeTopology {
    /// `node_adjacency[i][j]` is `true` when nodes `i` and `j` are linked.
    pub node_adjacency: Vec<Vec<bool>>,
    /// Bandwidth of the link between each pair of nodes.
    pub node_bandwidth: Vec<Vec<f32>>,
    /// Latency of the link between each pair of nodes.
    pub node_latency: Vec<Vec<f32>>,
    /// Bandwidth between devices on the same node.
    pub intra_node_bandwidth: f32,
    /// Latency between devices on the same node.
    pub intra_node_latency: f32,
}
124
/// Rank groupings and per-strategy structures computed for one rank.
#[derive(Debug, Clone)]
pub struct CommunicationGroups {
    /// Global ranks of every device on this rank's node.
    pub node_local_group: Vec<usize>,
    /// Global rank of the first device of each node.
    pub cross_node_group: Vec<usize>,
    /// This rank's position in the binary reduction tree.
    pub tree_structure: TreeStructure,
    /// This rank's neighbours in the ring.
    pub ring_structure: RingStructure,
    /// This rank's per-stage partners in the butterfly exchange.
    pub butterfly_structure: ButterflyStructure,
}
139
/// Position of one rank inside the binary reduction tree.
#[derive(Debug, Clone)]
pub struct TreeStructure {
    /// Parent rank, or `None` for the root (rank 0).
    pub parent: Option<usize>,
    /// Child ranks that exist within the world (zero, one or two entries).
    pub children: Vec<usize>,
    /// Depth recorded for this rank; the root is at depth 0.
    pub depth: usize,
    /// Height of the tree, derived from the world size.
    pub height: usize,
}
151
/// Neighbours of one rank in the ring, with wrap-around at both ends.
#[derive(Debug, Clone)]
pub struct RingStructure {
    /// Rank this rank sends to.
    pub next_rank: usize,
    /// Rank this rank receives from.
    pub prev_rank: usize,
    /// Number of ranks in the ring.
    pub ring_size: usize,
}
161
/// Per-stage exchange partners of one rank in the butterfly pattern.
#[derive(Debug, Clone)]
pub struct ButterflyStructure {
    /// `connections[stage]` lists this rank's partners for that stage; the list is
    /// empty when the computed partner index lies outside the world.
    pub connections: Vec<Vec<usize>>,
    /// Number of exchange stages.
    pub num_stages: usize,
}
169
/// Running statistics collected across all-reduce operations.
#[derive(Debug, Clone)]
pub struct AggregationStats {
    /// Number of operations recorded.
    pub total_operations: usize,
    /// Mean duration of the recorded operations, in milliseconds.
    pub avg_aggregation_time: f32,
    /// Sum of the gradient payload sizes, in bytes as reported by the tensors.
    pub total_bytes_transferred: usize,
    /// Compression ratio; defaults to `1.0`.
    pub compression_ratio: f32,
    /// Number of failed operations.
    pub failed_operations: usize,
    /// How many times each strategy has been recorded.
    pub strategy_history: HashMap<AggregationStrategy, usize>,
}
186
/// State kept for fault handling when fault tolerance is enabled.
#[derive(Debug)]
pub struct FaultDetector {
    /// Nodes currently considered failed.
    pub failed_nodes: Vec<usize>,
    /// Timeout threshold; the aggregator initialises it from `comm_timeout_ms`
    /// (milliseconds).
    pub timeout_threshold: u64,
    /// What to do when a failure is detected.
    pub recovery_strategy: RecoveryStrategy,
}
197
/// Reaction to a detected communication failure.
///
/// NOTE(review): no code visible in this file acts on these variants; the
/// descriptions below follow the variant names only — confirm against the
/// implementation that consumes them.
#[derive(Debug, Clone)]
pub enum RecoveryStrategy {
    /// Presumably: carry on without the failed participant.
    Skip,
    /// Presumably: attempt the operation again.
    Retry,
    /// Presumably: give up on the operation.
    Abort,
}
207
impl Default for AggregationStats {
    /// Empty statistics: all counters at zero, an empty history, and a
    /// compression ratio of `1.0`.
    fn default() -> Self {
        Self {
            total_operations: 0,
            avg_aggregation_time: 0.0,
            total_bytes_transferred: 0,
            // Not derivable: the ratio starts at 1.0 rather than f32's default of 0.0.
            compression_ratio: 1.0,
            failed_operations: 0,
            strategy_history: HashMap::new(),
        }
    }
}
220
221impl HierarchicalAggregator {
222 pub fn new(config: HierarchicalConfig) -> Result<Self> {
223 let node_topology = Self::detect_network_topology(&config)?;
224 let communication_groups = Self::build_communication_groups(&config, &node_topology)?;
225 let aggregation_stats = AggregationStats::default();
226
227 let fault_detector = if config.enable_fault_tolerance {
228 Some(FaultDetector {
229 failed_nodes: Vec::new(),
230 timeout_threshold: config.comm_timeout_ms,
231 recovery_strategy: RecoveryStrategy::Skip,
232 })
233 } else {
234 None
235 };
236
237 Ok(Self {
238 config,
239 node_topology,
240 communication_groups,
241 aggregation_stats,
242 fault_detector,
243 })
244 }
245
246 fn detect_network_topology(config: &HierarchicalConfig) -> Result<NodeTopology> {
248 let num_nodes = config.num_nodes;
249
250 let mut node_adjacency = vec![vec![false; num_nodes]; num_nodes];
252 let mut node_bandwidth = vec![vec![0.0; num_nodes]; num_nodes];
253 let mut node_latency = vec![vec![0.0; num_nodes]; num_nodes];
254
255 for i in 0..num_nodes {
258 for j in 0..num_nodes {
259 if i != j {
260 node_adjacency[i][j] = true;
261 node_bandwidth[i][j] = if (i as i32 - j as i32).abs() == 1 {
263 10000.0 } else {
265 1000.0 };
267 node_latency[i][j] = if (i as i32 - j as i32).abs() == 1 {
269 0.1 } else {
271 1.0 };
273 } else {
274 node_adjacency[i][j] = false;
275 node_bandwidth[i][j] = f32::INFINITY;
276 node_latency[i][j] = 0.0;
277 }
278 }
279 }
280
281 Ok(NodeTopology {
282 node_adjacency,
283 node_bandwidth,
284 node_latency,
285 intra_node_bandwidth: 80000.0, intra_node_latency: 0.01, })
288 }
289
290 fn build_communication_groups(
292 config: &HierarchicalConfig,
293 topology: &NodeTopology,
294 ) -> Result<CommunicationGroups> {
295 let node_local_group: Vec<usize> = (0..config.devices_per_node)
297 .map(|i| config.node_rank * config.devices_per_node + i)
298 .collect();
299
300 let cross_node_group: Vec<usize> =
302 (0..config.num_nodes).map(|i| i * config.devices_per_node).collect();
303
304 let tree_structure = Self::build_tree_structure(config, topology)?;
306
307 let ring_structure = Self::build_ring_structure(config)?;
309
310 let butterfly_structure = Self::build_butterfly_structure(config)?;
312
313 Ok(CommunicationGroups {
314 node_local_group,
315 cross_node_group,
316 tree_structure,
317 ring_structure,
318 butterfly_structure,
319 })
320 }
321
322 fn build_tree_structure(
324 config: &HierarchicalConfig,
325 _topology: &NodeTopology,
326 ) -> Result<TreeStructure> {
327 let world_size = config.world_size();
328 let rank = config.global_rank;
329
330 let parent = if rank == 0 { None } else { Some((rank - 1) / 2) };
332
333 let mut children = Vec::new();
334 let left_child = 2 * rank + 1;
335 let right_child = 2 * rank + 2;
336
337 if left_child < world_size {
338 children.push(left_child);
339 }
340 if right_child < world_size {
341 children.push(right_child);
342 }
343
344 let depth = (rank as f32).log2().floor() as usize;
346 let height = (world_size as f32).log2().ceil() as usize;
347
348 Ok(TreeStructure {
349 parent,
350 children,
351 depth,
352 height,
353 })
354 }
355
356 fn build_ring_structure(config: &HierarchicalConfig) -> Result<RingStructure> {
358 let world_size = config.world_size();
359 let rank = config.global_rank;
360
361 let next_rank = (rank + 1) % world_size;
362 let prev_rank = (rank + world_size - 1) % world_size;
363
364 Ok(RingStructure {
365 next_rank,
366 prev_rank,
367 ring_size: world_size,
368 })
369 }
370
371 fn build_butterfly_structure(config: &HierarchicalConfig) -> Result<ButterflyStructure> {
373 let world_size = config.world_size();
374 let rank = config.global_rank;
375 let num_stages = (world_size as f32).log2().ceil() as usize;
376
377 let mut connections = Vec::new();
378
379 for stage in 0..num_stages {
380 let mut stage_connections = Vec::new();
381 let distance = 1 << stage;
382
383 let partner = rank ^ distance;
385 if partner < world_size {
386 stage_connections.push(partner);
387 }
388
389 connections.push(stage_connections);
390 }
391
392 Ok(ButterflyStructure {
393 connections,
394 num_stages,
395 })
396 }
397
398 pub fn hierarchical_all_reduce(
400 &mut self,
401 gradients: &mut HashMap<String, Tensor>,
402 ) -> Result<()> {
403 let start_time = std::time::Instant::now();
404
405 let strategy = self.select_optimal_strategy(gradients)?;
407
408 match strategy {
410 AggregationStrategy::BinaryTree => {
411 self.tree_based_all_reduce(gradients)?;
412 },
413 AggregationStrategy::Ring => {
414 self.ring_based_all_reduce(gradients)?;
415 },
416 AggregationStrategy::Butterfly => {
417 self.butterfly_based_all_reduce(gradients)?;
418 },
419 AggregationStrategy::Adaptive => {
420 let optimal_strategy = self.adaptive_strategy_selection(gradients)?;
422 match optimal_strategy {
423 AggregationStrategy::BinaryTree => self.tree_based_all_reduce(gradients)?,
424 AggregationStrategy::Ring => self.ring_based_all_reduce(gradients)?,
425 AggregationStrategy::Butterfly => self.butterfly_based_all_reduce(gradients)?,
426 AggregationStrategy::Adaptive => {
427 return Err(anyhow::anyhow!(
428 "Invalid adaptive strategy selection: recursive Adaptive strategy returned"
429 ));
430 },
431 }
432 },
433 }
434
435 let elapsed = start_time.elapsed().as_millis() as f32;
437 self.update_aggregation_stats(strategy, elapsed, gradients)?;
438
439 Ok(())
440 }
441
442 fn select_optimal_strategy(
444 &self,
445 gradients: &HashMap<String, Tensor>,
446 ) -> Result<AggregationStrategy> {
447 match self.config.strategy {
448 AggregationStrategy::Adaptive => self.adaptive_strategy_selection(gradients),
449 strategy => Ok(strategy),
450 }
451 }
452
453 fn adaptive_strategy_selection(
455 &self,
456 gradients: &HashMap<String, Tensor>,
457 ) -> Result<AggregationStrategy> {
458 let world_size = self.config.world_size();
459 let num_nodes = self.config.num_nodes;
460
461 let total_data_size: usize = gradients.values().map(|tensor| tensor.memory_usage()).sum();
463
464 if world_size <= 8 {
466 Ok(AggregationStrategy::BinaryTree)
468 } else if total_data_size > 100 * 1024 * 1024 {
469 Ok(AggregationStrategy::Ring)
471 } else if num_nodes > 16 {
472 Ok(AggregationStrategy::Butterfly)
474 } else {
475 Ok(AggregationStrategy::BinaryTree)
477 }
478 }
479
480 fn tree_based_all_reduce(&mut self, gradients: &mut HashMap<String, Tensor>) -> Result<()> {
482 let tree = self.communication_groups.tree_structure.clone();
483
484 self.tree_reduce_up(gradients, &tree)?;
486
487 self.tree_broadcast_down(gradients, &tree)?;
489
490 Ok(())
491 }
492
493 fn ring_based_all_reduce(&mut self, gradients: &mut HashMap<String, Tensor>) -> Result<()> {
495 let ring = self.communication_groups.ring_structure.clone();
496
497 self.ring_reduce_scatter(gradients, &ring)?;
499
500 self.ring_all_gather(gradients, &ring)?;
502
503 Ok(())
504 }
505
506 fn butterfly_based_all_reduce(
508 &mut self,
509 gradients: &mut HashMap<String, Tensor>,
510 ) -> Result<()> {
511 let butterfly = self.communication_groups.butterfly_structure.clone();
512
513 for stage in 0..butterfly.num_stages {
515 self.butterfly_stage_operation(gradients, &butterfly, stage)?;
516 }
517
518 Ok(())
519 }
520
521 fn tree_reduce_up(
523 &mut self,
524 gradients: &mut HashMap<String, Tensor>,
525 tree: &TreeStructure,
526 ) -> Result<()> {
527 for &child_rank in &tree.children {
529 for (name, gradient) in gradients.iter_mut() {
530 let child_gradient = self.simulate_receive_gradient(child_rank, name)?;
532 *gradient = gradient.add(&child_gradient)?;
533 }
534 }
535
536 if let Some(parent_rank) = tree.parent {
538 for (name, gradient) in gradients.iter() {
539 self.simulate_send_gradient(parent_rank, name, gradient)?;
540 }
541 }
542
543 Ok(())
544 }
545
546 fn tree_broadcast_down(
548 &mut self,
549 gradients: &mut HashMap<String, Tensor>,
550 tree: &TreeStructure,
551 ) -> Result<()> {
552 if let Some(parent_rank) = tree.parent {
554 for (name, gradient) in gradients.iter_mut() {
555 *gradient = self.simulate_receive_gradient(parent_rank, name)?;
556 }
557 }
558
559 for &child_rank in &tree.children {
561 for (name, gradient) in gradients.iter() {
562 self.simulate_send_gradient(child_rank, name, gradient)?;
563 }
564 }
565
566 Ok(())
567 }
568
569 fn ring_reduce_scatter(
571 &mut self,
572 gradients: &mut HashMap<String, Tensor>,
573 ring: &RingStructure,
574 ) -> Result<()> {
575 let num_chunks = ring.ring_size;
576 let rank = self.config.global_rank;
577
578 for step in 0..num_chunks - 1 {
579 let _send_chunk = (rank + ring.ring_size - step) % ring.ring_size;
580 let _recv_chunk = (rank + ring.ring_size - step - 1) % ring.ring_size;
581
582 for (name, gradient) in gradients.iter_mut() {
584 let chunk_gradient = self.simulate_receive_gradient(ring.prev_rank, name)?;
585 *gradient = gradient.add(&chunk_gradient)?;
586 self.simulate_send_gradient(ring.next_rank, name, gradient)?;
587 }
588 }
589
590 Ok(())
591 }
592
593 fn ring_all_gather(
595 &mut self,
596 gradients: &mut HashMap<String, Tensor>,
597 ring: &RingStructure,
598 ) -> Result<()> {
599 let num_chunks = ring.ring_size;
600
601 for _step in 0..num_chunks - 1 {
602 for (name, gradient) in gradients.iter_mut() {
604 let chunk_gradient = self.simulate_receive_gradient(ring.prev_rank, name)?;
605 *gradient = gradient.add(&chunk_gradient)?;
606 self.simulate_send_gradient(ring.next_rank, name, gradient)?;
607 }
608 }
609
610 Ok(())
611 }
612
613 fn butterfly_stage_operation(
615 &mut self,
616 gradients: &mut HashMap<String, Tensor>,
617 butterfly: &ButterflyStructure,
618 stage: usize,
619 ) -> Result<()> {
620 if stage < butterfly.connections.len() {
621 for &partner_rank in &butterfly.connections[stage] {
622 for (name, gradient) in gradients.iter_mut() {
623 let partner_gradient = self.simulate_receive_gradient(partner_rank, name)?;
625 *gradient = gradient.add(&partner_gradient)?;
626 self.simulate_send_gradient(partner_rank, name, gradient)?;
627 }
628 }
629 }
630
631 Ok(())
632 }
633
634 fn simulate_receive_gradient(&self, _from_rank: usize, _name: &str) -> Result<Tensor> {
636 Ok(Tensor::zeros(&[1])?)
639 }
640
    /// Placeholder for sending the gradient `_name` to `_to_rank`.
    ///
    /// No communication takes place: all arguments are ignored and `Ok(())`
    /// is returned.
    fn simulate_send_gradient(
        &self,
        _to_rank: usize,
        _name: &str,
        _gradient: &Tensor,
    ) -> Result<()> {
        Ok(())
    }
651
652 fn update_aggregation_stats(
654 &mut self,
655 strategy: AggregationStrategy,
656 elapsed_ms: f32,
657 gradients: &HashMap<String, Tensor>,
658 ) -> Result<()> {
659 let stats = &mut self.aggregation_stats;
660
661 stats.total_operations += 1;
662 stats.avg_aggregation_time =
663 (stats.avg_aggregation_time * (stats.total_operations - 1) as f32 + elapsed_ms)
664 / stats.total_operations as f32;
665
666 let bytes_transferred: usize = gradients.values().map(|tensor| tensor.memory_usage()).sum();
667 stats.total_bytes_transferred += bytes_transferred;
668
669 *stats.strategy_history.entry(strategy).or_insert(0) += 1;
670
671 Ok(())
672 }
673
    /// Returns a shared reference to the statistics accumulated so far.
    pub fn get_stats(&self) -> &AggregationStats {
        &self.aggregation_stats
    }
678
679 pub fn reset_stats(&mut self) {
681 self.aggregation_stats = AggregationStats::default();
682 }
683
684 pub fn get_recommended_strategy(&self) -> AggregationStrategy {
686 let world_size = self.config.world_size();
687 let num_nodes = self.config.num_nodes;
688
689 if world_size <= 8 {
690 AggregationStrategy::BinaryTree
691 } else if num_nodes > 16 {
692 AggregationStrategy::Butterfly
693 } else {
694 AggregationStrategy::Ring
695 }
696 }
697}
698
#[cfg(test)]
mod tests {
    use super::*;

    /// Rank arithmetic and master predicates for a device in the middle of the cluster.
    #[test]
    fn test_hierarchical_config() {
        let cfg = HierarchicalConfig::new(4, 8, 2, 3);

        assert_eq!(
            (cfg.num_nodes, cfg.devices_per_node, cfg.node_rank, cfg.local_rank),
            (4, 8, 2, 3)
        );
        assert_eq!(cfg.global_rank, 19);
        assert_eq!(cfg.world_size(), 32);
        assert!(!cfg.is_master());
        assert!(!cfg.is_node_master());
    }

    /// The root of an 8-rank tree has no parent, two children and depth 0.
    #[test]
    fn test_tree_structure_building() {
        let cfg = HierarchicalConfig::new(2, 4, 0, 0);
        let topology = HierarchicalAggregator::detect_network_topology(&cfg).unwrap();
        let root = HierarchicalAggregator::build_tree_structure(&cfg, &topology).unwrap();

        assert_eq!(root.parent, None);
        assert_eq!(root.children, vec![1, 2]);
        assert_eq!(root.depth, 0);
    }

    /// Rank 1 of an 8-rank ring sits between ranks 0 and 2.
    #[test]
    fn test_ring_structure_building() {
        let cfg = HierarchicalConfig::new(2, 4, 0, 1);
        let ring = HierarchicalAggregator::build_ring_structure(&cfg).unwrap();

        assert_eq!(
            (ring.prev_rank, ring.next_rank, ring.ring_size),
            (0, 2, 8)
        );
    }

    /// A large payload on more than 8 ranks selects the ring strategy.
    #[test]
    fn test_adaptive_strategy_selection() {
        let cfg = HierarchicalConfig::new(4, 4, 0, 0);
        let aggregator = HierarchicalAggregator::new(cfg).unwrap();

        let gradients = HashMap::from([(
            "param1".to_string(),
            Tensor::zeros(&[8000, 8000]).unwrap(),
        )]);

        let chosen = aggregator.adaptive_strategy_selection(&gradients).unwrap();
        assert_eq!(chosen, AggregationStrategy::Ring);
    }

    /// One recorded operation shows up in the count, the mean and the history.
    #[test]
    fn test_aggregation_stats_update() {
        let cfg = HierarchicalConfig::new(2, 2, 0, 0);
        let mut aggregator = HierarchicalAggregator::new(cfg).unwrap();

        let gradients = HashMap::from([(
            "param1".to_string(),
            Tensor::zeros(&[10, 10]).unwrap(),
        )]);

        aggregator
            .update_aggregation_stats(AggregationStrategy::BinaryTree, 100.0, &gradients)
            .unwrap();

        let stats = aggregator.get_stats();
        assert_eq!(stats.total_operations, 1);
        assert_eq!(stats.avg_aggregation_time, 100.0);
        assert_eq!(
            stats.strategy_history.get(&AggregationStrategy::BinaryTree),
            Some(&1)
        );
    }

    /// Small worlds recommend the tree; many-node worlds recommend the butterfly.
    #[test]
    fn test_recommended_strategy() {
        let small = HierarchicalAggregator::new(HierarchicalConfig::new(2, 2, 0, 0)).unwrap();
        assert_eq!(
            small.get_recommended_strategy(),
            AggregationStrategy::BinaryTree
        );

        let large = HierarchicalAggregator::new(HierarchicalConfig::new(20, 1, 0, 0)).unwrap();
        assert_eq!(
            large.get_recommended_strategy(),
            AggregationStrategy::Butterfly
        );
    }

    /// Eight ranks need three butterfly stages, one connection list per stage.
    #[test]
    fn test_butterfly_structure() {
        let cfg = HierarchicalConfig::new(1, 8, 0, 0);
        let butterfly = HierarchicalAggregator::build_butterfly_structure(&cfg).unwrap();

        assert_eq!(butterfly.num_stages, 3);
        assert_eq!(butterfly.connections.len(), 3);
    }

    /// Topology matrices have one row per node and positive intra-node figures.
    #[test]
    fn test_network_topology_detection() {
        let cfg = HierarchicalConfig::new(3, 2, 0, 0);
        let topology = HierarchicalAggregator::detect_network_topology(&cfg).unwrap();

        assert_eq!(topology.node_adjacency.len(), 3);
        assert_eq!(topology.node_bandwidth.len(), 3);
        assert_eq!(topology.node_latency.len(), 3);
        assert!(topology.intra_node_bandwidth > 0.0);
        assert!(topology.intra_node_latency > 0.0);
    }
}