pub struct GradientAccumulationCallback { /* private fields */ }
Gradient accumulation callback with configurable scaling, gradient clipping, and overflow detection.
Simulates larger batch sizes by accumulating gradients over multiple mini-batches before updating parameters. This is useful when GPU memory is limited but you want to train with effectively larger batches.
Effective batch size = mini_batch_size * accumulation_steps
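A quick numerical check of the formula above, as a std-only Rust sketch that does not use this crate: with equal-sized mini-batches, averaging the mean gradient of each of 4 mini-batches of 8 samples gives the same value as the mean over one batch of 4 × 8 = 32 samples. Scalar per-sample gradients and the helper names are illustrative assumptions.

```rust
/// Effective batch size seen by the optimizer when gradients are
/// accumulated over `accumulation_steps` mini-batches.
fn effective_batch_size(mini_batch_size: usize, accumulation_steps: usize) -> usize {
    mini_batch_size * accumulation_steps
}

/// Mean of a slice of per-sample gradients (scalars for simplicity).
fn mean(values: &[f64]) -> f64 {
    values.iter().sum::<f64>() / values.len() as f64
}

/// Split the samples into mini-batches, take each mini-batch's mean
/// gradient, then average those means (an "average" scaling step).
fn accumulated_average(per_sample: &[f64], mini_batch_size: usize) -> f64 {
    let batch_means: Vec<f64> = per_sample.chunks(mini_batch_size).map(mean).collect();
    mean(&batch_means)
}
```

The equality relies on every mini-batch having the same size; a shorter final mini-batch would be over-weighted by a plain average of batch means.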
§Features
- Memory-efficient in-place accumulation
- Multiple scaling strategies
- Gradient overflow detection
- Memory usage tracking
- Automatic gradient zeroing
§Example
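use tensorlogic_train::{GradientAccumulationCallback, GradientScalingStrategy};
// `new` takes only the step count; use `with_strategy` to pick the scaling.
// `GradientAccumulationCallback::new(4)` would also default to average scaling.
let mut grad_accum = GradientAccumulationCallback::with_strategy(
    4, // accumulate over 4 mini-batches
    GradientScalingStrategy::Average,
).unwrap();
Implementations§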
impl GradientAccumulationCallback
pub fn new(accumulation_steps: usize) -> TrainResult<Self>
Create a new Gradient Accumulation callback with default average scaling.
§Arguments
accumulation_steps - Number of mini-batches to accumulate (e.g., 4, 8, 16)
pub fn with_strategy(accumulation_steps: usize, scaling_strategy: GradientScalingStrategy) -> TrainResult<Self>
Create a new Gradient Accumulation callback with specified scaling strategy.
§Arguments
accumulation_steps - Number of mini-batches to accumulate
scaling_strategy - How to scale accumulated gradients
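The effect of a scaling strategy can be sketched as a single multiplier applied to the accumulated sum. This is a hypothetical, std-only illustration: only `GradientScalingStrategy::Average` appears on this page, so the `ScalingStrategy` enum, its `Sum` variant, and the flattened `Vec<f64>` gradients are assumptions made for contrast, not the crate's definitions.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for `GradientScalingStrategy`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ScalingStrategy {
    /// Assumed variant: leave the accumulated sum unscaled.
    Sum,
    /// Divide the accumulated sum by the number of accumulated mini-batches.
    Average,
}

/// Factor to multiply the accumulated gradient sum by.
fn scale_factor(strategy: ScalingStrategy, steps: usize) -> f64 {
    match strategy {
        ScalingStrategy::Sum => 1.0,
        ScalingStrategy::Average => 1.0 / steps as f64,
    }
}

/// Apply the strategy to every accumulated gradient tensor (flattened here).
fn scale_gradients(grads: &mut HashMap<String, Vec<f64>>, strategy: ScalingStrategy, steps: usize) {
    let factor = scale_factor(strategy, steps);
    for values in grads.values_mut() {
        for v in values.iter_mut() {
            *v *= factor;
        }
    }
}
```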
pub fn with_grad_clipping(self, max_norm: f64) -> Self
Enable gradient clipping during accumulation.
§Arguments
max_norm - Maximum gradient norm before clipping
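One common way to clip is by the global L2 norm across all parameters, which rescales every gradient by the same factor and so preserves the overall direction. The sketch below assumes that variant; this page does not say whether `with_grad_clipping` clips the global norm or each tensor separately, and the helper functions are hypothetical.

```rust
use std::collections::HashMap;

/// Global L2 norm over all gradient tensors (flattened to `Vec<f64>` here).
fn global_norm(grads: &HashMap<String, Vec<f64>>) -> f64 {
    grads
        .values()
        .flat_map(|v| v.iter())
        .map(|x| x * x)
        .sum::<f64>()
        .sqrt()
}

/// Rescale all gradients so their global norm is at most `max_norm`.
/// Returns the norm measured before clipping.
fn clip_by_global_norm(grads: &mut HashMap<String, Vec<f64>>, max_norm: f64) -> f64 {
    let norm = global_norm(grads);
    if norm > max_norm && norm > 0.0 {
        let scale = max_norm / norm;
        for values in grads.values_mut() {
            for v in values.iter_mut() {
                *v *= scale;
            }
        }
    }
    norm
}
```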
pub fn accumulate(&mut self, gradients: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<()>
Accumulate gradients with optional clipping and overflow detection.
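In-place accumulation with overflow detection might look like the following minimal sketch. It is hypothetical: gradients are flattened to `Vec<f64>` rather than `Array<f64, Ix2>`, and skipping a batch containing NaN or infinity (while counting it) is an assumed policy; the real callback may instead return an error.

```rust
use std::collections::HashMap;

/// Hypothetical minimal accumulator, not the crate's type.
#[derive(Default)]
struct Accumulator {
    sums: HashMap<String, Vec<f64>>,
    steps: usize,
    overflow_count: usize,
}

impl Accumulator {
    /// Add one mini-batch of gradients to the running sums in place.
    /// A batch containing NaN or infinity is skipped and counted as an
    /// overflow (assumed policy); a length mismatch is an error.
    fn accumulate(&mut self, grads: &HashMap<String, Vec<f64>>) -> Result<(), String> {
        // Validate first so a rejected batch leaves the sums untouched.
        if grads.values().flatten().any(|x| !x.is_finite()) {
            self.overflow_count += 1;
            return Ok(());
        }
        for (name, g) in grads {
            if let Some(sum) = self.sums.get(name) {
                if sum.len() != g.len() {
                    return Err(format!("shape mismatch for parameter {name}"));
                }
            }
        }
        for (name, g) in grads {
            // A zeroed buffer is allocated once, the first time a parameter appears.
            let sum = self
                .sums
                .entry(name.clone())
                .or_insert_with(|| vec![0.0; g.len()]);
            // In-place addition: no new buffer per mini-batch.
            for (s, x) in sum.iter_mut().zip(g) {
                *s += x;
            }
        }
        self.steps += 1;
        Ok(())
    }
}
```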
pub fn should_update(&self) -> bool
Check if we should perform an optimizer step.
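`should_update` can be pictured as comparing a mini-batch counter against `accumulation_steps`. This hypothetical counter, not the crate's type, shows which batch indices would trigger an optimizer step when `accumulation_steps` is 4.

```rust
/// Hypothetical step counter showing when an optimizer step would fire.
struct StepCounter {
    accumulation_steps: usize,
    accumulated: usize,
}

impl StepCounter {
    /// Record one accumulated mini-batch.
    fn record(&mut self) {
        self.accumulated += 1;
    }

    /// True once `accumulation_steps` mini-batches have been accumulated.
    fn should_update(&self) -> bool {
        self.accumulated >= self.accumulation_steps
    }

    /// Start a new accumulation window after the optimizer step.
    fn reset(&mut self) {
        self.accumulated = 0;
    }
}
```

Driven over 8 batches, this counter fires after batches 3 and 7 (0-based), giving two optimizer steps of 4 mini-batches each.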
pub fn get_and_reset(&mut self) -> HashMap<String, Array<f64, Ix2>>
Get scaled accumulated gradients and reset state.
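A sketch of the get-and-reset step under average scaling, using `std::mem::take` to move the accumulated sums out and leave an empty state for the next window. The `Accum` type and flattened gradients are illustrative assumptions, not the crate's internals.

```rust
use std::collections::HashMap;

/// Hypothetical accumulator state (flattened gradients).
struct Accum {
    sums: HashMap<String, Vec<f64>>,
    steps: usize,
}

impl Accum {
    /// Return the accumulated gradients divided by the number of
    /// accumulated mini-batches (average scaling), then clear the state.
    fn get_and_reset(&mut self) -> HashMap<String, Vec<f64>> {
        // `mem::take` moves the sums out and leaves an empty map behind.
        let mut grads = std::mem::take(&mut self.sums);
        // `mem::replace` reads the step count and zeroes it in one move.
        let steps = std::mem::replace(&mut self.steps, 0);
        if steps > 0 {
            let factor = 1.0 / steps as f64;
            for values in grads.values_mut() {
                for v in values.iter_mut() {
                    *v *= factor;
                }
            }
        }
        grads
    }
}
```

Moving the map out avoids copying the gradients; an implementation that instead zeroes its buffers in place would keep their allocations for the next window.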
pub fn get_stats(&self) -> GradientAccumulationStats
Get statistics about gradient accumulation.
Trait Implementations§
impl Callback for GradientAccumulationCallback
fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()>
Called at the beginning of an epoch.
fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()>
Called at the beginning of training.
fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()>
Called at the end of training.
fn on_epoch_end(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()>
Called at the end of an epoch.
fn on_batch_begin(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()>
Called at the beginning of a batch.
fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()>
Called at the end of a batch.
fn on_validation_end(&mut self, _state: &TrainingState) -> TrainResult<()>
Called after validation.
fn should_stop(&self) -> bool
Check if training should stop early.
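The hooks above are called by the training loop at fixed points. This simplified, hypothetical sketch mirrors a subset of them with no-op defaults (returning `()` instead of `TrainResult<()>`, omitting `on_validation_end`, and using a stand-in `State` for `TrainingState`) and records the order in which a plain epoch/batch loop invokes them.

```rust
/// Hypothetical, simplified stand-in for the crate's `TrainingState`.
struct State {
    epoch: usize,
}

/// Simplified callback trait: every hook has a no-op default, so an
/// implementor overrides only the hooks it needs.
trait Hooks {
    fn on_train_begin(&mut self, _state: &State) {}
    fn on_epoch_begin(&mut self, _epoch: usize, _state: &State) {}
    fn on_batch_begin(&mut self, _batch: usize, _state: &State) {}
    fn on_batch_end(&mut self, _batch: usize, _state: &State) {}
    fn on_epoch_end(&mut self, _epoch: usize, _state: &State) {}
    fn on_train_end(&mut self, _state: &State) {}
    fn should_stop(&self) -> bool {
        false
    }
}

/// Records a few hook calls so the invocation order can be inspected.
#[derive(Default)]
struct Recorder {
    log: Vec<String>,
}

impl Hooks for Recorder {
    fn on_epoch_begin(&mut self, epoch: usize, _state: &State) {
        self.log.push(format!("epoch_begin {epoch}"));
    }
    fn on_batch_end(&mut self, batch: usize, _state: &State) {
        self.log.push(format!("batch_end {batch}"));
    }
    fn on_epoch_end(&mut self, epoch: usize, _state: &State) {
        self.log.push(format!("epoch_end {epoch}"));
    }
}

/// Drive the hooks the way a simple trainer loop might.
fn run(cb: &mut dyn Hooks, epochs: usize, batches: usize) {
    let mut state = State { epoch: 0 };
    cb.on_train_begin(&state);
    for epoch in 0..epochs {
        state.epoch = epoch;
        cb.on_epoch_begin(epoch, &state);
        for batch in 0..batches {
            cb.on_batch_begin(batch, &state);
            cb.on_batch_end(batch, &state);
        }
        cb.on_epoch_end(epoch, &state);
        if cb.should_stop() {
            break;
        }
    }
    cb.on_train_end(&state);
}
```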
Auto Trait Implementations§
impl Freeze for GradientAccumulationCallback
impl RefUnwindSafe for GradientAccumulationCallback
impl Send for GradientAccumulationCallback
impl Sync for GradientAccumulationCallback
impl Unpin for GradientAccumulationCallback
impl UnwindSafe for GradientAccumulationCallback
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.