// torsh_distributed/process_group.rs

1//! Process group management for distributed training
2
3#![allow(unexpected_cfgs)]
4
5use crate::backend::{Backend, BackendConfig, BackendType, MockBackend};
6use crate::{TorshDistributedError, TorshResult};
7use parking_lot::RwLock;
8use std::sync::Arc;
9
/// Identifier of a single process within a process group.
pub type Rank = u32;

/// Total number of processes participating in a process group.
pub type WorldSize = u32;
15
/// Process group for distributed communication.
///
/// Bundles a communication backend together with this process's identity
/// (`rank` out of `world_size`). The backend sits behind
/// `Arc<RwLock<..>>` so callers obtaining it via [`ProcessGroup::backend`]
/// can share and mutate it from multiple owners.
pub struct ProcessGroup {
    // Communication backend, boxed as a trait object and lock-guarded
    // for shared mutable access.
    backend: Arc<RwLock<Box<dyn Backend>>>,
    // Rank of this process within the group.
    rank: Rank,
    // Total number of processes in the group.
    world_size: WorldSize,
    // Rendezvous address supplied at construction; not read yet —
    // presumably reserved for real (non-mock) backends. Kept despite
    // being dead code so the constructor signature stays stable.
    #[allow(dead_code)]
    master_addr: String,
    // Rendezvous port supplied at construction; same status as `master_addr`.
    #[allow(dead_code)]
    master_port: u16,
}
26
27impl ProcessGroup {
28    /// Create a new process group
29    pub async fn new(
30        backend_type: BackendType,
31        rank: Rank,
32        world_size: WorldSize,
33        master_addr: &str,
34        master_port: u16,
35    ) -> TorshResult<Self> {
36        let mut backend = create_backend(backend_type, rank, world_size)?;
37
38        // Initialize the backend with default config
39        let config = BackendConfig::default();
40        backend.init(config).await?;
41
42        let pg = Self {
43            backend: Arc::new(RwLock::new(backend)),
44            rank,
45            world_size,
46            master_addr: master_addr.to_string(),
47            master_port,
48        };
49
50        Ok(pg)
51    }
52
53    /// Get the rank of this process
54    pub fn rank(&self) -> Rank {
55        self.rank
56    }
57
58    /// Get the world size
59    pub fn world_size(&self) -> WorldSize {
60        self.world_size
61    }
62
63    /// Get the backend type
64    pub fn backend_type(&self) -> BackendType {
65        self.backend.read().backend_type()
66    }
67
68    /// Get a reference to the backend
69    pub fn backend(&self) -> &Arc<RwLock<Box<dyn Backend>>> {
70        &self.backend
71    }
72}
73
74/// Create a backend based on the type
75fn create_backend(
76    backend_type: BackendType,
77    rank: Rank,
78    world_size: WorldSize,
79) -> TorshResult<Box<dyn Backend>> {
80    match backend_type {
81        #[cfg(feature = "nccl")]
82        BackendType::Nccl => {
83            // For now, use mock backend - NCCL backend needs implementation
84            Ok(Box::new(MockBackend::new(rank, world_size)))
85        }
86        #[cfg(not(feature = "nccl"))]
87        BackendType::Nccl => Err(TorshDistributedError::feature_not_available(
88            "NCCL backend",
89            "nccl",
90        )),
91        #[cfg(feature = "mpi")]
92        BackendType::Mpi => {
93            // For now, use mock backend - MPI backend needs implementation
94            Ok(Box::new(MockBackend::new(rank, world_size)))
95        }
96        #[cfg(not(feature = "mpi"))]
97        BackendType::Mpi => Err(TorshDistributedError::feature_not_available(
98            "MPI backend",
99            "mpi",
100        )),
101        BackendType::Gloo => {
102            // Use mock backend for now
103            Ok(Box::new(MockBackend::new(rank, world_size)))
104        }
105        BackendType::Custom(name) => Err(TorshDistributedError::feature_not_available(
106            format!("Custom backend: {}", name),
107            "custom backend implementation",
108        )),
109    }
110}