// torsh_core/distributed.rs

// Copyright (c) 2025 ToRSh Contributors
//
// Distributed Tensor Metadata Management
//
// This module provides data structures and abstractions for managing tensors
// distributed across multiple devices, nodes, or clusters. It enables efficient
// distributed training and inference at scale.
//
// # Key Features
//
// - **Tensor Sharding**: Automatic tensor partitioning across devices
// - **Communication Patterns**: AllReduce, AllGather, ReduceScatter, etc.
// - **Device Topology**: Hierarchical device organization (node, rack, cluster)
// - **Synchronization**: Efficient barrier and broadcast operations
// - **Fault Tolerance**: Checkpoint and recovery mechanisms
//
// # Design Principles
//
// 1. **Scalability**: Support thousands of devices
// 2. **Flexibility**: Multiple sharding strategies
// 3. **Performance**: Overlap computation and communication
// 4. **Resilience**: Handle device failures gracefully
//
// # Examples
//
// ```rust
// use torsh_core::distributed::{DeviceGroup, DistributedTensor, ShardingStrategy};
//
// // Create a distributed tensor across 4 GPUs
// let devices = DeviceGroup::new(vec![0, 1, 2, 3]);
// let tensor = DistributedTensor::new(vec![128, 64], ShardingStrategy::DataParallel, devices);
//
// // Collective operations are described via `CommunicationDescriptor`
// // together with `CollectiveOp` (e.g. `CollectiveOp::AllReduce`).
// ```
use core::fmt;

/// Device identifier in a distributed system
///
/// Uniquely identifies a device in a cluster with node, rack, and device ID.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeviceId {
    /// Node ID within the cluster
    node_id: usize,
    /// Rack ID (for datacenter topology)
    rack_id: usize,
    /// Local device ID on the node
    local_device_id: usize,
}

impl DeviceId {
    /// Construct a device ID from its full (node, rack, local) coordinates.
    pub fn new(node_id: usize, rack_id: usize, local_device_id: usize) -> Self {
        Self {
            node_id,
            rack_id,
            local_device_id,
        }
    }

    /// Convenience constructor for a single-node setup (rack 0, node 0).
    pub fn simple(local_device_id: usize) -> Self {
        Self {
            node_id: 0,
            rack_id: 0,
            local_device_id,
        }
    }

    /// Node ID within the cluster.
    pub fn node_id(&self) -> usize {
        self.node_id
    }

    /// Rack ID (datacenter topology).
    pub fn rack_id(&self) -> usize {
        self.rack_id
    }

    /// Device index local to its node.
    pub fn local_device_id(&self) -> usize {
        self.local_device_id
    }

    /// Global ID, encoded as `rack_id * 1000 + node_id * 100 + local_device_id`.
    ///
    /// NOTE(review): this encoding is only collision-free while `node_id < 10`
    /// and `local_device_id < 100` — confirm against expected cluster sizes.
    pub fn global_id(&self) -> usize {
        (self.rack_id * 10 + self.node_id) * 100 + self.local_device_id
    }
}

89/// Group of devices for distributed operations
90///
91/// Represents a logical group of devices that participate in collective operations.
92#[derive(Debug, Clone)]
93pub struct DeviceGroup {
94    /// Devices in this group
95    devices: Vec<DeviceId>,
96    /// Group name for debugging
97    name: Option<String>,
98}
99
100impl DeviceGroup {
101    /// Create a new device group
102    pub fn new(device_ids: Vec<usize>) -> Self {
103        let devices = device_ids.iter().map(|&id| DeviceId::simple(id)).collect();
104        Self {
105            devices,
106            name: None,
107        }
108    }
109
110    /// Create a device group with explicit device IDs
111    pub fn from_devices(devices: Vec<DeviceId>) -> Self {
112        Self {
113            devices,
114            name: None,
115        }
116    }
117
118    /// Set group name
119    pub fn with_name(mut self, name: impl Into<String>) -> Self {
120        self.name = Some(name.into());
121        self
122    }
123
124    /// Get devices in the group
125    pub fn devices(&self) -> &[DeviceId] {
126        &self.devices
127    }
128
129    /// Get group size
130    pub fn size(&self) -> usize {
131        self.devices.len()
132    }
133
134    /// Check if device is in the group
135    pub fn contains(&self, device_id: &DeviceId) -> bool {
136        self.devices.contains(device_id)
137    }
138
139    /// Get device rank (position in group)
140    pub fn rank(&self, device_id: &DeviceId) -> Option<usize> {
141        self.devices.iter().position(|d| d == device_id)
142    }
143
144    /// Get group name
145    pub fn name(&self) -> Option<&str> {
146        self.name.as_deref()
147    }
148}
149
/// Tensor sharding strategies
///
/// Different ways to partition a tensor across multiple devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardingStrategy {
    /// Replicate the full tensor on each device (data parallelism)
    Replicated,
    /// Shard along the batch dimension
    DataParallel,
    /// Shard along the model dimension (tensor parallelism)
    ModelParallel,
    /// Shard along a specific dimension
    DimSharded(usize),
    /// Pipeline parallelism (different layers on different devices)
    Pipeline,
    /// Combination of strategies
    Hybrid,
}

impl fmt::Display for ShardingStrategy {
    /// Formats the strategy as its variant name; `DimSharded` includes
    /// the dimension index, e.g. `DimSharded(1)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Replicated => "Replicated",
            Self::DataParallel => "DataParallel",
            Self::ModelParallel => "ModelParallel",
            Self::DimSharded(dim) => return write!(f, "DimSharded({dim})"),
            Self::Pipeline => "Pipeline",
            Self::Hybrid => "Hybrid",
        };
        f.write_str(label)
    }
}

182/// Shard descriptor
183///
184/// Describes a single shard of a distributed tensor.
185#[derive(Debug, Clone)]
186pub struct Shard {
187    /// Device where this shard is located
188    device_id: DeviceId,
189    /// Offset in the global tensor
190    offset: Vec<usize>,
191    /// Shape of this shard
192    shape: Vec<usize>,
193    /// Rank of this shard in the group
194    rank: usize,
195}
196
197impl Shard {
198    /// Create a new shard
199    pub fn new(device_id: DeviceId, offset: Vec<usize>, shape: Vec<usize>, rank: usize) -> Self {
200        Self {
201            device_id,
202            offset,
203            shape,
204            rank,
205        }
206    }
207
208    /// Get device ID
209    pub fn device_id(&self) -> DeviceId {
210        self.device_id
211    }
212
213    /// Get offset
214    pub fn offset(&self) -> &[usize] {
215        &self.offset
216    }
217
218    /// Get shape
219    pub fn shape(&self) -> &[usize] {
220        &self.shape
221    }
222
223    /// Get rank
224    pub fn rank(&self) -> usize {
225        self.rank
226    }
227
228    /// Calculate shard size (number of elements)
229    pub fn size(&self) -> usize {
230        self.shape.iter().product()
231    }
232}
233
234/// Distributed tensor metadata
235///
236/// Represents a tensor distributed across multiple devices.
237#[derive(Debug, Clone)]
238pub struct DistributedTensor {
239    /// Global shape of the tensor
240    global_shape: Vec<usize>,
241    /// Sharding strategy
242    strategy: ShardingStrategy,
243    /// Device group
244    device_group: DeviceGroup,
245    /// Shard descriptors
246    shards: Vec<Shard>,
247}
248
249impl DistributedTensor {
250    /// Create a new distributed tensor
251    pub fn new(
252        global_shape: Vec<usize>,
253        strategy: ShardingStrategy,
254        device_group: DeviceGroup,
255    ) -> Self {
256        let shards = Self::create_shards(&global_shape, strategy, &device_group);
257        Self {
258            global_shape,
259            strategy,
260            device_group,
261            shards,
262        }
263    }
264
265    /// Create shards based on strategy
266    fn create_shards(
267        global_shape: &[usize],
268        strategy: ShardingStrategy,
269        device_group: &DeviceGroup,
270    ) -> Vec<Shard> {
271        let num_devices = device_group.size();
272        let mut shards = Vec::new();
273
274        match strategy {
275            ShardingStrategy::Replicated => {
276                // Full tensor on each device
277                for (rank, &device_id) in device_group.devices().iter().enumerate() {
278                    shards.push(Shard::new(
279                        device_id,
280                        vec![0; global_shape.len()],
281                        global_shape.to_vec(),
282                        rank,
283                    ));
284                }
285            }
286            ShardingStrategy::DataParallel | ShardingStrategy::DimSharded(0) => {
287                // Shard along first dimension
288                if global_shape.is_empty() {
289                    return shards;
290                }
291                let dim0 = global_shape[0];
292                let chunk_size = (dim0 + num_devices - 1) / num_devices;
293
294                for (rank, &device_id) in device_group.devices().iter().enumerate() {
295                    let start = rank * chunk_size;
296                    let end = (start + chunk_size).min(dim0);
297                    if start >= dim0 {
298                        break;
299                    }
300
301                    let mut offset = vec![0; global_shape.len()];
302                    offset[0] = start;
303
304                    let mut shape = global_shape.to_vec();
305                    shape[0] = end - start;
306
307                    shards.push(Shard::new(device_id, offset, shape, rank));
308                }
309            }
310            ShardingStrategy::ModelParallel => {
311                // For now, same as data parallel
312                // In practice, this would shard model parameters
313                return Self::create_shards(
314                    global_shape,
315                    ShardingStrategy::DataParallel,
316                    device_group,
317                );
318            }
319            ShardingStrategy::DimSharded(dim) => {
320                // Shard along specified dimension
321                if dim >= global_shape.len() {
322                    return shards;
323                }
324                let dim_size = global_shape[dim];
325                let chunk_size = (dim_size + num_devices - 1) / num_devices;
326
327                for (rank, &device_id) in device_group.devices().iter().enumerate() {
328                    let start = rank * chunk_size;
329                    let end = (start + chunk_size).min(dim_size);
330                    if start >= dim_size {
331                        break;
332                    }
333
334                    let mut offset = vec![0; global_shape.len()];
335                    offset[dim] = start;
336
337                    let mut shape = global_shape.to_vec();
338                    shape[dim] = end - start;
339
340                    shards.push(Shard::new(device_id, offset, shape, rank));
341                }
342            }
343            _ => {
344                // Default to replicated
345                return Self::create_shards(
346                    global_shape,
347                    ShardingStrategy::Replicated,
348                    device_group,
349                );
350            }
351        }
352
353        shards
354    }
355
356    /// Get global shape
357    pub fn global_shape(&self) -> &[usize] {
358        &self.global_shape
359    }
360
361    /// Get sharding strategy
362    pub fn strategy(&self) -> ShardingStrategy {
363        self.strategy
364    }
365
366    /// Get device group
367    pub fn device_group(&self) -> &DeviceGroup {
368        &self.device_group
369    }
370
371    /// Get shards
372    pub fn shards(&self) -> &[Shard] {
373        &self.shards
374    }
375
376    /// Get shard for a specific device
377    pub fn shard_for_device(&self, device_id: &DeviceId) -> Option<&Shard> {
378        self.shards.iter().find(|s| &s.device_id == device_id)
379    }
380
381    /// Get total number of elements across all shards
382    pub fn total_elements(&self) -> usize {
383        match self.strategy {
384            ShardingStrategy::Replicated => {
385                // Only count once for replicated
386                self.global_shape.iter().product()
387            }
388            _ => {
389                // Sum all shard sizes
390                self.shards.iter().map(|s| s.size()).sum()
391            }
392        }
393    }
394}
395
/// Collective communication operations
///
/// Common collective operations for distributed tensors. These are pure
/// descriptors: this module only records the requested operation and performs
/// no communication itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CollectiveOp {
    /// All-reduce: reduce across all devices and broadcast result
    AllReduce(ReduceOp),
    /// All-gather: gather data from all devices and broadcast
    AllGather,
    /// Reduce-scatter: reduce and scatter results
    ReduceScatter(ReduceOp),
    /// Broadcast: send data from one device (`root`) to all others
    Broadcast { root: usize },
    /// Scatter: distribute data from one device (`root`) to all others
    Scatter { root: usize },
    /// Gather: collect data from all devices to one device (`root`)
    Gather { root: usize },
    /// All-to-all: each device sends unique data to every other device
    AllToAll,
    /// Barrier: synchronization point for all devices
    Barrier,
}

/// Reduction operations
///
/// Element-wise reductions applied by reducing collectives such as
/// `AllReduce` and `ReduceScatter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum reduction
    Sum,
    /// Product reduction
    Product,
    /// Minimum reduction
    Min,
    /// Maximum reduction
    Max,
    /// Average reduction
    Average,
}

impl fmt::Display for ReduceOp {
    /// Formats the operation as its bare variant name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Sum => "Sum",
            Self::Product => "Product",
            Self::Min => "Min",
            Self::Max => "Max",
            Self::Average => "Average",
        })
    }
}

/// Communication backend
///
/// Abstraction for different communication libraries (MPI, NCCL, Gloo, etc.).
/// A variant only labels the intended library; no backend is initialized or
/// invoked by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommBackend {
    /// NCCL (NVIDIA Collective Communications Library) for GPU
    NCCL,
    /// Gloo for CPU and GPU
    Gloo,
    /// MPI (Message Passing Interface)
    MPI,
    /// Custom implementation
    Custom,
}

461/// Communication descriptor
462///
463/// Describes a communication operation to be executed.
464#[derive(Debug, Clone)]
465pub struct CommunicationDescriptor {
466    /// Collective operation
467    operation: CollectiveOp,
468    /// Device group
469    device_group: DeviceGroup,
470    /// Backend to use
471    backend: CommBackend,
472    /// Whether to use asynchronous communication
473    async_op: bool,
474}
475
476impl CommunicationDescriptor {
477    /// Create a new communication descriptor
478    pub fn new(operation: CollectiveOp, device_group: DeviceGroup, backend: CommBackend) -> Self {
479        Self {
480            operation,
481            device_group,
482            backend,
483            async_op: false,
484        }
485    }
486
487    /// Set asynchronous flag
488    pub fn with_async(mut self, async_op: bool) -> Self {
489        self.async_op = async_op;
490        self
491    }
492
493    /// Get operation
494    pub fn operation(&self) -> CollectiveOp {
495        self.operation
496    }
497
498    /// Get device group
499    pub fn device_group(&self) -> &DeviceGroup {
500        &self.device_group
501    }
502
503    /// Get backend
504    pub fn backend(&self) -> CommBackend {
505        self.backend
506    }
507
508    /// Check if async
509    pub fn is_async(&self) -> bool {
510        self.async_op
511    }
512}
513
514/// Checkpoint metadata for fault tolerance
515///
516/// Contains information about a saved checkpoint of distributed tensors.
517#[derive(Debug, Clone)]
518pub struct CheckpointMetadata {
519    /// Checkpoint ID
520    id: String,
521    /// Global step number
522    step: u64,
523    /// List of device IDs that contributed to checkpoint
524    devices: Vec<DeviceId>,
525    /// Timestamp (Unix epoch seconds)
526    timestamp: u64,
527    /// Additional metadata
528    metadata: Vec<(String, String)>,
529}
530
531impl CheckpointMetadata {
532    /// Create a new checkpoint metadata
533    pub fn new(id: impl Into<String>, step: u64, devices: Vec<DeviceId>) -> Self {
534        Self {
535            id: id.into(),
536            step,
537            devices,
538            timestamp: 0, // Would use system time in real implementation
539            metadata: Vec::new(),
540        }
541    }
542
543    /// Add metadata entry
544    pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
545        self.metadata.push((key.into(), value.into()));
546    }
547
548    /// Get checkpoint ID
549    pub fn id(&self) -> &str {
550        &self.id
551    }
552
553    /// Get step number
554    pub fn step(&self) -> u64 {
555        self.step
556    }
557
558    /// Get devices
559    pub fn devices(&self) -> &[DeviceId] {
560        &self.devices
561    }
562
563    /// Get timestamp
564    pub fn timestamp(&self) -> u64 {
565        self.timestamp
566    }
567
568    /// Get metadata
569    pub fn metadata(&self) -> &[(String, String)] {
570        &self.metadata
571    }
572}
573
574/// Device topology for hierarchical communication
575///
576/// Represents the physical/logical topology of devices in a cluster.
577#[derive(Debug, Clone)]
578pub struct DeviceTopology {
579    /// All devices in the topology
580    devices: Vec<DeviceId>,
581    /// Number of nodes
582    num_nodes: usize,
583    /// Number of racks
584    num_racks: usize,
585    /// Devices per node
586    devices_per_node: usize,
587}
588
589impl DeviceTopology {
590    /// Create a new device topology
591    pub fn new(num_racks: usize, num_nodes: usize, devices_per_node: usize) -> Self {
592        let mut devices = Vec::new();
593        for rack_id in 0..num_racks {
594            for node_id in 0..num_nodes {
595                for device_id in 0..devices_per_node {
596                    devices.push(DeviceId::new(node_id, rack_id, device_id));
597                }
598            }
599        }
600
601        Self {
602            devices,
603            num_nodes,
604            num_racks,
605            devices_per_node,
606        }
607    }
608
609    /// Get all devices
610    pub fn devices(&self) -> &[DeviceId] {
611        &self.devices
612    }
613
614    /// Get devices on a specific node
615    pub fn node_devices(&self, node_id: usize) -> Vec<DeviceId> {
616        self.devices
617            .iter()
618            .filter(|d| d.node_id() == node_id)
619            .copied()
620            .collect()
621    }
622
623    /// Get devices in a specific rack
624    pub fn rack_devices(&self, rack_id: usize) -> Vec<DeviceId> {
625        self.devices
626            .iter()
627            .filter(|d| d.rack_id() == rack_id)
628            .copied()
629            .collect()
630    }
631
632    /// Get total number of devices
633    pub fn total_devices(&self) -> usize {
634        self.devices.len()
635    }
636
637    /// Get number of nodes
638    pub fn num_nodes(&self) -> usize {
639        self.num_nodes
640    }
641
642    /// Get number of racks
643    pub fn num_racks(&self) -> usize {
644        self.num_racks
645    }
646
647    /// Get devices per node
648    pub fn devices_per_node(&self) -> usize {
649        self.devices_per_node
650    }
651}
652
// Unit tests exercising the metadata types above. No real communication is
// performed anywhere in this module; tests only verify shard math, rank
// bookkeeping, and descriptor accessors.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_id() {
        let device = DeviceId::new(0, 1, 2);
        assert_eq!(device.node_id(), 0);
        assert_eq!(device.rack_id(), 1);
        assert_eq!(device.local_device_id(), 2);
        assert_eq!(device.global_id(), 1002); // 1*1000 + 0*100 + 2
    }

    #[test]
    fn test_simple_device_id() {
        // `simple` pins node and rack to 0.
        let device = DeviceId::simple(5);
        assert_eq!(device.local_device_id(), 5);
        assert_eq!(device.node_id(), 0);
        assert_eq!(device.rack_id(), 0);
    }

    #[test]
    fn test_device_group() {
        let group = DeviceGroup::new(vec![0, 1, 2, 3]);
        assert_eq!(group.size(), 4);
        assert!(group.contains(&DeviceId::simple(0)));
        assert_eq!(group.rank(&DeviceId::simple(2)), Some(2));
    }

    #[test]
    fn test_device_group_with_name() {
        let group = DeviceGroup::new(vec![0, 1]).with_name("test_group");
        assert_eq!(group.name(), Some("test_group"));
    }

    #[test]
    fn test_sharding_strategy_display() {
        assert_eq!(format!("{}", ShardingStrategy::Replicated), "Replicated");
        assert_eq!(
            format!("{}", ShardingStrategy::DataParallel),
            "DataParallel"
        );
        assert_eq!(
            format!("{}", ShardingStrategy::DimSharded(1)),
            "DimSharded(1)"
        );
    }

    #[test]
    fn test_shard() {
        let device = DeviceId::simple(0);
        let shard = Shard::new(device, vec![0, 0], vec![10, 20], 0);
        assert_eq!(shard.device_id(), device);
        assert_eq!(shard.offset(), &[0, 0]);
        assert_eq!(shard.shape(), &[10, 20]);
        assert_eq!(shard.rank(), 0);
        assert_eq!(shard.size(), 200);
    }

    #[test]
    fn test_distributed_tensor_replicated() {
        let group = DeviceGroup::new(vec![0, 1, 2, 3]);
        let tensor = DistributedTensor::new(vec![100, 50], ShardingStrategy::Replicated, group);

        assert_eq!(tensor.global_shape(), &[100, 50]);
        assert_eq!(tensor.shards().len(), 4);
        assert_eq!(tensor.strategy(), ShardingStrategy::Replicated);

        // All shards should have the full shape
        for shard in tensor.shards() {
            assert_eq!(shard.shape(), &[100, 50]);
        }
    }

    #[test]
    fn test_distributed_tensor_data_parallel() {
        let group = DeviceGroup::new(vec![0, 1, 2, 3]);
        let tensor = DistributedTensor::new(vec![100, 50], ShardingStrategy::DataParallel, group);

        assert_eq!(tensor.shards().len(), 4);

        // Each shard should have 25 rows (100 / 4)
        for shard in tensor.shards() {
            assert_eq!(shard.shape()[0], 25);
            assert_eq!(shard.shape()[1], 50);
        }
    }

    #[test]
    fn test_distributed_tensor_dim_sharded() {
        let group = DeviceGroup::new(vec![0, 1]);
        let tensor =
            DistributedTensor::new(vec![10, 20, 30], ShardingStrategy::DimSharded(1), group);

        assert_eq!(tensor.shards().len(), 2);

        // Sharded along dimension 1 (20 -> 10 + 10)
        assert_eq!(tensor.shards()[0].shape(), &[10, 10, 30]);
        assert_eq!(tensor.shards()[1].shape(), &[10, 10, 30]);
    }

    #[test]
    fn test_shard_for_device() {
        let group = DeviceGroup::new(vec![0, 1]);
        let tensor = DistributedTensor::new(vec![10, 20], ShardingStrategy::DataParallel, group);

        let device = DeviceId::simple(0);
        let shard = tensor.shard_for_device(&device);
        assert!(shard.is_some());
        assert_eq!(
            shard.expect("shard_for_device should succeed").device_id(),
            device
        );
    }

    #[test]
    fn test_collective_operations() {
        // Construction-only smoke test for every variant.
        let _all_reduce = CollectiveOp::AllReduce(ReduceOp::Sum);
        let _all_gather = CollectiveOp::AllGather;
        let _reduce_scatter = CollectiveOp::ReduceScatter(ReduceOp::Average);
        let _broadcast = CollectiveOp::Broadcast { root: 0 };
        let _scatter = CollectiveOp::Scatter { root: 0 };
        let _gather = CollectiveOp::Gather { root: 0 };
        let _all_to_all = CollectiveOp::AllToAll;
        let _barrier = CollectiveOp::Barrier;
    }

    #[test]
    fn test_reduce_op_display() {
        assert_eq!(format!("{}", ReduceOp::Sum), "Sum");
        assert_eq!(format!("{}", ReduceOp::Product), "Product");
        assert_eq!(format!("{}", ReduceOp::Min), "Min");
        assert_eq!(format!("{}", ReduceOp::Max), "Max");
        assert_eq!(format!("{}", ReduceOp::Average), "Average");
    }

    #[test]
    fn test_comm_backend() {
        // Construction-only smoke test for every variant.
        let _nccl = CommBackend::NCCL;
        let _gloo = CommBackend::Gloo;
        let _mpi = CommBackend::MPI;
        let _custom = CommBackend::Custom;
    }

    #[test]
    fn test_communication_descriptor() {
        let group = DeviceGroup::new(vec![0, 1, 2, 3]);
        let comm_desc = CommunicationDescriptor::new(
            CollectiveOp::AllReduce(ReduceOp::Sum),
            group.clone(),
            CommBackend::NCCL,
        )
        .with_async(true);

        assert_eq!(
            comm_desc.operation(),
            CollectiveOp::AllReduce(ReduceOp::Sum)
        );
        assert_eq!(comm_desc.backend(), CommBackend::NCCL);
        assert!(comm_desc.is_async());
    }

    #[test]
    fn test_checkpoint_metadata() {
        let devices = vec![DeviceId::simple(0), DeviceId::simple(1)];
        let mut checkpoint = CheckpointMetadata::new("ckpt_001", 1000, devices);
        checkpoint.add_metadata("model", "resnet50");
        checkpoint.add_metadata("optimizer", "adam");

        assert_eq!(checkpoint.id(), "ckpt_001");
        assert_eq!(checkpoint.step(), 1000);
        assert_eq!(checkpoint.devices().len(), 2);
        assert_eq!(checkpoint.metadata().len(), 2);
    }

    #[test]
    fn test_device_topology() {
        let topology = DeviceTopology::new(2, 3, 4); // 2 racks, 3 nodes, 4 devices per node
        assert_eq!(topology.total_devices(), 24); // 2 * 3 * 4
        assert_eq!(topology.num_racks(), 2);
        assert_eq!(topology.num_nodes(), 3);
        assert_eq!(topology.devices_per_node(), 4);

        let node0_devices = topology.node_devices(0);
        assert_eq!(node0_devices.len(), 8); // 4 devices * 2 racks

        let rack0_devices = topology.rack_devices(0);
        assert_eq!(rack0_devices.len(), 12); // 3 nodes * 4 devices
    }

    #[test]
    fn test_total_elements() {
        let group = DeviceGroup::new(vec![0, 1, 2, 3]);

        // Replicated: count only once
        let replicated =
            DistributedTensor::new(vec![100, 50], ShardingStrategy::Replicated, group.clone());
        assert_eq!(replicated.total_elements(), 5000); // 100 * 50

        // Data parallel: sum of all shards
        let sharded = DistributedTensor::new(vec![100, 50], ShardingStrategy::DataParallel, group);
        assert_eq!(sharded.total_elements(), 5000); // Still 100 * 50 total
    }

    #[test]
    fn test_from_devices() {
        let devices = vec![DeviceId::new(0, 0, 1), DeviceId::new(0, 0, 2)];
        let group = DeviceGroup::from_devices(devices);
        assert_eq!(group.size(), 2);
    }

    #[test]
    fn test_device_not_in_group() {
        let group = DeviceGroup::new(vec![0, 1, 2]);
        assert!(!group.contains(&DeviceId::simple(5)));
        assert_eq!(group.rank(&DeviceId::simple(5)), None);
    }
}