1use core::fmt;
38
/// Hierarchical identifier for a compute device: rack → node → local slot.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeviceId {
    /// Index of the node within its rack.
    node_id: usize,
    /// Index of the rack within the cluster.
    rack_id: usize,
    /// Index of the device within its node.
    local_device_id: usize,
}

impl DeviceId {
    /// Builds an identifier from its three hierarchical components.
    pub fn new(node_id: usize, rack_id: usize, local_device_id: usize) -> Self {
        Self { node_id, rack_id, local_device_id }
    }

    /// Convenience constructor for single-rack, single-node setups:
    /// rack and node indices both default to 0.
    pub fn simple(local_device_id: usize) -> Self {
        Self::new(0, 0, local_device_id)
    }

    /// Node index within the rack.
    pub fn node_id(&self) -> usize {
        self.node_id
    }

    /// Rack index within the cluster.
    pub fn rack_id(&self) -> usize {
        self.rack_id
    }

    /// Device index within the node.
    pub fn local_device_id(&self) -> usize {
        self.local_device_id
    }

    /// Flattens the hierarchy into a single number with a decimal
    /// encoding: `rack * 1000 + node * 100 + local`.
    ///
    /// NOTE(review): this encoding is only collision-free while
    /// `node_id < 10` and `local_device_id < 100` — confirm cluster
    /// bounds before relying on uniqueness.
    pub fn global_id(&self) -> usize {
        self.local_device_id + 100 * self.node_id + 1000 * self.rack_id
    }
}
88
89#[derive(Debug, Clone)]
93pub struct DeviceGroup {
94 devices: Vec<DeviceId>,
96 name: Option<String>,
98}
99
100impl DeviceGroup {
101 pub fn new(device_ids: Vec<usize>) -> Self {
103 let devices = device_ids.iter().map(|&id| DeviceId::simple(id)).collect();
104 Self {
105 devices,
106 name: None,
107 }
108 }
109
110 pub fn from_devices(devices: Vec<DeviceId>) -> Self {
112 Self {
113 devices,
114 name: None,
115 }
116 }
117
118 pub fn with_name(mut self, name: impl Into<String>) -> Self {
120 self.name = Some(name.into());
121 self
122 }
123
124 pub fn devices(&self) -> &[DeviceId] {
126 &self.devices
127 }
128
129 pub fn size(&self) -> usize {
131 self.devices.len()
132 }
133
134 pub fn contains(&self, device_id: &DeviceId) -> bool {
136 self.devices.contains(device_id)
137 }
138
139 pub fn rank(&self, device_id: &DeviceId) -> Option<usize> {
141 self.devices.iter().position(|d| d == device_id)
142 }
143
144 pub fn name(&self) -> Option<&str> {
146 self.name.as_deref()
147 }
148}
149
/// How a tensor's data is laid out across a device group.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardingStrategy {
    /// Every device holds a full copy of the tensor.
    Replicated,
    /// Split along the leading (batch) dimension.
    DataParallel,
    /// Model-parallel layout.
    ModelParallel,
    /// Split along an arbitrary dimension (given by the payload).
    DimSharded(usize),
    /// Pipeline-parallel layout.
    Pipeline,
    /// Combination of multiple strategies.
    Hybrid,
}

impl fmt::Display for ShardingStrategy {
    /// Renders the variant name; `DimSharded` includes its dimension,
    /// e.g. `DimSharded(1)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ShardingStrategy::DimSharded(dim) => write!(f, "DimSharded({})", dim),
            ShardingStrategy::Replicated => f.write_str("Replicated"),
            ShardingStrategy::DataParallel => f.write_str("DataParallel"),
            ShardingStrategy::ModelParallel => f.write_str("ModelParallel"),
            ShardingStrategy::Pipeline => f.write_str("Pipeline"),
            ShardingStrategy::Hybrid => f.write_str("Hybrid"),
        }
    }
}
181
182#[derive(Debug, Clone)]
186pub struct Shard {
187 device_id: DeviceId,
189 offset: Vec<usize>,
191 shape: Vec<usize>,
193 rank: usize,
195}
196
197impl Shard {
198 pub fn new(device_id: DeviceId, offset: Vec<usize>, shape: Vec<usize>, rank: usize) -> Self {
200 Self {
201 device_id,
202 offset,
203 shape,
204 rank,
205 }
206 }
207
208 pub fn device_id(&self) -> DeviceId {
210 self.device_id
211 }
212
213 pub fn offset(&self) -> &[usize] {
215 &self.offset
216 }
217
218 pub fn shape(&self) -> &[usize] {
220 &self.shape
221 }
222
223 pub fn rank(&self) -> usize {
225 self.rank
226 }
227
228 pub fn size(&self) -> usize {
230 self.shape.iter().product()
231 }
232}
233
234#[derive(Debug, Clone)]
238pub struct DistributedTensor {
239 global_shape: Vec<usize>,
241 strategy: ShardingStrategy,
243 device_group: DeviceGroup,
245 shards: Vec<Shard>,
247}
248
249impl DistributedTensor {
250 pub fn new(
252 global_shape: Vec<usize>,
253 strategy: ShardingStrategy,
254 device_group: DeviceGroup,
255 ) -> Self {
256 let shards = Self::create_shards(&global_shape, strategy, &device_group);
257 Self {
258 global_shape,
259 strategy,
260 device_group,
261 shards,
262 }
263 }
264
265 fn create_shards(
267 global_shape: &[usize],
268 strategy: ShardingStrategy,
269 device_group: &DeviceGroup,
270 ) -> Vec<Shard> {
271 let num_devices = device_group.size();
272 let mut shards = Vec::new();
273
274 match strategy {
275 ShardingStrategy::Replicated => {
276 for (rank, &device_id) in device_group.devices().iter().enumerate() {
278 shards.push(Shard::new(
279 device_id,
280 vec![0; global_shape.len()],
281 global_shape.to_vec(),
282 rank,
283 ));
284 }
285 }
286 ShardingStrategy::DataParallel | ShardingStrategy::DimSharded(0) => {
287 if global_shape.is_empty() {
289 return shards;
290 }
291 let dim0 = global_shape[0];
292 let chunk_size = (dim0 + num_devices - 1) / num_devices;
293
294 for (rank, &device_id) in device_group.devices().iter().enumerate() {
295 let start = rank * chunk_size;
296 let end = (start + chunk_size).min(dim0);
297 if start >= dim0 {
298 break;
299 }
300
301 let mut offset = vec![0; global_shape.len()];
302 offset[0] = start;
303
304 let mut shape = global_shape.to_vec();
305 shape[0] = end - start;
306
307 shards.push(Shard::new(device_id, offset, shape, rank));
308 }
309 }
310 ShardingStrategy::ModelParallel => {
311 return Self::create_shards(
314 global_shape,
315 ShardingStrategy::DataParallel,
316 device_group,
317 );
318 }
319 ShardingStrategy::DimSharded(dim) => {
320 if dim >= global_shape.len() {
322 return shards;
323 }
324 let dim_size = global_shape[dim];
325 let chunk_size = (dim_size + num_devices - 1) / num_devices;
326
327 for (rank, &device_id) in device_group.devices().iter().enumerate() {
328 let start = rank * chunk_size;
329 let end = (start + chunk_size).min(dim_size);
330 if start >= dim_size {
331 break;
332 }
333
334 let mut offset = vec![0; global_shape.len()];
335 offset[dim] = start;
336
337 let mut shape = global_shape.to_vec();
338 shape[dim] = end - start;
339
340 shards.push(Shard::new(device_id, offset, shape, rank));
341 }
342 }
343 _ => {
344 return Self::create_shards(
346 global_shape,
347 ShardingStrategy::Replicated,
348 device_group,
349 );
350 }
351 }
352
353 shards
354 }
355
356 pub fn global_shape(&self) -> &[usize] {
358 &self.global_shape
359 }
360
361 pub fn strategy(&self) -> ShardingStrategy {
363 self.strategy
364 }
365
366 pub fn device_group(&self) -> &DeviceGroup {
368 &self.device_group
369 }
370
371 pub fn shards(&self) -> &[Shard] {
373 &self.shards
374 }
375
376 pub fn shard_for_device(&self, device_id: &DeviceId) -> Option<&Shard> {
378 self.shards.iter().find(|s| &s.device_id == device_id)
379 }
380
381 pub fn total_elements(&self) -> usize {
383 match self.strategy {
384 ShardingStrategy::Replicated => {
385 self.global_shape.iter().product()
387 }
388 _ => {
389 self.shards.iter().map(|s| s.size()).sum()
391 }
392 }
393 }
394}
395
/// Collective communication primitives that can be scheduled on a
/// device group (see `CommunicationDescriptor`). Variants carrying a
/// `ReduceOp` specify how values are combined; `root` fields name the
/// rank that originates or collects the data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CollectiveOp {
    /// Reduce values across all devices with the given operator.
    AllReduce(ReduceOp),
    /// Gather every device's data onto every device.
    AllGather,
    /// Reduce across devices, then scatter the result.
    ReduceScatter(ReduceOp),
    /// Send data from the `root` rank to all devices.
    Broadcast { root: usize },
    /// Distribute pieces of the `root` rank's data across devices.
    Scatter { root: usize },
    /// Collect every device's data onto the `root` rank.
    Gather { root: usize },
    /// Exchange data between every pair of devices.
    AllToAll,
    /// Synchronization point with no data transfer.
    Barrier,
}
418
/// Element-wise combining operator used by reducing collectives
/// (`AllReduce`, `ReduceScatter`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Add contributions together.
    Sum,
    /// Multiply contributions together.
    Product,
    /// Keep the smallest contribution.
    Min,
    /// Keep the largest contribution.
    Max,
    /// Arithmetic mean of contributions.
    Average,
}

impl fmt::Display for ReduceOp {
    /// Renders the variant name, matching its `Debug` spelling.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            ReduceOp::Sum => "Sum",
            ReduceOp::Product => "Product",
            ReduceOp::Min => "Min",
            ReduceOp::Max => "Max",
            ReduceOp::Average => "Average",
        };
        f.write_str(label)
    }
}
445
/// Transport library used to execute collective operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommBackend {
    /// NVIDIA Collective Communications Library.
    NCCL,
    /// The Gloo collective communications library.
    Gloo,
    /// Message Passing Interface.
    MPI,
    /// A user-supplied backend.
    Custom,
}
460
461#[derive(Debug, Clone)]
465pub struct CommunicationDescriptor {
466 operation: CollectiveOp,
468 device_group: DeviceGroup,
470 backend: CommBackend,
472 async_op: bool,
474}
475
476impl CommunicationDescriptor {
477 pub fn new(operation: CollectiveOp, device_group: DeviceGroup, backend: CommBackend) -> Self {
479 Self {
480 operation,
481 device_group,
482 backend,
483 async_op: false,
484 }
485 }
486
487 pub fn with_async(mut self, async_op: bool) -> Self {
489 self.async_op = async_op;
490 self
491 }
492
493 pub fn operation(&self) -> CollectiveOp {
495 self.operation
496 }
497
498 pub fn device_group(&self) -> &DeviceGroup {
500 &self.device_group
501 }
502
503 pub fn backend(&self) -> CommBackend {
505 self.backend
506 }
507
508 pub fn is_async(&self) -> bool {
510 self.async_op
511 }
512}
513
514#[derive(Debug, Clone)]
518pub struct CheckpointMetadata {
519 id: String,
521 step: u64,
523 devices: Vec<DeviceId>,
525 timestamp: u64,
527 metadata: Vec<(String, String)>,
529}
530
531impl CheckpointMetadata {
532 pub fn new(id: impl Into<String>, step: u64, devices: Vec<DeviceId>) -> Self {
534 Self {
535 id: id.into(),
536 step,
537 devices,
538 timestamp: 0, metadata: Vec::new(),
540 }
541 }
542
543 pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
545 self.metadata.push((key.into(), value.into()));
546 }
547
548 pub fn id(&self) -> &str {
550 &self.id
551 }
552
553 pub fn step(&self) -> u64 {
555 self.step
556 }
557
558 pub fn devices(&self) -> &[DeviceId] {
560 &self.devices
561 }
562
563 pub fn timestamp(&self) -> u64 {
565 self.timestamp
566 }
567
568 pub fn metadata(&self) -> &[(String, String)] {
570 &self.metadata
571 }
572}
573
574#[derive(Debug, Clone)]
578pub struct DeviceTopology {
579 devices: Vec<DeviceId>,
581 num_nodes: usize,
583 num_racks: usize,
585 devices_per_node: usize,
587}
588
589impl DeviceTopology {
590 pub fn new(num_racks: usize, num_nodes: usize, devices_per_node: usize) -> Self {
592 let mut devices = Vec::new();
593 for rack_id in 0..num_racks {
594 for node_id in 0..num_nodes {
595 for device_id in 0..devices_per_node {
596 devices.push(DeviceId::new(node_id, rack_id, device_id));
597 }
598 }
599 }
600
601 Self {
602 devices,
603 num_nodes,
604 num_racks,
605 devices_per_node,
606 }
607 }
608
609 pub fn devices(&self) -> &[DeviceId] {
611 &self.devices
612 }
613
614 pub fn node_devices(&self, node_id: usize) -> Vec<DeviceId> {
616 self.devices
617 .iter()
618 .filter(|d| d.node_id() == node_id)
619 .copied()
620 .collect()
621 }
622
623 pub fn rack_devices(&self, rack_id: usize) -> Vec<DeviceId> {
625 self.devices
626 .iter()
627 .filter(|d| d.rack_id() == rack_id)
628 .copied()
629 .collect()
630 }
631
632 pub fn total_devices(&self) -> usize {
634 self.devices.len()
635 }
636
637 pub fn num_nodes(&self) -> usize {
639 self.num_nodes
640 }
641
642 pub fn num_racks(&self) -> usize {
644 self.num_racks
645 }
646
647 pub fn devices_per_node(&self) -> usize {
649 self.devices_per_node
650 }
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_id() {
        // Component accessors plus the decimal global-id encoding.
        let device = DeviceId::new(0, 1, 2);
        assert_eq!(device.node_id(), 0);
        assert_eq!(device.rack_id(), 1);
        assert_eq!(device.local_device_id(), 2);
        assert_eq!(device.global_id(), 1002);
    }

    #[test]
    fn test_simple_device_id() {
        // `simple` pins rack and node to zero.
        let device = DeviceId::simple(5);
        assert_eq!(device.local_device_id(), 5);
        assert_eq!(device.node_id(), 0);
        assert_eq!(device.rack_id(), 0);
    }

    #[test]
    fn test_device_group() {
        let group = DeviceGroup::new(vec![0, 1, 2, 3]);
        assert_eq!(group.size(), 4);
        assert!(group.contains(&DeviceId::simple(0)));
        assert_eq!(group.rank(&DeviceId::simple(2)), Some(2));
    }

    #[test]
    fn test_device_group_with_name() {
        let group = DeviceGroup::new(vec![0, 1]).with_name("test_group");
        assert_eq!(group.name(), Some("test_group"));
    }

    #[test]
    fn test_sharding_strategy_display() {
        assert_eq!(ShardingStrategy::Replicated.to_string(), "Replicated");
        assert_eq!(ShardingStrategy::DataParallel.to_string(), "DataParallel");
        assert_eq!(ShardingStrategy::DimSharded(1).to_string(), "DimSharded(1)");
    }

    #[test]
    fn test_shard() {
        let device = DeviceId::simple(0);
        let shard = Shard::new(device, vec![0, 0], vec![10, 20], 0);
        assert_eq!(shard.device_id(), device);
        assert_eq!(shard.offset(), &[0, 0]);
        assert_eq!(shard.shape(), &[10, 20]);
        assert_eq!(shard.rank(), 0);
        assert_eq!(shard.size(), 200);
    }

    #[test]
    fn test_distributed_tensor_replicated() {
        let group = DeviceGroup::new(vec![0, 1, 2, 3]);
        let tensor = DistributedTensor::new(vec![100, 50], ShardingStrategy::Replicated, group);

        assert_eq!(tensor.global_shape(), &[100, 50]);
        assert_eq!(tensor.shards().len(), 4);
        assert_eq!(tensor.strategy(), ShardingStrategy::Replicated);

        // Every replica covers the full global shape.
        for shard in tensor.shards() {
            assert_eq!(shard.shape(), &[100, 50]);
        }
    }

    #[test]
    fn test_distributed_tensor_data_parallel() {
        let group = DeviceGroup::new(vec![0, 1, 2, 3]);
        let tensor = DistributedTensor::new(vec![100, 50], ShardingStrategy::DataParallel, group);

        assert_eq!(tensor.shards().len(), 4);

        // 100 rows over 4 devices -> 25 rows each; dim 1 untouched.
        for shard in tensor.shards() {
            assert_eq!(shard.shape()[0], 25);
            assert_eq!(shard.shape()[1], 50);
        }
    }

    #[test]
    fn test_distributed_tensor_dim_sharded() {
        let group = DeviceGroup::new(vec![0, 1]);
        let tensor =
            DistributedTensor::new(vec![10, 20, 30], ShardingStrategy::DimSharded(1), group);

        assert_eq!(tensor.shards().len(), 2);

        // Dimension 1 (size 20) splits evenly into two halves of 10.
        assert_eq!(tensor.shards()[0].shape(), &[10, 10, 30]);
        assert_eq!(tensor.shards()[1].shape(), &[10, 10, 30]);
    }

    #[test]
    fn test_shard_for_device() {
        let group = DeviceGroup::new(vec![0, 1]);
        let tensor = DistributedTensor::new(vec![10, 20], ShardingStrategy::DataParallel, group);

        let device = DeviceId::simple(0);
        let shard = tensor.shard_for_device(&device);
        assert!(shard.is_some());
        assert_eq!(
            shard.expect("shard_for_device should succeed").device_id(),
            device
        );
    }

    #[test]
    fn test_collective_operations() {
        // Exercise construction of every collective variant.
        let _all_reduce = CollectiveOp::AllReduce(ReduceOp::Sum);
        let _all_gather = CollectiveOp::AllGather;
        let _reduce_scatter = CollectiveOp::ReduceScatter(ReduceOp::Average);
        let _broadcast = CollectiveOp::Broadcast { root: 0 };
        let _scatter = CollectiveOp::Scatter { root: 0 };
        let _gather = CollectiveOp::Gather { root: 0 };
        let _all_to_all = CollectiveOp::AllToAll;
        let _barrier = CollectiveOp::Barrier;
    }

    #[test]
    fn test_reduce_op_display() {
        assert_eq!(ReduceOp::Sum.to_string(), "Sum");
        assert_eq!(ReduceOp::Product.to_string(), "Product");
        assert_eq!(ReduceOp::Min.to_string(), "Min");
        assert_eq!(ReduceOp::Max.to_string(), "Max");
        assert_eq!(ReduceOp::Average.to_string(), "Average");
    }

    #[test]
    fn test_comm_backend() {
        let _nccl = CommBackend::NCCL;
        let _gloo = CommBackend::Gloo;
        let _mpi = CommBackend::MPI;
        let _custom = CommBackend::Custom;
    }

    #[test]
    fn test_communication_descriptor() {
        let group = DeviceGroup::new(vec![0, 1, 2, 3]);
        let comm_desc = CommunicationDescriptor::new(
            CollectiveOp::AllReduce(ReduceOp::Sum),
            group.clone(),
            CommBackend::NCCL,
        )
        .with_async(true);

        assert_eq!(
            comm_desc.operation(),
            CollectiveOp::AllReduce(ReduceOp::Sum)
        );
        assert_eq!(comm_desc.backend(), CommBackend::NCCL);
        assert!(comm_desc.is_async());
    }

    #[test]
    fn test_checkpoint_metadata() {
        let devices = vec![DeviceId::simple(0), DeviceId::simple(1)];
        let mut checkpoint = CheckpointMetadata::new("ckpt_001", 1000, devices);
        checkpoint.add_metadata("model", "resnet50");
        checkpoint.add_metadata("optimizer", "adam");

        assert_eq!(checkpoint.id(), "ckpt_001");
        assert_eq!(checkpoint.step(), 1000);
        assert_eq!(checkpoint.devices().len(), 2);
        assert_eq!(checkpoint.metadata().len(), 2);
    }

    #[test]
    fn test_device_topology() {
        // 2 racks x 3 nodes per rack x 4 devices per node.
        let topology = DeviceTopology::new(2, 3, 4);
        assert_eq!(topology.total_devices(), 24);
        assert_eq!(topology.num_racks(), 2);
        assert_eq!(topology.num_nodes(), 3);
        assert_eq!(topology.devices_per_node(), 4);

        // Node index 0 exists in both racks: 2 racks * 4 devices.
        let node0_devices = topology.node_devices(0);
        assert_eq!(node0_devices.len(), 8);

        // Rack 0 holds 3 nodes * 4 devices.
        let rack0_devices = topology.rack_devices(0);
        assert_eq!(rack0_devices.len(), 12);
    }

    #[test]
    fn test_total_elements() {
        let group = DeviceGroup::new(vec![0, 1, 2, 3]);

        // Replicated: copies are not multiple-counted.
        let replicated =
            DistributedTensor::new(vec![100, 50], ShardingStrategy::Replicated, group.clone());
        assert_eq!(replicated.total_elements(), 5000);

        // Sharded: the shard sizes sum back to the global count.
        let sharded = DistributedTensor::new(vec![100, 50], ShardingStrategy::DataParallel, group);
        assert_eq!(sharded.total_elements(), 5000);
    }

    #[test]
    fn test_from_devices() {
        let devices = vec![DeviceId::new(0, 0, 1), DeviceId::new(0, 0, 2)];
        let group = DeviceGroup::from_devices(devices);
        assert_eq!(group.size(), 2);
    }

    #[test]
    fn test_device_not_in_group() {
        let group = DeviceGroup::new(vec![0, 1, 2]);
        assert!(!group.contains(&DeviceId::simple(5)));
        assert_eq!(group.rank(&DeviceId::simple(5)), None);
    }
}