Skip to main content

trustformers_optim/
hierarchical_aggregation.rs

1use anyhow::Result;
2use std::collections::HashMap;
3use trustformers_core::parallel::CommunicationBackend;
4use trustformers_core::tensor::Tensor;
5
// Hierarchical aggregation strategies for distributed training.
//
// This module provides advanced hierarchical aggregation algorithms that optimize
// communication patterns for different network topologies and cluster configurations.
// It supports tree-based, ring-based, and butterfly aggregation patterns.
//
// NOTE: kept as plain comments on purpose. Written as `///`, this overview would be
// attached to the enum below, and `//!` is only accepted before the `use` items.

/// Algorithm used to combine gradients across all ranks.
///
/// `Adaptive` is a request to pick one of the three concrete algorithms at
/// aggregation time rather than an algorithm of its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregationStrategy {
    /// Binary tree aggregation (optimal for small clusters)
    BinaryTree,
    /// Ring-based aggregation (bandwidth-optimal)
    Ring,
    /// Butterfly aggregation (latency-optimal)
    Butterfly,
    /// Adaptive strategy that selects best algorithm based on cluster topology
    Adaptive,
}
23
/// Describes where the current process sits in the cluster and how aggregation
/// should be carried out.
#[derive(Debug, Clone)]
pub struct HierarchicalConfig {
    /// Number of nodes in the cluster
    pub num_nodes: usize,
    /// Number of devices per node
    pub devices_per_node: usize,
    /// Node rank (0-based)
    pub node_rank: usize,
    /// Local rank within node
    pub local_rank: usize,
    /// Global rank across all nodes
    /// (`HierarchicalConfig::new` sets it to `node_rank * devices_per_node + local_rank`)
    pub global_rank: usize,
    /// Aggregation strategy
    pub strategy: AggregationStrategy,
    /// Communication backend
    pub comm_backend: CommunicationBackend,
    /// Enable compression during aggregation
    // NOTE(review): no code visible in this file reads this flag — confirm where it is consumed.
    pub enable_compression: bool,
    /// Compression threshold (only compress if savings > threshold)
    pub compression_threshold: f32,
    /// Enable fault tolerance
    /// (when set, `HierarchicalAggregator::new` creates a `FaultDetector`)
    pub enable_fault_tolerance: bool,
    /// Timeout for communication operations (ms)
    pub comm_timeout_ms: u64,
}
49
/// Defaults describe a single-node, single-device job (rank 0 everywhere) that
/// uses the adaptive strategy over the MPI backend, with compression and fault
/// tolerance switched on.
impl Default for HierarchicalConfig {
    fn default() -> Self {
        Self {
            num_nodes: 1,
            devices_per_node: 1,
            node_rank: 0,
            local_rank: 0,
            global_rank: 0,
            strategy: AggregationStrategy::Adaptive,
            comm_backend: CommunicationBackend::Mpi,
            enable_compression: true,
            compression_threshold: 0.1,
            enable_fault_tolerance: true,
            // 30 000 ms, i.e. 30 seconds.
            comm_timeout_ms: 30000,
        }
    }
}
67
68impl HierarchicalConfig {
69    pub fn new(
70        num_nodes: usize,
71        devices_per_node: usize,
72        node_rank: usize,
73        local_rank: usize,
74    ) -> Self {
75        let global_rank = node_rank * devices_per_node + local_rank;
76        Self {
77            num_nodes,
78            devices_per_node,
79            node_rank,
80            local_rank,
81            global_rank,
82            ..Default::default()
83        }
84    }
85
86    pub fn world_size(&self) -> usize {
87        self.num_nodes * self.devices_per_node
88    }
89
90    pub fn is_master(&self) -> bool {
91        self.global_rank == 0
92    }
93
94    pub fn is_node_master(&self) -> bool {
95        self.local_rank == 0
96    }
97}
98
/// Hierarchical aggregation coordinator
///
/// Bundles the cluster configuration, the communication layouts derived from it
/// and the running statistics of the aggregations performed so far.
pub struct HierarchicalAggregator {
    /// Cluster layout and options supplied to `new`.
    config: HierarchicalConfig,
    /// Estimated inter-node topology computed in `new`.
    // Stored at construction; no method visible in this file reads it back,
    // which is why the dead-code lint is silenced.
    #[allow(dead_code)]
    node_topology: NodeTopology,
    /// Precomputed tree / ring / butterfly layouts for this rank.
    communication_groups: CommunicationGroups,
    /// Counters updated after each aggregation.
    aggregation_stats: AggregationStats,
    /// Present only when `enable_fault_tolerance` was set in the configuration.
    // Stored at construction; no method visible in this file reads it back,
    // which is why the dead-code lint is silenced.
    #[allow(dead_code)]
    fault_detector: Option<FaultDetector>,
}
109
/// Network topology representation
///
/// The three matrices are square, indexed `[from_node][to_node]`.
#[derive(Debug, Clone)]
pub struct NodeTopology {
    /// Adjacency matrix for inter-node connectivity
    pub node_adjacency: Vec<Vec<bool>>,
    /// Bandwidth matrix between nodes (MB/s)
    pub node_bandwidth: Vec<Vec<f32>>,
    /// Latency matrix between nodes (ms)
    pub node_latency: Vec<Vec<f32>>,
    /// Intra-node connectivity (assumed full connectivity)
    // NOTE(review): presumably MB/s like `node_bandwidth` — confirm.
    pub intra_node_bandwidth: f32,
    /// Intra-node latency
    // NOTE(review): presumably ms like `node_latency` — confirm.
    pub intra_node_latency: f32,
}
124
/// Communication groups for hierarchical operations
///
/// All ranks stored here are global ranks. The three layout fields describe the
/// position of the current rank only, not the whole cluster.
#[derive(Debug, Clone)]
pub struct CommunicationGroups {
    /// Ranks within the same node
    pub node_local_group: Vec<usize>,
    /// Node master ranks for cross-node communication
    /// (one entry per node: `node_index * devices_per_node`)
    pub cross_node_group: Vec<usize>,
    /// Binary tree structure for tree-based aggregation
    pub tree_structure: TreeStructure,
    /// Ring structure for ring-based aggregation
    pub ring_structure: RingStructure,
    /// Butterfly structure for butterfly aggregation
    pub butterfly_structure: ButterflyStructure,
}
139
/// Position of the current rank inside the binary aggregation tree.
#[derive(Debug, Clone)]
pub struct TreeStructure {
    /// Parent rank in the tree (`None` if this rank is the root)
    pub parent: Option<usize>,
    /// Children ranks in the tree
    pub children: Vec<usize>,
    /// Tree depth of this rank (the root, rank 0, is at depth 0)
    pub depth: usize,
    /// Tree height
    pub height: usize,
}
151
/// Neighbours of the current rank in the aggregation ring.
#[derive(Debug, Clone)]
pub struct RingStructure {
    /// Next rank in the ring
    pub next_rank: usize,
    /// Previous rank in the ring
    pub prev_rank: usize,
    /// Ring size
    pub ring_size: usize,
}
161
/// Exchange partners of the current rank for butterfly aggregation.
#[derive(Debug, Clone)]
pub struct ButterflyStructure {
    /// Butterfly connections for each stage
    /// (`connections[stage]` lists the partner ranks contacted in that stage)
    pub connections: Vec<Vec<usize>>,
    /// Number of stages
    pub num_stages: usize,
}
169
/// Aggregation operation statistics
///
/// Obtained from `HierarchicalAggregator::get_stats` and cleared with
/// `HierarchicalAggregator::reset_stats`.
#[derive(Debug, Clone)]
pub struct AggregationStats {
    /// Total number of aggregation operations
    pub total_operations: usize,
    /// Average aggregation time (ms)
    pub avg_aggregation_time: f32,
    /// Total bytes transferred
    /// (accumulated from the `memory_usage()` of the aggregated tensors)
    pub total_bytes_transferred: usize,
    /// Compression ratio achieved
    pub compression_ratio: f32,
    /// Number of failed operations
    pub failed_operations: usize,
    /// Strategy selection history
    /// (how many recorded operations used each strategy)
    pub strategy_history: HashMap<AggregationStrategy, usize>,
}
186
/// Fault detection and recovery
///
/// Created by `HierarchicalAggregator::new` when fault tolerance is enabled.
#[derive(Debug)]
pub struct FaultDetector {
    /// Failed nodes
    pub failed_nodes: Vec<usize>,
    /// Timeout threshold for detecting failures
    /// (initialised from `HierarchicalConfig::comm_timeout_ms`, so in milliseconds)
    pub timeout_threshold: u64,
    /// Recovery strategy
    pub recovery_strategy: RecoveryStrategy,
}
197
/// What to do once a node has been marked as failed.
#[derive(Debug, Clone)]
pub enum RecoveryStrategy {
    /// Skip failed nodes and continue
    Skip,
    /// Retry with backup nodes
    Retry,
    /// Abort aggregation
    Abort,
}
207
/// Hand-written rather than derived because the compression ratio starts at 1.0,
/// not at the `f32` default of 0.0. Every other field starts at zero or empty.
impl Default for AggregationStats {
    fn default() -> Self {
        Self {
            total_operations: 0,
            avg_aggregation_time: 0.0,
            total_bytes_transferred: 0,
            compression_ratio: 1.0,
            failed_operations: 0,
            strategy_history: HashMap::new(),
        }
    }
}
220
221impl HierarchicalAggregator {
222    pub fn new(config: HierarchicalConfig) -> Result<Self> {
223        let node_topology = Self::detect_network_topology(&config)?;
224        let communication_groups = Self::build_communication_groups(&config, &node_topology)?;
225        let aggregation_stats = AggregationStats::default();
226
227        let fault_detector = if config.enable_fault_tolerance {
228            Some(FaultDetector {
229                failed_nodes: Vec::new(),
230                timeout_threshold: config.comm_timeout_ms,
231                recovery_strategy: RecoveryStrategy::Skip,
232            })
233        } else {
234            None
235        };
236
237        Ok(Self {
238            config,
239            node_topology,
240            communication_groups,
241            aggregation_stats,
242            fault_detector,
243        })
244    }
245
246    /// Detect network topology and measure bandwidth/latency
247    fn detect_network_topology(config: &HierarchicalConfig) -> Result<NodeTopology> {
248        let num_nodes = config.num_nodes;
249
250        // Initialize topology matrices
251        let mut node_adjacency = vec![vec![false; num_nodes]; num_nodes];
252        let mut node_bandwidth = vec![vec![0.0; num_nodes]; num_nodes];
253        let mut node_latency = vec![vec![0.0; num_nodes]; num_nodes];
254
255        // For this implementation, assume full connectivity with estimated values
256        // In practice, these would be measured through benchmarking
257        for i in 0..num_nodes {
258            for j in 0..num_nodes {
259                if i != j {
260                    node_adjacency[i][j] = true;
261                    // Estimate bandwidth based on network topology
262                    node_bandwidth[i][j] = if (i as i32 - j as i32).abs() == 1 {
263                        10000.0 // Adjacent nodes: 10 GB/s
264                    } else {
265                        1000.0 // Non-adjacent nodes: 1 GB/s
266                    };
267                    // Estimate latency
268                    node_latency[i][j] = if (i as i32 - j as i32).abs() == 1 {
269                        0.1 // Adjacent nodes: 0.1ms
270                    } else {
271                        1.0 // Non-adjacent nodes: 1ms
272                    };
273                } else {
274                    node_adjacency[i][j] = false;
275                    node_bandwidth[i][j] = f32::INFINITY;
276                    node_latency[i][j] = 0.0;
277                }
278            }
279        }
280
281        Ok(NodeTopology {
282            node_adjacency,
283            node_bandwidth,
284            node_latency,
285            intra_node_bandwidth: 80000.0, // 80 GB/s intra-node
286            intra_node_latency: 0.01,      // 0.01ms intra-node
287        })
288    }
289
290    /// Build communication groups for different aggregation strategies
291    fn build_communication_groups(
292        config: &HierarchicalConfig,
293        topology: &NodeTopology,
294    ) -> Result<CommunicationGroups> {
295        // Node-local group
296        let node_local_group: Vec<usize> = (0..config.devices_per_node)
297            .map(|i| config.node_rank * config.devices_per_node + i)
298            .collect();
299
300        // Cross-node group (node masters)
301        let cross_node_group: Vec<usize> =
302            (0..config.num_nodes).map(|i| i * config.devices_per_node).collect();
303
304        // Build tree structure
305        let tree_structure = Self::build_tree_structure(config, topology)?;
306
307        // Build ring structure
308        let ring_structure = Self::build_ring_structure(config)?;
309
310        // Build butterfly structure
311        let butterfly_structure = Self::build_butterfly_structure(config)?;
312
313        Ok(CommunicationGroups {
314            node_local_group,
315            cross_node_group,
316            tree_structure,
317            ring_structure,
318            butterfly_structure,
319        })
320    }
321
322    /// Build binary tree structure for tree-based aggregation
323    fn build_tree_structure(
324        config: &HierarchicalConfig,
325        _topology: &NodeTopology,
326    ) -> Result<TreeStructure> {
327        let world_size = config.world_size();
328        let rank = config.global_rank;
329
330        // Build binary tree
331        let parent = if rank == 0 { None } else { Some((rank - 1) / 2) };
332
333        let mut children = Vec::new();
334        let left_child = 2 * rank + 1;
335        let right_child = 2 * rank + 2;
336
337        if left_child < world_size {
338            children.push(left_child);
339        }
340        if right_child < world_size {
341            children.push(right_child);
342        }
343
344        // Calculate depth and height
345        let depth = (rank as f32).log2().floor() as usize;
346        let height = (world_size as f32).log2().ceil() as usize;
347
348        Ok(TreeStructure {
349            parent,
350            children,
351            depth,
352            height,
353        })
354    }
355
356    /// Build ring structure for ring-based aggregation
357    fn build_ring_structure(config: &HierarchicalConfig) -> Result<RingStructure> {
358        let world_size = config.world_size();
359        let rank = config.global_rank;
360
361        let next_rank = (rank + 1) % world_size;
362        let prev_rank = (rank + world_size - 1) % world_size;
363
364        Ok(RingStructure {
365            next_rank,
366            prev_rank,
367            ring_size: world_size,
368        })
369    }
370
371    /// Build butterfly structure for butterfly aggregation
372    fn build_butterfly_structure(config: &HierarchicalConfig) -> Result<ButterflyStructure> {
373        let world_size = config.world_size();
374        let rank = config.global_rank;
375        let num_stages = (world_size as f32).log2().ceil() as usize;
376
377        let mut connections = Vec::new();
378
379        for stage in 0..num_stages {
380            let mut stage_connections = Vec::new();
381            let distance = 1 << stage;
382
383            // XOR-based butterfly connections
384            let partner = rank ^ distance;
385            if partner < world_size {
386                stage_connections.push(partner);
387            }
388
389            connections.push(stage_connections);
390        }
391
392        Ok(ButterflyStructure {
393            connections,
394            num_stages,
395        })
396    }
397
398    /// Perform hierarchical all-reduce operation
399    pub fn hierarchical_all_reduce(
400        &mut self,
401        gradients: &mut HashMap<String, Tensor>,
402    ) -> Result<()> {
403        let start_time = std::time::Instant::now();
404
405        // Select optimal strategy based on configuration and topology
406        let strategy = self.select_optimal_strategy(gradients)?;
407
408        // Perform aggregation based on selected strategy
409        match strategy {
410            AggregationStrategy::BinaryTree => {
411                self.tree_based_all_reduce(gradients)?;
412            },
413            AggregationStrategy::Ring => {
414                self.ring_based_all_reduce(gradients)?;
415            },
416            AggregationStrategy::Butterfly => {
417                self.butterfly_based_all_reduce(gradients)?;
418            },
419            AggregationStrategy::Adaptive => {
420                // Adaptive strategy selects the best algorithm dynamically
421                let optimal_strategy = self.adaptive_strategy_selection(gradients)?;
422                match optimal_strategy {
423                    AggregationStrategy::BinaryTree => self.tree_based_all_reduce(gradients)?,
424                    AggregationStrategy::Ring => self.ring_based_all_reduce(gradients)?,
425                    AggregationStrategy::Butterfly => self.butterfly_based_all_reduce(gradients)?,
426                    AggregationStrategy::Adaptive => {
427                        return Err(anyhow::anyhow!(
428                            "Invalid adaptive strategy selection: recursive Adaptive strategy returned"
429                        ));
430                    },
431                }
432            },
433        }
434
435        // Update statistics
436        let elapsed = start_time.elapsed().as_millis() as f32;
437        self.update_aggregation_stats(strategy, elapsed, gradients)?;
438
439        Ok(())
440    }
441
442    /// Select optimal aggregation strategy
443    fn select_optimal_strategy(
444        &self,
445        gradients: &HashMap<String, Tensor>,
446    ) -> Result<AggregationStrategy> {
447        match self.config.strategy {
448            AggregationStrategy::Adaptive => self.adaptive_strategy_selection(gradients),
449            strategy => Ok(strategy),
450        }
451    }
452
453    /// Adaptive strategy selection based on cluster topology and data characteristics
454    fn adaptive_strategy_selection(
455        &self,
456        gradients: &HashMap<String, Tensor>,
457    ) -> Result<AggregationStrategy> {
458        let world_size = self.config.world_size();
459        let num_nodes = self.config.num_nodes;
460
461        // Calculate total data size
462        let total_data_size: usize = gradients.values().map(|tensor| tensor.memory_usage()).sum();
463
464        // Strategy selection heuristics
465        if world_size <= 8 {
466            // Small clusters: tree is optimal
467            Ok(AggregationStrategy::BinaryTree)
468        } else if total_data_size > 100 * 1024 * 1024 {
469            // Large data: ring is bandwidth-optimal
470            Ok(AggregationStrategy::Ring)
471        } else if num_nodes > 16 {
472            // Large clusters with small data: butterfly is latency-optimal
473            Ok(AggregationStrategy::Butterfly)
474        } else {
475            // Default to tree for medium-sized clusters
476            Ok(AggregationStrategy::BinaryTree)
477        }
478    }
479
480    /// Tree-based all-reduce (divide-and-conquer)
481    fn tree_based_all_reduce(&mut self, gradients: &mut HashMap<String, Tensor>) -> Result<()> {
482        let tree = self.communication_groups.tree_structure.clone();
483
484        // Phase 1: Reduce up the tree
485        self.tree_reduce_up(gradients, &tree)?;
486
487        // Phase 2: Broadcast down the tree
488        self.tree_broadcast_down(gradients, &tree)?;
489
490        Ok(())
491    }
492
493    /// Ring-based all-reduce (bandwidth-optimal)
494    fn ring_based_all_reduce(&mut self, gradients: &mut HashMap<String, Tensor>) -> Result<()> {
495        let ring = self.communication_groups.ring_structure.clone();
496
497        // Phase 1: Reduce-scatter
498        self.ring_reduce_scatter(gradients, &ring)?;
499
500        // Phase 2: All-gather
501        self.ring_all_gather(gradients, &ring)?;
502
503        Ok(())
504    }
505
506    /// Butterfly-based all-reduce (latency-optimal)
507    fn butterfly_based_all_reduce(
508        &mut self,
509        gradients: &mut HashMap<String, Tensor>,
510    ) -> Result<()> {
511        let butterfly = self.communication_groups.butterfly_structure.clone();
512
513        // Butterfly all-reduce in multiple stages
514        for stage in 0..butterfly.num_stages {
515            self.butterfly_stage_operation(gradients, &butterfly, stage)?;
516        }
517
518        Ok(())
519    }
520
521    /// Tree reduce-up phase
522    fn tree_reduce_up(
523        &mut self,
524        gradients: &mut HashMap<String, Tensor>,
525        tree: &TreeStructure,
526    ) -> Result<()> {
527        // Collect gradients from children
528        for &child_rank in &tree.children {
529            for (name, gradient) in gradients.iter_mut() {
530                // Simulate receiving gradient from child
531                let child_gradient = self.simulate_receive_gradient(child_rank, name)?;
532                *gradient = gradient.add(&child_gradient)?;
533            }
534        }
535
536        // Send reduced gradients to parent
537        if let Some(parent_rank) = tree.parent {
538            for (name, gradient) in gradients.iter() {
539                self.simulate_send_gradient(parent_rank, name, gradient)?;
540            }
541        }
542
543        Ok(())
544    }
545
546    /// Tree broadcast-down phase
547    fn tree_broadcast_down(
548        &mut self,
549        gradients: &mut HashMap<String, Tensor>,
550        tree: &TreeStructure,
551    ) -> Result<()> {
552        // Receive final gradients from parent
553        if let Some(parent_rank) = tree.parent {
554            for (name, gradient) in gradients.iter_mut() {
555                *gradient = self.simulate_receive_gradient(parent_rank, name)?;
556            }
557        }
558
559        // Broadcast to children
560        for &child_rank in &tree.children {
561            for (name, gradient) in gradients.iter() {
562                self.simulate_send_gradient(child_rank, name, gradient)?;
563            }
564        }
565
566        Ok(())
567    }
568
569    /// Ring reduce-scatter phase
570    fn ring_reduce_scatter(
571        &mut self,
572        gradients: &mut HashMap<String, Tensor>,
573        ring: &RingStructure,
574    ) -> Result<()> {
575        let num_chunks = ring.ring_size;
576        let rank = self.config.global_rank;
577
578        for step in 0..num_chunks - 1 {
579            let _send_chunk = (rank + ring.ring_size - step) % ring.ring_size;
580            let _recv_chunk = (rank + ring.ring_size - step - 1) % ring.ring_size;
581
582            // Send to next rank and receive from previous rank
583            for (name, gradient) in gradients.iter_mut() {
584                let chunk_gradient = self.simulate_receive_gradient(ring.prev_rank, name)?;
585                *gradient = gradient.add(&chunk_gradient)?;
586                self.simulate_send_gradient(ring.next_rank, name, gradient)?;
587            }
588        }
589
590        Ok(())
591    }
592
593    /// Ring all-gather phase
594    fn ring_all_gather(
595        &mut self,
596        gradients: &mut HashMap<String, Tensor>,
597        ring: &RingStructure,
598    ) -> Result<()> {
599        let num_chunks = ring.ring_size;
600
601        for _step in 0..num_chunks - 1 {
602            // Send to next rank and receive from previous rank
603            for (name, gradient) in gradients.iter_mut() {
604                let chunk_gradient = self.simulate_receive_gradient(ring.prev_rank, name)?;
605                *gradient = gradient.add(&chunk_gradient)?;
606                self.simulate_send_gradient(ring.next_rank, name, gradient)?;
607            }
608        }
609
610        Ok(())
611    }
612
613    /// Butterfly stage operation
614    fn butterfly_stage_operation(
615        &mut self,
616        gradients: &mut HashMap<String, Tensor>,
617        butterfly: &ButterflyStructure,
618        stage: usize,
619    ) -> Result<()> {
620        if stage < butterfly.connections.len() {
621            for &partner_rank in &butterfly.connections[stage] {
622                for (name, gradient) in gradients.iter_mut() {
623                    // Exchange gradients with partner
624                    let partner_gradient = self.simulate_receive_gradient(partner_rank, name)?;
625                    *gradient = gradient.add(&partner_gradient)?;
626                    self.simulate_send_gradient(partner_rank, name, gradient)?;
627                }
628            }
629        }
630
631        Ok(())
632    }
633
634    /// Simulate receiving gradient from another rank
635    fn simulate_receive_gradient(&self, _from_rank: usize, _name: &str) -> Result<Tensor> {
636        // In a real implementation, this would use MPI or other communication backend
637        // For this implementation, we'll create a dummy tensor
638        Ok(Tensor::zeros(&[1])?)
639    }
640
    /// Simulate sending gradient to another rank
    ///
    /// Placeholder: all three arguments are ignored, nothing is transmitted and
    /// the call always succeeds.
    fn simulate_send_gradient(
        &self,
        _to_rank: usize,
        _name: &str,
        _gradient: &Tensor,
    ) -> Result<()> {
        // In a real implementation, this would use MPI or other communication backend
        Ok(())
    }
651
652    /// Update aggregation statistics
653    fn update_aggregation_stats(
654        &mut self,
655        strategy: AggregationStrategy,
656        elapsed_ms: f32,
657        gradients: &HashMap<String, Tensor>,
658    ) -> Result<()> {
659        let stats = &mut self.aggregation_stats;
660
661        stats.total_operations += 1;
662        stats.avg_aggregation_time =
663            (stats.avg_aggregation_time * (stats.total_operations - 1) as f32 + elapsed_ms)
664                / stats.total_operations as f32;
665
666        let bytes_transferred: usize = gradients.values().map(|tensor| tensor.memory_usage()).sum();
667        stats.total_bytes_transferred += bytes_transferred;
668
669        *stats.strategy_history.entry(strategy).or_insert(0) += 1;
670
671        Ok(())
672    }
673
    /// Get current aggregation statistics
    ///
    /// Returns a shared reference to the live counters; nothing is copied.
    pub fn get_stats(&self) -> &AggregationStats {
        &self.aggregation_stats
    }
678
679    /// Reset aggregation statistics
680    pub fn reset_stats(&mut self) {
681        self.aggregation_stats = AggregationStats::default();
682    }
683
684    /// Get recommended strategy for current configuration
685    pub fn get_recommended_strategy(&self) -> AggregationStrategy {
686        let world_size = self.config.world_size();
687        let num_nodes = self.config.num_nodes;
688
689        if world_size <= 8 {
690            AggregationStrategy::BinaryTree
691        } else if num_nodes > 16 {
692            AggregationStrategy::Butterfly
693        } else {
694            AggregationStrategy::Ring
695        }
696    }
697}
698
// Unit tests for the configuration helpers, the layout builders, the strategy
// heuristics and the statistics bookkeeping. No real communication takes place.
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` derives the global rank and the helper predicates follow from it.
    #[test]
    fn test_hierarchical_config() {
        let config = HierarchicalConfig::new(4, 8, 2, 3);
        assert_eq!(config.num_nodes, 4);
        assert_eq!(config.devices_per_node, 8);
        assert_eq!(config.node_rank, 2);
        assert_eq!(config.local_rank, 3);
        // 2 * 8 + 3
        assert_eq!(config.global_rank, 19);
        assert_eq!(config.world_size(), 32);
        assert!(!config.is_master());
        assert!(!config.is_node_master());
    }

    /// Rank 0 of an 8-rank world is the tree root with children 1 and 2.
    #[test]
    fn test_tree_structure_building() {
        let config = HierarchicalConfig::new(2, 4, 0, 0);
        let topology = HierarchicalAggregator::detect_network_topology(&config).unwrap();
        let tree = HierarchicalAggregator::build_tree_structure(&config, &topology).unwrap();

        assert_eq!(tree.parent, None); // Root node
        assert_eq!(tree.children, vec![1, 2]);
        assert_eq!(tree.depth, 0);
    }

    /// Rank 1 of an 8-rank ring sits between ranks 0 and 2.
    #[test]
    fn test_ring_structure_building() {
        let config = HierarchicalConfig::new(2, 4, 0, 1);
        let ring = HierarchicalAggregator::build_ring_structure(&config).unwrap();

        assert_eq!(ring.next_rank, 2);
        assert_eq!(ring.prev_rank, 0);
        assert_eq!(ring.ring_size, 8);
    }

    /// With 16 devices (more than 8) and a payload above 100 MiB the ring is chosen.
    #[test]
    fn test_adaptive_strategy_selection() {
        let config = HierarchicalConfig::new(4, 4, 0, 0);
        let aggregator = HierarchicalAggregator::new(config).unwrap();

        let mut gradients = HashMap::new();
        // Create a large tensor that exceeds 100MB threshold: 8000x8000x4bytes = 256MB
        // (assumes `Tensor::zeros` allocates 4-byte elements — TODO confirm)
        gradients.insert("param1".to_string(), Tensor::zeros(&[8000, 8000]).unwrap());

        let strategy = aggregator.adaptive_strategy_selection(&gradients).unwrap();
        // Should select ring for large data
        assert!(matches!(strategy, AggregationStrategy::Ring));
    }

    /// A single recorded operation sets the count, the average and the history.
    #[test]
    fn test_aggregation_stats_update() {
        let config = HierarchicalConfig::new(2, 2, 0, 0);
        let mut aggregator = HierarchicalAggregator::new(config).unwrap();

        let mut gradients = HashMap::new();
        gradients.insert("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap());

        aggregator
            .update_aggregation_stats(AggregationStrategy::BinaryTree, 100.0, &gradients)
            .unwrap();

        let stats = aggregator.get_stats();
        assert_eq!(stats.total_operations, 1);
        assert_eq!(stats.avg_aggregation_time, 100.0);
        assert_eq!(
            stats.strategy_history.get(&AggregationStrategy::BinaryTree),
            Some(&1)
        );
    }

    /// Small worlds get the tree; clusters with more than 16 nodes get the butterfly.
    #[test]
    fn test_recommended_strategy() {
        // 2 x 2 = 4 devices.
        let small_config = HierarchicalConfig::new(2, 2, 0, 0);
        let small_aggregator = HierarchicalAggregator::new(small_config).unwrap();
        assert!(matches!(
            small_aggregator.get_recommended_strategy(),
            AggregationStrategy::BinaryTree
        ));

        // 20 nodes with one device each.
        let large_config = HierarchicalConfig::new(20, 1, 0, 0);
        let large_aggregator = HierarchicalAggregator::new(large_config).unwrap();
        assert!(matches!(
            large_aggregator.get_recommended_strategy(),
            AggregationStrategy::Butterfly
        ));
    }

    /// An 8-rank world needs three butterfly stages, each with a connection list.
    #[test]
    fn test_butterfly_structure() {
        let config = HierarchicalConfig::new(1, 8, 0, 0);
        let butterfly = HierarchicalAggregator::build_butterfly_structure(&config).unwrap();

        assert_eq!(butterfly.num_stages, 3); // log2(8) = 3
        assert_eq!(butterfly.connections.len(), 3);
    }

    /// The estimated topology has one matrix row per node and positive intra-node figures.
    #[test]
    fn test_network_topology_detection() {
        let config = HierarchicalConfig::new(3, 2, 0, 0);
        let topology = HierarchicalAggregator::detect_network_topology(&config).unwrap();

        assert_eq!(topology.node_adjacency.len(), 3);
        assert_eq!(topology.node_bandwidth.len(), 3);
        assert_eq!(topology.node_latency.len(), 3);
        assert!(topology.intra_node_bandwidth > 0.0);
        assert!(topology.intra_node_latency > 0.0);
    }
}