// tensorlogic_infer/distributed.rs

//! Distributed execution infrastructure for multi-device and multi-node computation.
//!
//! This module provides distributed training and inference capabilities:
//! - `DistributedExecutor`: Multi-device execution coordination
//! - `DataParallelism`: Data-parallel training across devices
//! - `ModelParallelism`: Model-parallel execution with tensor sharding
//! - `CommunicationBackend`: Abstract interface for device communication
//! - `TlDistributedExecutor`: Trait for executors that support distributed execution
//!
//! # Parallelism Strategies
//!
//! ## Data Parallelism
//! - Each device processes a different subset of the batch
//! - Gradients are averaged across devices
//! - Suitable for models that fit on a single device
//!
//! ## Model Parallelism
//! - Model is split across multiple devices
//! - Each device processes different parts of the model
//! - Suitable for large models that don't fit on a single device
//!
//! ## Hybrid Parallelism
//! - Combines data and model parallelism
//! - Model is split across devices, each replica processes different data
//!
//! # Example
//!
//! ```
//! use tensorlogic_infer::distributed::{DistributedConfig, ParallelismStrategy};
//!
//! let config = DistributedConfig {
//!     parallelism: ParallelismStrategy::DataParallel,
//!     num_devices: 4,
//!     ..Default::default()
//! };
//! ```

use crate::capabilities::DeviceType;
use crate::error::ExecutorError;
use crate::placement::Device;
use crate::shape::TensorShape;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tensorlogic_ir::EinsumGraph;
45
/// Parallelism strategy for distributed execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ParallelismStrategy {
    /// Data parallelism: each device processes different data
    #[default]
    DataParallel,
    /// Model parallelism: model is split across devices
    ModelParallel,
    /// Pipeline parallelism: model stages on different devices
    PipelineParallel,
    /// Hybrid: combination of data and model parallelism
    Hybrid {
        /// Number of data-parallel replica groups.
        data_parallel_groups: usize,
    },
}
59
60/// Configuration for distributed execution.
61#[derive(Debug, Clone)]
62pub struct DistributedConfig {
63    /// Parallelism strategy to use
64    pub parallelism: ParallelismStrategy,
65    /// Number of devices to use
66    pub num_devices: usize,
67    /// Communication backend (e.g., "nccl", "gloo", "mpi")
68    pub backend: String,
69    /// Master address for multi-node setups
70    pub master_addr: Option<String>,
71    /// Master port for multi-node setups
72    pub master_port: Option<u16>,
73    /// Rank of this process (0 to world_size-1)
74    pub rank: usize,
75    /// Total number of processes (world size)
76    pub world_size: usize,
77    /// Enable gradient compression
78    pub enable_gradient_compression: bool,
79    /// Enable mixed precision
80    pub enable_mixed_precision: bool,
81    /// Bucket size for gradient bucketing (bytes)
82    pub bucket_size: usize,
83    /// Enable asynchronous communication
84    pub enable_async_communication: bool,
85}
86
87impl Default for DistributedConfig {
88    fn default() -> Self {
89        DistributedConfig {
90            parallelism: ParallelismStrategy::default(),
91            num_devices: 1,
92            backend: "gloo".to_string(),
93            master_addr: None,
94            master_port: None,
95            rank: 0,
96            world_size: 1,
97            enable_gradient_compression: false,
98            enable_mixed_precision: false,
99            bucket_size: 25 * 1024 * 1024, // 25MB
100            enable_async_communication: true,
101        }
102    }
103}
104
/// Tensor sharding specification for model parallelism.
///
/// Describes how one graph node's tensor is split into `num_shards` pieces
/// along `shard_dim`, and which device holds each piece.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShardingSpec {
    /// Node ID that this sharding applies to
    pub node_id: usize,
    /// Dimension along which to shard
    pub shard_dim: usize,
    /// Number of shards (equals `shard_to_device.len()` when built via `new`)
    pub num_shards: usize,
    /// Device assignment for each shard, indexed by shard ID
    pub shard_to_device: Vec<Device>,
}
117
118impl ShardingSpec {
119    /// Create a new sharding specification.
120    pub fn new(node_id: usize, shard_dim: usize, devices: Vec<Device>) -> Self {
121        let num_shards = devices.len();
122        ShardingSpec {
123            node_id,
124            shard_dim,
125            num_shards,
126            shard_to_device: devices,
127        }
128    }
129
130    /// Get the device for a specific shard.
131    pub fn device_for_shard(&self, shard_id: usize) -> Option<&Device> {
132        self.shard_to_device.get(shard_id)
133    }
134
135    /// Check if a shard ID is valid.
136    pub fn is_valid_shard(&self, shard_id: usize) -> bool {
137        shard_id < self.num_shards
138    }
139}
140
/// Placement plan for distributed execution.
///
/// Collects per-node device assignments, sharding specs for model
/// parallelism, and inter-node communication dependencies.
#[derive(Debug, Clone)]
pub struct DistributedPlacementPlan {
    /// Node to device mapping
    pub node_placement: HashMap<usize, Device>,
    /// Sharding specifications for model parallelism
    pub sharding_specs: Vec<ShardingSpec>,
    /// Communication dependencies (node -> nodes it depends on)
    pub communication_deps: HashMap<usize, Vec<usize>>,
}
151
152impl DistributedPlacementPlan {
153    /// Create a new empty placement plan.
154    pub fn new() -> Self {
155        DistributedPlacementPlan {
156            node_placement: HashMap::new(),
157            sharding_specs: Vec::new(),
158            communication_deps: HashMap::new(),
159        }
160    }
161
162    /// Add a node placement.
163    pub fn place_node(&mut self, node_id: usize, device: Device) {
164        self.node_placement.insert(node_id, device);
165    }
166
167    /// Add a sharding specification.
168    pub fn add_sharding(&mut self, spec: ShardingSpec) {
169        self.sharding_specs.push(spec);
170    }
171
172    /// Get the device for a node.
173    pub fn get_device(&self, node_id: usize) -> Option<&Device> {
174        self.node_placement.get(&node_id)
175    }
176
177    /// Get sharding spec for a node.
178    pub fn get_sharding(&self, node_id: usize) -> Option<&ShardingSpec> {
179        self.sharding_specs.iter().find(|s| s.node_id == node_id)
180    }
181
182    /// Check if a node is sharded.
183    pub fn is_sharded(&self, node_id: usize) -> bool {
184        self.get_sharding(node_id).is_some()
185    }
186}
187
188impl Default for DistributedPlacementPlan {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
/// Communication operation for distributed execution.
///
/// Mirrors the collective and point-to-point operations exposed by
/// [`CommunicationBackend`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommunicationOp {
    /// All-reduce: reduce values across all devices
    AllReduce {
        /// Reduction operation (sum, mean, max, min)
        reduction: ReductionOp,
    },
    /// Broadcast: send from one device to all others
    Broadcast {
        /// Source device rank
        src_rank: usize,
    },
    /// Scatter: distribute data from one device to all
    Scatter {
        /// Source device rank
        src_rank: usize,
    },
    /// Gather: collect data from all devices to one
    Gather {
        /// Destination device rank
        dst_rank: usize,
    },
    /// All-gather: gather data from all devices to all
    AllGather,
    /// Reduce-scatter: reduce and scatter results
    ReduceScatter {
        /// Reduction operation
        reduction: ReductionOp,
    },
    /// Peer-to-peer send
    Send {
        /// Destination device rank
        dst_rank: usize,
    },
    /// Peer-to-peer receive
    Recv {
        /// Source device rank
        src_rank: usize,
    },
}
235
/// Reduction operation for communication.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionOp {
    /// Sum reduction
    Sum,
    /// Mean reduction (sum / count)
    Mean,
    /// Maximum value
    Max,
    /// Minimum value
    Min,
    /// Product
    Product,
}
250
/// Abstract communication backend for device-to-device communication.
///
/// Implementations wrap a concrete transport (e.g. NCCL, Gloo, MPI).
/// Tensors are referred to by string ID; how IDs map to buffers is left to
/// the implementation. All fallible methods return `ExecutorError` on
/// communication failure.
pub trait CommunicationBackend: Send + Sync {
    /// Initialize the communication backend from `config` (rank, world size, ...).
    fn initialize(&mut self, config: &DistributedConfig) -> Result<(), ExecutorError>;

    /// Finalize and clean up the backend.
    fn finalize(&mut self) -> Result<(), ExecutorError>;

    /// Get the rank of this process.
    fn rank(&self) -> usize;

    /// Get the world size (total number of processes).
    fn world_size(&self) -> usize;

    /// Perform an all-reduce operation.
    fn all_reduce(&self, tensor_id: &str, reduction: ReductionOp) -> Result<(), ExecutorError>;

    /// Broadcast from source rank to all ranks.
    fn broadcast(&self, tensor_id: &str, src_rank: usize) -> Result<(), ExecutorError>;

    /// Scatter data from source rank to all ranks.
    fn scatter(&self, tensor_id: &str, src_rank: usize) -> Result<(), ExecutorError>;

    /// Gather data from all ranks to destination rank.
    fn gather(&self, tensor_id: &str, dst_rank: usize) -> Result<(), ExecutorError>;

    /// All-gather operation.
    fn all_gather(&self, tensor_id: &str) -> Result<(), ExecutorError>;

    /// Reduce-scatter operation.
    fn reduce_scatter(&self, tensor_id: &str, reduction: ReductionOp) -> Result<(), ExecutorError>;

    /// Point-to-point send.
    fn send(&self, tensor_id: &str, dst_rank: usize) -> Result<(), ExecutorError>;

    /// Point-to-point receive.
    fn recv(&self, tensor_id: &str, src_rank: usize) -> Result<(), ExecutorError>;

    /// Synchronize all processes.
    fn barrier(&self) -> Result<(), ExecutorError>;
}
292
/// Dummy communication backend for testing.
///
/// Records rank/world size from the config and treats every communication
/// operation as a successful no-op.
pub struct DummyCommunicationBackend {
    // Rank reported back to callers; set by `initialize`.
    rank: usize,
    // World size reported back to callers; set by `initialize`.
    world_size: usize,
}
298
299impl DummyCommunicationBackend {
300    /// Create a new dummy backend.
301    pub fn new() -> Self {
302        DummyCommunicationBackend {
303            rank: 0,
304            world_size: 1,
305        }
306    }
307}
308
309impl Default for DummyCommunicationBackend {
310    fn default() -> Self {
311        Self::new()
312    }
313}
314
315impl CommunicationBackend for DummyCommunicationBackend {
316    fn initialize(&mut self, config: &DistributedConfig) -> Result<(), ExecutorError> {
317        self.rank = config.rank;
318        self.world_size = config.world_size;
319        Ok(())
320    }
321
322    fn finalize(&mut self) -> Result<(), ExecutorError> {
323        Ok(())
324    }
325
326    fn rank(&self) -> usize {
327        self.rank
328    }
329
330    fn world_size(&self) -> usize {
331        self.world_size
332    }
333
334    fn all_reduce(&self, _tensor_id: &str, _reduction: ReductionOp) -> Result<(), ExecutorError> {
335        Ok(())
336    }
337
338    fn broadcast(&self, _tensor_id: &str, _src_rank: usize) -> Result<(), ExecutorError> {
339        Ok(())
340    }
341
342    fn scatter(&self, _tensor_id: &str, _src_rank: usize) -> Result<(), ExecutorError> {
343        Ok(())
344    }
345
346    fn gather(&self, _tensor_id: &str, _dst_rank: usize) -> Result<(), ExecutorError> {
347        Ok(())
348    }
349
350    fn all_gather(&self, _tensor_id: &str) -> Result<(), ExecutorError> {
351        Ok(())
352    }
353
354    fn reduce_scatter(
355        &self,
356        _tensor_id: &str,
357        _reduction: ReductionOp,
358    ) -> Result<(), ExecutorError> {
359        Ok(())
360    }
361
362    fn send(&self, _tensor_id: &str, _dst_rank: usize) -> Result<(), ExecutorError> {
363        Ok(())
364    }
365
366    fn recv(&self, _tensor_id: &str, _src_rank: usize) -> Result<(), ExecutorError> {
367        Ok(())
368    }
369
370    fn barrier(&self) -> Result<(), ExecutorError> {
371        Ok(())
372    }
373}
374
/// Data parallelism coordinator.
///
/// Splits batches into per-device ranges and synchronizes gradients through
/// the shared communication backend.
pub struct DataParallelCoordinator {
    // Distributed settings; `num_devices` drives batch splitting.
    config: DistributedConfig,
    // Shared communication backend used for gradient all-reduce.
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    // One logical device per `config.num_devices` (constructed as CPUs in `new`).
    devices: Vec<Device>,
}
381
382impl DataParallelCoordinator {
383    /// Create a new data parallel coordinator.
384    pub fn new(config: DistributedConfig, backend: Arc<RwLock<dyn CommunicationBackend>>) -> Self {
385        let devices = (0..config.num_devices)
386            .map(|i| Device::new(DeviceType::CPU, i))
387            .collect();
388
389        DataParallelCoordinator {
390            config,
391            backend,
392            devices,
393        }
394    }
395
396    /// Distribute batch across devices.
397    pub fn distribute_batch(&self, batch_size: usize) -> Vec<(usize, usize)> {
398        let per_device = batch_size / self.config.num_devices;
399        let remainder = batch_size % self.config.num_devices;
400
401        let mut distribution = Vec::new();
402        let mut offset = 0;
403
404        for i in 0..self.config.num_devices {
405            let size = per_device + if i < remainder { 1 } else { 0 };
406            distribution.push((offset, size));
407            offset += size;
408        }
409
410        distribution
411    }
412
413    /// Synchronize gradients across devices.
414    pub fn synchronize_gradients(&self) -> Result<(), ExecutorError> {
415        let backend = self.backend.read().expect("lock should not be poisoned");
416
417        // All-reduce gradients with mean reduction
418        backend.all_reduce("gradients", ReductionOp::Mean)?;
419
420        Ok(())
421    }
422
423    /// Get the list of devices.
424    pub fn devices(&self) -> &[Device] {
425        &self.devices
426    }
427}
428
/// Model parallelism coordinator.
///
/// Assigns graph nodes to devices and shards tensors across devices.
pub struct ModelParallelCoordinator {
    // Distributed settings; `num_devices` drives node placement and sharding.
    config: DistributedConfig,
    // Shared communication backend used for gathering shards.
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    // Node-to-device assignments built by `create_sharding_plan`.
    placement_plan: DistributedPlacementPlan,
}
435
436impl ModelParallelCoordinator {
437    /// Create a new model parallel coordinator.
438    pub fn new(config: DistributedConfig, backend: Arc<RwLock<dyn CommunicationBackend>>) -> Self {
439        ModelParallelCoordinator {
440            config,
441            backend,
442            placement_plan: DistributedPlacementPlan::new(),
443        }
444    }
445
446    /// Create a sharding plan for the graph.
447    pub fn create_sharding_plan(&mut self, graph: &EinsumGraph) -> Result<(), ExecutorError> {
448        let num_devices = self.config.num_devices;
449        let nodes_per_device = graph.nodes.len().div_ceil(num_devices);
450
451        // Simple sharding: distribute nodes across devices
452        for (node_id, _node) in graph.nodes.iter().enumerate() {
453            let device_idx = node_id / nodes_per_device;
454            let device = Device::new(DeviceType::CPU, device_idx);
455            self.placement_plan.place_node(node_id, device);
456        }
457
458        Ok(())
459    }
460
461    /// Get the placement plan.
462    pub fn placement_plan(&self) -> &DistributedPlacementPlan {
463        &self.placement_plan
464    }
465
466    /// Shard a tensor along a dimension.
467    pub fn shard_tensor(
468        &self,
469        _node_id: usize,
470        shape: &TensorShape,
471        shard_dim: usize,
472    ) -> Result<Vec<TensorShape>, ExecutorError> {
473        let num_shards = self.config.num_devices;
474
475        if shard_dim >= shape.rank() {
476            return Err(ExecutorError::InvalidInput(format!(
477                "Shard dimension {} exceeds tensor rank {}",
478                shard_dim,
479                shape.rank()
480            )));
481        }
482
483        let total_size = shape.dims[shard_dim].as_static().ok_or_else(|| {
484            ExecutorError::InvalidInput("Cannot shard dynamic dimension".to_string())
485        })?;
486
487        let per_shard = total_size / num_shards;
488        let remainder = total_size % num_shards;
489
490        let mut shard_shapes = Vec::new();
491        for i in 0..num_shards {
492            let shard_size = per_shard + if i < remainder { 1 } else { 0 };
493            let mut shard_shape = shape.clone();
494            shard_shape.dims[shard_dim] = crate::shape::DimSize::Static(shard_size);
495            shard_shapes.push(shard_shape);
496        }
497
498        Ok(shard_shapes)
499    }
500
501    /// Gather sharded tensors.
502    pub fn gather_shards(&self, _shard_dim: usize) -> Result<(), ExecutorError> {
503        let backend = self.backend.read().expect("lock should not be poisoned");
504        backend.all_gather("sharded_tensor")?;
505        Ok(())
506    }
507}
508
/// Pipeline parallelism coordinator.
///
/// Maps ranks to pipeline stages and moves activations/gradients between
/// adjacent stages through the communication backend.
pub struct PipelineParallelCoordinator {
    // Distributed settings for this pipeline.
    config: DistributedConfig,
    // Shared communication backend used for stage-to-stage send/recv.
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    // Number of pipeline stages.
    num_stages: usize,
    // Micro-batch size; defaults to 1 in `new`.
    micro_batch_size: usize,
}
516
517impl PipelineParallelCoordinator {
518    /// Create a new pipeline parallel coordinator.
519    pub fn new(
520        config: DistributedConfig,
521        backend: Arc<RwLock<dyn CommunicationBackend>>,
522        num_stages: usize,
523    ) -> Self {
524        PipelineParallelCoordinator {
525            config,
526            backend,
527            num_stages,
528            micro_batch_size: 1,
529        }
530    }
531
532    /// Set micro-batch size for pipeline parallelism.
533    pub fn set_micro_batch_size(&mut self, size: usize) {
534        self.micro_batch_size = size;
535    }
536
537    /// Get the stage assignment for this rank.
538    pub fn stage_for_rank(&self, rank: usize) -> usize {
539        rank % self.num_stages
540    }
541
542    /// Send activations to next stage.
543    pub fn send_activations(&self, stage: usize) -> Result<(), ExecutorError> {
544        if stage < self.num_stages - 1 {
545            let next_rank = stage + 1;
546            let backend = self.backend.read().expect("lock should not be poisoned");
547            backend.send("activations", next_rank)?;
548        }
549        Ok(())
550    }
551
552    /// Receive activations from previous stage.
553    pub fn recv_activations(&self, stage: usize) -> Result<(), ExecutorError> {
554        if stage > 0 {
555            let prev_rank = stage - 1;
556            let backend = self.backend.read().expect("lock should not be poisoned");
557            backend.recv("activations", prev_rank)?;
558        }
559        Ok(())
560    }
561
562    /// Send gradients to previous stage.
563    pub fn send_gradients(&self, stage: usize) -> Result<(), ExecutorError> {
564        if stage > 0 {
565            let prev_rank = stage - 1;
566            let backend = self.backend.read().expect("lock should not be poisoned");
567            backend.send("gradients", prev_rank)?;
568        }
569        Ok(())
570    }
571
572    /// Receive gradients from next stage.
573    pub fn recv_gradients(&self, stage: usize) -> Result<(), ExecutorError> {
574        if stage < self.num_stages - 1 {
575            let next_rank = stage + 1;
576            let backend = self.backend.read().expect("lock should not be poisoned");
577            backend.recv("gradients", next_rank)?;
578        }
579        Ok(())
580    }
581
582    /// Get the number of stages in the pipeline.
583    pub fn num_stages(&self) -> usize {
584        self.num_stages
585    }
586
587    /// Get the micro-batch size.
588    pub fn micro_batch_size(&self) -> usize {
589        self.micro_batch_size
590    }
591
592    /// Get the configuration.
593    pub fn config(&self) -> &DistributedConfig {
594        &self.config
595    }
596}
597
/// Distributed executor that coordinates multi-device execution.
///
/// Owns the coordinator matching the configured parallelism strategy
/// (both data and model coordinators for the hybrid strategy).
pub struct DistributedExecutor {
    // Configuration captured at construction time.
    config: DistributedConfig,
    // Shared communication backend; initialized in `new`, finalized on drop.
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    // Populated only for DataParallel and Hybrid strategies.
    data_parallel: Option<DataParallelCoordinator>,
    // Populated only for ModelParallel and Hybrid strategies.
    model_parallel: Option<ModelParallelCoordinator>,
    // Populated only for the PipelineParallel strategy.
    pipeline_parallel: Option<PipelineParallelCoordinator>,
}
606
607impl DistributedExecutor {
608    /// Create a new distributed executor.
609    pub fn new(
610        config: DistributedConfig,
611        backend: Arc<RwLock<dyn CommunicationBackend>>,
612    ) -> Result<Self, ExecutorError> {
613        // Initialize backend
614        backend
615            .write()
616            .expect("lock should not be poisoned")
617            .initialize(&config)?;
618
619        let mut executor = DistributedExecutor {
620            config: config.clone(),
621            backend: backend.clone(),
622            data_parallel: None,
623            model_parallel: None,
624            pipeline_parallel: None,
625        };
626
627        // Setup coordinators based on strategy
628        executor.setup_coordinators()?;
629
630        Ok(executor)
631    }
632
633    /// Setup coordinators based on parallelism strategy.
634    fn setup_coordinators(&mut self) -> Result<(), ExecutorError> {
635        match self.config.parallelism {
636            ParallelismStrategy::DataParallel => {
637                self.data_parallel = Some(DataParallelCoordinator::new(
638                    self.config.clone(),
639                    self.backend.clone(),
640                ));
641            }
642            ParallelismStrategy::ModelParallel => {
643                self.model_parallel = Some(ModelParallelCoordinator::new(
644                    self.config.clone(),
645                    self.backend.clone(),
646                ));
647            }
648            ParallelismStrategy::PipelineParallel => {
649                let num_stages = self.config.num_devices;
650                self.pipeline_parallel = Some(PipelineParallelCoordinator::new(
651                    self.config.clone(),
652                    self.backend.clone(),
653                    num_stages,
654                ));
655            }
656            ParallelismStrategy::Hybrid {
657                data_parallel_groups: _,
658            } => {
659                self.data_parallel = Some(DataParallelCoordinator::new(
660                    self.config.clone(),
661                    self.backend.clone(),
662                ));
663                self.model_parallel = Some(ModelParallelCoordinator::new(
664                    self.config.clone(),
665                    self.backend.clone(),
666                ));
667            }
668        }
669        Ok(())
670    }
671
672    /// Get the parallelism strategy.
673    pub fn strategy(&self) -> ParallelismStrategy {
674        self.config.parallelism
675    }
676
677    /// Get the rank of this process.
678    pub fn rank(&self) -> usize {
679        self.backend
680            .read()
681            .expect("lock should not be poisoned")
682            .rank()
683    }
684
685    /// Get the world size.
686    pub fn world_size(&self) -> usize {
687        self.backend
688            .read()
689            .expect("lock should not be poisoned")
690            .world_size()
691    }
692
693    /// Synchronize all processes.
694    pub fn barrier(&self) -> Result<(), ExecutorError> {
695        self.backend
696            .read()
697            .expect("lock should not be poisoned")
698            .barrier()
699    }
700
701    /// Get data parallel coordinator.
702    pub fn data_parallel(&self) -> Option<&DataParallelCoordinator> {
703        self.data_parallel.as_ref()
704    }
705
706    /// Get model parallel coordinator.
707    pub fn model_parallel(&self) -> Option<&ModelParallelCoordinator> {
708        self.model_parallel.as_ref()
709    }
710
711    /// Get pipeline parallel coordinator.
712    pub fn pipeline_parallel(&self) -> Option<&PipelineParallelCoordinator> {
713        self.pipeline_parallel.as_ref()
714    }
715}
716
717impl Drop for DistributedExecutor {
718    fn drop(&mut self) {
719        let _ = self
720            .backend
721            .write()
722            .expect("lock should not be poisoned")
723            .finalize();
724    }
725}
726
727/// Trait for executors that support distributed execution.
728pub trait TlDistributedExecutor {
729    /// Get the distributed executor.
730    fn distributed_executor(&self) -> Option<&DistributedExecutor>;
731
732    /// Enable distributed execution.
733    fn enable_distributed(&mut self, config: DistributedConfig) -> Result<(), ExecutorError>;
734
735    /// Disable distributed execution.
736    fn disable_distributed(&mut self);
737
738    /// Check if distributed execution is enabled.
739    fn is_distributed(&self) -> bool;
740
741    /// Get the current rank.
742    fn rank(&self) -> usize {
743        self.distributed_executor().map(|d| d.rank()).unwrap_or(0)
744    }
745
746    /// Get the world size.
747    fn world_size(&self) -> usize {
748        self.distributed_executor()
749            .map(|d| d.world_size())
750            .unwrap_or(1)
751    }
752}
753
/// Statistics for distributed execution.
#[derive(Debug, Clone, Default)]
pub struct DistributedStats {
    /// Total number of communication operations
    pub total_communications: usize,
    /// Total bytes communicated
    pub total_bytes_communicated: u64,
    /// Number of gradient synchronizations
    pub gradient_syncs: usize,
    /// Average communication time, in milliseconds
    pub avg_communication_time_ms: f64,
    /// Load imbalance metric (0.0 = perfect, 1.0 = worst)
    pub load_imbalance: f64,
}
768
769impl DistributedStats {
770    /// Get a summary of distributed execution statistics.
771    pub fn summary(&self) -> String {
772        format!(
773            "Distributed Stats: {} communications, {:.2} MB transferred, {} gradient syncs, {:.2}ms avg comm time, {:.2}% load imbalance",
774            self.total_communications,
775            self.total_bytes_communicated as f64 / 1_000_000.0,
776            self.gradient_syncs,
777            self.avg_communication_time_ms,
778            self.load_imbalance * 100.0
779        )
780    }
781}
782
// Unit tests for the distributed-execution scaffolding; all of them run
// against the in-process DummyCommunicationBackend (no real communication).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_distributed_config_default() {
        let config = DistributedConfig::default();
        assert_eq!(config.parallelism, ParallelismStrategy::DataParallel);
        assert_eq!(config.num_devices, 1);
        assert_eq!(config.rank, 0);
        assert_eq!(config.world_size, 1);
    }

    #[test]
    fn test_sharding_spec() {
        let devices = vec![
            Device::new(DeviceType::CPU, 0),
            Device::new(DeviceType::CPU, 1),
            Device::new(DeviceType::CPU, 2),
        ];
        let spec = ShardingSpec::new(0, 1, devices);

        assert_eq!(spec.num_shards, 3);
        assert_eq!(spec.shard_dim, 1);
        assert!(spec.is_valid_shard(0));
        assert!(spec.is_valid_shard(2));
        assert!(!spec.is_valid_shard(3));
    }

    #[test]
    fn test_distributed_placement_plan() {
        let mut plan = DistributedPlacementPlan::new();

        plan.place_node(0, Device::new(DeviceType::CPU, 0));
        plan.place_node(1, Device::new(DeviceType::CPU, 1));

        assert!(plan.get_device(0).is_some());
        assert!(plan.get_device(1).is_some());
        assert!(plan.get_device(2).is_none());
    }

    #[test]
    fn test_data_parallel_batch_distribution() {
        let config = DistributedConfig {
            num_devices: 4,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = DataParallelCoordinator::new(config, backend);

        // 10 elements over 4 devices: sizes must still sum to the batch size.
        let distribution = coordinator.distribute_batch(10);
        assert_eq!(distribution.len(), 4);

        // Check total size
        let total: usize = distribution.iter().map(|(_, size)| size).sum();
        assert_eq!(total, 10);
    }

    #[test]
    fn test_model_parallel_sharding() {
        let config = DistributedConfig {
            num_devices: 4,
            parallelism: ParallelismStrategy::ModelParallel,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = ModelParallelCoordinator::new(config, backend);

        let shape = TensorShape::static_shape(vec![8, 16]);
        let shards = coordinator.shard_tensor(0, &shape, 0).expect("unwrap");

        assert_eq!(shards.len(), 4);
        // Each shard should have size 2 in dimension 0
        assert_eq!(shards[0].dims[0].as_static().expect("unwrap"), 2);
    }

    #[test]
    fn test_pipeline_parallel_stage_assignment() {
        let config = DistributedConfig {
            num_devices: 4,
            parallelism: ParallelismStrategy::PipelineParallel,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = PipelineParallelCoordinator::new(config, backend, 4);

        // With world size == stage count, stage assignment is the identity.
        assert_eq!(coordinator.stage_for_rank(0), 0);
        assert_eq!(coordinator.stage_for_rank(1), 1);
        assert_eq!(coordinator.stage_for_rank(2), 2);
        assert_eq!(coordinator.stage_for_rank(3), 3);
    }

    #[test]
    fn test_distributed_executor_creation() {
        let config = DistributedConfig::default();
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));

        let executor = DistributedExecutor::new(config, backend);
        assert!(executor.is_ok());

        let executor = executor.expect("unwrap");
        assert_eq!(executor.rank(), 0);
        assert_eq!(executor.world_size(), 1);
    }

    #[test]
    fn test_communication_ops() {
        let op1 = CommunicationOp::AllReduce {
            reduction: ReductionOp::Sum,
        };
        let op2 = CommunicationOp::Broadcast { src_rank: 0 };

        assert_ne!(op1, op2);
    }

    #[test]
    fn test_reduction_ops() {
        let ops = [
            ReductionOp::Sum,
            ReductionOp::Mean,
            ReductionOp::Max,
            ReductionOp::Min,
            ReductionOp::Product,
        ];

        assert_eq!(ops.len(), 5);
    }

    #[test]
    fn test_dummy_backend() {
        let mut backend = DummyCommunicationBackend::new();
        let config = DistributedConfig::default();

        assert!(backend.initialize(&config).is_ok());
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert!(backend.all_reduce("test", ReductionOp::Sum).is_ok());
        assert!(backend.barrier().is_ok());
        assert!(backend.finalize().is_ok());
    }

    #[test]
    fn test_distributed_stats() {
        let stats = DistributedStats {
            total_communications: 100,
            total_bytes_communicated: 1_000_000,
            gradient_syncs: 50,
            avg_communication_time_ms: 10.5,
            load_imbalance: 0.15,
        };

        let summary = stats.summary();
        assert!(summary.contains("100 communications"));
        assert!(summary.contains("50 gradient syncs"));
    }

    #[test]
    fn test_hybrid_parallelism() {
        let config = DistributedConfig {
            parallelism: ParallelismStrategy::Hybrid {
                data_parallel_groups: 2,
            },
            num_devices: 8,
            ..Default::default()
        };

        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let executor = DistributedExecutor::new(config, backend).expect("unwrap");

        // Hybrid strategy constructs both coordinators.
        assert!(executor.data_parallel().is_some());
        assert!(executor.model_parallel().is_some());
    }

    #[test]
    fn test_sharding_invalid_dimension() {
        let config = DistributedConfig {
            num_devices: 4,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = ModelParallelCoordinator::new(config, backend);

        let shape = TensorShape::static_shape(vec![8, 16]);
        let result = coordinator.shard_tensor(0, &shape, 5);

        assert!(result.is_err());
    }
}