1#![allow(dead_code)]
5#![allow(clippy::await_holding_lock)]
6use crate::backend::ReduceOp;
7use crate::collectives::all_reduce;
8use crate::{process_group::ProcessGroup, TorshResult};
9use log::info;
10use std::collections::{HashMap, HashSet};
11use std::sync::{Arc, Mutex};
12use std::time::{Duration, Instant};
13use tokio::sync::{mpsc, Semaphore};
14use tokio::task::JoinHandle;
15use torsh_core::{error::Result, DeviceType};
16use torsh_nn::{Module, Parameter};
17use torsh_tensor::Tensor;
18
/// Controls how gradients are grouped into fixed-size buckets for coalesced
/// all-reduce calls.
#[derive(Debug, Clone)]
pub struct BucketConfig {
    /// Upper bound on a bucket's payload, in megabytes.
    pub max_bucket_size_mb: f32,
    /// Whether bucketing is used at all.
    pub enabled: bool,
    /// Lower-bound hint for bucket size, in megabytes.
    /// NOTE(review): not consulted anywhere in this file — confirm intent.
    pub min_bucket_size_mb: f32,
}
29
30impl Default for BucketConfig {
31 fn default() -> Self {
32 Self {
33 max_bucket_size_mb: 25.0,
34 enabled: true,
35 min_bucket_size_mb: 1.0,
36 }
37 }
38}
39
/// A group of parameters whose gradients are synchronized together in a
/// single collective call.
#[derive(Debug)]
struct GradientBucket {
    /// Names of member parameters, in insertion order.
    parameters: Vec<String>,
    /// Sum of member gradient sizes, in bytes.
    total_size: usize,
    /// Reserved readiness flag; currently unused.
    _ready: bool,
}
50
51impl GradientBucket {
52 fn new() -> Self {
53 Self {
54 parameters: Vec::new(),
55 total_size: 0,
56 _ready: false,
57 }
58 }
59
60 fn add_parameter(&mut self, name: String, size: usize) {
61 self.parameters.push(name);
62 self.total_size += size;
63 }
64
65 fn size_mb(&self) -> f32 {
66 self.total_size as f32 / (1024.0 * 1024.0)
67 }
68}
69
/// Point-in-time summary of gradient synchronization state.
#[derive(Debug, Clone)]
pub struct GradientSyncStats {
    /// Parameters that require gradients.
    pub total_parameters: usize,
    /// Of those, how many currently hold a gradient.
    pub parameters_with_grad: usize,
    /// Total size of held gradients, in megabytes (f32 elements assumed).
    pub total_gradient_size_mb: f32,
    /// Number of gradient buckets in the current layout.
    pub num_buckets: usize,
    /// Ranks participating in the process group.
    pub world_size: u32,
}
84
/// Read-only description of one gradient bucket.
#[derive(Debug, Clone)]
pub struct BucketInfo {
    /// Position of the bucket in the layout.
    pub index: usize,
    /// Bucket payload in megabytes.
    pub size_mb: f32,
    /// Number of parameters assigned to this bucket.
    pub num_parameters: usize,
    /// Names of the assigned parameters.
    pub parameter_names: Vec<String>,
}
97
/// One ready gradient, sent from the autograd side to the background sync
/// worker.
#[derive(Debug)]
struct GradientMessage {
    /// Name of the parameter the gradient belongs to.
    param_name: String,
    /// The gradient tensor itself (moved into the message).
    gradient: Tensor,
    /// Bucket the parameter was assigned to at initialization.
    bucket_index: usize,
}
108
/// Bookkeeping for detecting parameters that never produce a gradient in an
/// iteration (e.g. branches skipped by control flow).
#[derive(Debug, Default)]
struct UnusedParameterTracker {
    /// Every parameter that requires a gradient.
    all_parameters: HashSet<String>,
    /// Parameters that reported a gradient this iteration.
    used_parameters: HashSet<String>,
    /// Whether tracking is active.
    enabled: bool,
    /// Monotonic iteration counter, bumped by `start_iteration`.
    iteration: u64,
}
121
/// Policy for overlapping gradient communication with backward computation.
#[derive(Debug, Clone)]
pub struct OverlapConfig {
    /// Whether the background sync worker is used at all.
    pub enabled: bool,
    /// Maximum number of bucket syncs allowed in flight at once.
    pub max_pending_syncs: usize,
    /// How long to wait for a sync permit before dropping a gradient.
    pub sync_timeout_secs: u64,
    /// Whether to track parameters that never report a gradient.
    pub detect_unused_parameters: bool,
}
134
135impl Default for OverlapConfig {
136 fn default() -> Self {
137 Self {
138 enabled: true,
139 max_pending_syncs: 4,
140 sync_timeout_secs: 30,
141 detect_unused_parameters: true,
142 }
143 }
144}
145
/// Wraps a [`Module`] so gradients are averaged across all ranks of a
/// [`ProcessGroup`] after backward, PyTorch-DDP style: gradients are grouped
/// into buckets and (optionally) synchronized by a background worker that
/// overlaps communication with computation.
pub struct DistributedDataParallel<M: Module> {
    // The wrapped local module.
    module: M,
    // Communicator over which gradients are all-reduced.
    process_group: Arc<ProcessGroup>,
    // Reserved: local device ids (currently unused).
    _device_ids: Vec<usize>,
    // Reserved: output device (currently unused).
    _output_device: Option<usize>,
    // Reserved: whether buffers are broadcast each step (currently unused).
    _broadcast_buffers: bool,
    // Reserved: legacy copy of the bucket cap (see `bucket_config`).
    _bucket_cap_mb: f32,
    // Bucketing policy.
    bucket_config: BucketConfig,
    // Bucket layout built from the module's parameters.
    gradient_buckets: Vec<GradientBucket>,
    // Parameter name -> index into `gradient_buckets`.
    param_to_bucket: HashMap<String, usize>,
    // Communication/computation overlap policy.
    overlap_config: OverlapConfig,
    // Channel into the background sync worker (when overlap is enabled).
    gradient_sender: Option<mpsc::UnboundedSender<GradientMessage>>,
    // Handle of the spawned sync worker task.
    sync_task_handle: Option<JoinHandle<()>>,
    // Bounds the number of concurrently pending syncs.
    sync_semaphore: Arc<Semaphore>,
    // Tracks which parameters produced gradients this iteration.
    unused_param_tracker: Arc<Mutex<UnusedParameterTracker>>,
    // Per-bucket ready counters.
    // NOTE(review): never meaningfully updated in this file — candidate for removal.
    bucket_ready_count: Arc<Mutex<HashMap<usize, usize>>>,
}
171
172impl<M: Module> DistributedDataParallel<M> {
173 pub fn new(
175 module: M,
176 process_group: Arc<ProcessGroup>,
177 device_ids: Vec<usize>,
178 output_device: Option<usize>,
179 broadcast_buffers: bool,
180 bucket_cap_mb: f32,
181 ) -> TorshResult<Self> {
182 let bucket_config = BucketConfig {
183 max_bucket_size_mb: bucket_cap_mb,
184 enabled: true,
185 min_bucket_size_mb: 1.0,
186 };
187
188 Self::new_with_configs(
189 module,
190 process_group,
191 device_ids,
192 output_device,
193 broadcast_buffers,
194 bucket_config,
195 OverlapConfig::default(),
196 )
197 }
198
199 pub fn new_with_bucket_config(
201 module: M,
202 process_group: Arc<ProcessGroup>,
203 device_ids: Vec<usize>,
204 output_device: Option<usize>,
205 broadcast_buffers: bool,
206 bucket_config: BucketConfig,
207 ) -> TorshResult<Self> {
208 Self::new_with_configs(
209 module,
210 process_group,
211 device_ids,
212 output_device,
213 broadcast_buffers,
214 bucket_config,
215 OverlapConfig::default(),
216 )
217 }
218
219 pub fn new_with_configs(
221 module: M,
222 process_group: Arc<ProcessGroup>,
223 device_ids: Vec<usize>,
224 output_device: Option<usize>,
225 broadcast_buffers: bool,
226 bucket_config: BucketConfig,
227 overlap_config: OverlapConfig,
228 ) -> TorshResult<Self> {
229 let sync_semaphore = Arc::new(Semaphore::new(overlap_config.max_pending_syncs));
230 let unused_param_tracker = Arc::new(Mutex::new(UnusedParameterTracker::default()));
231 let bucket_ready_count = Arc::new(Mutex::new(HashMap::new()));
232
233 let mut ddp = Self {
234 module,
235 process_group,
236 _device_ids: device_ids,
237 _output_device: output_device,
238 _broadcast_buffers: broadcast_buffers,
239 _bucket_cap_mb: bucket_config.max_bucket_size_mb,
240 bucket_config,
241 gradient_buckets: Vec::new(),
242 param_to_bucket: HashMap::new(),
243 overlap_config,
244 gradient_sender: None,
245 sync_task_handle: None,
246 sync_semaphore,
247 unused_param_tracker,
248 bucket_ready_count,
249 };
250
251 ddp.initialize_buckets()?;
253
254 if ddp.overlap_config.detect_unused_parameters {
256 ddp.initialize_unused_parameter_tracking()?;
257 }
258
259 if ddp.overlap_config.enabled {
261 ddp.start_gradient_sync_worker()?;
262 }
263
264 Ok(ddp)
265 }
266
267 fn initialize_buckets(&mut self) -> TorshResult<()> {
269 if !self.bucket_config.enabled {
270 return Ok(());
271 }
272
273 let parameters = self.module.named_parameters();
274 let mut current_bucket = GradientBucket::new();
275 let mut bucket_index = 0;
276
277 let mut param_sizes: Vec<(String, usize)> = parameters
279 .iter()
280 .map(|(name, param)| {
281 let tensor = param.tensor();
282 let tensor_guard = tensor.read();
283 let size = tensor_guard.numel() * std::mem::size_of::<f32>(); (name.clone(), size)
285 })
286 .collect();
287
288 param_sizes.sort_by(|a, b| b.1.cmp(&a.1)); for (param_name, param_size) in param_sizes {
291 let new_size_mb = (current_bucket.total_size + param_size) as f32 / (1024.0 * 1024.0);
293
294 if new_size_mb > self.bucket_config.max_bucket_size_mb
295 && !current_bucket.parameters.is_empty()
296 {
297 self.gradient_buckets.push(current_bucket);
299 current_bucket = GradientBucket::new();
300 bucket_index += 1;
301 }
302
303 current_bucket.add_parameter(param_name.clone(), param_size);
305 self.param_to_bucket.insert(param_name, bucket_index);
306 }
307
308 if !current_bucket.parameters.is_empty() {
310 self.gradient_buckets.push(current_bucket);
311 }
312
313 info!(
314 "📦 Initialized {} gradient buckets",
315 self.gradient_buckets.len()
316 );
317 for (i, bucket) in self.gradient_buckets.iter().enumerate() {
318 info!(
319 " Bucket {}: {:.2} MB, {} parameters",
320 i,
321 bucket.size_mb(),
322 bucket.parameters.len()
323 );
324 }
325
326 Ok(())
327 }
328
329 fn initialize_unused_parameter_tracking(&mut self) -> TorshResult<()> {
331 let parameters = self.module.named_parameters();
332 let mut tracker = self
333 .unused_param_tracker
334 .lock()
335 .expect("lock should not be poisoned");
336
337 tracker.enabled = true;
338 tracker.all_parameters.clear();
339 tracker.used_parameters.clear();
340
341 for (name, param) in parameters {
343 let tensor = param.tensor();
344 let tensor_guard = tensor.read();
345 if tensor_guard.requires_grad() {
346 tracker.all_parameters.insert(name);
347 }
348 }
349
350 info!(
351 "🔍 Initialized unused parameter detection for {} parameters",
352 tracker.all_parameters.len()
353 );
354
355 Ok(())
356 }
357
358 fn start_gradient_sync_worker(&mut self) -> TorshResult<()> {
360 let (sender, mut receiver) = mpsc::unbounded_channel::<GradientMessage>();
361
362 let process_group = Arc::clone(&self.process_group);
363 let sync_semaphore = Arc::clone(&self.sync_semaphore);
364 let _bucket_ready_count = Arc::clone(&self.bucket_ready_count);
365 let _gradient_buckets_len = self.gradient_buckets.len();
366 let timeout_duration = Duration::from_secs(self.overlap_config.sync_timeout_secs);
367
368 let bucket_param_counts: HashMap<usize, usize> = self
370 .gradient_buckets
371 .iter()
372 .enumerate()
373 .map(|(i, bucket)| (i, bucket.parameters.len()))
374 .collect();
375
376 let handle = tokio::spawn(async move {
377 let mut pending_gradients: HashMap<usize, Vec<(String, Tensor)>> = HashMap::new();
378
379 while let Some(grad_msg) = receiver.recv().await {
380 let _permit =
382 match tokio::time::timeout(timeout_duration, sync_semaphore.acquire()).await {
383 Ok(Ok(permit)) => permit,
384 Ok(Err(_)) => {
385 info!(" Gradient sync semaphore closed, stopping worker");
386 break;
387 }
388 Err(_) => {
389 info!(
390 " Gradient sync timeout, dropping gradient for {}",
391 grad_msg.param_name
392 );
393 continue;
394 }
395 };
396
397 let bucket_index = grad_msg.bucket_index;
398
399 pending_gradients
401 .entry(bucket_index)
402 .or_insert_with(Vec::new)
403 .push((grad_msg.param_name, grad_msg.gradient));
404
405 let expected_count = bucket_param_counts.get(&bucket_index).copied().unwrap_or(0);
407 let current_count = pending_gradients
408 .get(&bucket_index)
409 .map(|v| v.len())
410 .unwrap_or(0);
411
412 if current_count >= expected_count && expected_count > 0 {
413 if let Some(bucket_gradients) = pending_gradients.remove(&bucket_index) {
415 let pg = Arc::clone(&process_group);
417 tokio::spawn(async move {
418 match Self::sync_bucket_gradients(bucket_gradients, &pg).await {
419 Ok(synchronized_gradients) => {
420 info!(
421 " Successfully synchronized bucket {} with {} gradients",
422 bucket_index,
423 synchronized_gradients.len()
424 );
425 }
428 Err(e) => {
429 info!(" Failed to sync bucket {}: {}", bucket_index, e);
430 }
431 }
432 });
433 }
434 }
435 }
436
437 info!(" Gradient synchronization worker stopped");
438 });
439
440 self.gradient_sender = Some(sender);
441 self.sync_task_handle = Some(handle);
442
443 info!(" Started background gradient synchronization worker");
444 Ok(())
445 }
446
447 async fn sync_bucket_gradients(
449 gradients: Vec<(String, Tensor)>,
450 process_group: &ProcessGroup,
451 ) -> TorshResult<Vec<(String, Tensor)>> {
452 let start_time = Instant::now();
453
454 if gradients.is_empty() {
455 return Ok(Vec::new());
456 }
457
458 if gradients.len() > 1 {
460 Self::sync_bucket_gradients_flattened(gradients, process_group).await
462 } else {
463 Self::sync_single_gradient(gradients, process_group).await
465 }
466 .inspect(|_result| {
467 let elapsed = start_time.elapsed();
468 if elapsed > Duration::from_millis(100) {
469 info!(
470 "⏱️ Bucket sync took {:.2}ms",
471 elapsed.as_secs_f32() * 1000.0
472 );
473 }
474 })
475 }
476
477 async fn sync_bucket_gradients_flattened(
479 gradients: Vec<(String, Tensor)>,
480 process_group: &ProcessGroup,
481 ) -> TorshResult<Vec<(String, Tensor)>> {
482 let mut gradient_shapes = Vec::new();
484 let mut gradient_sizes = Vec::new();
485 let mut flattened_data = Vec::new();
486
487 for (param_name, grad) in &gradients {
488 let shape = grad.shape();
489 let numel = grad.numel();
490 gradient_shapes.push((param_name.clone(), shape.dims().to_vec()));
491 gradient_sizes.push(numel);
492
493 let flattened_grad = grad.flatten()?;
495 let grad_data = flattened_grad.data()?;
496 flattened_data.extend_from_slice(&grad_data);
497 }
498
499 let total_size = flattened_data.len();
501 let mut flattened_tensor =
502 Tensor::from_data(flattened_data, vec![total_size], gradients[0].1.device())?;
503
504 all_reduce(&mut flattened_tensor, ReduceOp::Sum, process_group).await?;
506
507 let world_size = process_group.world_size() as f32;
509 flattened_tensor = flattened_tensor.div_scalar(world_size)?;
510
511 let flattened_data = flattened_tensor.data()?;
513 let mut result_gradients = Vec::new();
514 let mut current_offset = 0;
515
516 for ((param_name, original_shape), size) in
517 gradient_shapes.iter().zip(gradient_sizes.iter())
518 {
519 let grad_data = &flattened_data[current_offset..current_offset + size];
521
522 let reconstructed_grad = Tensor::from_data(
524 grad_data.to_vec(),
525 original_shape.clone(),
526 gradients[0].1.device(),
527 )?;
528
529 result_gradients.push((param_name.clone(), reconstructed_grad));
530 current_offset += size;
531 }
532
533 Ok(result_gradients)
534 }
535
536 async fn sync_single_gradient(
538 mut gradients: Vec<(String, Tensor)>,
539 process_group: &ProcessGroup,
540 ) -> TorshResult<Vec<(String, Tensor)>> {
541 if let Some((param_name, mut grad)) = gradients.pop() {
542 all_reduce(&mut grad, ReduceOp::Sum, process_group).await?;
543
544 let world_size = process_group.world_size() as f32;
546 grad = grad.div_scalar(world_size)?;
547
548 Ok(vec![(param_name, grad)])
549 } else {
550 Ok(Vec::new())
551 }
552 }
553
554 pub async fn sync_gradients(&mut self) -> TorshResult<()> {
556 if self.bucket_config.enabled && !self.gradient_buckets.is_empty() {
557 self.sync_gradients_bucketed().await
559 } else {
560 self.sync_gradients_naive().await
562 }
563 }
564
    /// Per-parameter fallback sync: all-reduce + average each gradient
    /// individually (one collective per parameter).
    ///
    /// NOTE(review): the tensor read guard is held across `.await` (covered by
    /// the file-level clippy allow); `set_grad` through a read guard implies
    /// the tensor uses interior mutability — confirm against torsh_tensor.
    async fn sync_gradients_naive(&mut self) -> TorshResult<()> {
        #[allow(clippy::await_holding_lock)]
        let parameters = self.module.parameters();

        for (_name, param) in parameters {
            let tensor = param.tensor();
            let tensor_guard = tensor.read();

            if tensor_guard.requires_grad() {
                if let Some(mut grad) = tensor_guard.grad() {
                    // Sum across ranks, then divide to obtain the mean gradient.
                    all_reduce(&mut grad, ReduceOp::Sum, &self.process_group).await?;

                    let world_size = self.process_group.world_size() as f32;
                    grad = grad.div_scalar(world_size)?;

                    tensor_guard.set_grad(Some(grad));
                }
            }
        }

        Ok(())
    }
592
593 async fn sync_gradients_bucketed(&mut self) -> TorshResult<()> {
595 let parameters = self.module.named_parameters();
596
597 for bucket in &self.gradient_buckets {
599 let mut bucket_gradients = Vec::new();
600 let mut bucket_params = Vec::new();
601
602 for param_name in &bucket.parameters {
604 if let Some(param) = parameters.get(param_name) {
605 let tensor = param.tensor();
606 let tensor_guard = tensor.read();
607
608 if tensor_guard.requires_grad() {
609 if let Some(grad) = tensor_guard.grad() {
610 bucket_gradients.push(grad);
611 bucket_params.push(param_name.clone());
612 }
613 }
614 }
615 }
616
617 if !bucket_gradients.is_empty() {
618 let gradients_with_names: Vec<(String, Tensor)> = bucket_gradients
625 .into_iter()
626 .zip(bucket_params.iter())
627 .map(|(grad, param_name)| (param_name.clone(), grad))
628 .collect();
629
630 match Self::sync_bucket_gradients_flattened(
632 gradients_with_names,
633 &self.process_group,
634 )
635 .await
636 {
637 Ok(synchronized_gradients) => {
638 for (param_name, synchronized_grad) in synchronized_gradients {
640 if let Some(param) = parameters.get(¶m_name) {
641 let tensor = param.tensor();
642 let tensor_guard = tensor.read();
643 tensor_guard.set_grad(Some(synchronized_grad));
644 }
645 }
646 }
647 Err(e) => {
648 info!(" Failed to sync bucket gradients efficiently, falling back to individual sync: {}", e);
649
650 #[allow(clippy::await_holding_lock)]
651 for param_name in &bucket_params {
653 if let Some(param) = parameters.get(param_name) {
654 let tensor = param.tensor();
655 let tensor_guard = tensor.read();
656
657 if let Some(mut grad) = tensor_guard.grad() {
658 all_reduce(&mut grad, ReduceOp::Sum, &self.process_group)
659 .await?;
660 let world_size = self.process_group.world_size() as f32;
661 grad = grad.div_scalar(world_size)?;
662 tensor_guard.set_grad(Some(grad));
663 }
664 }
665 }
666 }
667 }
668 }
669 }
670
671 Ok(())
672 }
673
    /// Placeholder for autograd gradient hooks; currently a no-op.
    /// NOTE(review): expected to wire per-parameter hooks that call
    /// `register_gradient_async` once autograd integration lands — confirm.
    pub fn register_gradient_hooks(&self) -> TorshResult<()> {
        Ok(())
    }
684
685 pub fn register_gradient_async(&self, param_name: &str, gradient: Tensor) -> TorshResult<()> {
688 if !self.overlap_config.enabled {
689 return Ok(()); }
691
692 if self.overlap_config.detect_unused_parameters {
694 let mut tracker = self
695 .unused_param_tracker
696 .lock()
697 .expect("lock should not be poisoned");
698 if tracker.enabled {
699 tracker.used_parameters.insert(param_name.to_string());
700 }
701 }
702
703 let bucket_index = self.param_to_bucket.get(param_name).copied().unwrap_or(0);
705
706 if let Some(sender) = &self.gradient_sender {
708 let message = GradientMessage {
709 param_name: param_name.to_string(),
710 gradient,
711 bucket_index,
712 };
713
714 if let Err(e) = sender.send(message) {
715 info!(" Failed to send gradient for {}: {}", param_name, e);
716 }
717 }
718
719 Ok(())
720 }
721
722 pub fn check_unused_parameters(&self) -> TorshResult<Vec<String>> {
724 if !self.overlap_config.detect_unused_parameters {
725 return Ok(Vec::new());
726 }
727
728 let tracker = self
729 .unused_param_tracker
730 .lock()
731 .expect("lock should not be poisoned");
732 if !tracker.enabled {
733 return Ok(Vec::new());
734 }
735
736 let unused: Vec<String> = tracker
737 .all_parameters
738 .difference(&tracker.used_parameters)
739 .cloned()
740 .collect();
741
742 if !unused.is_empty() {
743 info!(
744 " Found {} unused parameters in iteration {}:",
745 unused.len(),
746 tracker.iteration
747 );
748 for param in &unused {
749 info!(" - {}", param);
750 }
751 }
752
753 Ok(unused)
754 }
755
756 pub fn start_iteration(&self) -> TorshResult<()> {
758 if self.overlap_config.detect_unused_parameters {
759 let mut tracker = self
760 .unused_param_tracker
761 .lock()
762 .expect("lock should not be poisoned");
763 if tracker.enabled {
764 tracker.used_parameters.clear();
765 tracker.iteration += 1;
766 }
767 }
768 Ok(())
769 }
770
771 pub fn get_overlap_stats(&self) -> HashMap<String, serde_json::Value> {
773 let mut stats = HashMap::new();
774
775 stats.insert(
776 "overlap_enabled".to_string(),
777 serde_json::Value::Bool(self.overlap_config.enabled),
778 );
779 stats.insert(
780 "max_pending_syncs".to_string(),
781 serde_json::Value::Number(serde_json::Number::from(
782 self.overlap_config.max_pending_syncs,
783 )),
784 );
785 stats.insert(
786 "sync_timeout_secs".to_string(),
787 serde_json::Value::Number(serde_json::Number::from(
788 self.overlap_config.sync_timeout_secs,
789 )),
790 );
791 stats.insert(
792 "unused_param_detection".to_string(),
793 serde_json::Value::Bool(self.overlap_config.detect_unused_parameters),
794 );
795
796 if let Ok(tracker) = self.unused_param_tracker.lock() {
797 stats.insert(
798 "total_params".to_string(),
799 serde_json::Value::Number(serde_json::Number::from(tracker.all_parameters.len())),
800 );
801 stats.insert(
802 "used_params".to_string(),
803 serde_json::Value::Number(serde_json::Number::from(tracker.used_parameters.len())),
804 );
805 stats.insert(
806 "current_iteration".to_string(),
807 serde_json::Value::Number(serde_json::Number::from(tracker.iteration)),
808 );
809 }
810
811 let available_permits = self.sync_semaphore.available_permits();
813 stats.insert(
814 "available_sync_permits".to_string(),
815 serde_json::Value::Number(serde_json::Number::from(available_permits)),
816 );
817
818 stats
819 }
820
821 pub fn has_gradients(&self) -> bool {
823 let parameters = self.module.parameters();
824
825 for (_name, param) in parameters {
826 let tensor = param.tensor();
827 let tensor_guard = tensor.read();
828
829 if tensor_guard.requires_grad() && tensor_guard.has_grad() {
830 return true;
831 }
832 }
833
834 false
835 }
836
837 pub fn zero_grad(&mut self) -> TorshResult<()> {
839 let parameters = self.module.parameters();
840
841 for (_name, param) in parameters {
842 let tensor = param.tensor();
843 let tensor_guard = tensor.read();
844
845 if tensor_guard.requires_grad() {
846 tensor_guard.set_grad(None);
847 }
848 }
849
850 Ok(())
851 }
852
853 pub fn get_sync_stats(&self) -> GradientSyncStats {
855 let parameters = self.module.named_parameters();
856 let mut total_parameters = 0;
857 let mut parameters_with_grad = 0;
858 let mut total_gradient_size = 0;
859
860 for (_name, param) in parameters {
861 let tensor = param.tensor();
862 let tensor_guard = tensor.read();
863
864 if tensor_guard.requires_grad() {
865 total_parameters += 1;
866
867 if tensor_guard.has_grad() {
868 parameters_with_grad += 1;
869 total_gradient_size += tensor_guard.numel() * std::mem::size_of::<f32>();
870 }
871 }
872 }
873
874 GradientSyncStats {
875 total_parameters,
876 parameters_with_grad,
877 total_gradient_size_mb: total_gradient_size as f32 / (1024.0 * 1024.0),
878 num_buckets: self.gradient_buckets.len(),
879 world_size: self.process_group.world_size(),
880 }
881 }
882
883 pub fn set_bucketing_enabled(&mut self, enabled: bool) -> TorshResult<()> {
885 self.bucket_config.enabled = enabled;
886
887 if enabled && self.gradient_buckets.is_empty() {
888 self.initialize_buckets()?;
890 }
891
892 Ok(())
893 }
894
895 pub fn get_bucket_info(&self) -> Vec<BucketInfo> {
897 self.gradient_buckets
898 .iter()
899 .enumerate()
900 .map(|(i, bucket)| BucketInfo {
901 index: i,
902 size_mb: bucket.size_mb(),
903 num_parameters: bucket.parameters.len(),
904 parameter_names: bucket.parameters.clone(),
905 })
906 .collect()
907 }
908
    /// Placeholder consistency check; always reports `true`.
    /// NOTE(review): presumably intended to verify gradients agree across
    /// ranks (e.g. via a checksum all-gather) — not yet implemented.
    pub async fn check_gradient_consistency(&self) -> TorshResult<bool> {
        Ok(true)
    }
921}
922
/// Transparent [`Module`] delegation: DDP behaves exactly like the wrapped
/// module for forward, training mode, and device placement — only gradient
/// handling differs.
impl<M: Module> Module for DistributedDataParallel<M> {
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        // Forward is purely local; synchronization happens after backward.
        self.module.forward(input)
    }

    fn parameters(&self) -> HashMap<String, Parameter> {
        self.module.parameters()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.module.named_parameters()
    }

    fn training(&self) -> bool {
        self.module.training()
    }

    fn train(&mut self) {
        self.module.train()
    }

    fn eval(&mut self) {
        self.module.eval()
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.module.to_device(device)
    }
}