//! Distributed execution: parallelism strategies, tensor sharding,
//! communication backends, and coordinators for data-, model-, and
//! pipeline-parallel execution.

use crate::capabilities::DeviceType;
use crate::error::ExecutorError;
use crate::placement::Device;
use crate::shape::TensorShape;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tensorlogic_ir::EinsumGraph;

/// Strategy used to parallelize execution across devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ParallelismStrategy {
    /// Replicate the model and split each batch across devices.
    #[default]
    DataParallel,
    /// Partition the model's tensors and operators across devices.
    ModelParallel,
    /// Split the model into sequential stages executed as a pipeline.
    PipelineParallel,
    /// Combine data and model parallelism; `data_parallel_groups` is the
    /// number of replica groups.
    Hybrid { data_parallel_groups: usize },
}

/// Configuration for distributed execution.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Parallelism strategy to apply.
    pub parallelism: ParallelismStrategy,
    /// Number of devices available to this process.
    pub num_devices: usize,
    /// Communication backend identifier (e.g. "gloo").
    pub backend: String,
    /// Address of the rendezvous/master process, if any.
    pub master_addr: Option<String>,
    /// Port of the rendezvous/master process, if any.
    pub master_port: Option<u16>,
    /// Rank of this process within the world.
    pub rank: usize,
    /// Total number of participating processes.
    pub world_size: usize,
    /// Compress gradients before communicating them.
    pub enable_gradient_compression: bool,
    /// Use mixed-precision communication/compute.
    pub enable_mixed_precision: bool,
    /// Gradient bucket size in bytes for fused all-reduce.
    pub bucket_size: usize,
    /// Overlap communication with computation where possible.
    pub enable_async_communication: bool,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        DistributedConfig {
            parallelism: ParallelismStrategy::default(),
            num_devices: 1,
            backend: "gloo".to_string(),
            master_addr: None,
            master_port: None,
            rank: 0,
            world_size: 1,
            enable_gradient_compression: false,
            enable_mixed_precision: false,
            bucket_size: 25 * 1024 * 1024,
            enable_async_communication: true,
        }
    }
}

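// Configuration sketch: override only what you need and take the rest from
// `Default`. The "nccl" backend string is illustrative only; this module
// ships just the no-op `DummyCommunicationBackend` below.
//
//     let config = DistributedConfig {
//         parallelism: ParallelismStrategy::DataParallel,
//         num_devices: 4,
//         world_size: 4,
//         backend: "nccl".to_string(),
//         ..Default::default()
//     };
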
/// Describes how the tensor produced by one graph node is sharded across devices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShardingSpec {
    /// Graph node whose output tensor is sharded.
    pub node_id: usize,
    /// Dimension along which the tensor is split.
    pub shard_dim: usize,
    /// Number of shards (equal to the number of target devices).
    pub num_shards: usize,
    /// Mapping from shard index to the device holding that shard.
    pub shard_to_device: Vec<Device>,
}

impl ShardingSpec {
    /// Creates a spec that shards `node_id` along `shard_dim`, one shard per device.
    pub fn new(node_id: usize, shard_dim: usize, devices: Vec<Device>) -> Self {
        let num_shards = devices.len();
        ShardingSpec {
            node_id,
            shard_dim,
            num_shards,
            shard_to_device: devices,
        }
    }

    /// Returns the device holding `shard_id`, if it exists.
    pub fn device_for_shard(&self, shard_id: usize) -> Option<&Device> {
        self.shard_to_device.get(shard_id)
    }

    /// Returns `true` if `shard_id` is within range.
    pub fn is_valid_shard(&self, shard_id: usize) -> bool {
        shard_id < self.num_shards
    }
}

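// Sketch: a three-way split along dimension 1 across three CPU devices.
//
//     let spec = ShardingSpec::new(0, 1, vec![
//         Device::new(DeviceType::CPU, 0),
//         Device::new(DeviceType::CPU, 1),
//         Device::new(DeviceType::CPU, 2),
//     ]);
//     assert_eq!(spec.num_shards, 3);
//     assert!(spec.device_for_shard(2).is_some());
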
/// Placement decisions for a distributed execution: which device runs each
/// node, how tensors are sharded, and which nodes must exchange data.
#[derive(Debug, Clone)]
pub struct DistributedPlacementPlan {
    /// Node id -> device assignment.
    pub node_placement: HashMap<usize, Device>,
    /// Sharding specs for nodes whose tensors are split across devices.
    pub sharding_specs: Vec<ShardingSpec>,
    /// Node id -> ids of nodes it must communicate with.
    pub communication_deps: HashMap<usize, Vec<usize>>,
}

impl DistributedPlacementPlan {
    /// Creates an empty plan.
    pub fn new() -> Self {
        DistributedPlacementPlan {
            node_placement: HashMap::new(),
            sharding_specs: Vec::new(),
            communication_deps: HashMap::new(),
        }
    }

    /// Assigns `node_id` to `device`.
    pub fn place_node(&mut self, node_id: usize, device: Device) {
        self.node_placement.insert(node_id, device);
    }

    /// Records a sharding spec.
    pub fn add_sharding(&mut self, spec: ShardingSpec) {
        self.sharding_specs.push(spec);
    }

    /// Returns the device assigned to `node_id`, if any.
    pub fn get_device(&self, node_id: usize) -> Option<&Device> {
        self.node_placement.get(&node_id)
    }

    /// Returns the sharding spec for `node_id`, if any.
    pub fn get_sharding(&self, node_id: usize) -> Option<&ShardingSpec> {
        self.sharding_specs.iter().find(|s| s.node_id == node_id)
    }

    /// Returns `true` if `node_id` has a sharding spec.
    pub fn is_sharded(&self, node_id: usize) -> bool {
        self.get_sharding(node_id).is_some()
    }
}

impl Default for DistributedPlacementPlan {
    fn default() -> Self {
        Self::new()
    }
}

/// Collective and point-to-point communication operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommunicationOp {
    /// Reduce across all ranks, leaving the result on every rank.
    AllReduce {
        reduction: ReductionOp,
    },
    /// Copy a tensor from `src_rank` to all ranks.
    Broadcast {
        src_rank: usize,
    },
    /// Split a tensor on `src_rank` and distribute the pieces.
    Scatter {
        src_rank: usize,
    },
    /// Collect pieces from all ranks onto `dst_rank`.
    Gather {
        dst_rank: usize,
    },
    /// Gather pieces from all ranks onto every rank.
    AllGather,
    /// Reduce across ranks, then scatter the result.
    ReduceScatter {
        reduction: ReductionOp,
    },
    /// Point-to-point send to `dst_rank`.
    Send {
        dst_rank: usize,
    },
    /// Point-to-point receive from `src_rank`.
    Recv {
        src_rank: usize,
    },
}

/// Element-wise reduction applied by collective operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionOp {
    Sum,
    Mean,
    Max,
    Min,
    Product,
}

/// Abstraction over a collective-communication library (e.g. Gloo, NCCL, MPI).
///
/// Tensors are identified by string ids; resolving ids to buffers is the
/// backend's responsibility.
pub trait CommunicationBackend: Send + Sync {
    /// Initializes the backend from the distributed configuration.
    fn initialize(&mut self, config: &DistributedConfig) -> Result<(), ExecutorError>;

    /// Shuts the backend down and releases resources.
    fn finalize(&mut self) -> Result<(), ExecutorError>;

    /// Rank of this process.
    fn rank(&self) -> usize;

    /// Total number of processes.
    fn world_size(&self) -> usize;

    /// All-reduces `tensor_id` across ranks with `reduction`.
    fn all_reduce(&self, tensor_id: &str, reduction: ReductionOp) -> Result<(), ExecutorError>;

    /// Broadcasts `tensor_id` from `src_rank` to all ranks.
    fn broadcast(&self, tensor_id: &str, src_rank: usize) -> Result<(), ExecutorError>;

    /// Scatters `tensor_id` from `src_rank` across ranks.
    fn scatter(&self, tensor_id: &str, src_rank: usize) -> Result<(), ExecutorError>;

    /// Gathers `tensor_id` from all ranks onto `dst_rank`.
    fn gather(&self, tensor_id: &str, dst_rank: usize) -> Result<(), ExecutorError>;

    /// All-gathers `tensor_id` onto every rank.
    fn all_gather(&self, tensor_id: &str) -> Result<(), ExecutorError>;

    /// Reduces `tensor_id` across ranks, then scatters the result.
    fn reduce_scatter(&self, tensor_id: &str, reduction: ReductionOp) -> Result<(), ExecutorError>;

    /// Sends `tensor_id` to `dst_rank`.
    fn send(&self, tensor_id: &str, dst_rank: usize) -> Result<(), ExecutorError>;

    /// Receives `tensor_id` from `src_rank`.
    fn recv(&self, tensor_id: &str, src_rank: usize) -> Result<(), ExecutorError>;

    /// Blocks until all ranks reach this point.
    fn barrier(&self) -> Result<(), ExecutorError>;
}

/// No-op backend for single-process execution and testing: every collective
/// succeeds immediately without moving any data.
pub struct DummyCommunicationBackend {
    rank: usize,
    world_size: usize,
}

impl DummyCommunicationBackend {
    /// Creates a single-rank dummy backend.
    pub fn new() -> Self {
        DummyCommunicationBackend {
            rank: 0,
            world_size: 1,
        }
    }
}

impl Default for DummyCommunicationBackend {
    fn default() -> Self {
        Self::new()
    }
}

impl CommunicationBackend for DummyCommunicationBackend {
    fn initialize(&mut self, config: &DistributedConfig) -> Result<(), ExecutorError> {
        self.rank = config.rank;
        self.world_size = config.world_size;
        Ok(())
    }

    fn finalize(&mut self) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }

    fn all_reduce(&self, _tensor_id: &str, _reduction: ReductionOp) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn broadcast(&self, _tensor_id: &str, _src_rank: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn scatter(&self, _tensor_id: &str, _src_rank: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn gather(&self, _tensor_id: &str, _dst_rank: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn all_gather(&self, _tensor_id: &str) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn reduce_scatter(
        &self,
        _tensor_id: &str,
        _reduction: ReductionOp,
    ) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn send(&self, _tensor_id: &str, _dst_rank: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn recv(&self, _tensor_id: &str, _src_rank: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn barrier(&self) -> Result<(), ExecutorError> {
        Ok(())
    }
}

/// Coordinates data-parallel execution: splits batches across devices and
/// synchronizes gradients after the backward pass.
pub struct DataParallelCoordinator {
    config: DistributedConfig,
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    devices: Vec<Device>,
}

impl DataParallelCoordinator {
    /// Creates a coordinator with one (currently CPU) device per configured slot.
    pub fn new(config: DistributedConfig, backend: Arc<RwLock<dyn CommunicationBackend>>) -> Self {
        let devices = (0..config.num_devices)
            .map(|i| Device::new(DeviceType::CPU, i))
            .collect();

        DataParallelCoordinator {
            config,
            backend,
            devices,
        }
    }

    /// Splits a batch into contiguous `(offset, size)` ranges, one per device.
    /// The first `batch_size % num_devices` devices receive one extra sample.
    pub fn distribute_batch(&self, batch_size: usize) -> Vec<(usize, usize)> {
        let per_device = batch_size / self.config.num_devices;
        let remainder = batch_size % self.config.num_devices;

        let mut distribution = Vec::new();
        let mut offset = 0;

        for i in 0..self.config.num_devices {
            let size = per_device + if i < remainder { 1 } else { 0 };
            distribution.push((offset, size));
            offset += size;
        }

        distribution
    }

    /// Averages gradients across ranks via all-reduce.
    pub fn synchronize_gradients(&self) -> Result<(), ExecutorError> {
        let backend = self.backend.read().unwrap();
        backend.all_reduce("gradients", ReductionOp::Mean)?;
        Ok(())
    }

    /// Devices participating in data-parallel execution.
    pub fn devices(&self) -> &[Device] {
        &self.devices
    }
}

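// Worked example for `distribute_batch`: 10 samples over 4 devices gives
// `10 / 4 = 2` per device with remainder 2, so the first two devices take
// one extra sample each.
//
//     let coordinator = DataParallelCoordinator::new(
//         DistributedConfig { num_devices: 4, ..Default::default() },
//         Arc::new(RwLock::new(DummyCommunicationBackend::new())),
//     );
//     assert_eq!(
//         coordinator.distribute_batch(10),
//         vec![(0, 3), (3, 3), (6, 2), (8, 2)]
//     );
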
/// Coordinates model-parallel execution: plans node placement across devices
/// and computes shard shapes for partitioned tensors.
pub struct ModelParallelCoordinator {
    config: DistributedConfig,
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    placement_plan: DistributedPlacementPlan,
}

impl ModelParallelCoordinator {
    /// Creates a coordinator with an empty placement plan.
    pub fn new(config: DistributedConfig, backend: Arc<RwLock<dyn CommunicationBackend>>) -> Self {
        ModelParallelCoordinator {
            config,
            backend,
            placement_plan: DistributedPlacementPlan::new(),
        }
    }

    /// Builds a simple placement plan by assigning contiguous blocks of graph
    /// nodes to devices in order.
    pub fn create_sharding_plan(&mut self, graph: &EinsumGraph) -> Result<(), ExecutorError> {
        let num_devices = self.config.num_devices;
        let nodes_per_device = graph.nodes.len().div_ceil(num_devices);

        for (node_id, _node) in graph.nodes.iter().enumerate() {
            let device_idx = node_id / nodes_per_device;
            let device = Device::new(DeviceType::CPU, device_idx);
            self.placement_plan.place_node(node_id, device);
        }

        Ok(())
    }

    /// The current placement plan.
    pub fn placement_plan(&self) -> &DistributedPlacementPlan {
        &self.placement_plan
    }

    /// Computes per-shard shapes for splitting `shape` along `shard_dim`
    /// across all devices. The first `total % num_shards` shards receive one
    /// extra element; dynamic dimensions cannot be sharded.
    pub fn shard_tensor(
        &self,
        _node_id: usize,
        shape: &TensorShape,
        shard_dim: usize,
    ) -> Result<Vec<TensorShape>, ExecutorError> {
        let num_shards = self.config.num_devices;

        if shard_dim >= shape.rank() {
            return Err(ExecutorError::InvalidInput(format!(
                "Shard dimension {} exceeds tensor rank {}",
                shard_dim,
                shape.rank()
            )));
        }

        let total_size = shape.dims[shard_dim].as_static().ok_or_else(|| {
            ExecutorError::InvalidInput("Cannot shard dynamic dimension".to_string())
        })?;

        let per_shard = total_size / num_shards;
        let remainder = total_size % num_shards;

        let mut shard_shapes = Vec::new();
        for i in 0..num_shards {
            let shard_size = per_shard + if i < remainder { 1 } else { 0 };
            let mut shard_shape = shape.clone();
            shard_shape.dims[shard_dim] = crate::shape::DimSize::Static(shard_size);
            shard_shapes.push(shard_shape);
        }

        Ok(shard_shapes)
    }

    /// Reassembles a sharded tensor on every rank via all-gather.
    pub fn gather_shards(&self, _shard_dim: usize) -> Result<(), ExecutorError> {
        let backend = self.backend.read().unwrap();
        backend.all_gather("sharded_tensor")?;
        Ok(())
    }
}

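// Worked example for `shard_tensor`: an 8x16 tensor split along dimension 0
// across 4 devices yields four 2x16 shards; a dimension of size 10 would
// instead yield shard sizes 3, 3, 2, 2.
//
//     let shape = TensorShape::static_shape(vec![8, 16]);
//     let shards = coordinator.shard_tensor(0, &shape, 0)?;
//     assert_eq!(shards.len(), 4); // each shard is [2, 16]
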
/// Coordinates pipeline-parallel execution: maps ranks to stages and moves
/// activations forward and gradients backward between adjacent stages.
pub struct PipelineParallelCoordinator {
    config: DistributedConfig,
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    num_stages: usize,
    micro_batch_size: usize,
}

impl PipelineParallelCoordinator {
    /// Creates a coordinator with `num_stages` pipeline stages and a
    /// micro-batch size of 1.
    pub fn new(
        config: DistributedConfig,
        backend: Arc<RwLock<dyn CommunicationBackend>>,
        num_stages: usize,
    ) -> Self {
        PipelineParallelCoordinator {
            config,
            backend,
            num_stages,
            micro_batch_size: 1,
        }
    }

    /// Sets the micro-batch size used for pipelining.
    pub fn set_micro_batch_size(&mut self, size: usize) {
        self.micro_batch_size = size;
    }

    /// Maps a rank to its pipeline stage (ranks wrap around the stages).
    pub fn stage_for_rank(&self, rank: usize) -> usize {
        rank % self.num_stages
    }

    /// Sends forward activations to the next stage; a no-op on the last stage.
    /// Assumes stage `i` is hosted by rank `i`.
    pub fn send_activations(&self, stage: usize) -> Result<(), ExecutorError> {
        if stage < self.num_stages - 1 {
            let next_rank = stage + 1;
            let backend = self.backend.read().unwrap();
            backend.send("activations", next_rank)?;
        }
        Ok(())
    }

    /// Receives forward activations from the previous stage; a no-op on stage 0.
    pub fn recv_activations(&self, stage: usize) -> Result<(), ExecutorError> {
        if stage > 0 {
            let prev_rank = stage - 1;
            let backend = self.backend.read().unwrap();
            backend.recv("activations", prev_rank)?;
        }
        Ok(())
    }

    /// Sends backward gradients to the previous stage; a no-op on stage 0.
    pub fn send_gradients(&self, stage: usize) -> Result<(), ExecutorError> {
        if stage > 0 {
            let prev_rank = stage - 1;
            let backend = self.backend.read().unwrap();
            backend.send("gradients", prev_rank)?;
        }
        Ok(())
    }

    /// Receives backward gradients from the next stage; a no-op on the last stage.
    pub fn recv_gradients(&self, stage: usize) -> Result<(), ExecutorError> {
        if stage < self.num_stages - 1 {
            let next_rank = stage + 1;
            let backend = self.backend.read().unwrap();
            backend.recv("gradients", next_rank)?;
        }
        Ok(())
    }

    /// Number of pipeline stages.
    pub fn num_stages(&self) -> usize {
        self.num_stages
    }

    /// Current micro-batch size.
    pub fn micro_batch_size(&self) -> usize {
        self.micro_batch_size
    }

    /// The distributed configuration.
    pub fn config(&self) -> &DistributedConfig {
        &self.config
    }
}

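// Pipeline flow sketch (assumes, as the send/recv methods do, that stage `i`
// is hosted by rank `i`; `my_rank` below is a hypothetical variable for this
// process's rank):
//
//     let stage = coordinator.stage_for_rank(my_rank);
//     coordinator.recv_activations(stage)?; // no-op on stage 0
//     // ... run this stage's forward computation ...
//     coordinator.send_activations(stage)?; // no-op on the last stage
//     // backward pass: recv_gradients(stage), then send_gradients(stage)
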
/// Top-level entry point for distributed execution: owns the communication
/// backend and whichever coordinators the configured strategy requires.
pub struct DistributedExecutor {
    config: DistributedConfig,
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    data_parallel: Option<DataParallelCoordinator>,
    model_parallel: Option<ModelParallelCoordinator>,
    pipeline_parallel: Option<PipelineParallelCoordinator>,
}

impl DistributedExecutor {
    /// Initializes the backend and sets up coordinators for the configured
    /// parallelism strategy.
    pub fn new(
        config: DistributedConfig,
        backend: Arc<RwLock<dyn CommunicationBackend>>,
    ) -> Result<Self, ExecutorError> {
        backend.write().unwrap().initialize(&config)?;

        let mut executor = DistributedExecutor {
            config: config.clone(),
            backend: backend.clone(),
            data_parallel: None,
            model_parallel: None,
            pipeline_parallel: None,
        };

        executor.setup_coordinators()?;

        Ok(executor)
    }

    /// Instantiates the coordinator(s) matching the configured strategy;
    /// hybrid parallelism gets both data- and model-parallel coordinators.
    fn setup_coordinators(&mut self) -> Result<(), ExecutorError> {
        match self.config.parallelism {
            ParallelismStrategy::DataParallel => {
                self.data_parallel = Some(DataParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                ));
            }
            ParallelismStrategy::ModelParallel => {
                self.model_parallel = Some(ModelParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                ));
            }
            ParallelismStrategy::PipelineParallel => {
                let num_stages = self.config.num_devices;
                self.pipeline_parallel = Some(PipelineParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                    num_stages,
                ));
            }
            ParallelismStrategy::Hybrid {
                data_parallel_groups: _,
            } => {
                self.data_parallel = Some(DataParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                ));
                self.model_parallel = Some(ModelParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                ));
            }
        }
        Ok(())
    }

    /// The configured parallelism strategy.
    pub fn strategy(&self) -> ParallelismStrategy {
        self.config.parallelism
    }

    /// Rank reported by the backend.
    pub fn rank(&self) -> usize {
        self.backend.read().unwrap().rank()
    }

    /// World size reported by the backend.
    pub fn world_size(&self) -> usize {
        self.backend.read().unwrap().world_size()
    }

    /// Synchronizes all ranks.
    pub fn barrier(&self) -> Result<(), ExecutorError> {
        self.backend.read().unwrap().barrier()
    }

    /// Data-parallel coordinator, if the strategy uses one.
    pub fn data_parallel(&self) -> Option<&DataParallelCoordinator> {
        self.data_parallel.as_ref()
    }

    /// Model-parallel coordinator, if the strategy uses one.
    pub fn model_parallel(&self) -> Option<&ModelParallelCoordinator> {
        self.model_parallel.as_ref()
    }

    /// Pipeline-parallel coordinator, if the strategy uses one.
    pub fn pipeline_parallel(&self) -> Option<&PipelineParallelCoordinator> {
        self.pipeline_parallel.as_ref()
    }
}

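// Usage sketch (single process, using the dummy backend from this module):
//
//     let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
//     let executor = DistributedExecutor::new(DistributedConfig::default(), backend)?;
//     assert_eq!(executor.rank(), 0);
//     assert_eq!(executor.world_size(), 1);
//     executor.barrier()?; // trivially succeeds on the dummy backend
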
impl Drop for DistributedExecutor {
    fn drop(&mut self) {
        // Best-effort finalize; errors during drop are intentionally ignored.
        let _ = self.backend.write().unwrap().finalize();
    }
}

/// Implemented by executors that can optionally run in distributed mode.
pub trait TlDistributedExecutor {
    /// The active distributed executor, if distributed mode is enabled.
    fn distributed_executor(&self) -> Option<&DistributedExecutor>;

    /// Enables distributed execution with the given configuration.
    fn enable_distributed(&mut self, config: DistributedConfig) -> Result<(), ExecutorError>;

    /// Disables distributed execution.
    fn disable_distributed(&mut self);

    /// Returns `true` if distributed mode is enabled.
    fn is_distributed(&self) -> bool;

    /// This process's rank, or 0 when not distributed.
    fn rank(&self) -> usize {
        self.distributed_executor().map(|d| d.rank()).unwrap_or(0)
    }

    /// The world size, or 1 when not distributed.
    fn world_size(&self) -> usize {
        self.distributed_executor()
            .map(|d| d.world_size())
            .unwrap_or(1)
    }
}

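// Sketch of wiring the trait into an executor type (`MyExecutor` is a
// hypothetical struct, not part of this crate):
//
//     struct MyExecutor { dist: Option<DistributedExecutor> }
//
//     impl TlDistributedExecutor for MyExecutor {
//         fn distributed_executor(&self) -> Option<&DistributedExecutor> {
//             self.dist.as_ref()
//         }
//         fn enable_distributed(&mut self, config: DistributedConfig) -> Result<(), ExecutorError> {
//             let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
//             self.dist = Some(DistributedExecutor::new(config, backend)?);
//             Ok(())
//         }
//         fn disable_distributed(&mut self) { self.dist = None; }
//         fn is_distributed(&self) -> bool { self.dist.is_some() }
//     }
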
/// Aggregate statistics about distributed execution.
#[derive(Debug, Clone, Default)]
pub struct DistributedStats {
    /// Total number of communication operations issued.
    pub total_communications: usize,
    /// Total bytes moved between ranks.
    pub total_bytes_communicated: u64,
    /// Number of gradient synchronizations performed.
    pub gradient_syncs: usize,
    /// Mean latency of a communication operation, in milliseconds.
    pub avg_communication_time_ms: f64,
    /// Load imbalance across devices as a fraction (0.0 = perfectly balanced).
    pub load_imbalance: f64,
}

impl DistributedStats {
    /// Renders a one-line human-readable summary.
    pub fn summary(&self) -> String {
        format!(
            "Distributed Stats: {} communications, {:.2} MB transferred, {} gradient syncs, {:.2}ms avg comm time, {:.2}% load imbalance",
            self.total_communications,
            self.total_bytes_communicated as f64 / 1_000_000.0,
            self.gradient_syncs,
            self.avg_communication_time_ms,
            self.load_imbalance * 100.0
        )
    }
}

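// Sketch: this module only defines the container; populating the counters is
// left to callers. `summary` renders them for logging:
//
//     let stats = DistributedStats { gradient_syncs: 50, ..Default::default() };
//     println!("{}", stats.summary());
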
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_distributed_config_default() {
        let config = DistributedConfig::default();
        assert_eq!(config.parallelism, ParallelismStrategy::DataParallel);
        assert_eq!(config.num_devices, 1);
        assert_eq!(config.rank, 0);
        assert_eq!(config.world_size, 1);
    }

    #[test]
    fn test_sharding_spec() {
        let devices = vec![
            Device::new(DeviceType::CPU, 0),
            Device::new(DeviceType::CPU, 1),
            Device::new(DeviceType::CPU, 2),
        ];
        let spec = ShardingSpec::new(0, 1, devices);

        assert_eq!(spec.num_shards, 3);
        assert_eq!(spec.shard_dim, 1);
        assert!(spec.is_valid_shard(0));
        assert!(spec.is_valid_shard(2));
        assert!(!spec.is_valid_shard(3));
    }

    #[test]
    fn test_distributed_placement_plan() {
        let mut plan = DistributedPlacementPlan::new();

        plan.place_node(0, Device::new(DeviceType::CPU, 0));
        plan.place_node(1, Device::new(DeviceType::CPU, 1));

        assert!(plan.get_device(0).is_some());
        assert!(plan.get_device(1).is_some());
        assert!(plan.get_device(2).is_none());
    }

    #[test]
    fn test_data_parallel_batch_distribution() {
        let config = DistributedConfig {
            num_devices: 4,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = DataParallelCoordinator::new(config, backend);

        let distribution = coordinator.distribute_batch(10);
        assert_eq!(distribution.len(), 4);

        let total: usize = distribution.iter().map(|(_, size)| size).sum();
        assert_eq!(total, 10);
    }

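    // Added check (a sketch derived from `distribute_batch`'s remainder rule):
    // with 10 samples over 4 devices, the first two devices take one extra
    // sample and offsets stay contiguous.
    #[test]
    fn test_uneven_batch_offsets() {
        let config = DistributedConfig {
            num_devices: 4,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = DataParallelCoordinator::new(config, backend);

        assert_eq!(
            coordinator.distribute_batch(10),
            vec![(0, 3), (3, 3), (6, 2), (8, 2)]
        );
    }
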
    #[test]
    fn test_model_parallel_sharding() {
        let config = DistributedConfig {
            num_devices: 4,
            parallelism: ParallelismStrategy::ModelParallel,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = ModelParallelCoordinator::new(config, backend);

        let shape = TensorShape::static_shape(vec![8, 16]);
        let shards = coordinator.shard_tensor(0, &shape, 0).unwrap();

        assert_eq!(shards.len(), 4);
        assert_eq!(shards[0].dims[0].as_static().unwrap(), 2);
    }

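    // Added sketch: an unevenly divisible dimension follows the same
    // remainder rule as batch distribution (the first shards get one extra).
    #[test]
    fn test_model_parallel_uneven_sharding() {
        let config = DistributedConfig {
            num_devices: 4,
            parallelism: ParallelismStrategy::ModelParallel,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = ModelParallelCoordinator::new(config, backend);

        let shape = TensorShape::static_shape(vec![10, 16]);
        let shards = coordinator.shard_tensor(0, &shape, 0).unwrap();

        let sizes: Vec<usize> = shards
            .iter()
            .map(|s| s.dims[0].as_static().unwrap())
            .collect();
        assert_eq!(sizes, vec![3, 3, 2, 2]);
    }
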
    #[test]
    fn test_pipeline_parallel_stage_assignment() {
        let config = DistributedConfig {
            num_devices: 4,
            parallelism: ParallelismStrategy::PipelineParallel,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = PipelineParallelCoordinator::new(config, backend, 4);

        assert_eq!(coordinator.stage_for_rank(0), 0);
        assert_eq!(coordinator.stage_for_rank(1), 1);
        assert_eq!(coordinator.stage_for_rank(2), 2);
        assert_eq!(coordinator.stage_for_rank(3), 3);
    }

    #[test]
    fn test_distributed_executor_creation() {
        let config = DistributedConfig::default();
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));

        let executor = DistributedExecutor::new(config, backend);
        assert!(executor.is_ok());

        let executor = executor.unwrap();
        assert_eq!(executor.rank(), 0);
        assert_eq!(executor.world_size(), 1);
    }

    #[test]
    fn test_communication_ops() {
        let op1 = CommunicationOp::AllReduce {
            reduction: ReductionOp::Sum,
        };
        let op2 = CommunicationOp::Broadcast { src_rank: 0 };

        assert_ne!(op1, op2);
    }

    #[test]
    fn test_reduction_ops() {
        let ops = [
            ReductionOp::Sum,
            ReductionOp::Mean,
            ReductionOp::Max,
            ReductionOp::Min,
            ReductionOp::Product,
        ];

        assert_eq!(ops.len(), 5);
    }

    #[test]
    fn test_dummy_backend() {
        let mut backend = DummyCommunicationBackend::new();
        let config = DistributedConfig::default();

        assert!(backend.initialize(&config).is_ok());
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert!(backend.all_reduce("test", ReductionOp::Sum).is_ok());
        assert!(backend.barrier().is_ok());
        assert!(backend.finalize().is_ok());
    }

    #[test]
    fn test_distributed_stats() {
        let stats = DistributedStats {
            total_communications: 100,
            total_bytes_communicated: 1_000_000,
            gradient_syncs: 50,
            avg_communication_time_ms: 10.5,
            load_imbalance: 0.15,
        };

        let summary = stats.summary();
        assert!(summary.contains("100 communications"));
        assert!(summary.contains("50 gradient syncs"));
    }

    #[test]
    fn test_hybrid_parallelism() {
        let config = DistributedConfig {
            parallelism: ParallelismStrategy::Hybrid {
                data_parallel_groups: 2,
            },
            num_devices: 8,
            ..Default::default()
        };

        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let executor = DistributedExecutor::new(config, backend).unwrap();

        assert!(executor.data_parallel().is_some());
        assert!(executor.model_parallel().is_some());
    }

    #[test]
    fn test_sharding_invalid_dimension() {
        let config = DistributedConfig {
            num_devices: 4,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = ModelParallelCoordinator::new(config, backend);

        let shape = TensorShape::static_shape(vec![8, 16]);
        let result = coordinator.shard_tensor(0, &shape, 5);

        assert!(result.is_err());
    }
}