// tenflowers_neural/distributed/types.rs

1//! Distributed training types: structs, enums, error types
2
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use std::time::Duration;
6use tenflowers_core::{Device, Result, Tensor, TensorError};
7
/// Communication backend types for distributed training.
///
/// Derives `Eq` + `Hash` because the runtime stores backend implementations
/// in a `HashMap` keyed by this enum. Hardware-specific backends are gated
/// behind Cargo features so builds without the matching system libraries
/// still compile.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CommunicationBackend {
    /// NCCL backend for NVIDIA GPUs (requires the `nccl` feature)
    #[cfg(feature = "nccl")]
    Nccl,
    /// Gloo backend for general CPU/GPU communication
    Gloo,
    /// MPI backend for HPC environments (requires the `mpi` feature)
    #[cfg(feature = "mpi")]
    Mpi,
    /// Thread-based backend for single-node multi-GPU
    Thread,
    /// Custom user-defined backend, identified by name
    Custom(String),
}
24
/// Communication group for collective operations.
///
/// Mirrors the usual MPI/NCCL "communicator" concept: each process knows its
/// own rank within the group and the group's total size.
#[derive(Debug, Clone)]
pub struct CommunicationGroup {
    /// Group identifier (used as the key in the runtime's group table)
    pub group_id: String,
    /// Rank of this process in the group
    /// (presumably 0-based and < `world_size` — not validated here)
    pub rank: usize,
    /// Total number of processes in the group
    pub world_size: usize,
    /// Devices participating in this group
    pub devices: Vec<Device>,
    /// Backend used for communication
    pub backend: CommunicationBackend,
}
39
/// Collective operation types.
///
/// These mirror the methods on `CommunicationBackendImpl`; rank indices refer
/// to ranks within the `CommunicationGroup` the operation runs on.
#[derive(Debug, Clone)]
pub enum CollectiveOp {
    /// All-reduce operation combining tensors across ranks with `reduction_op`
    AllReduce { reduction_op: ReductionOp },
    /// All-gather operation (every rank collects every rank's tensor)
    AllGather,
    /// Reduce-scatter operation using `reduction_op`
    ReduceScatter { reduction_op: ReductionOp },
    /// Broadcast from root rank to the rest of the group
    Broadcast { root_rank: usize },
    /// Point-to-point send to `dest_rank`
    Send { dest_rank: usize },
    /// Point-to-point receive from `src_rank`
    Recv { src_rank: usize },
}
56
/// Reduction operations for collective ops (all-reduce, reduce-scatter).
///
/// Derives `PartialEq`/`Eq`/`Hash` so backend implementations can compare
/// operations and use them as lookup keys — consistent with the derives on
/// `CommunicationBackend`. All variants are unit variants, so the derives
/// are trivially sound and fully backward compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionOp {
    /// Element-wise sum
    Sum,
    /// Element-wise average
    Average,
    /// Element-wise minimum
    Min,
    /// Element-wise maximum
    Max,
    /// Element-wise product
    Product,
}
66
/// Aggregate communication performance metrics.
///
/// `Clone` is derived here: this replaces a hand-written `impl Clone` that
/// was field-by-field identical to the derived implementation (every field
/// type is `Clone`), so behavior is unchanged.
#[derive(Debug, Default, Clone)]
pub struct CommunicationMetrics {
    /// Total bytes communicated
    pub total_bytes: u64,
    /// Number of operations performed
    pub operation_count: u64,
    /// Total communication time
    pub total_time: Duration,
    /// Average bandwidth (bytes/second)
    pub avg_bandwidth: f64,
    /// Per-operation metrics, keyed by operation name
    pub operation_metrics: HashMap<String, OperationMetrics>,
}

/// Metrics for specific operation types.
#[derive(Debug, Default, Clone)]
pub struct OperationMetrics {
    /// Number of times this operation ran
    pub count: u64,
    /// Cumulative time spent in this operation
    pub total_time: Duration,
    /// Cumulative bytes moved by this operation
    pub total_bytes: u64,
    /// Average latency per call
    pub avg_latency: Duration,
}
102
/// Trait for communication backend implementations.
///
/// Note: For simplicity, we use f32 tensors for now. This can be extended to
/// support multiple types using enum dispatch or other type erasure
/// techniques. Implementations must be `Send + Sync` because the runtime
/// stores them behind shared ownership and may call them from worker threads.
pub trait CommunicationBackendImpl: Send + Sync {
    /// Initialize the backend with the given configuration.
    /// Expected to be called before any collective operation.
    fn initialize(&mut self, config: &BackendConfig) -> Result<()>;

    /// Register a communication group with this backend.
    fn create_group(&mut self, group: &CommunicationGroup) -> Result<()>;

    /// Perform an all-reduce over `group`, combining each rank's `tensor`
    /// with the reduction `op`; returns the reduced tensor.
    fn all_reduce_f32(
        &self,
        tensor: &Tensor<f32>,
        group: &CommunicationGroup,
        op: ReductionOp,
    ) -> Result<Tensor<f32>>;

    /// Perform an all-gather over `group`; returns one tensor per rank
    /// (presumably ordered by rank — confirm against implementations).
    fn all_gather_f32(
        &self,
        tensor: &Tensor<f32>,
        group: &CommunicationGroup,
    ) -> Result<Vec<Tensor<f32>>>;

    /// Broadcast `tensor` from `root_rank` to all ranks in `group`;
    /// returns the broadcast tensor on every rank.
    fn broadcast_f32(
        &self,
        tensor: &Tensor<f32>,
        root_rank: usize,
        group: &CommunicationGroup,
    ) -> Result<Tensor<f32>>;

    /// Point-to-point send of `tensor` to `dest_rank` within `group`.
    fn send_f32(
        &self,
        tensor: &Tensor<f32>,
        dest_rank: usize,
        group: &CommunicationGroup,
    ) -> Result<()>;

    /// Point-to-point receive from `src_rank` within `group`. The caller
    /// supplies the expected `shape` of the incoming tensor.
    fn recv_f32(
        &self,
        shape: &[usize],
        src_rank: usize,
        group: &CommunicationGroup,
    ) -> Result<Tensor<f32>>;

    /// Release backend resources; counterpart to [`initialize`](Self::initialize).
    fn finalize(&mut self) -> Result<()>;

    /// Human-readable backend name (for logging/diagnostics).
    fn name(&self) -> &str;
}
158
/// Configuration for communication backends.
#[derive(Debug, Clone)]
pub struct BackendConfig {
    /// Backend-specific options
    pub options: HashMap<String, String>,
    /// Timeout for operations
    pub timeout: Duration,
    /// Enable compression
    pub compression: bool,
    /// Compression algorithm if enabled
    pub compression_algo: CompressionAlgorithm,
}

impl Default for BackendConfig {
    /// Defaults: no backend options, 30-second timeout, compression off.
    fn default() -> Self {
        BackendConfig {
            compression: false,
            compression_algo: CompressionAlgorithm::None,
            timeout: Duration::from_secs(30),
            options: HashMap::new(),
        }
    }
}

/// Compression algorithms for communication.
#[derive(Debug, Clone)]
pub enum CompressionAlgorithm {
    /// No compression
    None,
    /// Top-k sparsification
    TopK { k: usize },
    /// Random sparsification
    Random { ratio: f32 },
    /// Quantization
    Quantization { bits: u8 },
    /// Custom compression
    Custom(String),
}
202
/// Result of a collective operation, generic over the element type `T`.
/// (The previous doc said "f32 tensors", but the enum itself is fully
/// generic; only the current `CommunicationBackendImpl` trait is f32-only.)
#[derive(Debug)]
pub enum CollectiveResult<T> {
    /// Single tensor result (e.g. from all-reduce or broadcast)
    Tensor(Tensor<T>),
    /// Multiple tensor result (e.g. from all-gather)
    TensorList(Vec<Tensor<T>>),
    /// No result (e.g. from send)
    None,
}
213
214/// Communication runtime for managing distributed operations
215pub struct CommunicationRuntime {
216    /// Active communication groups
217    pub(super) groups: HashMap<String, CommunicationGroup>,
218    /// Default communication group
219    pub(super) default_group: Option<String>,
220    /// Backend implementations
221    pub(super) backends: HashMap<CommunicationBackend, Box<dyn CommunicationBackendImpl>>,
222    /// Performance metrics
223    pub(super) metrics: Arc<Mutex<CommunicationMetrics>>,
224}