// tenflowers_neural/distributed/types.rs
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tenflowers_core::{Device, Result, Tensor, TensorError};

/// Identifies which communication backend implementation carries out an operation.
///
/// Derives `Eq` + `Hash` so it can be used as the key of the runtime's backend table.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum CommunicationBackend {
    /// NCCL backend; only compiled when the `nccl` cargo feature is enabled.
    #[cfg(feature = "nccl")]
    Nccl,
    /// Gloo backend.
    Gloo,
    /// MPI backend; only compiled when the `mpi` cargo feature is enabled.
    #[cfg(feature = "mpi")]
    Mpi,
    /// In-process backend (presumably thread-based — confirm against its implementation).
    Thread,
    /// A backend identified by a caller-chosen name.
    Custom(String),
}

25#[derive(Debug, Clone)]
27pub struct CommunicationGroup {
28 pub group_id: String,
30 pub rank: usize,
32 pub world_size: usize,
34 pub devices: Vec<Device>,
36 pub backend: CommunicationBackend,
38}
39
40#[derive(Debug, Clone)]
42pub enum CollectiveOp {
43 AllReduce { reduction_op: ReductionOp },
45 AllGather,
47 ReduceScatter { reduction_op: ReductionOp },
49 Broadcast { root_rank: usize },
51 Send { dest_rank: usize },
53 Recv { src_rank: usize },
55}
56
/// How values contributed by different ranks are combined by a reducing collective
/// (all-reduce, reduce-scatter).
///
/// This is a fieldless `Copy` enum, so it also derives `PartialEq`/`Eq`/`Hash`. Backends
/// can then compare operations with `==` (e.g. `op == ReductionOp::Average`) or use
/// them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionOp {
    /// Sum of all contributions.
    Sum,
    /// Mean of all contributions (presumably the sum divided by the number of
    /// contributors — confirm per backend).
    Average,
    /// Smallest contribution.
    Min,
    /// Largest contribution.
    Max,
    /// Product of all contributions.
    Product,
}

/// Aggregate statistics about the communication performed by a runtime.
///
/// `Clone` is derived. Every field is `Clone`, so the derive is exactly the
/// field-by-field copy that was previously hand-written, and it stays correct when
/// fields are added.
#[derive(Debug, Default, Clone)]
pub struct CommunicationMetrics {
    /// Total number of bytes recorded across all operations.
    pub total_bytes: u64,
    /// Number of operations recorded.
    pub operation_count: u64,
    /// Cumulative time spent in recorded operations.
    pub total_time: Duration,
    /// Average bandwidth. Units are not fixed by this type (presumably bytes per
    /// second — confirm at the site that updates it).
    pub avg_bandwidth: f64,
    /// Per-operation breakdown, keyed by an operation name chosen by whoever records
    /// the metrics.
    pub operation_metrics: HashMap<String, OperationMetrics>,
}

/// Statistics for a single kind of communication operation.
#[derive(Debug, Default, Clone)]
pub struct OperationMetrics {
    /// Number of times the operation was recorded.
    pub count: u64,
    /// Cumulative time spent in this operation.
    pub total_time: Duration,
    /// Total bytes moved by this operation.
    pub total_bytes: u64,
    /// Average latency per call (presumably `total_time / count`, maintained by the
    /// recorder — nothing here computes it).
    pub avg_latency: Duration,
}

103pub trait CommunicationBackendImpl: Send + Sync {
107 fn initialize(&mut self, config: &BackendConfig) -> Result<()>;
109
110 fn create_group(&mut self, group: &CommunicationGroup) -> Result<()>;
112
113 fn all_reduce_f32(
115 &self,
116 tensor: &Tensor<f32>,
117 group: &CommunicationGroup,
118 op: ReductionOp,
119 ) -> Result<Tensor<f32>>;
120
121 fn all_gather_f32(
123 &self,
124 tensor: &Tensor<f32>,
125 group: &CommunicationGroup,
126 ) -> Result<Vec<Tensor<f32>>>;
127
128 fn broadcast_f32(
130 &self,
131 tensor: &Tensor<f32>,
132 root_rank: usize,
133 group: &CommunicationGroup,
134 ) -> Result<Tensor<f32>>;
135
136 fn send_f32(
138 &self,
139 tensor: &Tensor<f32>,
140 dest_rank: usize,
141 group: &CommunicationGroup,
142 ) -> Result<()>;
143
144 fn recv_f32(
146 &self,
147 shape: &[usize],
148 src_rank: usize,
149 group: &CommunicationGroup,
150 ) -> Result<Tensor<f32>>;
151
152 fn finalize(&mut self) -> Result<()>;
154
155 fn name(&self) -> &str;
157}
158
159#[derive(Debug, Clone)]
161pub struct BackendConfig {
162 pub options: HashMap<String, String>,
164 pub timeout: Duration,
166 pub compression: bool,
168 pub compression_algo: CompressionAlgorithm,
170}
171
172impl Default for BackendConfig {
173 fn default() -> Self {
174 Self {
175 options: HashMap::new(),
176 timeout: Duration::from_secs(30),
177 compression: false,
178 compression_algo: CompressionAlgorithm::None,
179 }
180 }
181}
182
/// Compression scheme a backend may apply to communicated data.
#[derive(Clone, Debug)]
pub enum CompressionAlgorithm {
    /// No compression.
    None,
    /// Top-k sparsification; `k` is the number of elements kept (presumably the
    /// largest-magnitude ones — confirm in the compressor).
    TopK { k: usize },
    /// Random sparsification keeping roughly `ratio` of the elements (presumably in
    /// `(0, 1]` — not validated here).
    Random { ratio: f32 },
    /// Quantization to `bits` bits per value.
    Quantization { bits: u8 },
    /// A compression scheme identified by a caller-chosen name.
    Custom(String),
}

203#[derive(Debug)]
205pub enum CollectiveResult<T> {
206 Tensor(Tensor<T>),
208 TensorList(Vec<Tensor<T>>),
210 None,
212}
213
214pub struct CommunicationRuntime {
216 pub(super) groups: HashMap<String, CommunicationGroup>,
218 pub(super) default_group: Option<String>,
220 pub(super) backends: HashMap<CommunicationBackend, Box<dyn CommunicationBackendImpl>>,
222 pub(super) metrics: Arc<Mutex<CommunicationMetrics>>,
224}