use candle_core::{DType, Tensor};

use crate::error::{Result, UnslothError};
use crate::memory::CheckpointConfig;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrecisionMode {
    Full,
    Half,
    BFloat16,
}

impl PrecisionMode {
    #[must_use]
    pub fn to_dtype(&self) -> DType {
        match self {
            Self::Full => DType::F32,
            Self::Half => DType::F16,
            Self::BFloat16 => DType::BF16,
        }
    }

    pub fn from_dtype(dtype: DType) -> Result<Self> {
        match dtype {
            DType::F32 => Ok(Self::Full),
            DType::F16 => Ok(Self::Half),
            DType::BF16 => Ok(Self::BFloat16),
            _ => Err(UnslothError::InvalidConfig(format!(
                "Unsupported dtype for mixed precision: {dtype:?}"
            ))),
        }
    }
}

#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    pub compute_precision: PrecisionMode,
    pub master_precision: PrecisionMode,
    pub loss_scale: f32,
    pub dynamic_loss_scale: bool,
    pub min_loss_scale: f32,
    pub max_loss_scale: f32,
    pub scale_growth_factor: f32,
    pub scale_backoff_factor: f32,
    pub scale_growth_interval: usize,
}

impl Default for MixedPrecisionConfig {
    fn default() -> Self {
        Self {
            compute_precision: PrecisionMode::Half,
            master_precision: PrecisionMode::Full,
            loss_scale: 65536.0,
            dynamic_loss_scale: true,
            min_loss_scale: 1.0,
            max_loss_scale: 2_147_483_648.0,
            scale_growth_factor: 2.0,
            scale_backoff_factor: 0.5,
            scale_growth_interval: 2000,
        }
    }
}

impl MixedPrecisionConfig {
    #[must_use]
    pub fn new(compute_precision: PrecisionMode) -> Self {
        Self {
            compute_precision,
            ..Default::default()
        }
    }

    #[must_use]
    pub fn fp16() -> Self {
        Self::new(PrecisionMode::Half)
    }

    #[must_use]
    pub fn bf16() -> Self {
        Self::new(PrecisionMode::BFloat16)
    }

    #[must_use]
    pub fn fp32() -> Self {
        Self {
            compute_precision: PrecisionMode::Full,
            master_precision: PrecisionMode::Full,
            dynamic_loss_scale: false,
            loss_scale: 1.0,
            ..Default::default()
        }
    }
}

#[derive(Debug, Clone)]
pub struct TrainingConfig {
    pub batch_size: usize,
    pub max_seq_len: usize,
    pub gradient_accumulation_steps: usize,
    pub mixed_precision: Option<MixedPrecisionConfig>,
    pub checkpoint_config: CheckpointConfig,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            batch_size: 4,
            max_seq_len: 2048,
            gradient_accumulation_steps: 4,
            mixed_precision: Some(MixedPrecisionConfig::default()),
            checkpoint_config: CheckpointConfig::default(),
        }
    }
}

pub fn convert_precision(tensor: &Tensor, precision: PrecisionMode) -> Result<Tensor> {
    let target_dtype = precision.to_dtype();
    if tensor.dtype() == target_dtype {
        Ok(tensor.clone())
    } else {
        Ok(tensor.to_dtype(target_dtype)?)
    }
}

pub fn scale_loss(loss: &Tensor, config: &MixedPrecisionConfig) -> Result<Tensor> {
    if (config.loss_scale - 1.0).abs() < f32::EPSILON {
        Ok(loss.clone())
    } else {
        Ok((loss * f64::from(config.loss_scale))?)
    }
}

pub fn unscale_gradients(
    gradients: &[Tensor],
    config: &MixedPrecisionConfig,
) -> Result<Vec<Tensor>> {
    if (config.loss_scale - 1.0).abs() < f32::EPSILON {
        Ok(gradients.to_vec())
    } else {
        let scale = 1.0 / f64::from(config.loss_scale);
        gradients
            .iter()
            .map(|g| (g * scale).map_err(Into::into))
            .collect()
    }
}

pub fn has_inf_or_nan(gradients: &[Tensor]) -> Result<bool> {
    for grad in gradients {
        let grad_f32 = grad.to_dtype(DType::F32)?;
        let values: Vec<f32> = grad_f32.flatten_all()?.to_vec1()?;

        for &val in &values {
            if val.is_nan() || val.is_infinite() {
                return Ok(true);
            }
        }
    }
    Ok(false)
}

#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::cast_sign_loss)]
pub fn update_loss_scale(
    config: &mut MixedPrecisionConfig,
    has_overflow: bool,
    steps_since_overflow: usize,
) -> f32 {
    if !config.dynamic_loss_scale {
        return config.loss_scale;
    }

    if has_overflow {
        config.loss_scale =
            (config.loss_scale * config.scale_backoff_factor).max(config.min_loss_scale);
    } else if steps_since_overflow >= config.scale_growth_interval {
        config.loss_scale =
            (config.loss_scale * config.scale_growth_factor).min(config.max_loss_scale);
    }

    config.loss_scale
}

pub fn compute_gradient_checkpointed<F>(
    _input: &Tensor,
    _forward_fn: F,
    _config: &CheckpointConfig,
) -> Result<Tensor>
where
    F: Fn(&Tensor) -> Result<Tensor>,
{
    Err(UnslothError::InvalidConfig(
        "Gradient checkpointing is not yet implemented. This feature is planned for a future release.".to_string(),
    ))
}

pub fn scale_gradients(gradients: &[Tensor], scale: f32) -> Result<Vec<Tensor>> {
    gradients
        .iter()
        .map(|g| (g * f64::from(scale)).map_err(Into::into))
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert_eq!(config.batch_size, 4);
        assert!(config.mixed_precision.is_some());
    }

    #[test]
    fn test_precision_mode_to_dtype() {
        assert_eq!(PrecisionMode::Full.to_dtype(), DType::F32);
        assert_eq!(PrecisionMode::Half.to_dtype(), DType::F16);
        assert_eq!(PrecisionMode::BFloat16.to_dtype(), DType::BF16);
    }

    #[test]
    fn test_precision_mode_from_dtype() {
        assert_eq!(
            PrecisionMode::from_dtype(DType::F32).unwrap(),
            PrecisionMode::Full
        );
        assert_eq!(
            PrecisionMode::from_dtype(DType::F16).unwrap(),
            PrecisionMode::Half
        );
        assert_eq!(
            PrecisionMode::from_dtype(DType::BF16).unwrap(),
            PrecisionMode::BFloat16
        );

        assert!(PrecisionMode::from_dtype(DType::U8).is_err());
    }
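
    // Hedged addition: a small illustrative check that `from_dtype(to_dtype(m))`
    // returns the original `PrecisionMode` for every supported variant.
    #[test]
    fn test_precision_mode_round_trip() {
        for mode in [
            PrecisionMode::Full,
            PrecisionMode::Half,
            PrecisionMode::BFloat16,
        ] {
            assert_eq!(PrecisionMode::from_dtype(mode.to_dtype()).unwrap(), mode);
        }
    }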

    #[test]
    fn test_mixed_precision_config_defaults() {
        let config = MixedPrecisionConfig::default();
        assert_eq!(config.compute_precision, PrecisionMode::Half);
        assert_eq!(config.master_precision, PrecisionMode::Full);
        assert_eq!(config.loss_scale, 65536.0);
        assert!(config.dynamic_loss_scale);
    }

    #[test]
    fn test_mixed_precision_config_fp16() {
        let config = MixedPrecisionConfig::fp16();
        assert_eq!(config.compute_precision, PrecisionMode::Half);
        assert_eq!(config.master_precision, PrecisionMode::Full);
    }

    #[test]
    fn test_mixed_precision_config_bf16() {
        let config = MixedPrecisionConfig::bf16();
        assert_eq!(config.compute_precision, PrecisionMode::BFloat16);
    }

    #[test]
    fn test_mixed_precision_config_fp32() {
        let config = MixedPrecisionConfig::fp32();
        assert_eq!(config.compute_precision, PrecisionMode::Full);
        assert_eq!(config.master_precision, PrecisionMode::Full);
        assert!(!config.dynamic_loss_scale);
        assert_eq!(config.loss_scale, 1.0);
    }

    #[test]
    fn test_convert_precision() {
        let device = Device::Cpu;
        let tensor = Tensor::ones((2, 3), DType::F32, &device).unwrap();

        let fp16 = convert_precision(&tensor, PrecisionMode::Half).unwrap();
        assert_eq!(fp16.dtype(), DType::F16);

        let bf16 = convert_precision(&tensor, PrecisionMode::BFloat16).unwrap();
        assert_eq!(bf16.dtype(), DType::BF16);

        let same = convert_precision(&tensor, PrecisionMode::Full).unwrap();
        assert_eq!(same.dtype(), DType::F32);
    }

    #[test]
    fn test_scale_loss() {
        let device = Device::Cpu;
        let loss = Tensor::full(2.0f32, (), &device).unwrap();

        let mut config = MixedPrecisionConfig::default();
        config.loss_scale = 4.0;

        let scaled = scale_loss(&loss, &config).unwrap();
        let value: f32 = scaled.to_scalar().unwrap();

        assert!((value - 8.0).abs() < 1e-5);
    }
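
    // Hedged addition: a minimal sketch of the `loss_scale == 1.0` fast path in
    // `scale_loss`, which should return the loss value unchanged.
    #[test]
    fn test_scale_loss_identity() {
        let device = Device::Cpu;
        let loss = Tensor::full(3.0f32, (), &device).unwrap();

        let mut config = MixedPrecisionConfig::default();
        config.loss_scale = 1.0;

        let scaled = scale_loss(&loss, &config).unwrap();
        let value: f32 = scaled.to_scalar().unwrap();
        assert!((value - 3.0).abs() < 1e-6);
    }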

    #[test]
    fn test_unscale_gradients() {
        let device = Device::Cpu;
        let grad1 = Tensor::full(8.0f32, (2, 2), &device).unwrap();
        let grad2 = Tensor::full(16.0f32, (2, 2), &device).unwrap();

        let gradients = vec![grad1, grad2];

        let mut config = MixedPrecisionConfig::default();
        config.loss_scale = 4.0;

        let unscaled = unscale_gradients(&gradients, &config).unwrap();

        let vals1: Vec<f32> = unscaled[0].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals1 {
            assert!((val - 2.0).abs() < 1e-5);
        }

        let vals2: Vec<f32> = unscaled[1].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals2 {
            assert!((val - 4.0).abs() < 1e-5);
        }
    }
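
    // Hedged addition: an illustrative round trip showing that scaling gradients
    // by `loss_scale` with `scale_gradients` and then calling `unscale_gradients`
    // with the same config recovers the original values.
    #[test]
    fn test_scale_unscale_round_trip() {
        let device = Device::Cpu;
        let grad = Tensor::full(3.0f32, (2, 2), &device).unwrap();

        let mut config = MixedPrecisionConfig::default();
        config.loss_scale = 128.0;

        let scaled = scale_gradients(&[grad], config.loss_scale).unwrap();
        let recovered = unscale_gradients(&scaled, &config).unwrap();

        let vals: Vec<f32> = recovered[0].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals {
            assert!((val - 3.0).abs() < 1e-5);
        }
    }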

    #[test]
    fn test_has_inf_or_nan() {
        let device = Device::Cpu;

        let grad1 = Tensor::ones((2, 2), DType::F32, &device).unwrap();
        let grad2 = Tensor::full(2.0f32, (2, 2), &device).unwrap();
        assert!(!has_inf_or_nan(&[grad1, grad2]).unwrap());

        let nan_grad = Tensor::full(f32::NAN, (2, 2), &device).unwrap();
        assert!(has_inf_or_nan(&[nan_grad]).unwrap());

        let inf_grad = Tensor::full(f32::INFINITY, (2, 2), &device).unwrap();
        assert!(has_inf_or_nan(&[inf_grad]).unwrap());
    }
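
    // Hedged addition: a sketch of the usual dynamic-loss-scaling reaction to an
    // overflow, using only functions defined in this module. When
    // `has_inf_or_nan` reports non-finite gradients, the optimizer step would be
    // skipped and `update_loss_scale` backs the scale off by
    // `scale_backoff_factor`.
    #[test]
    fn test_overflow_triggers_backoff() {
        let device = Device::Cpu;
        let bad_grad = Tensor::full(f32::INFINITY, (2, 2), &device).unwrap();

        let mut config = MixedPrecisionConfig::default();
        config.loss_scale = 1024.0;

        let overflow = has_inf_or_nan(&[bad_grad]).unwrap();
        assert!(overflow);

        let new_scale = update_loss_scale(&mut config, overflow, 0);
        assert!((new_scale - 512.0).abs() < f32::EPSILON);
    }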

    #[test]
    fn test_update_loss_scale_on_overflow() {
        let mut config = MixedPrecisionConfig {
            loss_scale: 1000.0,
            scale_backoff_factor: 0.5,
            ..Default::default()
        };

        let new_scale = update_loss_scale(&mut config, true, 0);
        assert_eq!(new_scale, 500.0);
        assert_eq!(config.loss_scale, 500.0);
    }

    #[test]
    fn test_update_loss_scale_growth() {
        let mut config = MixedPrecisionConfig {
            loss_scale: 100.0,
            scale_growth_factor: 2.0,
            scale_growth_interval: 100,
            ..Default::default()
        };

        let new_scale = update_loss_scale(&mut config, false, 100);
        assert_eq!(new_scale, 200.0);
        assert_eq!(config.loss_scale, 200.0);
    }

    #[test]
    fn test_update_loss_scale_no_change() {
        let mut config = MixedPrecisionConfig::default();
        config.loss_scale = 100.0;

        let new_scale = update_loss_scale(&mut config, false, 10);
        assert_eq!(new_scale, 100.0);
    }

    #[test]
    fn test_update_loss_scale_bounds() {
        let mut config = MixedPrecisionConfig {
            min_loss_scale: 1.0,
            max_loss_scale: 1000.0,
            loss_scale: 2.0,
            scale_backoff_factor: 0.5,
            ..Default::default()
        };

        update_loss_scale(&mut config, true, 0);
        assert!((config.loss_scale - 1.0).abs() < f32::EPSILON);

        config.loss_scale = 600.0;
        config.scale_growth_factor = 2.0;
        config.scale_growth_interval = 10;
        update_loss_scale(&mut config, false, 10);
        assert!((config.loss_scale - 1000.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_scale_gradients() {
        let device = Device::Cpu;
        let grad1 = Tensor::ones((2, 3), DType::F32, &device).unwrap();
        let grad2 = Tensor::full(2.0f32, (2, 3), &device).unwrap();

        let gradients = vec![grad1, grad2];
        let scale = 0.5;

        let scaled = scale_gradients(&gradients, scale).unwrap();

        let vals1: Vec<f32> = scaled[0].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals1 {
            assert!((val - 0.5).abs() < 1e-5);
        }

        let vals2: Vec<f32> = scaled[1].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals2 {
            assert!((val - 1.0).abs() < 1e-5);
        }
    }
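
    // Hedged addition: `compute_gradient_checkpointed` is currently a stub that
    // always returns an error, so callers should expect `Err` until the feature
    // is implemented. The identity closure used here is purely illustrative.
    #[test]
    fn test_gradient_checkpointing_not_implemented() {
        let device = Device::Cpu;
        let input = Tensor::ones((2, 2), DType::F32, &device).unwrap();
        let config = crate::memory::CheckpointConfig::default();

        let result = compute_gradient_checkpointed(&input, |t: &Tensor| Ok(t.clone()), &config);
        assert!(result.is_err());
    }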
}