pub fn mixed_precision_train_step(
    graph: &mut Graph,
    model: &SequentialModel,
    input: &Tensor,
    target: &Tensor,
    _config: &MixedPrecisionConfig,
    scaler: &mut DynamicLossScaler,
) -> Result<(f32, bool), ModelError>
Runs one mixed-precision forward and backward pass with dynamic loss scaling.
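- Cast the input to forward_dtype and back to F32, so the values carry forward_dtype's rounding while graph computation stays in F32
- Run the forward pass and compute the loss
- Multiply the loss by the current scale and backpropagate
- Check the gradients for overflow (inf or NaN)
- Update the scaler state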
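The round-trip cast in the first step keeps arithmetic in F32 but quantizes each input value to the narrower type's precision. A minimal sketch, assuming forward_dtype is bfloat16 (the dtypes MixedPrecisionConfig actually supports are not shown here); `round_trip_bf16` is a hypothetical helper, not part of this crate.

```rust
/// Hypothetical helper (not this crate's API): rounds an f32 to the nearest
/// bfloat16 value (round-to-nearest-even) and widens it back to f32.
fn round_trip_bf16(x: f32) -> f32 {
    if x.is_nan() {
        return x; // leave NaN untouched; rounding its bits could turn it into inf
    }
    let bits = x.to_bits();
    // bfloat16 keeps the top 16 bits of an f32. Adding 0x7FFF plus the lowest
    // kept bit, then clearing the low 16 bits, rounds to nearest with ties to even.
    let lsb = (bits >> 16) & 1;
    let rounded = bits.wrapping_add(0x7FFF + lsb) & 0xFFFF_0000;
    f32::from_bits(rounded)
}

fn main() {
    let input = [1.0_f32, 3.14159, 1.0 + 1.0 / 256.0, 65504.0];
    // Each value now carries bfloat16 precision but is still stored as f32.
    let cast: Vec<f32> = input.iter().map(|&v| round_trip_bf16(v)).collect();
    println!("{:?}", cast);
}
```

With bfloat16 the round trip mainly loses mantissa bits (3.14159 becomes 3.140625). With F16 the idea is the same, but its narrower exponent range also means values above 65504 overflow to infinity.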
Returns (loss_value, step_applied); step_applied is false when overflow was detected and the step was skipped.
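The scaling, overflow check, and step_applied flag follow the usual dynamic loss-scaling pattern. The loss is multiplied by a scale factor before backprop. If any gradient is non-finite, the step is skipped and the scale shrinks. After a run of clean steps, the scale grows. This is a minimal sketch of that pattern only. `ScalerSketch` and its fields are hypothetical, not the fields of this crate's DynamicLossScaler, and the factor and interval defaults mirror PyTorch's GradScaler.

```rust
/// Hypothetical dynamic loss scaler illustrating the pattern; not this crate's type.
struct ScalerSketch {
    scale: f32,
    growth_factor: f32,
    backoff_factor: f32,
    growth_interval: u32,
    clean_steps: u32,
}

impl ScalerSketch {
    fn new(initial_scale: f32) -> Self {
        Self {
            scale: initial_scale,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            clean_steps: 0,
        }
    }

    /// Scales the loss before backprop so small gradients do not underflow.
    fn scale_loss(&self, loss: f32) -> f32 {
        loss * self.scale
    }

    /// Checks the scaled gradients for inf/NaN. If they are finite, unscales
    /// them in place. Then updates the scale and returns whether the step applies.
    fn unscale_and_update(&mut self, grads: &mut [f32]) -> bool {
        if grads.iter().any(|g| !g.is_finite()) {
            // Overflow: skip this step and back off the scale.
            self.scale *= self.backoff_factor;
            self.clean_steps = 0;
            return false;
        }
        let inv = 1.0 / self.scale;
        for g in grads.iter_mut() {
            *g *= inv;
        }
        self.clean_steps += 1;
        if self.clean_steps >= self.growth_interval {
            // A long run without overflow: try a larger scale.
            self.scale *= self.growth_factor;
            self.clean_steps = 0;
        }
        true
    }
}

fn main() {
    let mut scaler = ScalerSketch::new(1024.0);
    let scaled = scaler.scale_loss(0.5);
    // Stand-in for gradients that backprop would produce from the scaled loss.
    let mut grads = vec![scaled * 0.1, scaled * -0.2];
    let applied = scaler.unscale_and_update(&mut grads);
    println!("applied={applied}, grads={grads:?}, scale={}", scaler.scale);

    let mut bad = vec![f32::INFINITY, 1.0];
    let applied = scaler.unscale_and_update(&mut bad);
    println!("applied={applied}, scale={}", scaler.scale);
}
```

In this sketch the gradients are unscaled only after the finiteness check, so a skipped step leaves them untouched. The description above does not say whether this crate unscales inside this function or later in the optimizer.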