tenflowers_neural/distributed/
data_parallel.rs1use parking_lot::RwLock;
4use std::sync::Arc;
5use tenflowers_core::{Device, Result, Tensor, TensorError};
6
7use super::types::{
8 BackendConfig, CollectiveOp, CollectiveResult, CommunicationBackend, CommunicationGroup,
9 CommunicationRuntime, ReductionOp,
10};
11use crate::Model;
12
/// Single-process data parallelism: one shared model executed across several
/// local devices, with gradient synchronization after the backward pass.
pub struct DataParallel {
    /// The wrapped model, behind a lock so forward/sync paths can borrow it.
    pub(crate) base_model: Arc<RwLock<Box<dyn Model<f32>>>>,
    /// Devices participating in parallel execution (one replica per device).
    pub(crate) device_replicas: Vec<Device>,
    /// Runtime used to issue collective operations (all-reduce, broadcast).
    pub(crate) comm_runtime: Arc<RwLock<CommunicationRuntime>>,
    /// When false, `sync_gradients` becomes a no-op.
    pub(crate) is_training: bool,
    /// Requested synchronization strategy.
    /// NOTE(review): stored and settable via `set_sync_mode`, but not
    /// consulted by any method in this file — confirm it is read elsewhere.
    pub(crate) sync_mode: SynchronizationMode,
}
26
/// Multi-process data parallelism: each process owns one replica on one
/// device and synchronizes gradients across a process group.
pub struct DistributedDataParallel {
    /// The wrapped model for this process's replica.
    pub(crate) base_model: Arc<RwLock<Box<dyn Model<f32>>>>,
    /// Process group (rank, world size, group id) used for collectives.
    pub(crate) process_group: Arc<CommunicationGroup>,
    /// Runtime used to issue collective operations.
    pub(crate) comm_runtime: Arc<RwLock<CommunicationRuntime>>,
    /// Device this process's replica lives on.
    pub(crate) device: Device,
    /// When true, parameters are broadcast from rank 0 at construction.
    pub(crate) broadcast_buffers: bool,
    /// When false, `sync_gradients` becomes a no-op.
    pub(crate) is_training: bool,
    /// Maximum gradient-bucket size in bytes for bucketed all-reduce.
    pub(crate) bucket_size: usize,
    /// Additional DDP tuning options.
    /// NOTE(review): stored but not read by any method in this file.
    pub(crate) ddp_config: DDPConfig,
}
46
/// Tuning options for [`DistributedDataParallel`].
///
/// NOTE(review): these flags are stored but not consulted by code visible in
/// this file — confirm they take effect elsewhere.
#[derive(Debug, Clone)]
pub struct DDPConfig {
    /// Detect parameters that received no gradient in an iteration.
    pub find_unused_parameters: bool,
    /// Expose gradients as views into communication buckets.
    pub gradient_as_bucket_view: bool,
    /// Assume the computation graph does not change between iterations.
    pub static_graph: bool,
    /// Delay the all-reduce until all gradients are ready.
    pub delay_all_reduce: bool,
}
59
/// Gradient synchronization strategy for [`DataParallel`].
#[derive(Debug, Clone, Copy)]
pub enum SynchronizationMode {
    /// All replicas synchronize gradients every step.
    Synchronous,
    /// Replicas proceed without waiting on synchronization.
    Asynchronous,
    /// Asynchronous, but no replica may fall more than `max_staleness`
    /// steps behind.
    BoundedStaleness { max_staleness: u32 },
}
70
71impl Default for DDPConfig {
72 fn default() -> Self {
73 Self {
74 find_unused_parameters: false,
75 gradient_as_bucket_view: false,
76 static_graph: false,
77 delay_all_reduce: true,
78 }
79 }
80}
81
82impl DataParallel {
83 pub fn new(
85 model: Box<dyn Model<f32>>,
86 devices: Vec<Device>,
87 comm_runtime: Arc<RwLock<CommunicationRuntime>>,
88 ) -> Result<Self> {
89 if devices.is_empty() {
90 return Err(TensorError::invalid_argument_op(
91 "DataParallel::new",
92 "No devices provided",
93 ));
94 }
95
96 #[allow(clippy::arc_with_non_send_sync)]
97 let base_model = Arc::new(RwLock::new(model));
98
99 Self::replicate_parameters(&base_model, &devices)?;
100
101 Ok(Self {
102 base_model,
103 device_replicas: devices,
104 comm_runtime,
105 is_training: true,
106 sync_mode: SynchronizationMode::Synchronous,
107 })
108 }
109
110 fn replicate_parameters(
112 model: &Arc<RwLock<Box<dyn Model<f32>>>>,
113 devices: &[Device],
114 ) -> Result<()> {
115 let model_read = model.read();
116 let parameters = model_read.parameters();
117
118 for param in parameters {
119 for device in devices {
120 if *param.device() != *device {
121 param.to(device.clone())?;
122 }
123 }
124 }
125
126 Ok(())
127 }
128
129 pub fn forward_parallel(&self, inputs: &[Tensor<f32>]) -> Result<Vec<Tensor<f32>>> {
131 if inputs.len() != self.device_replicas.len() {
132 return Err(TensorError::invalid_argument_op(
133 "forward_parallel",
134 &format!(
135 "Expected {} inputs for {} devices",
136 self.device_replicas.len(),
137 inputs.len()
138 ),
139 ));
140 }
141
142 let model = self.base_model.read();
143 let mut outputs = Vec::with_capacity(inputs.len());
144
145 for (input, device) in inputs.iter().zip(&self.device_replicas) {
146 let input_on_device = if *input.device() != *device {
147 input.to(device.clone())?
148 } else {
149 input.clone()
150 };
151
152 let output = model.forward(&input_on_device)?;
153 outputs.push(output);
154 }
155
156 Ok(outputs)
157 }
158
159 pub fn sync_gradients(&mut self) -> Result<()> {
161 if !self.is_training {
162 return Ok(());
163 }
164
165 let mut model = self.base_model.write();
166 let mut parameters = model.parameters_mut();
167
168 for param in parameters.iter_mut() {
169 if let Some(grad) = param.grad() {
170 let comm_runtime = self.comm_runtime.read();
171
172 let op = CollectiveOp::AllReduce {
173 reduction_op: ReductionOp::Average,
174 };
175
176 if let Ok(CollectiveResult::Tensor(synced_grad)) =
177 comm_runtime.collective_op_f32(op, grad, None)
178 {
179 param.set_grad(Some(synced_grad));
180 }
181 }
182 }
183
184 Ok(())
185 }
186
187 pub fn set_sync_mode(&mut self, mode: SynchronizationMode) {
189 self.sync_mode = mode;
190 }
191
192 pub fn devices(&self) -> &[Device] {
194 &self.device_replicas
195 }
196}
197
impl Model<f32> for DataParallel {
    /// Copies `input` to every replica device, runs a forward pass on each,
    /// gathers the outputs onto the first device, and returns their mean.
    ///
    /// NOTE(review): every replica receives the *same* input (no batch
    /// scattering), so for a deterministic model the averaged result equals a
    /// single forward pass — confirm this is the intended semantics.
    fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        let inputs: Vec<Tensor<f32>> = self
            .device_replicas
            .iter()
            .map(|device| input.to(device.clone()))
            .collect::<Result<Vec<_>>>()?;

        let outputs = self.forward_parallel(&inputs)?;

        // `new` rejects an empty device list, so indexing [0] cannot panic.
        let primary_device = &self.device_replicas[0];
        let gathered_outputs: Vec<Tensor<f32>> = outputs
            .into_iter()
            .map(|output| output.to(primary_device.clone()))
            .collect::<Result<Vec<_>>>()?;

        // Sum all per-device outputs on the primary device...
        let mut result = gathered_outputs[0].clone();
        for output in &gathered_outputs[1..] {
            result = result.add(output)?;
        }

        // ...then divide by the replica count to get the mean.
        let num_devices = gathered_outputs.len() as f32;
        let divisor = Tensor::from_scalar(num_devices);
        result.div(&divisor)
    }

    // NOTE(review): parameters are not exposed through the wrapper (they sit
    // behind the model's RwLock); callers must go through `base_model`.
    fn parameters(&self) -> Vec<&Tensor<f32>> {
        vec![]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Tensor<f32>> {
        vec![]
    }

    /// Propagates the training flag to both the wrapper and the inner model.
    fn set_training(&mut self, training: bool) {
        self.is_training = training;
        self.base_model.write().set_training(training);
    }

    /// Clears gradients on the wrapped model.
    fn zero_grad(&mut self) {
        self.base_model.write().zero_grad();
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
245
246impl DistributedDataParallel {
247 pub fn new(
249 model: Box<dyn Model<f32>>,
250 device: Device,
251 process_group: Arc<CommunicationGroup>,
252 comm_runtime: Arc<RwLock<CommunicationRuntime>>,
253 config: DDPConfig,
254 ) -> Result<Self> {
255 #[allow(clippy::arc_with_non_send_sync)]
256 let base_model = Arc::new(RwLock::new(model));
257
258 let mut ddp = Self {
259 base_model,
260 process_group,
261 comm_runtime,
262 device,
263 broadcast_buffers: true,
264 is_training: true,
265 bucket_size: 25 * 1024 * 1024, ddp_config: config,
267 };
268
269 if ddp.broadcast_buffers {
270 ddp.broadcast_parameters()?;
271 }
272
273 Ok(ddp)
274 }
275
276 fn broadcast_parameters(&mut self) -> Result<()> {
278 let model = self.base_model.read();
279 let parameters = model.parameters();
280
281 let comm_runtime = self.comm_runtime.read();
282
283 for param in parameters {
284 let op = CollectiveOp::Broadcast { root_rank: 0 };
285
286 if let Ok(CollectiveResult::Tensor(_synced_param)) =
287 comm_runtime.collective_op_f32(op, param, Some(&self.process_group.group_id))
288 {
289 }
292 }
293
294 Ok(())
295 }
296
297 pub fn sync_gradients(&mut self) -> Result<()> {
299 if !self.is_training {
300 return Ok(());
301 }
302
303 let mut model = self.base_model.write();
304 let mut parameters = model.parameters_mut();
305
306 let mut gradient_buckets = self.create_gradient_buckets(&mut parameters)?;
307
308 let comm_runtime = self.comm_runtime.read();
309
310 for bucket in &gradient_buckets {
311 for grad_tensor in bucket {
312 let op = CollectiveOp::AllReduce {
313 reduction_op: ReductionOp::Average,
314 };
315
316 if let Ok(CollectiveResult::Tensor(_synced_grad)) = comm_runtime.collective_op_f32(
317 op,
318 grad_tensor,
319 Some(&self.process_group.group_id),
320 ) {
321 }
323 }
324 }
325
326 Ok(())
327 }
328
329 fn create_gradient_buckets<'a>(
331 &self,
332 parameters: &'a mut [&'a mut Tensor<f32>],
333 ) -> Result<Vec<Vec<&'a Tensor<f32>>>> {
334 let mut buckets = Vec::new();
335 let mut current_bucket = Vec::new();
336 let mut current_bucket_size = 0;
337
338 for param in parameters {
339 if let Some(grad) = param.grad() {
340 let grad_size = grad.shape().size() * std::mem::size_of::<f32>();
341
342 if current_bucket_size + grad_size > self.bucket_size && !current_bucket.is_empty()
343 {
344 buckets.push(std::mem::take(&mut current_bucket));
345 current_bucket_size = 0;
346 }
347
348 current_bucket.push(grad);
349 current_bucket_size += grad_size;
350 }
351 }
352
353 if !current_bucket.is_empty() {
354 buckets.push(current_bucket);
355 }
356
357 Ok(buckets)
358 }
359
360 pub fn process_group(&self) -> &CommunicationGroup {
362 &self.process_group
363 }
364
365 pub fn local_rank(&self) -> usize {
367 self.process_group.rank
368 }
369
370 pub fn world_size(&self) -> usize {
372 self.process_group.world_size
373 }
374
375 pub fn set_bucket_size(&mut self, size: usize) {
377 self.bucket_size = size;
378 }
379}
380
381impl Model<f32> for DistributedDataParallel {
382 fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
383 let input_on_device = if *input.device() != self.device {
384 input.to(self.device.clone())?
385 } else {
386 input.clone()
387 };
388
389 self.base_model.read().forward(&input_on_device)
390 }
391
392 fn parameters(&self) -> Vec<&Tensor<f32>> {
393 vec![]
394 }
395
396 fn parameters_mut(&mut self) -> Vec<&mut Tensor<f32>> {
397 vec![]
398 }
399
400 fn set_training(&mut self, training: bool) {
401 self.is_training = training;
402 self.base_model.write().set_training(training);
403 }
404
405 fn zero_grad(&mut self) {
406 self.base_model.write().zero_grad();
407 }
408
409 fn as_any(&self) -> &dyn std::any::Any {
410 self
411 }
412}
413
/// Convenience helpers for wiring up distributed training.
pub mod utils {
    use super::super::types::CommunicationBackend;
    use super::*;

    /// Creates and initializes a communication runtime plus the main process
    /// group (`"ddp_main"`) for the given backend, rank, and world size.
    ///
    /// # Errors
    /// Fails for unsupported backends (anything other than `Thread`, or
    /// `Nccl` without the `nccl` feature), or when runtime initialization or
    /// group creation fails.
    pub fn init_process_group(
        backend: CommunicationBackend,
        rank: usize,
        world_size: usize,
    ) -> Result<(Arc<RwLock<CommunicationRuntime>>, Arc<CommunicationGroup>)> {
        let mut comm_runtime = CommunicationRuntime::new();

        match backend {
            CommunicationBackend::Thread => {
                comm_runtime.register_backend(
                    CommunicationBackend::Thread,
                    Box::new(crate::backends::thread::ThreadBackend::new()),
                );
            }
            // NCCL support is compiled in only with the `nccl` feature; with
            // the feature off, `Nccl` falls through to the unsupported arm.
            #[cfg(feature = "nccl")]
            CommunicationBackend::Nccl => {
                comm_runtime.register_backend(
                    CommunicationBackend::Nccl,
                    Box::new(crate::backends::nccl::NcclBackend::new()),
                );
            }
            _ => {
                return Err(TensorError::unsupported_operation_simple(format!(
                    "Backend {backend:?} not supported"
                )));
            }
        }

        let config = BackendConfig::default();
        comm_runtime.initialize(&config)?;

        let devices = super::super::auto_detect_available_devices();
        let process_group = Arc::new(CommunicationGroup {
            group_id: "ddp_main".to_string(),
            rank,
            world_size,
            devices,
            backend,
        });

        // Register the group with the (already initialized) runtime before
        // handing both back to the caller.
        let runtime = Arc::new(RwLock::new(comm_runtime));
        runtime.write().create_group((*process_group).clone())?;

        Ok((runtime, process_group))
    }

    /// Builds a [`DataParallel`] wrapper over all auto-detected devices.
    ///
    /// NOTE(review): always initializes with rank 0, so this is suitable for
    /// single-process use only.
    pub fn create_data_parallel(model: Box<dyn Model<f32>>) -> Result<DataParallel> {
        let devices = super::super::auto_detect_available_devices();
        let comm_runtime = super::super::utils::init_distributed(0, devices.len(), None)?;
        let runtime = Arc::new(RwLock::new(comm_runtime));

        DataParallel::new(model, devices, runtime)
    }

    /// Builds a [`DistributedDataParallel`] wrapper using the default
    /// [`DDPConfig`] and a freshly initialized process group.
    pub fn create_distributed_data_parallel(
        model: Box<dyn Model<f32>>,
        device: Device,
        backend: CommunicationBackend,
        rank: usize,
        world_size: usize,
    ) -> Result<DistributedDataParallel> {
        let (comm_runtime, process_group) = init_process_group(backend, rank, world_size)?;
        let config = DDPConfig::default();

        DistributedDataParallel::new(model, device, process_group, comm_runtime, config)
    }
}