Skip to main content

DistributedDataParallel

Struct DistributedDataParallel 

Source
pub struct DistributedDataParallel<M: Module> { /* private fields */ }
Expand description

Distributed Data Parallel wrapper for models

Implementations§

Source§

impl<M: Module> DistributedDataParallel<M>

Source

pub fn new( module: M, process_group: Arc<ProcessGroup>, device_ids: Vec<usize>, output_device: Option<usize>, broadcast_buffers: bool, bucket_cap_mb: f32, ) -> TorshResult<Self>

Create a new DDP wrapper

Source

pub fn new_with_bucket_config( module: M, process_group: Arc<ProcessGroup>, device_ids: Vec<usize>, output_device: Option<usize>, broadcast_buffers: bool, bucket_config: BucketConfig, ) -> TorshResult<Self>

Create a new DDP wrapper with custom bucket configuration

Source

pub fn new_with_configs( module: M, process_group: Arc<ProcessGroup>, device_ids: Vec<usize>, output_device: Option<usize>, broadcast_buffers: bool, bucket_config: BucketConfig, overlap_config: OverlapConfig, ) -> TorshResult<Self>

Create a new DDP wrapper with custom configurations

Source

pub async fn sync_gradients(&mut self) -> TorshResult<()>

Synchronize gradients across all processes

Source

pub fn register_gradient_hooks(&self) -> TorshResult<()>

Register gradient synchronization hooks. This should be called during the backward pass to automatically sync gradients.

Source

pub fn register_gradient_async( &self, param_name: &str, gradient: Tensor, ) -> TorshResult<()>

Register a gradient for asynchronous synchronization (overlap mode). This should be called when a gradient becomes available during the backward pass.

Source

pub fn check_unused_parameters(&self) -> TorshResult<Vec<String>>

Check for unused parameters and issue warnings

Source

pub fn start_iteration(&self) -> TorshResult<()>

Start a new iteration (reset unused parameter tracking)

Source

pub fn get_overlap_stats(&self) -> HashMap<String, Value>

Get overlap computation statistics

Source

pub fn has_gradients(&self) -> bool

Check if any parameters have gradients

Source

pub fn zero_grad(&mut self) -> TorshResult<()>

Zero all gradients

Source

pub fn get_sync_stats(&self) -> GradientSyncStats

Get gradient synchronization statistics

Source

pub fn set_bucketing_enabled(&mut self, enabled: bool) -> TorshResult<()>

Enable/disable gradient bucketing at runtime

Source

pub fn get_bucket_info(&self) -> Vec<BucketInfo>

Get bucket information for debugging

Source

pub async fn check_gradient_consistency(&self) -> TorshResult<bool>

Perform a gradient consistency check across all processes. This is useful for debugging distributed training issues.

Trait Implementations§

Source§

impl<M: Module> Module for DistributedDataParallel<M>

Source§

fn forward(&self, input: &Tensor) -> Result<Tensor>

Forward pass through the module Read more
Source§

fn parameters(&self) -> HashMap<String, Parameter>

Get all parameters in the module (non-recursive) Read more
Source§

fn named_parameters(&self) -> HashMap<String, Parameter>

Get named parameters (non-recursive) Read more
Source§

fn training(&self) -> bool

Check if in training mode Read more
Source§

fn train(&mut self)

Set training mode Read more
Source§

fn eval(&mut self)

Set evaluation mode Read more
Source§

fn to_device(&mut self, device: DeviceType) -> Result<()>

Move module to device Read more
Source§

fn all_parameters(&self) -> HashMap<String, Parameter>

Get all parameters recursively including submodules Read more
Source§

fn all_named_parameters(&self) -> HashMap<String, Parameter>

Get all named parameters recursively with module prefixes Read more
Source§

fn set_training(&mut self, _training: bool)

Set training mode (internal implementation) Read more
Source§

fn load_state_dict( &mut self, state_dict: &HashMap<String, Tensor>, strict: bool, ) -> Result<(), TorshError>

Load state dictionary into the module Read more
Source§

fn load_state_dict_strict( &mut self, state_dict: &HashMap<String, Tensor>, ) -> Result<(), TorshError>

Load state dictionary with default strict=true
Source§

fn state_dict(&self) -> HashMap<String, Tensor>

Save state dictionary from the module
Source§

fn name(&self) -> Option<&str>

Get the module name (optional, for debugging and serialization)
Source§

fn buffers(&self) -> Vec<Arc<RwLock<RawRwLock, Tensor>>>

Get all buffers (non-trainable parameters)
Source§

fn named_buffers(&self) -> HashMap<String, Arc<RwLock<RawRwLock, Tensor>>>

Get named buffers
Source§

fn children(&self) -> Vec<&dyn Module>

Get all direct child modules Read more
Source§

fn named_children(&self) -> Vec<(String, &dyn Module)>

Get all direct child modules with names Read more
Source§

fn modules(&self) -> Vec<&dyn Module>
where Self: Sized,

Get all modules recursively (depth-first traversal)
Source§

fn named_modules(&self) -> Vec<(String, &dyn Module)>
where Self: Sized,

Get all modules recursively with hierarchical names
Source§

fn zero_grad(&mut self)

Zero all gradients recursively Read more
Source§

fn num_parameters(&self) -> usize

Count total number of parameters
Source§

fn num_trainable_parameters(&self) -> usize

Count trainable parameters
Source§

fn memory_usage(&self) -> usize

Get memory usage estimate in bytes
Source§

fn freeze(&mut self)

Freeze all parameters (set requires_grad = false) Read more
Source§

fn unfreeze(&mut self)

Unfreeze all parameters (set requires_grad = true) Read more
Source§

fn extra_repr(&self) -> String

Get string representation
Source§

fn register_hook( &mut self, _hook_type: HookType, _callback: Box<dyn Fn(&dyn Module, &Tensor, Option<&Tensor>) -> Result<(), TorshError> + Send + Sync>, ) -> Option<HookHandle>

Register a hook for this module (default implementation does nothing)
Source§

fn remove_hook(&mut self, _hook_type: HookType, _handle: HookHandle) -> bool

Remove a hook by handle (default implementation does nothing)
Source§

fn execute_hooks( &self, _hook_type: HookType, _input: &Tensor, _output: Option<&Tensor>, ) -> Result<(), TorshError>

Execute hooks of a specific type (default implementation does nothing)
Source§

fn forward_with_hooks(&self, input: &Tensor) -> Result<Tensor, TorshError>

Forward pass with hooks support
Source§

fn has_hooks(&self, _hook_type: HookType) -> bool

Check if module has hooks registered
Source§

fn call(&self, input: &Tensor) -> Result<Tensor, TorshError>

Convenient method to call forward and handle common patterns Read more
Source§

fn apply(&self, input: &Tensor) -> Result<Tensor, TorshError>

Apply the module to input (alias for forward) Read more
Source§

fn has_parameters(&self) -> bool

Check if the module has any parameters
Source§

fn has_children(&self) -> bool

Check if the module has any child modules
Source§

fn parameter_count(&self) -> usize

Get parameter count (convenience method)
Source§

fn trainable_parameter_count(&self) -> usize

Get trainable parameter count (convenience method)
Source§

fn memory_usage_mb(&self) -> f64

Get memory usage in MB (convenience method)
Source§

fn toggle_training(&mut self)

Toggle training mode (convenience method)
Source§

fn eval_mode(&self) -> bool

Check if module is in evaluation mode
Source§

fn sequential_forward( modules: &[&dyn Module], input: Tensor, ) -> Result<Tensor, TorshError>
where Self: Sized,

Sequential forward pass through multiple modules Read more
Source§

fn batch_forward(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>, TorshError>

Apply module multiple times with different inputs (batch processing) Read more
Source§

fn conditional_forward( &self, input: &Tensor, condition: bool, ) -> Result<Tensor, TorshError>

Forward with condition - only apply if condition is true Read more
Source§

fn residual_forward(&self, input: &Tensor) -> Result<Tensor, TorshError>

Forward with residual connection Read more
Source§

fn module_info(&self) -> ModuleInfo

Get detailed module information for debugging Read more
Source§

fn check_training_readiness(&self) -> Result<(), TorshError>

Check if module is ready for training Read more
Source§

fn parameter_names_matching(&self, pattern: &str) -> Vec<String>

Get parameter names matching a pattern Read more
Source§

fn parameters_by_type(&self, param_type: &str) -> HashMap<String, Parameter>

Get parameters by layer type (e.g., “weight”, “bias”) Read more
Source§

fn clone_parameters(&self) -> HashMap<String, Tensor>

Clone module parameters (for creating copies or checkpoints) Read more
Source§

fn diagnose(&self) -> ModuleDiagnostics

Quick diagnostic check of module health Read more

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T> Instrument for T

Source§

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper. Read more
Source§

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper. Read more
Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> ModuleApply for T
where T: Module,

Source§

fn apply<F>(&mut self, f: &F) -> Result<(), TorshError>
where F: Fn(&mut dyn Module) -> Result<(), TorshError>,

Apply a function to all submodules recursively
Source§

fn apply_to_parameters<F>(&mut self, _f: &F) -> Result<(), TorshError>
where F: Fn(&mut Parameter) -> Result<(), TorshError>,

Apply function to all parameters recursively
Source§

fn apply_to_modules<F>(&mut self, _f: &F) -> Result<(), TorshError>
where F: Fn(&mut dyn Module) -> Result<(), TorshError>,

Apply function to all modules recursively
Source§

impl<T> ModuleComposition for T
where T: Module + 'static,

Source§

fn then<Other>(self, other: Other) -> ComposedModule<T, Other>
where Other: Module + 'static,

Compose this module with another module sequentially Read more
Source§

fn parallel<Other>(self, other: Other) -> ParallelModule<T, Other>
where Other: Module + 'static,

Compose this module with another module in parallel Read more
Source§

fn residual(self) -> ResidualModule<T>

Add a residual connection Read more
Source§

fn conditional<F>(self, condition_fn: F) -> ConditionalModule<T, F>
where F: Fn() -> bool + Send + Sync,

Add conditional execution Read more
Source§

impl<T> ModuleExt for T
where T: Module + ?Sized,

Source§

fn and_then<F>(&self, input: &Tensor, f: F) -> Result<Tensor, TorshError>

Chain forward pass with a transformation function Read more
Source§

fn map<F>(&self, input: &Tensor, f: F) -> Result<Tensor, TorshError>
where F: FnOnce(Tensor) -> Tensor,

Apply module and map the output with a function Read more
Source§

fn with_input<F>(&self, input: &Tensor, f: F) -> Result<Tensor, TorshError>

Forward pass with input transformation Read more
Source§

fn summary(&self) -> String

Get human-readable summary of the module Read more
Source§

fn print_summary(&self)

Print module summary to stdout
Source§

fn parameter_stats(&self) -> ParameterStats

Get parameter statistics Read more
Source§

fn has_finite_parameters(&self) -> bool

Check if module has NaN or Inf in parameters Read more
Source§

fn parameter_names(&self) -> Vec<String>

Get list of parameter names Read more
Source§

fn get_parameter(&self, name: &str) -> Option<Parameter>

Get parameter by name Read more
Source§

fn freeze_matching(&mut self, pattern: &str) -> usize

Freeze specific parameters by name pattern Read more
Source§

fn unfreeze_matching(&mut self, pattern: &str) -> usize

Unfreeze specific parameters by name pattern Read more
Source§

fn frozen_parameters(&self) -> Vec<String>

Get list of frozen parameters Read more
Source§

fn trainable_parameters(&self) -> Vec<String>

Get list of trainable parameters Read more
Source§

fn clone_state_dict(&self) -> HashMap<String, Tensor>

Clone module parameters into a new state dict Read more
Source§

fn apply_to_parameters<F>(&self, f: F)
where F: FnMut(&str, &Parameter),

Apply a function to all parameters Read more
Source§

fn parameters_by_type(&self) -> HashMap<String, usize>

Count parameters by layer type Read more
Source§

fn validate(&self) -> Result<ValidationReport, TorshError>

Validate module configuration Read more
Source§

fn device(&self) -> Option<DeviceType>

Get device of parameters (if consistent) Read more
Source§

fn is_cpu(&self) -> bool

Check if all parameters are on CPU Read more
Source§

fn is_cuda(&self) -> bool

Check if all parameters are on CUDA device Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> Same for T

Source§

type Output = T

Should always be Self
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V

Source§

impl<T> WithSubscriber for T

Source§

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper. Read more