pub fn validate_weight_decay(weight_decay: f32) -> PyResult<()>
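Validates that `weight_decay` is non-negative, returning `Ok(())` on success and an error (a Python exception, per the `PyResult` return type) otherwise.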
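The following is a minimal sketch of how such a check might be written. It is not the library's actual implementation. So that it compiles without the `pyo3` crate, `ValueError` and the `PyResult` alias are hypothetical stand-ins for PyO3's `PyErr` and `pyo3::PyResult`. The sketch also assumes that NaN should be rejected. That behavior is not stated in the original description.

```rust
use std::fmt;

/// Hypothetical stand-in for a Python `ValueError` (PyO3's `PyErr` in real code).
#[derive(Debug, Clone, PartialEq)]
pub struct ValueError(pub String);

impl fmt::Display for ValueError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "ValueError: {}", self.0)
    }
}

/// Hypothetical stand-in for `pyo3::PyResult<T>`.
pub type PyResult<T> = Result<T, ValueError>;

/// Validate that weight decay is non-negative.
///
/// Assumption: NaN is rejected explicitly, because `NAN < 0.0` is false
/// and a plain `< 0.0` check would let it through.
/// Note: `-0.0 < 0.0` is false, so negative zero is accepted.
pub fn validate_weight_decay(weight_decay: f32) -> PyResult<()> {
    if weight_decay.is_nan() || weight_decay < 0.0 {
        return Err(ValueError(format!(
            "weight_decay must be non-negative, got {}",
            weight_decay
        )));
    }
    Ok(())
}
```

In actual PyO3 code, the error branch would more likely return something like `Err(PyValueError::new_err(...))` using `pyo3::exceptions::PyValueError`. The exact exception type and message raised by the original function are not specified here.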