pub struct DistributedDataParallel<M: Module> { /* private fields */ }
Expand description
Distributed Data Parallel wrapper for models
Implementations§
Source§impl<M: Module> DistributedDataParallel<M>
impl<M: Module> DistributedDataParallel<M>
Sourcepub fn new(
module: M,
process_group: Arc<ProcessGroup>,
device_ids: Vec<usize>,
output_device: Option<usize>,
broadcast_buffers: bool,
bucket_cap_mb: f32,
) -> TorshResult<Self>
pub fn new( module: M, process_group: Arc<ProcessGroup>, device_ids: Vec<usize>, output_device: Option<usize>, broadcast_buffers: bool, bucket_cap_mb: f32, ) -> TorshResult<Self>
Create a new DDP wrapper
Sourcepub fn new_with_bucket_config(
module: M,
process_group: Arc<ProcessGroup>,
device_ids: Vec<usize>,
output_device: Option<usize>,
broadcast_buffers: bool,
bucket_config: BucketConfig,
) -> TorshResult<Self>
pub fn new_with_bucket_config( module: M, process_group: Arc<ProcessGroup>, device_ids: Vec<usize>, output_device: Option<usize>, broadcast_buffers: bool, bucket_config: BucketConfig, ) -> TorshResult<Self>
Create a new DDP wrapper with custom bucket configuration
Sourcepub fn new_with_configs(
module: M,
process_group: Arc<ProcessGroup>,
device_ids: Vec<usize>,
output_device: Option<usize>,
broadcast_buffers: bool,
bucket_config: BucketConfig,
overlap_config: OverlapConfig,
) -> TorshResult<Self>
pub fn new_with_configs( module: M, process_group: Arc<ProcessGroup>, device_ids: Vec<usize>, output_device: Option<usize>, broadcast_buffers: bool, bucket_config: BucketConfig, overlap_config: OverlapConfig, ) -> TorshResult<Self>
Create a new DDP wrapper with custom configurations
Sourcepub async fn sync_gradients(&mut self) -> TorshResult<()>
pub async fn sync_gradients(&mut self) -> TorshResult<()>
Synchronize gradients across all processes
Sourcepub fn register_gradient_hooks(&self) -> TorshResult<()>
pub fn register_gradient_hooks(&self) -> TorshResult<()>
Register gradient synchronization hooks. This should be called during the backward pass to automatically sync gradients.
Sourcepub fn register_gradient_async(
&self,
param_name: &str,
gradient: Tensor,
) -> TorshResult<()>
pub fn register_gradient_async( &self, param_name: &str, gradient: Tensor, ) -> TorshResult<()>
Register a gradient for asynchronous synchronization (overlap mode). This should be called when a gradient becomes available during the backward pass.
Sourcepub fn check_unused_parameters(&self) -> TorshResult<Vec<String>>
pub fn check_unused_parameters(&self) -> TorshResult<Vec<String>>
Check for unused parameters and issue warnings
Sourcepub fn start_iteration(&self) -> TorshResult<()>
pub fn start_iteration(&self) -> TorshResult<()>
Start a new iteration (reset unused parameter tracking)
Sourcepub fn get_overlap_stats(&self) -> HashMap<String, Value>
pub fn get_overlap_stats(&self) -> HashMap<String, Value>
Get overlap computation statistics
Sourcepub fn has_gradients(&self) -> bool
pub fn has_gradients(&self) -> bool
Check if any parameters have gradients
Sourcepub fn zero_grad(&mut self) -> TorshResult<()>
pub fn zero_grad(&mut self) -> TorshResult<()>
Zero all gradients
Sourcepub fn get_sync_stats(&self) -> GradientSyncStats
pub fn get_sync_stats(&self) -> GradientSyncStats
Get gradient synchronization statistics
Sourcepub fn set_bucketing_enabled(&mut self, enabled: bool) -> TorshResult<()>
pub fn set_bucketing_enabled(&mut self, enabled: bool) -> TorshResult<()>
Enable/disable gradient bucketing at runtime
Sourcepub fn get_bucket_info(&self) -> Vec<BucketInfo>
pub fn get_bucket_info(&self) -> Vec<BucketInfo>
Get bucket information for debugging
Sourcepub async fn check_gradient_consistency(&self) -> TorshResult<bool>
pub async fn check_gradient_consistency(&self) -> TorshResult<bool>
Perform a gradient consistency check across all processes. This is useful for debugging distributed training issues.
Trait Implementations§
Source§impl<M: Module> Module for DistributedDataParallel<M>
impl<M: Module> Module for DistributedDataParallel<M>
Source§fn parameters(&self) -> HashMap<String, Parameter>
fn parameters(&self) -> HashMap<String, Parameter>
Source§fn named_parameters(&self) -> HashMap<String, Parameter>
fn named_parameters(&self) -> HashMap<String, Parameter>
Source§fn all_parameters(&self) -> HashMap<String, Parameter>
fn all_parameters(&self) -> HashMap<String, Parameter>
Source§fn all_named_parameters(&self) -> HashMap<String, Parameter>
fn all_named_parameters(&self) -> HashMap<String, Parameter>
Source§fn set_training(&mut self, _training: bool)
fn set_training(&mut self, _training: bool)
Source§fn load_state_dict(
&mut self,
state_dict: &HashMap<String, Tensor>,
strict: bool,
) -> Result<(), TorshError>
fn load_state_dict( &mut self, state_dict: &HashMap<String, Tensor>, strict: bool, ) -> Result<(), TorshError>
Source§fn load_state_dict_strict(
&mut self,
state_dict: &HashMap<String, Tensor>,
) -> Result<(), TorshError>
fn load_state_dict_strict( &mut self, state_dict: &HashMap<String, Tensor>, ) -> Result<(), TorshError>
Source§fn name(&self) -> Option<&str>
fn name(&self) -> Option<&str>
Source§fn buffers(&self) -> Vec<Arc<RwLock<RawRwLock, Tensor>>>
fn buffers(&self) -> Vec<Arc<RwLock<RawRwLock, Tensor>>>
Source§fn named_children(&self) -> Vec<(String, &dyn Module)>
fn named_children(&self) -> Vec<(String, &dyn Module)>
Source§fn modules(&self) -> Vec<&dyn Module>
where
Self: Sized,
fn modules(&self) -> Vec<&dyn Module>
where
Self: Sized,
Source§fn named_modules(&self) -> Vec<(String, &dyn Module)>
where
Self: Sized,
fn named_modules(&self) -> Vec<(String, &dyn Module)>
where
Self: Sized,
Source§fn num_parameters(&self) -> usize
fn num_parameters(&self) -> usize
Source§fn num_trainable_parameters(&self) -> usize
fn num_trainable_parameters(&self) -> usize
Source§fn memory_usage(&self) -> usize
fn memory_usage(&self) -> usize
Source§fn extra_repr(&self) -> String
fn extra_repr(&self) -> String
Source§fn register_hook(
&mut self,
_hook_type: HookType,
_callback: Box<dyn Fn(&dyn Module, &Tensor, Option<&Tensor>) -> Result<(), TorshError> + Send + Sync>,
) -> Option<HookHandle>
fn register_hook( &mut self, _hook_type: HookType, _callback: Box<dyn Fn(&dyn Module, &Tensor, Option<&Tensor>) -> Result<(), TorshError> + Send + Sync>, ) -> Option<HookHandle>
Source§fn remove_hook(&mut self, _hook_type: HookType, _handle: HookHandle) -> bool
fn remove_hook(&mut self, _hook_type: HookType, _handle: HookHandle) -> bool
Source§fn execute_hooks(
&self,
_hook_type: HookType,
_input: &Tensor,
_output: Option<&Tensor>,
) -> Result<(), TorshError>
fn execute_hooks( &self, _hook_type: HookType, _input: &Tensor, _output: Option<&Tensor>, ) -> Result<(), TorshError>
Source§fn forward_with_hooks(&self, input: &Tensor) -> Result<Tensor, TorshError>
fn forward_with_hooks(&self, input: &Tensor) -> Result<Tensor, TorshError>
Source§fn call(&self, input: &Tensor) -> Result<Tensor, TorshError>
fn call(&self, input: &Tensor) -> Result<Tensor, TorshError>
Source§fn apply(&self, input: &Tensor) -> Result<Tensor, TorshError>
fn apply(&self, input: &Tensor) -> Result<Tensor, TorshError>
Source§fn has_parameters(&self) -> bool
fn has_parameters(&self) -> bool
Source§fn has_children(&self) -> bool
fn has_children(&self) -> bool
Source§fn parameter_count(&self) -> usize
fn parameter_count(&self) -> usize
Source§fn trainable_parameter_count(&self) -> usize
fn trainable_parameter_count(&self) -> usize
Source§fn memory_usage_mb(&self) -> f64
fn memory_usage_mb(&self) -> f64
Source§fn toggle_training(&mut self)
fn toggle_training(&mut self)
Source§fn sequential_forward(
modules: &[&dyn Module],
input: Tensor,
) -> Result<Tensor, TorshError>
where
Self: Sized,
fn sequential_forward(
modules: &[&dyn Module],
input: Tensor,
) -> Result<Tensor, TorshError>
where
Self: Sized,
Source§fn batch_forward(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>, TorshError>
fn batch_forward(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>, TorshError>
Source§fn conditional_forward(
&self,
input: &Tensor,
condition: bool,
) -> Result<Tensor, TorshError>
fn conditional_forward( &self, input: &Tensor, condition: bool, ) -> Result<Tensor, TorshError>
Source§fn residual_forward(&self, input: &Tensor) -> Result<Tensor, TorshError>
fn residual_forward(&self, input: &Tensor) -> Result<Tensor, TorshError>
Source§fn module_info(&self) -> ModuleInfo
fn module_info(&self) -> ModuleInfo
Source§fn check_training_readiness(&self) -> Result<(), TorshError>
fn check_training_readiness(&self) -> Result<(), TorshError>
Source§fn parameter_names_matching(&self, pattern: &str) -> Vec<String>
fn parameter_names_matching(&self, pattern: &str) -> Vec<String>
Source§fn parameters_by_type(&self, param_type: &str) -> HashMap<String, Parameter>
fn parameters_by_type(&self, param_type: &str) -> HashMap<String, Parameter>
Source§fn clone_parameters(&self) -> HashMap<String, Tensor>
fn clone_parameters(&self) -> HashMap<String, Tensor>
Source§fn diagnose(&self) -> ModuleDiagnostics
fn diagnose(&self) -> ModuleDiagnostics
Auto Trait Implementations§
impl<M> Freeze for DistributedDataParallel<M>
where
M: Freeze,
impl<M> !RefUnwindSafe for DistributedDataParallel<M>
impl<M> Send for DistributedDataParallel<M>
impl<M> Sync for DistributedDataParallel<M>
impl<M> Unpin for DistributedDataParallel<M>
where
M: Unpin,
impl<M> UnsafeUnpin for DistributedDataParallel<M>
where
M: UnsafeUnpin,
impl<M> !UnwindSafe for DistributedDataParallel<M>
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
T: ?Sized,
impl<T> BorrowMut<T> for T
where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> Instrument for T
impl<T> Instrument for T
Source§fn instrument(self, span: Span) -> Instrumented<Self>
fn instrument(self, span: Span) -> Instrumented<Self>
Source§fn in_current_span(self) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more