1use crate::{FxGraph, Node, TorshResult};
4use petgraph::graph::NodeIndex;
5use serde::{Deserialize, Serialize};
6use std::collections::{HashMap, HashSet};
7use std::sync::{Arc, RwLock};
8use torsh_core::{device::DeviceType, error::TorshError};
9use torsh_tensor::Tensor;
10
/// Configuration shared by every process participating in a distributed run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedConfig {
    /// Total number of participating processes (ranks).
    pub world_size: usize,
    /// Zero-based identifier of the current process; must be < `world_size`.
    pub rank: usize,
    /// Address of the coordinating (master) process.
    pub master_addr: String,
    /// Port the master process listens on.
    pub master_port: u16,
    /// Communication backend used for collectives and point-to-point ops.
    pub backend: CommunicationBackendType,
    /// Timeout value — presumably seconds; not consumed anywhere in this
    /// file, so the unit should be confirmed against the backend impls.
    pub timeout: u64,
}
27
28impl Default for DistributedConfig {
29 fn default() -> Self {
30 Self {
31 world_size: 1,
32 rank: 0,
33 master_addr: "localhost".to_string(),
34 master_port: 23456,
35 backend: CommunicationBackendType::Nccl,
36 timeout: 300,
37 }
38 }
39}
40
/// Which underlying transport implements the collective operations.
/// NOTE(review): only `Tcp` has an implementation in this file
/// (`TcpBackend`); `ProcessGroup::new` rejects the others.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CommunicationBackendType {
    /// NVIDIA collective communication library (GPU-oriented).
    Nccl,
    /// Facebook Gloo (CPU-oriented).
    Gloo,
    /// Message Passing Interface.
    Mpi,
    /// Plain TCP sockets.
    Tcp,
}
53
/// Kinds of communication operations a partition may schedule.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CollectiveOp {
    /// Reduce a tensor across all ranks, result visible on every rank.
    AllReduce,
    /// Gather every rank's tensor onto every rank.
    AllGather,
    /// Reduce across ranks, scattering disjoint shards of the result.
    ReduceScatter,
    /// Copy a tensor from one root rank to all ranks.
    Broadcast,
    /// Point-to-point send to `dst_rank`.
    Send,
    /// Point-to-point receive from `src_rank`.
    Recv,
    /// Synchronization point across all ranks; no data moved.
    Barrier,
}
72
/// Element-wise reduction applied by AllReduce / ReduceScatter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReduceOp {
    Sum,
    Product,
    Min,
    Max,
    Average,
}
82
/// A single scheduled communication operation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommOp {
    /// Which collective/point-to-point operation to perform.
    pub op_type: CollectiveOp,
    /// Reduction kind; only meaningful for reducing ops (e.g. AllReduce).
    pub reduce_op: Option<ReduceOp>,
    /// Source rank for Recv/Broadcast-style ops; `None` otherwise.
    pub src_rank: Option<usize>,
    /// Destination rank for Send-style ops; `None` otherwise.
    pub dst_rank: Option<usize>,
    /// Message tag used to match sends with receives.
    pub tag: u32,
}
92
/// How the computation graph is split across ranks.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DistributionStrategy {
    /// Replicate the full graph on every rank; all-reduce gradients.
    DataParallel,
    /// Split the graph's nodes contiguously across ranks.
    ModelParallel,
    /// Split by execution stages; adjacent ranks exchange activations.
    PipelineParallel,
    /// Combination of model parallelism with data-parallel replicas.
    HybridParallel,
}
105
/// Placement information produced alongside a partitioning.
#[derive(Debug, Clone)]
pub struct DeviceMapping {
    /// NOTE(review): despite the name, the partitioner stores the owning
    /// *rank* here (it inserts `rank` as the value) — confirm whether a
    /// device ordinal was intended.
    pub node_to_device: HashMap<NodeIndex, usize>,
    /// Device type each rank executes on (always `Cpu` in this file).
    pub rank_to_device_type: HashMap<usize, DeviceType>,
    /// Groups of ranks that communicate together (e.g. all ranks, or
    /// adjacent pipeline-stage pairs).
    pub comm_groups: Vec<Vec<usize>>,
}
116
/// The slice of the graph assigned to one rank, plus its communication needs.
#[derive(Debug, Clone)]
pub struct DistributedPartition {
    /// Graph nodes executed locally by this rank.
    pub nodes: HashSet<NodeIndex>,
    /// Nodes whose input value arrives from another rank (value = source rank).
    pub external_inputs: HashMap<NodeIndex, usize>,
    /// Nodes whose output must be shipped to other ranks (value = destinations).
    pub external_outputs: HashMap<NodeIndex, Vec<usize>>,
    /// Communication operations attached to specific nodes.
    pub comm_ops: Vec<(NodeIndex, CommOp)>,
    /// Rank that owns this partition.
    pub rank: usize,
}
131
/// Complete plan for executing a graph across `world_size` ranks.
#[derive(Debug, Clone)]
pub struct DistributedExecutionPlan {
    /// Per-rank partitions, keyed by rank.
    pub partitions: HashMap<usize, DistributedPartition>,
    /// Topologically-ordered stages; nodes within a stage are independent.
    pub execution_order: Vec<Vec<NodeIndex>>,
    /// Per-rank communication schedule (currently always empty — see
    /// `compute_comm_schedule`).
    pub comm_schedule: HashMap<usize, Vec<CommOp>>,
    /// Node/rank placement metadata.
    pub device_mapping: DeviceMapping,
}
144
/// Splits an [`FxGraph`] into a [`DistributedExecutionPlan`] according to a
/// [`DistributionStrategy`].
pub struct DistributedPartitioner {
    // Cluster topology (world size, ranks) used to size the partitioning.
    config: DistributedConfig,
    // Strategy selecting which partitioning algorithm `partition` runs.
    strategy: DistributionStrategy,
}
150
impl DistributedPartitioner {
    /// Creates a partitioner for the given cluster config and strategy.
    pub fn new(config: DistributedConfig, strategy: DistributionStrategy) -> Self {
        Self { config, strategy }
    }

    /// Partitions `graph` according to the configured strategy.
    ///
    /// # Errors
    /// Propagates errors from the strategy-specific partitioners, e.g. when
    /// the graph contains cycles (topological sort fails).
    pub fn partition(&self, graph: &FxGraph) -> TorshResult<DistributedExecutionPlan> {
        match self.strategy {
            DistributionStrategy::DataParallel => self.partition_data_parallel(graph),
            DistributionStrategy::ModelParallel => self.partition_model_parallel(graph),
            DistributionStrategy::PipelineParallel => self.partition_pipeline_parallel(graph),
            DistributionStrategy::HybridParallel => self.partition_hybrid_parallel(graph),
        }
    }

    /// Data parallelism: every rank receives a full replica of the graph.
    /// Nodes whose op name mentions "backward"/"grad" get an AllReduce(Sum)
    /// attached so gradients are averaged (summed) across replicas.
    fn partition_data_parallel(&self, graph: &FxGraph) -> TorshResult<DistributedExecutionPlan> {
        let mut partitions = HashMap::new();
        let mut device_mapping = DeviceMapping {
            node_to_device: HashMap::new(),
            rank_to_device_type: HashMap::new(),
            comm_groups: vec![],
        };

        for rank in 0..self.config.world_size {
            // Full replica: every node index is owned by this rank too.
            let mut partition = DistributedPartition {
                nodes: graph.nodes().map(|(idx, _)| idx).collect(),
                external_inputs: HashMap::new(),
                external_outputs: HashMap::new(),
                comm_ops: vec![],
                rank,
            };

            for (node_idx, node) in graph.nodes() {
                match node {
                    // Heuristic gradient detection by op-name substring.
                    Node::Call(op_name, _)
                        if op_name.contains("backward") || op_name.contains("grad") =>
                    {
                        partition.comm_ops.push((
                            node_idx,
                            CommOp {
                                op_type: CollectiveOp::AllReduce,
                                reduce_op: Some(ReduceOp::Sum),
                                src_rank: None,
                                dst_rank: None,
                                // Tag derived from the node id keeps matching
                                // ops aligned across ranks.
                                tag: node_idx.index() as u32,
                            },
                        ));
                    }
                    _ => {}
                }

                // NOTE(review): re-inserted on every rank iteration, so the
                // final mapping records the *last* rank for every node —
                // confirm that is intended for replicated placement.
                device_mapping.node_to_device.insert(node_idx, rank);
            }

            device_mapping
                .rank_to_device_type
                .insert(rank, DeviceType::Cpu);
            partitions.insert(rank, partition);
        }

        // All ranks form a single communication group.
        device_mapping
            .comm_groups
            .push((0..self.config.world_size).collect());

        Ok(DistributedExecutionPlan {
            partitions,
            execution_order: self.compute_execution_order(graph)?,
            comm_schedule: self.compute_comm_schedule(graph)?,
            device_mapping,
        })
    }

    /// Model parallelism: nodes are split into contiguous chunks (in the
    /// graph's iteration order) of roughly equal size, one chunk per rank.
    /// Cross-chunk edges become Send/Recv pairs.
    fn partition_model_parallel(&self, graph: &FxGraph) -> TorshResult<DistributedExecutionPlan> {
        let nodes: Vec<_> = graph.nodes().collect();
        // Ceiling division so every node is assigned.
        // NOTE(review): panics if `world_size == 0` — no validation upstream.
        let nodes_per_rank = (nodes.len() + self.config.world_size - 1) / self.config.world_size;

        let mut partitions = HashMap::new();
        let mut device_mapping = DeviceMapping {
            node_to_device: HashMap::new(),
            rank_to_device_type: HashMap::new(),
            comm_groups: vec![],
        };

        for rank in 0..self.config.world_size {
            let start_idx = rank * nodes_per_rank;
            let end_idx = ((rank + 1) * nodes_per_rank).min(nodes.len());

            let mut partition = DistributedPartition {
                nodes: HashSet::new(),
                external_inputs: HashMap::new(),
                external_outputs: HashMap::new(),
                comm_ops: vec![],
                rank,
            };

            // Claim this rank's slice of nodes and record ownership.
            for i in start_idx..end_idx {
                let (node_idx, _) = nodes[i];
                partition.nodes.insert(node_idx);
                device_mapping.node_to_device.insert(node_idx, rank);
            }

            // Derive cross-rank communication from graph edges.
            // NOTE(review): only predecessors of *earlier* ranks are visible
            // here, since later ranks haven't populated `node_to_device` yet.
            for &node_idx in &partition.nodes {
                let predecessors: Vec<_> = graph
                    .graph
                    .neighbors_directed(node_idx, petgraph::Direction::Incoming)
                    .collect();

                // A predecessor on another rank means we must Recv its value.
                for pred_idx in predecessors {
                    if let Some(&src_rank) = device_mapping.node_to_device.get(&pred_idx) {
                        if src_rank != rank {
                            partition.external_inputs.insert(node_idx, src_rank);
                            partition.comm_ops.push((
                                node_idx,
                                CommOp {
                                    op_type: CollectiveOp::Recv,
                                    reduce_op: None,
                                    src_rank: Some(src_rank),
                                    dst_rank: Some(rank),
                                    tag: node_idx.index() as u32,
                                },
                            ));
                        }
                    }
                }

                let successors: Vec<_> = graph
                    .graph
                    .neighbors_directed(node_idx, petgraph::Direction::Outgoing)
                    .collect();

                // Collect the distinct remote ranks that consume this node.
                let mut dst_ranks = vec![];
                for succ_idx in successors {
                    if let Some(&dst_rank) = device_mapping.node_to_device.get(&succ_idx) {
                        if dst_rank != rank && !dst_ranks.contains(&dst_rank) {
                            dst_ranks.push(dst_rank);
                        }
                    }
                }

                // One Send per distinct consumer rank.
                if !dst_ranks.is_empty() {
                    partition
                        .external_outputs
                        .insert(node_idx, dst_ranks.clone());
                    for &dst_rank in &dst_ranks {
                        partition.comm_ops.push((
                            node_idx,
                            CommOp {
                                op_type: CollectiveOp::Send,
                                reduce_op: None,
                                src_rank: Some(rank),
                                dst_rank: Some(dst_rank),
                                tag: node_idx.index() as u32,
                            },
                        ));
                    }
                }
            }

            device_mapping
                .rank_to_device_type
                .insert(rank, DeviceType::Cpu);
            partitions.insert(rank, partition);
        }

        // All ranks form one communication group.
        device_mapping
            .comm_groups
            .push((0..self.config.world_size).collect());

        Ok(DistributedExecutionPlan {
            partitions,
            execution_order: self.compute_execution_order(graph)?,
            comm_schedule: self.compute_comm_schedule(graph)?,
            device_mapping,
        })
    }

    /// Pipeline parallelism: topological stages are split contiguously
    /// across ranks; each rank Recvs its first stage's inputs from the
    /// previous rank and Sends its last stage's outputs to the next rank.
    fn partition_pipeline_parallel(
        &self,
        graph: &FxGraph,
    ) -> TorshResult<DistributedExecutionPlan> {
        let execution_order = self.compute_execution_order(graph)?;
        // Ceiling division over stages (not nodes).
        let stages_per_rank =
            (execution_order.len() + self.config.world_size - 1) / self.config.world_size;

        let mut partitions = HashMap::new();
        let mut device_mapping = DeviceMapping {
            node_to_device: HashMap::new(),
            rank_to_device_type: HashMap::new(),
            comm_groups: vec![],
        };

        for rank in 0..self.config.world_size {
            let start_stage = rank * stages_per_rank;
            let end_stage = ((rank + 1) * stages_per_rank).min(execution_order.len());

            let mut partition = DistributedPartition {
                nodes: HashSet::new(),
                external_inputs: HashMap::new(),
                external_outputs: HashMap::new(),
                comm_ops: vec![],
                rank,
            };

            // Own every node of every stage in this rank's range.
            for stage_idx in start_stage..end_stage {
                for &node_idx in &execution_order[stage_idx] {
                    partition.nodes.insert(node_idx);
                    device_mapping.node_to_device.insert(node_idx, rank);
                }
            }

            // Non-first ranks receive their first stage from the previous rank.
            // NOTE(review): indexes `execution_order[start_stage]` without
            // checking `start_stage < len` — could panic when
            // world_size > number of stages. Confirm inputs guarantee this.
            if rank > 0 {
                for &node_idx in &execution_order[start_stage] {
                    partition.external_inputs.insert(node_idx, rank - 1);
                    partition.comm_ops.push((
                        node_idx,
                        CommOp {
                            op_type: CollectiveOp::Recv,
                            reduce_op: None,
                            src_rank: Some(rank - 1),
                            dst_rank: Some(rank),
                            // Tag namespaced by rank to keep stage boundaries
                            // distinct (1000 nodes per rank assumed).
                            tag: (rank * 1000 + node_idx.index()) as u32,
                        },
                    ));
                }
            }

            // Non-last ranks send their final stage to the next rank.
            if rank < self.config.world_size - 1 && end_stage < execution_order.len() {
                for &node_idx in &execution_order[end_stage - 1] {
                    partition.external_outputs.insert(node_idx, vec![rank + 1]);
                    partition.comm_ops.push((
                        node_idx,
                        CommOp {
                            op_type: CollectiveOp::Send,
                            reduce_op: None,
                            src_rank: Some(rank),
                            dst_rank: Some(rank + 1),
                            // Mirrors the receiver's tag formula so the pair matches.
                            tag: ((rank + 1) * 1000 + node_idx.index()) as u32,
                        },
                    ));
                }
            }

            device_mapping
                .rank_to_device_type
                .insert(rank, DeviceType::Cpu);
            partitions.insert(rank, partition);
        }

        // Adjacent pipeline stages form pairwise communication groups.
        for rank in 0..self.config.world_size - 1 {
            device_mapping.comm_groups.push(vec![rank, rank + 1]);
        }

        Ok(DistributedExecutionPlan {
            partitions,
            execution_order,
            comm_schedule: self.compute_comm_schedule(graph)?,
            device_mapping,
        })
    }

    /// Hybrid parallelism: model-parallel over the first half of the ranks,
    /// then replicate those partitions data-parallel onto the second half.
    /// Falls back to pure data parallelism for small worlds (<= 2 ranks).
    fn partition_hybrid_parallel(&self, graph: &FxGraph) -> TorshResult<DistributedExecutionPlan> {
        if self.config.world_size <= 2 {
            self.partition_data_parallel(graph)
        } else {
            let model_parallel_ranks = self.config.world_size / 2;
            let mut base_plan = self.partition_model_parallel(graph)?;

            let mut new_partitions = base_plan.partitions.clone();

            // Each upper-half rank replicates the partition of a lower-half
            // "base" rank and adds gradient all-reduces for the replica group.
            for rank in model_parallel_ranks..self.config.world_size {
                let base_rank = rank % model_parallel_ranks;
                if let Some(base_partition) = base_plan.partitions.get(&base_rank) {
                    let mut new_partition = base_partition.clone();
                    new_partition.rank = rank;

                    for (node_idx, node) in graph.nodes() {
                        if new_partition.nodes.contains(&node_idx) {
                            if let Node::Call(op_name, _) = node {
                                // Same gradient heuristic as data parallel.
                                if op_name.contains("backward") || op_name.contains("grad") {
                                    new_partition.comm_ops.push((
                                        node_idx,
                                        CommOp {
                                            op_type: CollectiveOp::AllReduce,
                                            reduce_op: Some(ReduceOp::Sum),
                                            src_rank: None,
                                            dst_rank: None,
                                            // 10000 offset separates these tags
                                            // from model-parallel Send/Recv tags.
                                            tag: (rank * 10000 + node_idx.index()) as u32,
                                        },
                                    ));
                                }
                            }
                        }
                    }

                    new_partitions.insert(rank, new_partition);
                }
            }

            base_plan.partitions = new_partitions;
            Ok(base_plan)
        }
    }

    /// Groups a topological order of the graph into stages of nodes whose
    /// predecessors have all been scheduled in earlier iterations.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if the graph contains cycles.
    fn compute_execution_order(&self, graph: &FxGraph) -> TorshResult<Vec<Vec<NodeIndex>>> {
        use petgraph::algo::toposort;

        let topo_order = toposort(&graph.graph, None)
            .map_err(|_| TorshError::InvalidArgument("Graph contains cycles".to_string()))?;

        let mut stages = vec![];
        let mut current_stage = vec![];
        let mut processed = HashSet::new();

        for node_idx in topo_order {
            let predecessors: Vec<_> = graph
                .graph
                .neighbors_directed(node_idx, petgraph::Direction::Incoming)
                .collect();

            // Ready when all predecessors are already scheduled.
            let can_execute = predecessors.iter().all(|&pred| processed.contains(&pred));

            if can_execute || predecessors.is_empty() {
                current_stage.push(node_idx);
                processed.insert(node_idx);
            } else {
                // Dependency not yet satisfied within this stage: close the
                // stage and start a new one with this node.
                if !current_stage.is_empty() {
                    stages.push(current_stage);
                    current_stage = vec![];
                }
                current_stage.push(node_idx);
                processed.insert(node_idx);
            }
        }

        // Flush the trailing stage.
        if !current_stage.is_empty() {
            stages.push(current_stage);
        }

        Ok(stages)
    }

    /// Builds the per-rank communication schedule.
    ///
    /// Currently a stub: every rank gets an empty op list. The real comm ops
    /// live in each `DistributedPartition::comm_ops`.
    fn compute_comm_schedule(&self, _graph: &FxGraph) -> TorshResult<HashMap<usize, Vec<CommOp>>> {
        let mut schedule = HashMap::new();

        for rank in 0..self.config.world_size {
            schedule.insert(rank, vec![]);
        }

        Ok(schedule)
    }
}
533
/// A group of communicating processes bound to one backend instance.
pub struct ProcessGroup {
    // Config used to (re)initialize the backend.
    config: DistributedConfig,
    // Trait object so the backend can be selected at runtime.
    backend: Box<dyn CommunicationBackend + Send + Sync>,
}
539
/// Interface every communication backend must implement.
///
/// Collective methods take `&self`; backends needing interior mutability
/// must provide it themselves.
pub trait CommunicationBackend {
    /// Initializes the backend from the distributed configuration.
    fn init(&mut self, config: &DistributedConfig) -> TorshResult<()>;

    /// Shuts the backend down; further collectives should fail.
    fn finalize(&mut self) -> TorshResult<()>;

    /// Reduces `tensor` in place across all ranks with `op`.
    fn all_reduce(&self, tensor: &mut Tensor, op: ReduceOp) -> TorshResult<()>;

    /// Gathers `input` from every rank into `outputs` (one slot per rank).
    fn all_gather(&self, input: &Tensor, outputs: &mut [Tensor]) -> TorshResult<()>;

    /// Copies `tensor` from rank `root` to all ranks, in place.
    fn broadcast(&self, tensor: &mut Tensor, root: usize) -> TorshResult<()>;

    /// Sends `tensor` to rank `dst`; `tag` matches the peer's `recv`.
    fn send(&self, tensor: &Tensor, dst: usize, tag: u32) -> TorshResult<()>;

    /// Receives into `tensor` from rank `src`; `tag` matches the peer's `send`.
    fn recv(&self, tensor: &mut Tensor, src: usize, tag: u32) -> TorshResult<()>;

    /// Blocks until all ranks reach this point.
    fn barrier(&self) -> TorshResult<()>;

    /// This process's rank within the group.
    fn rank(&self) -> usize;

    /// Number of ranks in the group.
    fn world_size(&self) -> usize;
}
572
/// TCP-based communication backend.
///
/// NOTE(review): currently a stub — collectives validate initialization and
/// then no-op; no sockets are opened anywhere in this file.
pub struct TcpBackend {
    // Rank assigned at `init` time.
    rank: usize,
    // World size assigned at `init` time.
    world_size: usize,
    // Guards all collectives; set by `init`, cleared by `finalize`.
    initialized: bool,
}
579
580impl TcpBackend {
581 pub fn new() -> Self {
582 Self {
583 rank: 0,
584 world_size: 1,
585 initialized: false,
586 }
587 }
588}
589
590impl CommunicationBackend for TcpBackend {
591 fn init(&mut self, config: &DistributedConfig) -> TorshResult<()> {
592 self.rank = config.rank;
593 self.world_size = config.world_size;
594 self.initialized = true;
595
596 Ok(())
599 }
600
601 fn finalize(&mut self) -> TorshResult<()> {
602 self.initialized = false;
603 Ok(())
604 }
605
606 fn all_reduce(&self, _tensor: &mut Tensor, _op: ReduceOp) -> TorshResult<()> {
607 if !self.initialized {
608 return Err(TorshError::InvalidArgument(
609 "Backend not initialized".to_string(),
610 ));
611 }
612
613 if self.world_size == 1 {
616 return Ok(());
617 }
618
619 Ok(())
621 }
622
623 fn all_gather(&self, _input: &Tensor, _outputs: &mut [Tensor]) -> TorshResult<()> {
624 if !self.initialized {
625 return Err(TorshError::InvalidArgument(
626 "Backend not initialized".to_string(),
627 ));
628 }
629
630 Ok(())
632 }
633
634 fn broadcast(&self, _tensor: &mut Tensor, _root: usize) -> TorshResult<()> {
635 if !self.initialized {
636 return Err(TorshError::InvalidArgument(
637 "Backend not initialized".to_string(),
638 ));
639 }
640
641 Ok(())
643 }
644
645 fn send(&self, _tensor: &Tensor, _dst: usize, _tag: u32) -> TorshResult<()> {
646 if !self.initialized {
647 return Err(TorshError::InvalidArgument(
648 "Backend not initialized".to_string(),
649 ));
650 }
651
652 Ok(())
654 }
655
656 fn recv(&self, _tensor: &mut Tensor, _src: usize, _tag: u32) -> TorshResult<()> {
657 if !self.initialized {
658 return Err(TorshError::InvalidArgument(
659 "Backend not initialized".to_string(),
660 ));
661 }
662
663 Ok(())
665 }
666
667 fn barrier(&self) -> TorshResult<()> {
668 if !self.initialized {
669 return Err(TorshError::InvalidArgument(
670 "Backend not initialized".to_string(),
671 ));
672 }
673
674 Ok(())
676 }
677
678 fn rank(&self) -> usize {
679 self.rank
680 }
681
682 fn world_size(&self) -> usize {
683 self.world_size
684 }
685}
686
impl ProcessGroup {
    /// Constructs a process group for `config`.
    ///
    /// # Errors
    /// Only the `Tcp` backend is implemented; any other
    /// `CommunicationBackendType` yields `InvalidArgument`.
    pub fn new(config: DistributedConfig) -> TorshResult<Self> {
        let backend: Box<dyn CommunicationBackend + Send + Sync> = match config.backend {
            CommunicationBackendType::Tcp => Box::new(TcpBackend::new()),
            _ => {
                return Err(TorshError::InvalidArgument(format!(
                    "Backend {:?} not implemented",
                    config.backend
                )));
            }
        };

        Ok(Self { config, backend })
    }

    /// Initializes the underlying backend with this group's config.
    pub fn init(&mut self) -> TorshResult<()> {
        self.backend.init(&self.config)
    }

    /// Shuts the backend down.
    pub fn finalize(&mut self) -> TorshResult<()> {
        self.backend.finalize()
    }

    /// Rank of this process as reported by the backend.
    pub fn rank(&self) -> usize {
        self.backend.rank()
    }

    /// World size as reported by the backend.
    pub fn world_size(&self) -> usize {
        self.backend.world_size()
    }

    /// Dispatches `op` to the matching backend call, operating on `tensor`.
    ///
    /// # Errors
    /// - Send without `dst_rank` / Recv without `src_rank` → `InvalidArgument`.
    /// - AllGather and ReduceScatter are not wired up here and always error.
    pub fn execute_collective(&self, op: &CommOp, tensor: &mut Tensor) -> TorshResult<()> {
        match op.op_type {
            CollectiveOp::AllReduce => {
                // Missing reduce_op defaults to Sum.
                let reduce_op = op.reduce_op.unwrap_or(ReduceOp::Sum);
                self.backend.all_reduce(tensor, reduce_op)
            }
            CollectiveOp::Broadcast => {
                // Broadcast root defaults to rank 0.
                let root = op.src_rank.unwrap_or(0);
                self.backend.broadcast(tensor, root)
            }
            CollectiveOp::Send => {
                let dst = op.dst_rank.ok_or_else(|| {
                    TorshError::InvalidArgument("Send operation requires dst_rank".to_string())
                })?;
                self.backend.send(tensor, dst, op.tag)
            }
            CollectiveOp::Recv => {
                let src = op.src_rank.ok_or_else(|| {
                    TorshError::InvalidArgument("Recv operation requires src_rank".to_string())
                })?;
                self.backend.recv(tensor, src, op.tag)
            }
            CollectiveOp::Barrier => self.backend.barrier(),
            // AllGather / ReduceScatter need multi-tensor buffers this
            // single-tensor entry point cannot provide.
            _ => Err(TorshError::InvalidArgument(format!(
                "Collective operation {:?} not implemented",
                op.op_type
            ))),
        }
    }
}
754
/// Executes the local partition of a distributed plan on this rank.
pub struct DistributedExecutor {
    // Identifies this rank and how to reach peers.
    config: DistributedConfig,
    // Shared so future worker threads could take read locks for collectives.
    process_group: Arc<RwLock<ProcessGroup>>,
    // Must be set via `set_execution_plan` before `execute` is called.
    execution_plan: Option<DistributedExecutionPlan>,
}
761
impl DistributedExecutor {
    /// Creates an executor and its process group (not yet initialized).
    ///
    /// # Errors
    /// Fails if the configured backend is unimplemented (see `ProcessGroup::new`).
    pub fn new(config: DistributedConfig) -> TorshResult<Self> {
        let process_group = ProcessGroup::new(config.clone())?;

        Ok(Self {
            config,
            process_group: Arc::new(RwLock::new(process_group)),
            execution_plan: None,
        })
    }

    /// Initializes the communication backend.
    ///
    /// # Errors
    /// `InvalidArgument` if the process-group lock is poisoned.
    pub fn init(&mut self) -> TorshResult<()> {
        let mut pg = self
            .process_group
            .write()
            .map_err(|_| TorshError::InvalidArgument("Failed to acquire write lock".to_string()))?;
        pg.init()
    }

    /// Installs the plan that `execute` will run.
    pub fn set_execution_plan(&mut self, plan: DistributedExecutionPlan) {
        self.execution_plan = Some(plan);
    }

    /// Runs this rank's partition of the installed plan.
    ///
    /// # Errors
    /// `InvalidArgument` if no plan is set or the plan lacks a partition
    /// for `config.rank`.
    pub fn execute(
        &self,
        graph: &FxGraph,
        inputs: HashMap<String, Tensor>,
    ) -> TorshResult<Vec<Tensor>> {
        let plan = self
            .execution_plan
            .as_ref()
            .ok_or_else(|| TorshError::InvalidArgument("No execution plan set".to_string()))?;

        let partition = plan.partitions.get(&self.config.rank).ok_or_else(|| {
            TorshError::InvalidArgument(format!("No partition for rank {}", self.config.rank))
        })?;

        self.execute_partition(graph, partition, inputs)
    }

    /// Executes one partition: injects placeholders for external inputs,
    /// interprets the (currently full) local graph, then runs scheduled
    /// collectives.
    fn execute_partition(
        &self,
        graph: &FxGraph,
        partition: &DistributedPartition,
        inputs: HashMap<String, Tensor>,
    ) -> TorshResult<Vec<Tensor>> {
        let mut interpreter = crate::interpreter::GraphInterpreter::new(DeviceType::Cpu);

        let local_graph = self.create_local_graph(graph, partition)?;

        let mut local_inputs = inputs;

        // For each node fed by another rank, insert a placeholder tensor
        // under the key "external_<node index>".
        // NOTE(review): a real Recv should fill this — the placeholder is a
        // zeros([1]) stand-in and its shape likely won't match the consumer.
        for (&node_idx, &_src_rank) in &partition.external_inputs {
            for (comm_node_idx, comm_op) in &partition.comm_ops {
                if *comm_node_idx == node_idx && comm_op.op_type == CollectiveOp::Recv {
                    let placeholder = torsh_tensor::creation::zeros(&[1]);
                    let node_index = node_idx.index();
                    local_inputs.insert(format!("external_{node_index}"), placeholder?);
                    break;
                }
            }
        }

        let outputs = interpreter.run(&local_graph, local_inputs)?;

        // NOTE(review): this loop has no effect — it only locates the Send op
        // and breaks. Presumably the actual send of the produced output was
        // intended here; confirm and implement or remove.
        for (&node_idx, _dst_ranks) in &partition.external_outputs {
            for (comm_node_idx, comm_op) in &partition.comm_ops {
                if *comm_node_idx == node_idx && comm_op.op_type == CollectiveOp::Send {
                    break;
                }
            }
        }

        // Run group-wide collectives. Only Barrier is actually executed;
        // AllReduce/Broadcast acquire the lock but perform no transfer
        // (NOTE(review): they would need the concrete tensor to operate on).
        for (_node_idx, comm_op) in &partition.comm_ops {
            match comm_op.op_type {
                CollectiveOp::AllReduce | CollectiveOp::Broadcast | CollectiveOp::Barrier => {
                    let pg = self.process_group.read().map_err(|_| {
                        TorshError::InvalidArgument("Failed to acquire read lock".to_string())
                    })?;

                    if comm_op.op_type == CollectiveOp::Barrier {
                        // Barrier needs no payload; a dummy tensor satisfies
                        // the `execute_collective` signature.
                        let mut temp_tensor = torsh_tensor::creation::zeros(&[1])?;
                        pg.execute_collective(comm_op, &mut temp_tensor)?;
                    }
                }
                _ => {}
            }
        }

        Ok(outputs)
    }

    /// Extracts the subgraph for a partition.
    ///
    /// Currently a stub that clones the full graph; every rank therefore
    /// executes all nodes regardless of the partition's `nodes` set.
    fn create_local_graph(
        &self,
        graph: &FxGraph,
        _partition: &DistributedPartition,
    ) -> TorshResult<FxGraph> {
        Ok(graph.clone())
    }

    /// Shuts down the communication backend.
    pub fn finalize(&mut self) -> TorshResult<()> {
        let mut pg = self
            .process_group
            .write()
            .map_err(|_| TorshError::InvalidArgument("Failed to acquire write lock".to_string()))?;
        pg.finalize()
    }
}
897
898pub fn init_distributed(config: DistributedConfig) -> TorshResult<DistributedExecutor> {
901 let mut executor = DistributedExecutor::new(config)?;
902 executor.init()?;
903 Ok(executor)
904}
905
906pub fn create_execution_plan(
908 graph: &FxGraph,
909 config: DistributedConfig,
910 strategy: DistributionStrategy,
911) -> TorshResult<DistributedExecutionPlan> {
912 let partitioner = DistributedPartitioner::new(config, strategy);
913 partitioner.partition(graph)
914}
915
916pub fn execute_distributed(
918 graph: &FxGraph,
919 inputs: HashMap<String, Tensor>,
920 config: DistributedConfig,
921 strategy: DistributionStrategy,
922) -> TorshResult<Vec<Tensor>> {
923 let mut executor = init_distributed(config.clone())?;
924 let plan = create_execution_plan(graph, config, strategy)?;
925 executor.set_execution_plan(plan);
926
927 let outputs = executor.execute(graph, inputs)?;
928 executor.finalize()?;
929
930 Ok(outputs)
931}
932
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tracer::ModuleTracer;
    use torsh_tensor::creation::ones;

    /// Default config matches a single local process.
    #[test]
    fn test_distributed_config() {
        let config = DistributedConfig::default();
        assert_eq!(config.world_size, 1);
        assert_eq!(config.rank, 0);
        assert_eq!(config.master_addr, "localhost");
    }

    /// `ProcessGroup::new` rejects unimplemented backends and accepts Tcp.
    /// (The previous version matched on Ok/Err with empty arms and thus
    /// asserted nothing.)
    #[test]
    fn test_process_group_creation() {
        // Default backend is Nccl, which is not implemented.
        assert!(ProcessGroup::new(DistributedConfig::default()).is_err());

        // Tcp is the one implemented backend and must construct.
        let tcp_config = DistributedConfig {
            backend: CommunicationBackendType::Tcp,
            ..Default::default()
        };
        assert!(ProcessGroup::new(tcp_config).is_ok());
    }

    #[test]
    fn test_distributed_partitioner_data_parallel() {
        let config = DistributedConfig {
            world_size: 2,
            rank: 0,
            ..Default::default()
        };

        let partitioner = DistributedPartitioner::new(config, DistributionStrategy::DataParallel);

        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        tracer.add_call("relu", vec!["x".to_string()]);
        tracer.add_output("node_0");
        let graph = tracer.finalize();

        let result = partitioner.partition(&graph);
        assert!(result.is_ok());

        // Data parallel replicates the graph: one partition per rank.
        let plan = result.unwrap();
        assert_eq!(plan.partitions.len(), 2);
    }

    #[test]
    fn test_distributed_partitioner_model_parallel() {
        let config = DistributedConfig {
            world_size: 2,
            rank: 0,
            ..Default::default()
        };

        let partitioner = DistributedPartitioner::new(config, DistributionStrategy::ModelParallel);

        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        tracer.add_call("linear", vec!["x".to_string()]);
        tracer.add_call("relu", vec!["node_0".to_string()]);
        tracer.add_output("node_1");
        let graph = tracer.finalize();

        let result = partitioner.partition(&graph);
        assert!(result.is_ok());

        // Nodes are split contiguously: one partition per rank.
        let plan = result.unwrap();
        assert_eq!(plan.partitions.len(), 2);
    }

    /// Executor construction propagates the backend check from
    /// `ProcessGroup::new`: Nccl (default) fails, Tcp succeeds.
    #[test]
    fn test_distributed_executor_creation() {
        assert!(DistributedExecutor::new(DistributedConfig::default()).is_err());

        let tcp_config = DistributedConfig {
            backend: CommunicationBackendType::Tcp,
            ..Default::default()
        };
        assert!(DistributedExecutor::new(tcp_config).is_ok());
    }

    #[test]
    fn test_tcp_backend() {
        let mut backend = TcpBackend::new();
        let config = DistributedConfig::default();

        assert!(backend.init(&config).is_ok());
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert!(backend.finalize().is_ok());
    }

    /// CommOp round-trips through serde_json.
    #[test]
    fn test_comm_op_serialization() {
        let comm_op = CommOp {
            op_type: CollectiveOp::AllReduce,
            reduce_op: Some(ReduceOp::Sum),
            src_rank: None,
            dst_rank: None,
            tag: 42,
        };

        let serialized = serde_json::to_string(&comm_op).unwrap();
        let deserialized: CommOp = serde_json::from_str(&serialized).unwrap();

        assert_eq!(comm_op.tag, deserialized.tag);
        match (comm_op.op_type, deserialized.op_type) {
            (CollectiveOp::AllReduce, CollectiveOp::AllReduce) => {}
            _ => panic!("Serialization failed"),
        }
    }

    #[test]
    fn test_execution_plan_creation() {
        let config = DistributedConfig {
            world_size: 2,
            rank: 0,
            ..Default::default()
        };

        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        tracer.add_call("relu", vec!["x".to_string()]);
        tracer.add_output("node_0");
        let graph = tracer.finalize();

        let result = create_execution_plan(&graph, config, DistributionStrategy::DataParallel);
        assert!(result.is_ok());
    }

    /// With the default (Nccl) config, `execute_distributed` must fail fast
    /// in `init_distributed` because the backend is unimplemented. (The
    /// previous version accepted both Ok and Err, asserting nothing on the
    /// error path.)
    #[test]
    fn test_distributed_execution_single_rank() {
        let config = DistributedConfig::default();

        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        tracer.add_call("relu", vec!["x".to_string()]);
        tracer.add_output("node_0");
        let graph = tracer.finalize();

        let mut inputs = HashMap::new();
        inputs.insert("x".to_string(), ones(&[2, 3]).unwrap());

        let result =
            execute_distributed(&graph, inputs, config, DistributionStrategy::DataParallel);
        assert!(result.is_err());
    }
}
1098}