unsloth_rs/training.rs

// SPDX-License-Identifier: MIT
// Copyright 2026 Tyler Zervas

//! Training utilities.
//!
//! This module provides training utilities including:
//! - Mixed precision training support (FP32, FP16, BF16)
//! - Gradient scaling for numerical stability
//! - Gradient checkpointing configuration
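//!
//! # Example
//!
//! A sketch of one mixed-precision optimization step (this assumes the crate is
//! built as `unsloth_rs` with this module public; the tensors below stand in for
//! a real forward/backward pass):
//!
//! ```
//! use candle_core::{DType, Device, Tensor};
//! use unsloth_rs::training::{
//!     has_inf_or_nan, scale_loss, unscale_gradients, update_loss_scale, MixedPrecisionConfig,
//! };
//!
//! let device = Device::Cpu;
//! let mut config = MixedPrecisionConfig::fp16();
//!
//! // Placeholder loss and gradients; in practice these come from the model.
//! let loss = Tensor::full(2.0f32, (), &device).unwrap();
//! let grads = vec![Tensor::ones((2, 2), DType::F32, &device).unwrap()];
//!
//! // Scale the loss before backpropagation to avoid gradient underflow.
//! let _scaled_loss = scale_loss(&loss, &config).unwrap();
//!
//! // After backward: check for overflow, unscale, and adjust the loss scale.
//! let overflow = has_inf_or_nan(&grads).unwrap();
//! if !overflow {
//!     let _true_grads = unscale_gradients(&grads, &config).unwrap();
//! }
//! update_loss_scale(&mut config, overflow, 1);
//! ```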

use candle_core::{DType, Tensor};

use crate::error::{Result, UnslothError};
use crate::memory::CheckpointConfig;

/// Precision mode for training.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrecisionMode {
    /// Full precision (FP32)
    Full,
    /// Half precision (FP16)
    Half,
    /// Brain float 16 (BF16)
    BFloat16,
}

impl PrecisionMode {
    /// Convert precision mode to Candle `DType`.
    #[must_use]
    pub fn to_dtype(&self) -> DType {
        match self {
            Self::Full => DType::F32,
            Self::Half => DType::F16,
            Self::BFloat16 => DType::BF16,
        }
    }

    /// Get the precision mode from a Candle `DType`.
    ///
    /// # Errors
    /// Returns an error if the dtype is not a supported floating-point type.
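    ///
    /// # Example
    ///
    /// A minimal sketch (module path `unsloth_rs::training` assumed):
    ///
    /// ```
    /// use candle_core::DType;
    /// use unsloth_rs::training::PrecisionMode;
    ///
    /// assert_eq!(
    ///     PrecisionMode::from_dtype(DType::BF16).unwrap(),
    ///     PrecisionMode::BFloat16
    /// );
    /// // Integer dtypes are not valid training precisions.
    /// assert!(PrecisionMode::from_dtype(DType::U8).is_err());
    /// ```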
    pub fn from_dtype(dtype: DType) -> Result<Self> {
        match dtype {
            DType::F32 => Ok(Self::Full),
            DType::F16 => Ok(Self::Half),
            DType::BF16 => Ok(Self::BFloat16),
            _ => Err(UnslothError::InvalidConfig(format!(
                "Unsupported dtype for mixed precision: {dtype:?}"
            ))),
        }
    }
}

/// Mixed precision training configuration.
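///
/// # Example
///
/// A minimal sketch, assuming the module is exposed as `unsloth_rs::training`:
///
/// ```
/// use unsloth_rs::training::{MixedPrecisionConfig, PrecisionMode};
///
/// let config = MixedPrecisionConfig::bf16();
/// assert_eq!(config.compute_precision, PrecisionMode::BFloat16);
/// // Master weights stay in full precision by default.
/// assert_eq!(config.master_precision, PrecisionMode::Full);
/// ```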
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Precision mode for computation
    pub compute_precision: PrecisionMode,
    /// Precision mode for master weights (usually FP32)
    pub master_precision: PrecisionMode,
    /// Loss scale factor to prevent gradient underflow
    pub loss_scale: f32,
    /// Enable dynamic loss scaling
    pub dynamic_loss_scale: bool,
    /// Minimum loss scale for dynamic scaling
    pub min_loss_scale: f32,
    /// Maximum loss scale for dynamic scaling
    pub max_loss_scale: f32,
    /// Growth factor for dynamic loss scaling
    pub scale_growth_factor: f32,
    /// Backoff factor for dynamic loss scaling
    pub scale_backoff_factor: f32,
    /// Number of consecutive non-overflow steps before increasing scale
    pub scale_growth_interval: usize,
}

impl Default for MixedPrecisionConfig {
    fn default() -> Self {
        Self {
            compute_precision: PrecisionMode::Half,
            master_precision: PrecisionMode::Full,
            loss_scale: 65536.0, // 2^16
            dynamic_loss_scale: true,
            min_loss_scale: 1.0,
            max_loss_scale: 2_147_483_648.0, // 2^31
            scale_growth_factor: 2.0,
            scale_backoff_factor: 0.5,
            scale_growth_interval: 2000,
        }
    }
}

impl MixedPrecisionConfig {
    /// Create a new mixed precision configuration.
    #[must_use]
    pub fn new(compute_precision: PrecisionMode) -> Self {
        Self {
            compute_precision,
            ..Default::default()
        }
    }

    /// Create a configuration for FP16 training.
    #[must_use]
    pub fn fp16() -> Self {
        Self::new(PrecisionMode::Half)
    }

    /// Create a configuration for BF16 training.
    #[must_use]
    pub fn bf16() -> Self {
        Self::new(PrecisionMode::BFloat16)
    }

    /// Create a configuration for FP32 training (no mixed precision).
    #[must_use]
    pub fn fp32() -> Self {
        Self {
            compute_precision: PrecisionMode::Full,
            master_precision: PrecisionMode::Full,
            dynamic_loss_scale: false,
            loss_scale: 1.0,
            ..Default::default()
        }
    }
}

/// Training configuration.
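///
/// # Example
///
/// A small sketch (assuming `unsloth_rs::training` is the public path of this module):
///
/// ```
/// use unsloth_rs::training::{MixedPrecisionConfig, TrainingConfig};
///
/// let config = TrainingConfig {
///     batch_size: 8,
///     mixed_precision: Some(MixedPrecisionConfig::bf16()),
///     ..Default::default()
/// };
/// assert_eq!(config.gradient_accumulation_steps, 4);
/// ```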
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Batch size
    pub batch_size: usize,
    /// Maximum sequence length
    pub max_seq_len: usize,
    /// Gradient accumulation steps
    pub gradient_accumulation_steps: usize,
    /// Mixed precision configuration (None = FP32)
    pub mixed_precision: Option<MixedPrecisionConfig>,
    /// Gradient checkpointing
    pub checkpoint_config: CheckpointConfig,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            batch_size: 4,
            max_seq_len: 2048,
            gradient_accumulation_steps: 4,
            mixed_precision: Some(MixedPrecisionConfig::default()),
            checkpoint_config: CheckpointConfig::default(),
        }
    }
}

/// Convert tensor to specified precision.
///
/// # Arguments
/// * `tensor` - Input tensor to convert
/// * `precision` - Target precision mode (FP32, FP16, or BF16)
///
/// # Returns
/// Tensor converted to target precision
///
/// # Errors
/// Returns an error if the dtype conversion fails.
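///
/// # Example
///
/// A small sketch, assuming this module is reachable as `unsloth_rs::training`:
///
/// ```
/// use candle_core::{DType, Device, Tensor};
/// use unsloth_rs::training::{convert_precision, PrecisionMode};
///
/// let device = Device::Cpu;
/// let weights = Tensor::ones((2, 3), DType::F32, &device).unwrap();
/// let half = convert_precision(&weights, PrecisionMode::Half).unwrap();
/// assert_eq!(half.dtype(), DType::F16);
/// ```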
pub fn convert_precision(tensor: &Tensor, precision: PrecisionMode) -> Result<Tensor> {
    let target_dtype = precision.to_dtype();
    if tensor.dtype() == target_dtype {
        Ok(tensor.clone())
    } else {
        Ok(tensor.to_dtype(target_dtype)?)
    }
}

/// Scale loss for mixed precision training.
///
/// Scales the loss by the loss scale factor to prevent gradient underflow
/// in lower precision formats.
///
/// # Arguments
/// * `loss` - Original loss tensor to scale
/// * `config` - Mixed precision configuration containing the loss scale factor
///
/// # Returns
/// Scaled loss tensor
///
/// # Errors
/// Returns an error if tensor multiplication fails.
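///
/// # Example
///
/// A minimal sketch (crate and module path `unsloth_rs::training` assumed):
///
/// ```
/// use candle_core::{Device, Tensor};
/// use unsloth_rs::training::{scale_loss, MixedPrecisionConfig};
///
/// let device = Device::Cpu;
/// let loss = Tensor::full(0.5f32, (), &device).unwrap();
/// let config = MixedPrecisionConfig {
///     loss_scale: 1024.0,
///     ..Default::default()
/// };
/// let scaled: f32 = scale_loss(&loss, &config).unwrap().to_scalar().unwrap();
/// assert!((scaled - 512.0).abs() < 1e-3);
/// ```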
pub fn scale_loss(loss: &Tensor, config: &MixedPrecisionConfig) -> Result<Tensor> {
    if (config.loss_scale - 1.0).abs() < f32::EPSILON {
        Ok(loss.clone())
    } else {
        Ok((loss * f64::from(config.loss_scale))?)
    }
}

/// Unscale gradients after backward pass.
///
/// Divides gradients by the loss scale factor to get the true gradient values.
///
/// # Arguments
/// * `gradients` - Scaled gradients from backward pass
/// * `config` - Mixed precision configuration
///
/// # Returns
/// Unscaled gradients
///
/// # Errors
/// Returns an error if tensor operations fail.
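///
/// # Example
///
/// A minimal sketch, assuming the public path `unsloth_rs::training`:
///
/// ```
/// use candle_core::{Device, Tensor};
/// use unsloth_rs::training::{unscale_gradients, MixedPrecisionConfig};
///
/// let device = Device::Cpu;
/// let grads = vec![Tensor::full(8.0f32, (2, 2), &device).unwrap()];
/// let config = MixedPrecisionConfig {
///     loss_scale: 4.0,
///     ..Default::default()
/// };
/// let unscaled = unscale_gradients(&grads, &config).unwrap();
/// let vals: Vec<f32> = unscaled[0].flatten_all().unwrap().to_vec1().unwrap();
/// assert!(vals.iter().all(|v| (v - 2.0).abs() < 1e-5));
/// ```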
pub fn unscale_gradients(
    gradients: &[Tensor],
    config: &MixedPrecisionConfig,
) -> Result<Vec<Tensor>> {
    if (config.loss_scale - 1.0).abs() < f32::EPSILON {
        Ok(gradients.to_vec())
    } else {
        let scale = 1.0 / f64::from(config.loss_scale);
        gradients
            .iter()
            .map(|g| (g * scale).map_err(Into::into))
            .collect()
    }
}

/// Check if gradients contain NaN or Inf values.
///
/// Used to detect gradient overflow in mixed precision training.
///
/// # Arguments
/// * `gradients` - Slice of gradient tensors to check for numerical instability
///
/// # Returns
/// `true` if any gradient contains NaN or Inf, `false` otherwise
///
/// # Errors
/// Returns an error if tensor dtype conversion or flattening fails.
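///
/// # Example
///
/// A small sketch (assuming the module is exported as `unsloth_rs::training`):
///
/// ```
/// use candle_core::{DType, Device, Tensor};
/// use unsloth_rs::training::has_inf_or_nan;
///
/// let device = Device::Cpu;
/// let healthy = Tensor::ones((2, 2), DType::F32, &device).unwrap();
/// let exploded = Tensor::full(f32::NAN, (2, 2), &device).unwrap();
///
/// assert!(!has_inf_or_nan(&[healthy]).unwrap());
/// assert!(has_inf_or_nan(&[exploded]).unwrap());
/// ```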
pub fn has_inf_or_nan(gradients: &[Tensor]) -> Result<bool> {
    for grad in gradients {
        let grad_f32 = grad.to_dtype(DType::F32)?;
        let values: Vec<f32> = grad_f32.flatten_all()?.to_vec1()?;

        for &val in &values {
            if val.is_nan() || val.is_infinite() {
                return Ok(true);
            }
        }
    }
    Ok(false)
}

/// Update loss scale based on gradient overflow status.
///
/// Implements dynamic loss scaling to automatically adjust the loss scale
/// based on whether gradients overflow.
///
/// # Arguments
/// * `config` - Mixed precision configuration (will be modified)
/// * `has_overflow` - Whether gradients overflowed in this step
/// * `steps_since_overflow` - Number of steps since last overflow
///
/// # Returns
/// New loss scale value
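///
/// # Example
///
/// A sketch of dynamic scaling with the default factors (module path
/// `unsloth_rs::training` assumed):
///
/// ```
/// use unsloth_rs::training::{update_loss_scale, MixedPrecisionConfig};
///
/// let mut config = MixedPrecisionConfig {
///     loss_scale: 1024.0,
///     ..Default::default()
/// };
///
/// // Overflow halves the scale (default backoff factor is 0.5).
/// assert_eq!(update_loss_scale(&mut config, true, 0), 512.0);
///
/// // After `scale_growth_interval` clean steps, the scale doubles again.
/// let interval = config.scale_growth_interval;
/// assert_eq!(update_loss_scale(&mut config, false, interval), 1024.0);
/// ```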
#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::cast_sign_loss)]
pub fn update_loss_scale(
    config: &mut MixedPrecisionConfig,
    has_overflow: bool,
    steps_since_overflow: usize,
) -> f32 {
    if !config.dynamic_loss_scale {
        return config.loss_scale;
    }

    if has_overflow {
        // Reduce loss scale on overflow
        config.loss_scale =
            (config.loss_scale * config.scale_backoff_factor).max(config.min_loss_scale);
    } else if steps_since_overflow >= config.scale_growth_interval {
        // Increase loss scale after many successful steps
        config.loss_scale =
            (config.loss_scale * config.scale_growth_factor).min(config.max_loss_scale);
    }

    config.loss_scale
}

/// Compute gradient with optional checkpointing.
///
/// When implemented, this function will perform gradient computation with activation
/// checkpointing, which trades compute for memory by recomputing activations during
/// the backward pass instead of storing them.
///
/// # Arguments
/// * `_input` - Input tensor for the forward pass
/// * `_forward_fn` - Function that computes the forward pass
/// * `_config` - Checkpoint configuration specifying the checkpointing strategy
///
/// # Returns
/// Computed gradient tensor
///
/// # Errors
/// Returns an error if gradient computation fails.
///
/// # Note
/// This is currently unimplemented and always returns an error.
/// Gradient checkpointing is planned for a future release.
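///
/// # Example
///
/// A sketch of the intended call shape (paths `unsloth_rs::training` and
/// `unsloth_rs::memory` assumed); until checkpointing lands, the call returns an error:
///
/// ```
/// use candle_core::{DType, Device, Tensor};
/// use unsloth_rs::memory::CheckpointConfig;
/// use unsloth_rs::training::compute_gradient_checkpointed;
///
/// let device = Device::Cpu;
/// let input = Tensor::ones((2, 2), DType::F32, &device).unwrap();
/// let result = compute_gradient_checkpointed(
///     &input,
///     |x: &Tensor| Ok(x.clone()), // identity forward pass, just for illustration
///     &CheckpointConfig::default(),
/// );
/// assert!(result.is_err());
/// ```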
pub fn compute_gradient_checkpointed<F>(
    _input: &Tensor,
    _forward_fn: F,
    _config: &CheckpointConfig,
) -> Result<Tensor>
where
    F: Fn(&Tensor) -> Result<Tensor>,
{
    // TODO: Implement gradient checkpointing
    // This would recompute forward pass during backward instead of storing activations
    Err(UnslothError::InvalidConfig(
        "Gradient checkpointing is not yet implemented. This feature is planned for a future release.".to_string()
    ))
}

/// Scale gradients for mixed precision training.
///
/// # Errors
/// Returns an error if tensor multiplication fails.
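///
/// # Example
///
/// A minimal sketch (module path `unsloth_rs::training` assumed):
///
/// ```
/// use candle_core::{Device, Tensor};
/// use unsloth_rs::training::scale_gradients;
///
/// let device = Device::Cpu;
/// let grads = vec![Tensor::full(2.0f32, (2, 3), &device).unwrap()];
/// let scaled = scale_gradients(&grads, 0.5).unwrap();
/// let vals: Vec<f32> = scaled[0].flatten_all().unwrap().to_vec1().unwrap();
/// assert!(vals.iter().all(|v| (v - 1.0).abs() < 1e-5));
/// ```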
pub fn scale_gradients(gradients: &[Tensor], scale: f32) -> Result<Vec<Tensor>> {
    gradients
        .iter()
        .map(|g| (g * f64::from(scale)).map_err(Into::into))
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert_eq!(config.batch_size, 4);
        assert!(config.mixed_precision.is_some());
    }

    #[test]
    fn test_precision_mode_to_dtype() {
        assert_eq!(PrecisionMode::Full.to_dtype(), DType::F32);
        assert_eq!(PrecisionMode::Half.to_dtype(), DType::F16);
        assert_eq!(PrecisionMode::BFloat16.to_dtype(), DType::BF16);
    }

    #[test]
    fn test_precision_mode_from_dtype() {
        assert_eq!(
            PrecisionMode::from_dtype(DType::F32).unwrap(),
            PrecisionMode::Full
        );
        assert_eq!(
            PrecisionMode::from_dtype(DType::F16).unwrap(),
            PrecisionMode::Half
        );
        assert_eq!(
            PrecisionMode::from_dtype(DType::BF16).unwrap(),
            PrecisionMode::BFloat16
        );

        // Test unsupported dtype
        assert!(PrecisionMode::from_dtype(DType::U8).is_err());
    }

    #[test]
    fn test_mixed_precision_config_defaults() {
        let config = MixedPrecisionConfig::default();
        assert_eq!(config.compute_precision, PrecisionMode::Half);
        assert_eq!(config.master_precision, PrecisionMode::Full);
        assert_eq!(config.loss_scale, 65536.0);
        assert!(config.dynamic_loss_scale);
    }

    #[test]
    fn test_mixed_precision_config_fp16() {
        let config = MixedPrecisionConfig::fp16();
        assert_eq!(config.compute_precision, PrecisionMode::Half);
        assert_eq!(config.master_precision, PrecisionMode::Full);
    }

    #[test]
    fn test_mixed_precision_config_bf16() {
        let config = MixedPrecisionConfig::bf16();
        assert_eq!(config.compute_precision, PrecisionMode::BFloat16);
    }

    #[test]
    fn test_mixed_precision_config_fp32() {
        let config = MixedPrecisionConfig::fp32();
        assert_eq!(config.compute_precision, PrecisionMode::Full);
        assert_eq!(config.master_precision, PrecisionMode::Full);
        assert!(!config.dynamic_loss_scale);
        assert_eq!(config.loss_scale, 1.0);
    }

    #[test]
    fn test_convert_precision() {
        let device = Device::Cpu;
        let tensor = Tensor::ones((2, 3), DType::F32, &device).unwrap();

        // Convert to FP16
        let fp16 = convert_precision(&tensor, PrecisionMode::Half).unwrap();
        assert_eq!(fp16.dtype(), DType::F16);

        // Convert to BF16
        let bf16 = convert_precision(&tensor, PrecisionMode::BFloat16).unwrap();
        assert_eq!(bf16.dtype(), DType::BF16);

        // Convert to same precision should work
        let same = convert_precision(&tensor, PrecisionMode::Full).unwrap();
        assert_eq!(same.dtype(), DType::F32);
    }

    #[test]
    fn test_scale_loss() {
        let device = Device::Cpu;
        let loss = Tensor::full(2.0f32, (), &device).unwrap(); // scalar tensor

        let mut config = MixedPrecisionConfig::default();
        config.loss_scale = 4.0;

        let scaled = scale_loss(&loss, &config).unwrap();
        let value: f32 = scaled.to_scalar().unwrap();

        assert!((value - 8.0).abs() < 1e-5);
    }

    #[test]
    fn test_unscale_gradients() {
        let device = Device::Cpu;
        let grad1 = Tensor::full(8.0f32, (2, 2), &device).unwrap();
        let grad2 = Tensor::full(16.0f32, (2, 2), &device).unwrap();

        let gradients = vec![grad1, grad2];

        let mut config = MixedPrecisionConfig::default();
        config.loss_scale = 4.0;

        let unscaled = unscale_gradients(&gradients, &config).unwrap();

        // Check first gradient: 8.0 / 4.0 = 2.0
        let vals1: Vec<f32> = unscaled[0].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals1 {
            assert!((val - 2.0).abs() < 1e-5);
        }

        // Check second gradient: 16.0 / 4.0 = 4.0
        let vals2: Vec<f32> = unscaled[1].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals2 {
            assert!((val - 4.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_has_inf_or_nan() {
        let device = Device::Cpu;

        // Test normal gradients
        let grad1 = Tensor::ones((2, 2), DType::F32, &device).unwrap();
        let grad2 = Tensor::full(2.0f32, (2, 2), &device).unwrap();
        assert!(!has_inf_or_nan(&[grad1, grad2]).unwrap());

        // Test with NaN
        let nan_grad = Tensor::full(f32::NAN, (2, 2), &device).unwrap();
        assert!(has_inf_or_nan(&[nan_grad]).unwrap());

        // Test with Inf
        let inf_grad = Tensor::full(f32::INFINITY, (2, 2), &device).unwrap();
        assert!(has_inf_or_nan(&[inf_grad]).unwrap());
    }

    #[test]
    fn test_update_loss_scale_on_overflow() {
        let mut config = MixedPrecisionConfig {
            loss_scale: 1000.0,
            scale_backoff_factor: 0.5,
            ..Default::default()
        };

        // Test backoff on overflow
        let new_scale = update_loss_scale(&mut config, true, 0);
        assert_eq!(new_scale, 500.0);
        assert_eq!(config.loss_scale, 500.0);
    }

    #[test]
    fn test_update_loss_scale_growth() {
        let mut config = MixedPrecisionConfig {
            loss_scale: 100.0,
            scale_growth_factor: 2.0,
            scale_growth_interval: 100,
            ..Default::default()
        };

        // Test growth after many successful steps
        let new_scale = update_loss_scale(&mut config, false, 100);
        assert_eq!(new_scale, 200.0);
        assert_eq!(config.loss_scale, 200.0);
    }

    #[test]
    fn test_update_loss_scale_no_change() {
        let mut config = MixedPrecisionConfig::default();
        config.loss_scale = 100.0;

        // No change if not enough steps and no overflow
        let new_scale = update_loss_scale(&mut config, false, 10);
        assert_eq!(new_scale, 100.0);
    }

    #[test]
    fn test_update_loss_scale_bounds() {
        let mut config = MixedPrecisionConfig {
            min_loss_scale: 1.0,
            max_loss_scale: 1000.0,
            loss_scale: 2.0,
            scale_backoff_factor: 0.5,
            ..Default::default()
        };

        // Test min bound
        update_loss_scale(&mut config, true, 0);
        assert!((config.loss_scale - 1.0).abs() < f32::EPSILON); // Should hit min

        // Test max bound
        config.loss_scale = 600.0;
        config.scale_growth_factor = 2.0;
        config.scale_growth_interval = 10;
        update_loss_scale(&mut config, false, 10);
        assert!((config.loss_scale - 1000.0).abs() < f32::EPSILON); // Should hit max
    }

    #[test]
    fn test_scale_gradients() {
        let device = Device::Cpu;
        let grad1 = Tensor::ones((2, 3), DType::F32, &device).unwrap();
        let grad2 = Tensor::full(2.0f32, (2, 3), &device).unwrap();

        let gradients = vec![grad1, grad2];
        let scale = 0.5;

        let scaled = scale_gradients(&gradients, scale).unwrap();

        // Check first gradient: 1.0 * 0.5 = 0.5
        let vals1: Vec<f32> = scaled[0].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals1 {
            assert!((val - 0.5).abs() < 1e-5);
        }

        // Check second gradient: 2.0 * 0.5 = 1.0
        let vals2: Vec<f32> = scaled[1].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals2 {
            assert!((val - 1.0).abs() < 1e-5);
        }
    }
}