Function mixed_precision_train_step 

pub fn mixed_precision_train_step(
    graph: &mut Graph,
    model: &SequentialModel,
    input: &Tensor,
    target: &Tensor,
    _config: &MixedPrecisionConfig,
    scaler: &mut DynamicLossScaler,
) -> Result<(f32, bool), ModelError>

Runs a mixed-precision forward+backward step.

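  1. Cast the input to forward_dtype and back to F32, so the graph computes in F32 on values already rounded to forward_dtype precision
  2. Run the forward pass and compute the loss
  3. Scale the loss by the scaler's current factor and backpropagate
  4. Check the gradients for overflow (non-finite values)
  5. Update the scaler state from the result of the overflow check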
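The five steps above can be sketched on a one-parameter toy model (prediction = w * x, mean-squared-error loss). Everything in this sketch is hypothetical: ToyScaler stands in for DynamicLossScaler, bfloat16 is assumed as the forward dtype, and the growth/backoff policy plus the gradient unscaling are common dynamic-loss-scaling choices, not confirmed behavior of this crate.

```rust
/// Hypothetical stand-in for a dynamic loss scaler (NOT this crate's DynamicLossScaler).
#[derive(Debug)]
struct ToyScaler {
    scale: f32,
    growth_factor: f32,
    backoff_factor: f32,
    growth_interval: u32,
    good_steps: u32,
}

impl ToyScaler {
    /// Step 5 (assumed policy): shrink the scale after an overflow,
    /// grow it after `growth_interval` consecutive clean steps.
    fn update(&mut self, found_overflow: bool) {
        if found_overflow {
            self.scale *= self.backoff_factor;
            self.good_steps = 0;
        } else {
            self.good_steps += 1;
            if self.good_steps >= self.growth_interval {
                self.scale *= self.growth_factor;
                self.good_steps = 0;
            }
        }
    }
}

/// Round an f32 to bfloat16 precision (round-to-nearest-even on the 16
/// dropped mantissa bits) and widen it back to f32. Assumes finite input.
fn round_trip_bf16(x: f32) -> f32 {
    let bits = x.to_bits();
    let bias = 0x7FFF + ((bits >> 16) & 1);
    f32::from_bits(bits.wrapping_add(bias) & 0xFFFF_0000)
}

/// Toy step on prediction = w * x with MSE loss.
/// Returns (loss, step_applied, unscaled gradient w.r.t. w).
fn toy_mixed_precision_step(
    w: f32,
    xs: &[f32],
    ys: &[f32],
    scaler: &mut ToyScaler,
) -> (f32, bool, f32) {
    let n = xs.len() as f32;
    // 1. Cast inputs to bf16 precision, then compute in f32.
    let xs_cast: Vec<f32> = xs.iter().map(|&x| round_trip_bf16(x)).collect();
    // 2. Forward pass and loss.
    let residuals: Vec<f32> = xs_cast.iter().zip(ys).map(|(&x, &y)| w * x - y).collect();
    let loss = residuals.iter().map(|r| r * r).sum::<f32>() / n;
    // 3. Backprop the scaled loss: d(scale * loss)/dw = scale * (2/n) * sum(r * x).
    let scale = scaler.scale; // remember the scale used for this step
    let dot: f32 = residuals.iter().zip(&xs_cast).map(|(r, x)| r * x).sum();
    let scaled_grad = scale * (2.0 / n) * dot;
    // 4. Overflow check: an inf or NaN gradient means the step is skipped.
    let found_overflow = !scaled_grad.is_finite();
    // 5. Update the scaler from the overflow result.
    scaler.update(found_overflow);
    // Unscale with the pre-update scale (an assumption of this sketch).
    let grad = if found_overflow { 0.0 } else { scaled_grad / scale };
    (loss, !found_overflow, grad)
}
```

Unscaling with the scale captured before step 5 matters: if the scaler grows on this step, dividing by the new scale would halve the gradient.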

Returns (loss_value, step_applied), where step_applied reports whether the gradients passed the overflow check in step 4.
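A caller will typically branch on the (loss_value, step_applied) tuple. The sketch below assumes the weights are updated only when step_applied is true; drive, StepStats, and apply_update are hypothetical names, the step closure stands in for a call to mixed_precision_train_step, and a generic error type E stands in for ModelError.

```rust
/// Counts of applied vs. skipped steps over a run (hypothetical helper type).
#[derive(Debug, Default, PartialEq)]
struct StepStats {
    applied: usize,
    skipped: usize,
    last_loss: Option<f32>,
}

/// Drive `n_steps` training steps. `step` has the same Ok payload as
/// mixed_precision_train_step: (loss_value, step_applied).
/// `apply_update` stands in for a hypothetical optimizer update.
fn drive<E, S, U>(n_steps: usize, mut step: S, mut apply_update: U) -> Result<StepStats, E>
where
    S: FnMut() -> Result<(f32, bool), E>,
    U: FnMut(),
{
    let mut stats = StepStats::default();
    for _ in 0..n_steps {
        // Propagate model errors to the caller with `?`.
        let (loss, step_applied) = step()?;
        stats.last_loss = Some(loss);
        if step_applied {
            // Gradients are finite: safe to update the weights.
            apply_update();
            stats.applied += 1;
        } else {
            // Overflow detected: skip the update; the scaler state was
            // already adjusted inside the step.
            stats.skipped += 1;
        }
    }
    Ok(stats)
}
```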