pub struct TensorParallel { /* private fields */ }
Expand description
Tensor-parallel wrapper for modules.
Implementations§
Source§impl TensorParallel
impl TensorParallel
Source
pub fn new(
module: Box<dyn Module>,
tp_group: Arc<ProcessGroup>,
config: TensorParallelConfig,
layer_info: TensorParallelLayer,
) -> TorshResult<Self>
pub fn new( module: Box<dyn Module>, tp_group: Arc<ProcessGroup>, config: TensorParallelConfig, layer_info: TensorParallelLayer, ) -> TorshResult<Self>
Create a new tensor parallel wrapper
Source
pub fn tp_world_size(&self) -> usize
pub fn tp_world_size(&self) -> usize
Get tensor parallel world size
Source
pub fn get_shard_info(&self, param_name: &str) -> Option<&ShardInfo>
pub fn get_shard_info(&self, param_name: &str) -> Option<&ShardInfo>
Get sharding information for a parameter
Source
pub fn uses_sequence_parallel(&self) -> bool
pub fn uses_sequence_parallel(&self) -> bool
Check if layer uses sequence parallelism
Source
pub fn memory_stats(&self) -> TensorParallelStats
pub fn memory_stats(&self) -> TensorParallelStats
Get memory usage statistics
Trait Implementations§
Source§impl Module for TensorParallel
impl Module for TensorParallel
Source§fn parameters(&self) -> HashMap<String, Parameter>
fn parameters(&self) -> HashMap<String, Parameter>
Get all parameters in the module (non-recursive) Read more
Source§fn named_parameters(&self) -> HashMap<String, Parameter>
fn named_parameters(&self) -> HashMap<String, Parameter>
Get named parameters (non-recursive) Read more
Source§fn all_parameters(&self) -> HashMap<String, Parameter>
fn all_parameters(&self) -> HashMap<String, Parameter>
Get all parameters recursively including submodules Read more
Source§fn all_named_parameters(&self) -> HashMap<String, Parameter>
fn all_named_parameters(&self) -> HashMap<String, Parameter>
Get all named parameters recursively with module prefixes Read more
Source§fn set_training(&mut self, _training: bool)
fn set_training(&mut self, _training: bool)
Set training mode (internal implementation) Read more
Source§fn load_state_dict(
&mut self,
state_dict: &HashMap<String, Tensor>,
strict: bool,
) -> Result<(), TorshError>
fn load_state_dict( &mut self, state_dict: &HashMap<String, Tensor>, strict: bool, ) -> Result<(), TorshError>
Load state dictionary into the module Read more
Source§fn load_state_dict_strict(
&mut self,
state_dict: &HashMap<String, Tensor>,
) -> Result<(), TorshError>
fn load_state_dict_strict( &mut self, state_dict: &HashMap<String, Tensor>, ) -> Result<(), TorshError>
Load state dictionary with default strict=true
Source§fn name(&self) -> Option<&str>
fn name(&self) -> Option<&str>
Get the module name (optional, for debugging and serialization)
Source§fn buffers(&self) -> Vec<Arc<RwLock<RawRwLock, Tensor>>>
fn buffers(&self) -> Vec<Arc<RwLock<RawRwLock, Tensor>>>
Get all buffers (non-trainable parameters)
Source§fn named_children(&self) -> Vec<(String, &dyn Module)>
fn named_children(&self) -> Vec<(String, &dyn Module)>
Get all direct child modules with names Read more
Source§fn modules(&self) -> Vec<&dyn Module>
where
    Self: Sized,
fn modules(&self) -> Vec<&dyn Module>
where
    Self: Sized,
Get all modules recursively (depth-first traversal)
Source§fn named_modules(&self) -> Vec<(String, &dyn Module)>
where
    Self: Sized,
fn named_modules(&self) -> Vec<(String, &dyn Module)>
where
    Self: Sized,
Get all modules recursively with hierarchical names
Source§fn num_parameters(&self) -> usize
fn num_parameters(&self) -> usize
Count total number of parameters
Source§fn num_trainable_parameters(&self) -> usize
fn num_trainable_parameters(&self) -> usize
Count trainable parameters
Source§fn memory_usage(&self) -> usize
fn memory_usage(&self) -> usize
Get memory usage estimate in bytes
Source§fn extra_repr(&self) -> String
fn extra_repr(&self) -> String
Get string representation
Source§fn register_hook(
&mut self,
_hook_type: HookType,
_callback: Box<dyn Fn(&dyn Module, &Tensor, Option<&Tensor>) -> Result<(), TorshError> + Send + Sync>,
) -> Option<HookHandle>
fn register_hook( &mut self, _hook_type: HookType, _callback: Box<dyn Fn(&dyn Module, &Tensor, Option<&Tensor>) -> Result<(), TorshError> + Send + Sync>, ) -> Option<HookHandle>
Register a hook for this module (default implementation does nothing)
Source§fn remove_hook(&mut self, _hook_type: HookType, _handle: HookHandle) -> bool
fn remove_hook(&mut self, _hook_type: HookType, _handle: HookHandle) -> bool
Remove a hook by handle (default implementation does nothing)
Source§fn execute_hooks(
&self,
_hook_type: HookType,
_input: &Tensor,
_output: Option<&Tensor>,
) -> Result<(), TorshError>
fn execute_hooks( &self, _hook_type: HookType, _input: &Tensor, _output: Option<&Tensor>, ) -> Result<(), TorshError>
Execute hooks of a specific type (default implementation does nothing)
Source§fn forward_with_hooks(&self, input: &Tensor) -> Result<Tensor, TorshError>
fn forward_with_hooks(&self, input: &Tensor) -> Result<Tensor, TorshError>
Forward pass with hooks support
Source§fn call(&self, input: &Tensor) -> Result<Tensor, TorshError>
fn call(&self, input: &Tensor) -> Result<Tensor, TorshError>
Convenient method to call forward and handle common patterns Read more
Source§fn apply(&self, input: &Tensor) -> Result<Tensor, TorshError>
fn apply(&self, input: &Tensor) -> Result<Tensor, TorshError>
Apply the module to input (alias for forward) Read more
Source§fn has_parameters(&self) -> bool
fn has_parameters(&self) -> bool
Check if the module has any parameters
Source§fn has_children(&self) -> bool
fn has_children(&self) -> bool
Check if the module has any child modules
Source§fn parameter_count(&self) -> usize
fn parameter_count(&self) -> usize
Get parameter count (convenience method)
Source§fn trainable_parameter_count(&self) -> usize
fn trainable_parameter_count(&self) -> usize
Get trainable parameter count (convenience method)
Source§fn memory_usage_mb(&self) -> f64
fn memory_usage_mb(&self) -> f64
Get memory usage in MB (convenience method)
Source§fn toggle_training(&mut self)
fn toggle_training(&mut self)
Toggle training mode (convenience method)
Source§fn sequential_forward(
    modules: &[&dyn Module],
    input: Tensor,
) -> Result<Tensor, TorshError>
where
    Self: Sized,
fn sequential_forward(
    modules: &[&dyn Module],
    input: Tensor,
) -> Result<Tensor, TorshError>
where
    Self: Sized,
Sequential forward pass through multiple modules Read more
Source§fn batch_forward(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>, TorshError>
fn batch_forward(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>, TorshError>
Apply module multiple times with different inputs (batch processing) Read more
Source§fn conditional_forward(
&self,
input: &Tensor,
condition: bool,
) -> Result<Tensor, TorshError>
fn conditional_forward( &self, input: &Tensor, condition: bool, ) -> Result<Tensor, TorshError>
Forward with condition - only apply if condition is true Read more
Source§fn residual_forward(&self, input: &Tensor) -> Result<Tensor, TorshError>
fn residual_forward(&self, input: &Tensor) -> Result<Tensor, TorshError>
Forward with residual connection Read more
Source§fn module_info(&self) -> ModuleInfo
fn module_info(&self) -> ModuleInfo
Get detailed module information for debugging Read more
Source§fn check_training_readiness(&self) -> Result<(), TorshError>
fn check_training_readiness(&self) -> Result<(), TorshError>
Check if module is ready for training Read more
Source§fn parameter_names_matching(&self, pattern: &str) -> Vec<String>
fn parameter_names_matching(&self, pattern: &str) -> Vec<String>
Get parameter names matching a pattern Read more
Source§fn parameters_by_type(&self, param_type: &str) -> HashMap<String, Parameter>
fn parameters_by_type(&self, param_type: &str) -> HashMap<String, Parameter>
Get parameters by layer type (e.g., “weight”, “bias”) Read more
Source§fn clone_parameters(&self) -> HashMap<String, Tensor>
fn clone_parameters(&self) -> HashMap<String, Tensor>
Clone module parameters (for creating copies or checkpoints) Read more
Source§fn diagnose(&self) -> ModuleDiagnostics
fn diagnose(&self) -> ModuleDiagnostics
Quick diagnostic check of module health Read more
Auto Trait Implementations§
impl Freeze for TensorParallel
impl !RefUnwindSafe for TensorParallel
impl Send for TensorParallel
impl Sync for TensorParallel
impl Unpin for TensorParallel
impl UnsafeUnpin for TensorParallel
impl !UnwindSafe for TensorParallel
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
    T: ?Sized,
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> Instrument for T
impl<T> Instrument for T
Source§fn instrument(self, span: Span) -> Instrumented<Self>
fn instrument(self, span: Span) -> Instrumented<Self>
Source§fn in_current_span(self) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§impl<T> ModuleApply for T
where
    T: Module,
impl<T> ModuleApply for T
where
    T: Module,
Source§fn apply<F>(&mut self, f: &F) -> Result<(), TorshError>
fn apply<F>(&mut self, f: &F) -> Result<(), TorshError>
Apply a function to all submodules recursively
Source§fn apply_to_parameters<F>(&mut self, _f: &F) -> Result<(), TorshError>
fn apply_to_parameters<F>(&mut self, _f: &F) -> Result<(), TorshError>
Apply function to all parameters recursively
Source§fn apply_to_modules<F>(&mut self, _f: &F) -> Result<(), TorshError>
fn apply_to_modules<F>(&mut self, _f: &F) -> Result<(), TorshError>
Apply function to all modules recursively
Source§impl<T> ModuleComposition for T
where
    T: Module + 'static,
impl<T> ModuleComposition for T
where
    T: Module + 'static,
Source§fn then<Other>(self, other: Other) -> ComposedModule<T, Other>
where
    Other: Module + 'static,
fn then<Other>(self, other: Other) -> ComposedModule<T, Other>
where
    Other: Module + 'static,
Compose this module with another module sequentially Read more
Source§fn parallel<Other>(self, other: Other) -> ParallelModule<T, Other>
where
    Other: Module + 'static,
fn parallel<Other>(self, other: Other) -> ParallelModule<T, Other>
where
    Other: Module + 'static,
Compose this module with another module in parallel Read more
Source§fn residual(self) -> ResidualModule<T>
fn residual(self) -> ResidualModule<T>
Add a residual connection Read more
Source§fn conditional<F>(self, condition_fn: F) -> ConditionalModule<T, F>
fn conditional<F>(self, condition_fn: F) -> ConditionalModule<T, F>
Add conditional execution Read more
Source§impl<T> ModuleExt for T
impl<T> ModuleExt for T
Source§fn and_then<F>(&self, input: &Tensor, f: F) -> Result<Tensor, TorshError>
fn and_then<F>(&self, input: &Tensor, f: F) -> Result<Tensor, TorshError>
Chain forward pass with a transformation function Read more
Source§fn map<F>(&self, input: &Tensor, f: F) -> Result<Tensor, TorshError>
fn map<F>(&self, input: &Tensor, f: F) -> Result<Tensor, TorshError>
Apply module and map the output with a function Read more
Source§fn with_input<F>(&self, input: &Tensor, f: F) -> Result<Tensor, TorshError>
fn with_input<F>(&self, input: &Tensor, f: F) -> Result<Tensor, TorshError>
Forward pass with input transformation Read more
Source§fn print_summary(&self)
fn print_summary(&self)
Print module summary to stdout
Source§fn parameter_stats(&self) -> ParameterStats
fn parameter_stats(&self) -> ParameterStats
Get parameter statistics Read more
Source§fn has_finite_parameters(&self) -> bool
fn has_finite_parameters(&self) -> bool
Check if module has NaN or Inf in parameters Read more
Source§fn freeze_matching(&mut self, pattern: &str) -> usize
fn freeze_matching(&mut self, pattern: &str) -> usize
Freeze specific parameters by name pattern Read more
Source§fn unfreeze_matching(&mut self, pattern: &str) -> usize
fn unfreeze_matching(&mut self, pattern: &str) -> usize
Unfreeze specific parameters by name pattern Read more
Source§fn clone_state_dict(&self) -> HashMap<String, Tensor>
fn clone_state_dict(&self) -> HashMap<String, Tensor>
Clone module parameters into a new state dict Read more
Source§fn apply_to_parameters<F>(&self, f: F)
fn apply_to_parameters<F>(&self, f: F)
Apply a function to all parameters Read more
Source§fn parameters_by_type(&self) -> HashMap<String, usize>
fn parameters_by_type(&self) -> HashMap<String, usize>
Count parameters by layer type Read more
Source§fn validate(&self) -> Result<ValidationReport, TorshError>
fn validate(&self) -> Result<ValidationReport, TorshError>
Validate module configuration Read more