pub struct MixedPrecisionTrainer { /* private fields */ }
Mixed precision trainer
Implementations§
Source§impl MixedPrecisionTrainer
impl MixedPrecisionTrainer
Sourcepub fn new(config: MixedPrecisionConfig) -> Self
pub fn new(config: MixedPrecisionConfig) -> Self
Create new mixed precision trainer
Sourcepub fn to_mixed_precision(
&self,
array: &Array2<Float>,
operation_name: &str,
) -> MixedPrecisionArray
pub fn to_mixed_precision( &self, array: &Array2<Float>, operation_name: &str, ) -> MixedPrecisionArray
Convert array to mixed precision format
Sourcepub fn to_full_precision(&self, array: &MixedPrecisionArray) -> Array2<Float>
pub fn to_full_precision(&self, array: &MixedPrecisionArray) -> Array2<Float>
Convert mixed precision array back to FP32
Sourcepub fn scale_gradients(&self, gradients: &mut Array2<Float>)
pub fn scale_gradients(&self, gradients: &mut Array2<Float>)
Scale gradients to prevent underflow
Sourcepub fn unscale_gradients(&self, gradients: &mut Array2<Float>) -> bool
pub fn unscale_gradients(&self, gradients: &mut Array2<Float>) -> bool
Unscale gradients after backward pass
Sourcepub fn update_scale(&mut self, overflow_detected: bool)
pub fn update_scale(&mut self, overflow_detected: bool)
Update loss scale based on overflow detection
Sourcepub fn should_skip_step(&self) -> bool
pub fn should_skip_step(&self) -> bool
Check if current step should be skipped due to overflow
Sourcepub fn get_loss_scale(&self) -> Float
pub fn get_loss_scale(&self) -> Float
Get current loss scale
Sourcepub fn train_ensemble_mixed_precision<F>(
&mut self,
x: &Array2<Float>,
y: &Array1<Int>,
n_estimators: usize,
train_fn: F,
) -> Result<Vec<Array1<Float>>>
pub fn train_ensemble_mixed_precision<F>( &mut self, x: &Array2<Float>, y: &Array1<Int>, n_estimators: usize, train_fn: F, ) -> Result<Vec<Array1<Float>>>
Train ensemble with mixed precision
Sourcepub fn scaler_state(&self) -> &ScalerState
pub fn scaler_state(&self) -> &ScalerState
Get scaler state
Sourcepub fn reset_scaler(&mut self)
pub fn reset_scaler(&mut self)
Reset scaler state
Auto Trait Implementations§
impl Freeze for MixedPrecisionTrainer
impl RefUnwindSafe for MixedPrecisionTrainer
impl Send for MixedPrecisionTrainer
impl Sync for MixedPrecisionTrainer
impl Unpin for MixedPrecisionTrainer
impl UnwindSafe for MixedPrecisionTrainer
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
T: ?Sized,
impl<T> BorrowMut<T> for T
where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`.
Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`.
Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more