
tensorlogic_infer/distributed.rs

//! Distributed execution infrastructure for multi-device and multi-node computation.
//!
//! This module provides distributed training and inference capabilities:
//! - `DistributedExecutor`: Multi-device execution coordination
//! - `DataParallelism`: Data-parallel training across devices
//! - `ModelParallelism`: Model-parallel execution with tensor sharding
//! - `CommunicationBackend`: Abstract interface for device communication
//! - `TlDistributedExecutor`: Trait for executors that support distributed execution
//!
//! # Parallelism Strategies
//!
//! ## Data Parallelism
//! - Each device processes a different subset of the batch
//! - Gradients are averaged across devices
//! - Suitable for models that fit on a single device
//!
//! ## Model Parallelism
//! - Model is split across multiple devices
//! - Each device processes different parts of the model
//! - Suitable for large models that don't fit on a single device
//!
//! ## Hybrid Parallelism
//! - Combines data and model parallelism
//! - The model is split across devices, and each replica processes different data
//!
//! # Example
//!
//! ```
//! use tensorlogic_infer::distributed::{DistributedConfig, ParallelismStrategy};
//!
//! let config = DistributedConfig {
//!     parallelism: ParallelismStrategy::DataParallel,
//!     num_devices: 4,
//!     ..Default::default()
//! };
//! ```
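//!
//! A hybrid setup can be sketched the same way; the group count below is an
//! arbitrary illustrative value, not a recommended default:
//!
//! ```
//! use tensorlogic_infer::distributed::{DistributedConfig, ParallelismStrategy};
//!
//! let hybrid = DistributedConfig {
//!     parallelism: ParallelismStrategy::Hybrid { data_parallel_groups: 2 },
//!     num_devices: 8,
//!     world_size: 8,
//!     ..Default::default()
//! };
//! assert_eq!(hybrid.num_devices, 8);
//! ```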

use crate::capabilities::DeviceType;
use crate::error::ExecutorError;
use crate::placement::Device;
use crate::shape::TensorShape;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tensorlogic_ir::EinsumGraph;

/// Parallelism strategy for distributed execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ParallelismStrategy {
    /// Data parallelism: each device processes different data
    #[default]
    DataParallel,
    /// Model parallelism: model is split across devices
    ModelParallel,
    /// Pipeline parallelism: model stages on different devices
    PipelineParallel,
    /// Hybrid: combination of data and model parallelism
    Hybrid { data_parallel_groups: usize },
}

/// Configuration for distributed execution.
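///
/// # Example
///
/// A minimal sketch of a two-process, data-parallel configuration; the address
/// and port values are placeholders, not defaults shipped by this module:
///
/// ```
/// use tensorlogic_infer::distributed::{DistributedConfig, ParallelismStrategy};
///
/// let config = DistributedConfig {
///     parallelism: ParallelismStrategy::DataParallel,
///     num_devices: 2,
///     master_addr: Some("127.0.0.1".to_string()),
///     master_port: Some(29500),
///     rank: 0,
///     world_size: 2,
///     ..Default::default()
/// };
/// assert_eq!(config.world_size, 2);
/// ```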
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Parallelism strategy to use
    pub parallelism: ParallelismStrategy,
    /// Number of devices to use
    pub num_devices: usize,
    /// Communication backend (e.g., "nccl", "gloo", "mpi")
    pub backend: String,
    /// Master address for multi-node setups
    pub master_addr: Option<String>,
    /// Master port for multi-node setups
    pub master_port: Option<u16>,
    /// Rank of this process (0 to world_size-1)
    pub rank: usize,
    /// Total number of processes (world size)
    pub world_size: usize,
    /// Enable gradient compression
    pub enable_gradient_compression: bool,
    /// Enable mixed precision
    pub enable_mixed_precision: bool,
    /// Bucket size for gradient bucketing (bytes)
    pub bucket_size: usize,
    /// Enable asynchronous communication
    pub enable_async_communication: bool,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        DistributedConfig {
            parallelism: ParallelismStrategy::default(),
            num_devices: 1,
            backend: "gloo".to_string(),
            master_addr: None,
            master_port: None,
            rank: 0,
            world_size: 1,
            enable_gradient_compression: false,
            enable_mixed_precision: false,
            bucket_size: 25 * 1024 * 1024, // 25MB
            enable_async_communication: true,
        }
    }
}

/// Tensor sharding specification for model parallelism.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShardingSpec {
    /// Node ID that this sharding applies to
    pub node_id: usize,
    /// Dimension along which to shard
    pub shard_dim: usize,
    /// Number of shards
    pub num_shards: usize,
    /// Device assignment for each shard
    pub shard_to_device: Vec<Device>,
}

impl ShardingSpec {
    /// Create a new sharding specification.
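    ///
    /// The number of shards is taken from the length of `devices`. A small
    /// sketch (marked `ignore` because it assumes `Device` and `DeviceType`
    /// are publicly reachable at the paths below):
    ///
    /// ```ignore
    /// use tensorlogic_infer::capabilities::DeviceType;
    /// use tensorlogic_infer::distributed::ShardingSpec;
    /// use tensorlogic_infer::placement::Device;
    ///
    /// let devices = vec![Device::new(DeviceType::CPU, 0), Device::new(DeviceType::CPU, 1)];
    /// let spec = ShardingSpec::new(0, 1, devices);
    /// assert_eq!(spec.num_shards, 2);
    /// assert!(spec.device_for_shard(1).is_some());
    /// assert!(!spec.is_valid_shard(2));
    /// ```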
    pub fn new(node_id: usize, shard_dim: usize, devices: Vec<Device>) -> Self {
        let num_shards = devices.len();
        ShardingSpec {
            node_id,
            shard_dim,
            num_shards,
            shard_to_device: devices,
        }
    }

    /// Get the device for a specific shard.
    pub fn device_for_shard(&self, shard_id: usize) -> Option<&Device> {
        self.shard_to_device.get(shard_id)
    }

    /// Check if a shard ID is valid.
    pub fn is_valid_shard(&self, shard_id: usize) -> bool {
        shard_id < self.num_shards
    }
}

/// Placement plan for distributed execution.
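///
/// A sketch of building a plan by hand (marked `ignore` because it assumes
/// `Device` and `DeviceType` are publicly reachable at the paths below):
///
/// ```ignore
/// use tensorlogic_infer::capabilities::DeviceType;
/// use tensorlogic_infer::distributed::DistributedPlacementPlan;
/// use tensorlogic_infer::placement::Device;
///
/// let mut plan = DistributedPlacementPlan::new();
/// plan.place_node(0, Device::new(DeviceType::CPU, 0));
/// plan.place_node(1, Device::new(DeviceType::CPU, 1));
/// assert!(plan.get_device(0).is_some());
/// assert!(!plan.is_sharded(0));
/// ```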
#[derive(Debug, Clone)]
pub struct DistributedPlacementPlan {
    /// Node to device mapping
    pub node_placement: HashMap<usize, Device>,
    /// Sharding specifications for model parallelism
    pub sharding_specs: Vec<ShardingSpec>,
    /// Communication dependencies (node -> nodes it depends on)
    pub communication_deps: HashMap<usize, Vec<usize>>,
}

impl DistributedPlacementPlan {
    /// Create a new empty placement plan.
    pub fn new() -> Self {
        DistributedPlacementPlan {
            node_placement: HashMap::new(),
            sharding_specs: Vec::new(),
            communication_deps: HashMap::new(),
        }
    }

    /// Add a node placement.
    pub fn place_node(&mut self, node_id: usize, device: Device) {
        self.node_placement.insert(node_id, device);
    }

    /// Add a sharding specification.
    pub fn add_sharding(&mut self, spec: ShardingSpec) {
        self.sharding_specs.push(spec);
    }

    /// Get the device for a node.
    pub fn get_device(&self, node_id: usize) -> Option<&Device> {
        self.node_placement.get(&node_id)
    }

    /// Get sharding spec for a node.
    pub fn get_sharding(&self, node_id: usize) -> Option<&ShardingSpec> {
        self.sharding_specs.iter().find(|s| s.node_id == node_id)
    }

    /// Check if a node is sharded.
    pub fn is_sharded(&self, node_id: usize) -> bool {
        self.get_sharding(node_id).is_some()
    }
}

impl Default for DistributedPlacementPlan {
    fn default() -> Self {
        Self::new()
    }
}

/// Communication operation for distributed execution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommunicationOp {
    /// All-reduce: reduce values across all devices
    AllReduce {
        /// Reduction operation (sum, mean, max, min, product)
        reduction: ReductionOp,
    },
    /// Broadcast: send from one device to all others
    Broadcast {
        /// Source device rank
        src_rank: usize,
    },
    /// Scatter: distribute data from one device to all
    Scatter {
        /// Source device rank
        src_rank: usize,
    },
    /// Gather: collect data from all devices to one
    Gather {
        /// Destination device rank
        dst_rank: usize,
    },
    /// All-gather: gather data from all devices to all
    AllGather,
    /// Reduce-scatter: reduce and scatter results
    ReduceScatter {
        /// Reduction operation
        reduction: ReductionOp,
    },
    /// Peer-to-peer send
    Send {
        /// Destination device rank
        dst_rank: usize,
    },
    /// Peer-to-peer receive
    Recv {
        /// Source device rank
        src_rank: usize,
    },
}

/// Reduction operation for communication.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionOp {
    /// Sum reduction
    Sum,
    /// Mean reduction (sum / count)
    Mean,
    /// Maximum value
    Max,
    /// Minimum value
    Min,
    /// Product
    Product,
}

/// Abstract communication backend for device-to-device communication.
pub trait CommunicationBackend: Send + Sync {
    /// Initialize the communication backend.
    fn initialize(&mut self, config: &DistributedConfig) -> Result<(), ExecutorError>;

    /// Finalize and clean up the backend.
    fn finalize(&mut self) -> Result<(), ExecutorError>;

    /// Get the rank of this process.
    fn rank(&self) -> usize;

    /// Get the world size (total number of processes).
    fn world_size(&self) -> usize;

    /// Perform an all-reduce operation.
    fn all_reduce(&self, tensor_id: &str, reduction: ReductionOp) -> Result<(), ExecutorError>;

    /// Broadcast from source rank to all ranks.
    fn broadcast(&self, tensor_id: &str, src_rank: usize) -> Result<(), ExecutorError>;

    /// Scatter data from source rank to all ranks.
    fn scatter(&self, tensor_id: &str, src_rank: usize) -> Result<(), ExecutorError>;

    /// Gather data from all ranks to destination rank.
    fn gather(&self, tensor_id: &str, dst_rank: usize) -> Result<(), ExecutorError>;

    /// All-gather operation.
    fn all_gather(&self, tensor_id: &str) -> Result<(), ExecutorError>;

    /// Reduce-scatter operation.
    fn reduce_scatter(&self, tensor_id: &str, reduction: ReductionOp) -> Result<(), ExecutorError>;

    /// Point-to-point send.
    fn send(&self, tensor_id: &str, dst_rank: usize) -> Result<(), ExecutorError>;

    /// Point-to-point receive.
    fn recv(&self, tensor_id: &str, src_rank: usize) -> Result<(), ExecutorError>;

    /// Synchronize all processes.
    fn barrier(&self) -> Result<(), ExecutorError>;
}

/// Dummy communication backend for testing.
pub struct DummyCommunicationBackend {
    rank: usize,
    world_size: usize,
}

impl DummyCommunicationBackend {
    /// Create a new dummy backend.
    pub fn new() -> Self {
        DummyCommunicationBackend {
            rank: 0,
            world_size: 1,
        }
    }
}

impl Default for DummyCommunicationBackend {
    fn default() -> Self {
        Self::new()
    }
}

impl CommunicationBackend for DummyCommunicationBackend {
    fn initialize(&mut self, config: &DistributedConfig) -> Result<(), ExecutorError> {
        self.rank = config.rank;
        self.world_size = config.world_size;
        Ok(())
    }

    fn finalize(&mut self) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }

    fn all_reduce(&self, _tensor_id: &str, _reduction: ReductionOp) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn broadcast(&self, _tensor_id: &str, _src_rank: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn scatter(&self, _tensor_id: &str, _src_rank: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn gather(&self, _tensor_id: &str, _dst_rank: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn all_gather(&self, _tensor_id: &str) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn reduce_scatter(
        &self,
        _tensor_id: &str,
        _reduction: ReductionOp,
    ) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn send(&self, _tensor_id: &str, _dst_rank: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn recv(&self, _tensor_id: &str, _src_rank: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn barrier(&self) -> Result<(), ExecutorError> {
        Ok(())
    }
}

/// Data parallelism coordinator.
pub struct DataParallelCoordinator {
    config: DistributedConfig,
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    devices: Vec<Device>,
}

impl DataParallelCoordinator {
    /// Create a new data parallel coordinator.
    pub fn new(config: DistributedConfig, backend: Arc<RwLock<dyn CommunicationBackend>>) -> Self {
        let devices = (0..config.num_devices)
            .map(|i| Device::new(DeviceType::CPU, i))
            .collect();

        DataParallelCoordinator {
            config,
            backend,
            devices,
        }
    }

    /// Distribute batch across devices.
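    ///
    /// Returns `(offset, size)` pairs, one per device; any remainder is spread
    /// over the first devices. A small sketch using the in-process dummy backend:
    ///
    /// ```
    /// use tensorlogic_infer::distributed::{
    ///     DataParallelCoordinator, DistributedConfig, DummyCommunicationBackend,
    /// };
    /// use std::sync::{Arc, RwLock};
    ///
    /// let config = DistributedConfig {
    ///     num_devices: 4,
    ///     ..Default::default()
    /// };
    /// let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
    /// let coordinator = DataParallelCoordinator::new(config, backend);
    ///
    /// assert_eq!(
    ///     coordinator.distribute_batch(10),
    ///     vec![(0, 3), (3, 3), (6, 2), (8, 2)]
    /// );
    /// ```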
    pub fn distribute_batch(&self, batch_size: usize) -> Vec<(usize, usize)> {
        let per_device = batch_size / self.config.num_devices;
        let remainder = batch_size % self.config.num_devices;

        let mut distribution = Vec::new();
        let mut offset = 0;

        for i in 0..self.config.num_devices {
            let size = per_device + if i < remainder { 1 } else { 0 };
            distribution.push((offset, size));
            offset += size;
        }

        distribution
    }

    /// Synchronize gradients across devices.
    pub fn synchronize_gradients(&self) -> Result<(), ExecutorError> {
        let backend = self.backend.read().unwrap();

        // All-reduce gradients with mean reduction
        backend.all_reduce("gradients", ReductionOp::Mean)?;

        Ok(())
    }

    /// Get the list of devices.
    pub fn devices(&self) -> &[Device] {
        &self.devices
    }
}

/// Model parallelism coordinator.
pub struct ModelParallelCoordinator {
    config: DistributedConfig,
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    placement_plan: DistributedPlacementPlan,
}

impl ModelParallelCoordinator {
    /// Create a new model parallel coordinator.
    pub fn new(config: DistributedConfig, backend: Arc<RwLock<dyn CommunicationBackend>>) -> Self {
        ModelParallelCoordinator {
            config,
            backend,
            placement_plan: DistributedPlacementPlan::new(),
        }
    }

    /// Create a sharding plan for the graph.
    pub fn create_sharding_plan(&mut self, graph: &EinsumGraph) -> Result<(), ExecutorError> {
        let num_devices = self.config.num_devices;
        let nodes_per_device = graph.nodes.len().div_ceil(num_devices);

        // Simple sharding: distribute nodes across devices
        for (node_id, _node) in graph.nodes.iter().enumerate() {
            let device_idx = node_id / nodes_per_device;
            let device = Device::new(DeviceType::CPU, device_idx);
            self.placement_plan.place_node(node_id, device);
        }

        Ok(())
    }

    /// Get the placement plan.
    pub fn placement_plan(&self) -> &DistributedPlacementPlan {
        &self.placement_plan
    }

    /// Shard a tensor along a dimension.
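    ///
    /// A sketch (marked `ignore` because it assumes `TensorShape` is publicly
    /// reachable at `tensorlogic_infer::shape::TensorShape`): splitting an
    /// `[8, 16]` tensor over 4 devices along dimension 0 yields four `[2, 16]` shards.
    ///
    /// ```ignore
    /// use tensorlogic_infer::distributed::{
    ///     DistributedConfig, DummyCommunicationBackend, ModelParallelCoordinator,
    /// };
    /// use tensorlogic_infer::shape::TensorShape;
    /// use std::sync::{Arc, RwLock};
    ///
    /// let config = DistributedConfig { num_devices: 4, ..Default::default() };
    /// let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
    /// let coordinator = ModelParallelCoordinator::new(config, backend);
    ///
    /// let shape = TensorShape::static_shape(vec![8, 16]);
    /// let shards = coordinator.shard_tensor(0, &shape, 0).unwrap();
    /// assert_eq!(shards.len(), 4);
    /// ```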
    pub fn shard_tensor(
        &self,
        _node_id: usize,
        shape: &TensorShape,
        shard_dim: usize,
    ) -> Result<Vec<TensorShape>, ExecutorError> {
        let num_shards = self.config.num_devices;

        if shard_dim >= shape.rank() {
            return Err(ExecutorError::InvalidInput(format!(
                "Shard dimension {} exceeds tensor rank {}",
                shard_dim,
                shape.rank()
            )));
        }

        let total_size = shape.dims[shard_dim].as_static().ok_or_else(|| {
            ExecutorError::InvalidInput("Cannot shard dynamic dimension".to_string())
        })?;

        let per_shard = total_size / num_shards;
        let remainder = total_size % num_shards;

        let mut shard_shapes = Vec::new();
        for i in 0..num_shards {
            let shard_size = per_shard + if i < remainder { 1 } else { 0 };
            let mut shard_shape = shape.clone();
            shard_shape.dims[shard_dim] = crate::shape::DimSize::Static(shard_size);
            shard_shapes.push(shard_shape);
        }

        Ok(shard_shapes)
    }

    /// Gather sharded tensors.
    pub fn gather_shards(&self, _shard_dim: usize) -> Result<(), ExecutorError> {
        let backend = self.backend.read().unwrap();
        backend.all_gather("sharded_tensor")?;
        Ok(())
    }
}

/// Pipeline parallelism coordinator.
pub struct PipelineParallelCoordinator {
    config: DistributedConfig,
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    num_stages: usize,
    micro_batch_size: usize,
}

impl PipelineParallelCoordinator {
    /// Create a new pipeline parallel coordinator.
    pub fn new(
        config: DistributedConfig,
        backend: Arc<RwLock<dyn CommunicationBackend>>,
        num_stages: usize,
    ) -> Self {
        PipelineParallelCoordinator {
            config,
            backend,
            num_stages,
            micro_batch_size: 1,
        }
    }

    /// Set micro-batch size for pipeline parallelism.
    pub fn set_micro_batch_size(&mut self, size: usize) {
        self.micro_batch_size = size;
    }

    /// Get the stage assignment for this rank.
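    ///
    /// Ranks are assigned round-robin over the stages. A small sketch using
    /// the in-process dummy backend:
    ///
    /// ```
    /// use tensorlogic_infer::distributed::{
    ///     DistributedConfig, DummyCommunicationBackend, PipelineParallelCoordinator,
    /// };
    /// use std::sync::{Arc, RwLock};
    ///
    /// let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
    /// let coordinator = PipelineParallelCoordinator::new(DistributedConfig::default(), backend, 4);
    /// assert_eq!(coordinator.stage_for_rank(5), 1); // ranks wrap around the 4 stages
    /// ```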
    pub fn stage_for_rank(&self, rank: usize) -> usize {
        rank % self.num_stages
    }

    /// Send activations to next stage.
    pub fn send_activations(&self, stage: usize) -> Result<(), ExecutorError> {
        if stage < self.num_stages - 1 {
            let next_rank = stage + 1;
            let backend = self.backend.read().unwrap();
            backend.send("activations", next_rank)?;
        }
        Ok(())
    }

    /// Receive activations from previous stage.
    pub fn recv_activations(&self, stage: usize) -> Result<(), ExecutorError> {
        if stage > 0 {
            let prev_rank = stage - 1;
            let backend = self.backend.read().unwrap();
            backend.recv("activations", prev_rank)?;
        }
        Ok(())
    }

    /// Send gradients to previous stage.
    pub fn send_gradients(&self, stage: usize) -> Result<(), ExecutorError> {
        if stage > 0 {
            let prev_rank = stage - 1;
            let backend = self.backend.read().unwrap();
            backend.send("gradients", prev_rank)?;
        }
        Ok(())
    }

    /// Receive gradients from next stage.
    pub fn recv_gradients(&self, stage: usize) -> Result<(), ExecutorError> {
        if stage < self.num_stages - 1 {
            let next_rank = stage + 1;
            let backend = self.backend.read().unwrap();
            backend.recv("gradients", next_rank)?;
        }
        Ok(())
    }

    /// Get the number of stages in the pipeline.
    pub fn num_stages(&self) -> usize {
        self.num_stages
    }

    /// Get the micro-batch size.
    pub fn micro_batch_size(&self) -> usize {
        self.micro_batch_size
    }

    /// Get the configuration.
    pub fn config(&self) -> &DistributedConfig {
        &self.config
    }
}

/// Distributed executor that coordinates multi-device execution.
pub struct DistributedExecutor {
    config: DistributedConfig,
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    data_parallel: Option<DataParallelCoordinator>,
    model_parallel: Option<ModelParallelCoordinator>,
    pipeline_parallel: Option<PipelineParallelCoordinator>,
}

impl DistributedExecutor {
    /// Create a new distributed executor.
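    ///
    /// Initializes the backend with the given config and sets up the
    /// coordinators for the chosen strategy. A minimal single-process sketch
    /// using the dummy backend:
    ///
    /// ```
    /// use tensorlogic_infer::distributed::{
    ///     DistributedConfig, DistributedExecutor, DummyCommunicationBackend,
    /// };
    /// use std::sync::{Arc, RwLock};
    ///
    /// let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
    /// let executor = DistributedExecutor::new(DistributedConfig::default(), backend).unwrap();
    /// assert_eq!(executor.rank(), 0);
    /// assert_eq!(executor.world_size(), 1);
    /// ```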
    pub fn new(
        config: DistributedConfig,
        backend: Arc<RwLock<dyn CommunicationBackend>>,
    ) -> Result<Self, ExecutorError> {
        // Initialize backend
        backend.write().unwrap().initialize(&config)?;

        let mut executor = DistributedExecutor {
            config: config.clone(),
            backend: backend.clone(),
            data_parallel: None,
            model_parallel: None,
            pipeline_parallel: None,
        };

        // Setup coordinators based on strategy
        executor.setup_coordinators()?;

        Ok(executor)
    }

    /// Setup coordinators based on parallelism strategy.
    fn setup_coordinators(&mut self) -> Result<(), ExecutorError> {
        match self.config.parallelism {
            ParallelismStrategy::DataParallel => {
                self.data_parallel = Some(DataParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                ));
            }
            ParallelismStrategy::ModelParallel => {
                self.model_parallel = Some(ModelParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                ));
            }
            ParallelismStrategy::PipelineParallel => {
                let num_stages = self.config.num_devices;
                self.pipeline_parallel = Some(PipelineParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                    num_stages,
                ));
            }
            ParallelismStrategy::Hybrid {
                data_parallel_groups: _,
            } => {
                self.data_parallel = Some(DataParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                ));
                self.model_parallel = Some(ModelParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                ));
            }
        }
        Ok(())
    }

    /// Get the parallelism strategy.
    pub fn strategy(&self) -> ParallelismStrategy {
        self.config.parallelism
    }

    /// Get the rank of this process.
    pub fn rank(&self) -> usize {
        self.backend.read().unwrap().rank()
    }

    /// Get the world size.
    pub fn world_size(&self) -> usize {
        self.backend.read().unwrap().world_size()
    }

    /// Synchronize all processes.
    pub fn barrier(&self) -> Result<(), ExecutorError> {
        self.backend.read().unwrap().barrier()
    }

    /// Get data parallel coordinator.
    pub fn data_parallel(&self) -> Option<&DataParallelCoordinator> {
        self.data_parallel.as_ref()
    }

    /// Get model parallel coordinator.
    pub fn model_parallel(&self) -> Option<&ModelParallelCoordinator> {
        self.model_parallel.as_ref()
    }

    /// Get pipeline parallel coordinator.
    pub fn pipeline_parallel(&self) -> Option<&PipelineParallelCoordinator> {
        self.pipeline_parallel.as_ref()
    }
}

impl Drop for DistributedExecutor {
    fn drop(&mut self) {
        let _ = self.backend.write().unwrap().finalize();
    }
}

/// Trait for executors that support distributed execution.
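///
/// A minimal sketch of an implementation for a hypothetical `MyExecutor` type
/// (the struct is illustrative, not part of this crate; marked `ignore`
/// because it also assumes `ExecutorError` is publicly reachable at
/// `tensorlogic_infer::error::ExecutorError`):
///
/// ```ignore
/// use tensorlogic_infer::distributed::{
///     DistributedConfig, DistributedExecutor, DummyCommunicationBackend, TlDistributedExecutor,
/// };
/// use tensorlogic_infer::error::ExecutorError;
/// use std::sync::{Arc, RwLock};
///
/// struct MyExecutor {
///     distributed: Option<DistributedExecutor>,
/// }
///
/// impl TlDistributedExecutor for MyExecutor {
///     fn distributed_executor(&self) -> Option<&DistributedExecutor> {
///         self.distributed.as_ref()
///     }
///
///     fn enable_distributed(&mut self, config: DistributedConfig) -> Result<(), ExecutorError> {
///         let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
///         self.distributed = Some(DistributedExecutor::new(config, backend)?);
///         Ok(())
///     }
///
///     fn disable_distributed(&mut self) {
///         self.distributed = None;
///     }
///
///     fn is_distributed(&self) -> bool {
///         self.distributed.is_some()
///     }
/// }
/// ```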
pub trait TlDistributedExecutor {
    /// Get the distributed executor.
    fn distributed_executor(&self) -> Option<&DistributedExecutor>;

    /// Enable distributed execution.
    fn enable_distributed(&mut self, config: DistributedConfig) -> Result<(), ExecutorError>;

    /// Disable distributed execution.
    fn disable_distributed(&mut self);

    /// Check if distributed execution is enabled.
    fn is_distributed(&self) -> bool;

    /// Get the current rank.
    fn rank(&self) -> usize {
        self.distributed_executor().map(|d| d.rank()).unwrap_or(0)
    }

    /// Get the world size.
    fn world_size(&self) -> usize {
        self.distributed_executor()
            .map(|d| d.world_size())
            .unwrap_or(1)
    }
}

/// Statistics for distributed execution.
#[derive(Debug, Clone, Default)]
pub struct DistributedStats {
    /// Total number of communication operations
    pub total_communications: usize,
    /// Total bytes communicated
    pub total_bytes_communicated: u64,
    /// Number of gradient synchronizations
    pub gradient_syncs: usize,
    /// Average communication time
    pub avg_communication_time_ms: f64,
    /// Load imbalance metric (0.0 = perfect, 1.0 = worst)
    pub load_imbalance: f64,
}

impl DistributedStats {
    /// Get a summary of distributed execution statistics.
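    ///
    /// # Example
    ///
    /// ```
    /// use tensorlogic_infer::distributed::DistributedStats;
    ///
    /// let stats = DistributedStats {
    ///     total_communications: 10,
    ///     total_bytes_communicated: 2_000_000,
    ///     gradient_syncs: 5,
    ///     avg_communication_time_ms: 1.5,
    ///     load_imbalance: 0.1,
    /// };
    /// assert!(stats.summary().contains("10 communications"));
    /// ```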
    pub fn summary(&self) -> String {
        format!(
            "Distributed Stats: {} communications, {:.2} MB transferred, {} gradient syncs, {:.2}ms avg comm time, {:.2}% load imbalance",
            self.total_communications,
            self.total_bytes_communicated as f64 / 1_000_000.0,
            self.gradient_syncs,
            self.avg_communication_time_ms,
            self.load_imbalance * 100.0
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_distributed_config_default() {
        let config = DistributedConfig::default();
        assert_eq!(config.parallelism, ParallelismStrategy::DataParallel);
        assert_eq!(config.num_devices, 1);
        assert_eq!(config.rank, 0);
        assert_eq!(config.world_size, 1);
    }

    #[test]
    fn test_sharding_spec() {
        let devices = vec![
            Device::new(DeviceType::CPU, 0),
            Device::new(DeviceType::CPU, 1),
            Device::new(DeviceType::CPU, 2),
        ];
        let spec = ShardingSpec::new(0, 1, devices);

        assert_eq!(spec.num_shards, 3);
        assert_eq!(spec.shard_dim, 1);
        assert!(spec.is_valid_shard(0));
        assert!(spec.is_valid_shard(2));
        assert!(!spec.is_valid_shard(3));
    }

    #[test]
    fn test_distributed_placement_plan() {
        let mut plan = DistributedPlacementPlan::new();

        plan.place_node(0, Device::new(DeviceType::CPU, 0));
        plan.place_node(1, Device::new(DeviceType::CPU, 1));

        assert!(plan.get_device(0).is_some());
        assert!(plan.get_device(1).is_some());
        assert!(plan.get_device(2).is_none());
    }

    #[test]
    fn test_data_parallel_batch_distribution() {
        let config = DistributedConfig {
            num_devices: 4,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = DataParallelCoordinator::new(config, backend);

        let distribution = coordinator.distribute_batch(10);
        assert_eq!(distribution.len(), 4);

        // Check total size
        let total: usize = distribution.iter().map(|(_, size)| size).sum();
        assert_eq!(total, 10);
    }

    #[test]
    fn test_model_parallel_sharding() {
        let config = DistributedConfig {
            num_devices: 4,
            parallelism: ParallelismStrategy::ModelParallel,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = ModelParallelCoordinator::new(config, backend);

        let shape = TensorShape::static_shape(vec![8, 16]);
        let shards = coordinator.shard_tensor(0, &shape, 0).unwrap();

        assert_eq!(shards.len(), 4);
        // Each shard should have size 2 in dimension 0
        assert_eq!(shards[0].dims[0].as_static().unwrap(), 2);
    }

    #[test]
    fn test_pipeline_parallel_stage_assignment() {
        let config = DistributedConfig {
            num_devices: 4,
            parallelism: ParallelismStrategy::PipelineParallel,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = PipelineParallelCoordinator::new(config, backend, 4);

        assert_eq!(coordinator.stage_for_rank(0), 0);
        assert_eq!(coordinator.stage_for_rank(1), 1);
        assert_eq!(coordinator.stage_for_rank(2), 2);
        assert_eq!(coordinator.stage_for_rank(3), 3);
    }

    #[test]
    fn test_distributed_executor_creation() {
        let config = DistributedConfig::default();
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));

        let executor = DistributedExecutor::new(config, backend);
        assert!(executor.is_ok());

        let executor = executor.unwrap();
        assert_eq!(executor.rank(), 0);
        assert_eq!(executor.world_size(), 1);
    }

    #[test]
    fn test_communication_ops() {
        let op1 = CommunicationOp::AllReduce {
            reduction: ReductionOp::Sum,
        };
        let op2 = CommunicationOp::Broadcast { src_rank: 0 };

        assert_ne!(op1, op2);
    }

    #[test]
    fn test_reduction_ops() {
        let ops = [
            ReductionOp::Sum,
            ReductionOp::Mean,
            ReductionOp::Max,
            ReductionOp::Min,
            ReductionOp::Product,
        ];

        assert_eq!(ops.len(), 5);
    }

    #[test]
    fn test_dummy_backend() {
        let mut backend = DummyCommunicationBackend::new();
        let config = DistributedConfig::default();

        assert!(backend.initialize(&config).is_ok());
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert!(backend.all_reduce("test", ReductionOp::Sum).is_ok());
        assert!(backend.barrier().is_ok());
        assert!(backend.finalize().is_ok());
    }

    #[test]
    fn test_distributed_stats() {
        let stats = DistributedStats {
            total_communications: 100,
            total_bytes_communicated: 1_000_000,
            gradient_syncs: 50,
            avg_communication_time_ms: 10.5,
            load_imbalance: 0.15,
        };

        let summary = stats.summary();
        assert!(summary.contains("100 communications"));
        assert!(summary.contains("50 gradient syncs"));
    }

    #[test]
    fn test_hybrid_parallelism() {
        let config = DistributedConfig {
            parallelism: ParallelismStrategy::Hybrid {
                data_parallel_groups: 2,
            },
            num_devices: 8,
            ..Default::default()
        };

        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let executor = DistributedExecutor::new(config, backend).unwrap();

        assert!(executor.data_parallel().is_some());
        assert!(executor.model_parallel().is_some());
    }

    #[test]
    fn test_sharding_invalid_dimension() {
        let config = DistributedConfig {
            num_devices: 4,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = ModelParallelCoordinator::new(config, backend);

        let shape = TensorShape::static_shape(vec![8, 16]);
        let result = coordinator.shard_tensor(0, &shape, 5);

        assert!(result.is_err());
    }
954}