torsh_distributed/
process_group.rs1#![allow(unexpected_cfgs)]
4
5use crate::backend::{Backend, BackendConfig, BackendType, MockBackend};
6use crate::{TorshDistributedError, TorshResult};
7use parking_lot::RwLock;
8use std::sync::Arc;
9
/// Integer type used for a process rank (see the `rank` field of `ProcessGroup`).
pub type Rank = u32;

/// Integer type used for a group's world size (see the `world_size` field of
/// `ProcessGroup`).
pub type WorldSize = u32;
15
16pub struct ProcessGroup {
18 backend: Arc<RwLock<Box<dyn Backend>>>,
19 rank: Rank,
20 world_size: WorldSize,
21 #[allow(dead_code)]
22 master_addr: String,
23 #[allow(dead_code)]
24 master_port: u16,
25}
26
27impl ProcessGroup {
28 pub async fn new(
30 backend_type: BackendType,
31 rank: Rank,
32 world_size: WorldSize,
33 master_addr: &str,
34 master_port: u16,
35 ) -> TorshResult<Self> {
36 let mut backend = create_backend(backend_type, rank, world_size)?;
37
38 let config = BackendConfig::default();
40 backend.init(config).await?;
41
42 let pg = Self {
43 backend: Arc::new(RwLock::new(backend)),
44 rank,
45 world_size,
46 master_addr: master_addr.to_string(),
47 master_port,
48 };
49
50 Ok(pg)
51 }
52
53 pub fn rank(&self) -> Rank {
55 self.rank
56 }
57
58 pub fn world_size(&self) -> WorldSize {
60 self.world_size
61 }
62
63 pub fn backend_type(&self) -> BackendType {
65 self.backend.read().backend_type()
66 }
67
68 pub fn backend(&self) -> &Arc<RwLock<Box<dyn Backend>>> {
70 &self.backend
71 }
72}
73
74fn create_backend(
76 backend_type: BackendType,
77 rank: Rank,
78 world_size: WorldSize,
79) -> TorshResult<Box<dyn Backend>> {
80 match backend_type {
81 #[cfg(feature = "nccl")]
82 BackendType::Nccl => {
83 Ok(Box::new(MockBackend::new(rank, world_size)))
85 }
86 #[cfg(not(feature = "nccl"))]
87 BackendType::Nccl => Err(TorshDistributedError::feature_not_available(
88 "NCCL backend",
89 "nccl",
90 )),
91 #[cfg(feature = "mpi")]
92 BackendType::Mpi => {
93 Ok(Box::new(MockBackend::new(rank, world_size)))
95 }
96 #[cfg(not(feature = "mpi"))]
97 BackendType::Mpi => Err(TorshDistributedError::feature_not_available(
98 "MPI backend",
99 "mpi",
100 )),
101 BackendType::Gloo => {
102 Ok(Box::new(MockBackend::new(rank, world_size)))
104 }
105 BackendType::Custom(name) => Err(TorshDistributedError::feature_not_available(
106 format!("Custom backend: {}", name),
107 "custom backend implementation",
108 )),
109 }
110}