// stable_diffusion_trainer/trainer/scheduler/mod.rs

//! Learning rate configuration: the base learning rate and the scheduler applied to it.
2
3use crate::prelude::*;
4use std::fmt::Display;
5
6/// The learning rate scheduler structure.
7#[derive(Debug, Serialize, Deserialize)]
8pub struct LearningRate {
9    /// The amount of the learning rate.
10    #[serde(default = "default_amount")]
11    pub amount: f32,
12    /// The learning rate scheduler.
13    pub scheduler: LearningRateScheduler
14}
15
/// Serde fallback for the learning rate amount when the field is absent.
fn default_amount() -> f32 {
    1e-3
}
17
18impl Default for LearningRate {
19    fn default() -> Self {
20        let amount = default_amount();
21        let scheduler = LearningRateScheduler::Constant;
22        LearningRate { amount, scheduler }
23    }
24}
25
26impl LearningRate {
27    /// Create a new learning rate structure.
28    pub fn new() -> Self {
29        Default::default()
30    }
31
32    /// Set the amount of the learning rate.
33    pub fn with_amount(mut self, amount: f32) -> Self {
34        self.amount = amount;
35        self
36    }
37
38    /// Set the learning rate scheduler.
39    pub fn with_scheduler(mut self, scheduler: LearningRateScheduler) -> Self {
40        self.scheduler = scheduler;
41        self
42    }
43}
44
45/// The learning rate scheduler enumeration.
46#[derive(Debug, Serialize, Deserialize)]
47pub enum LearningRateScheduler {
48    /// Adafactor learning rate scheduler.
49    Adafactor,
50    /// Constant learning rate scheduler.
51    Constant,
52    /// Constant with warmup learning rate scheduler.
53    ConstantWithWarmup,
54    /// Cosine learning rate scheduler.
55    Cosine,
56    /// Cosine with restarts learning rate scheduler.
57    CosineWithRestarts,
58    /// Linear learning rate scheduler.
59    Linear,
60    /// Polynomial learning rate scheduler.
61    Polynomial
62}
63
64impl Display for LearningRateScheduler {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        match self {
67            LearningRateScheduler::Adafactor => write!(f, "adafactor"),
68            LearningRateScheduler::Constant => write!(f, "constant"),
69            LearningRateScheduler::ConstantWithWarmup => write!(f, "constant_with_warmup"),
70            LearningRateScheduler::Cosine => write!(f, "cosine"),
71            LearningRateScheduler::CosineWithRestarts => write!(f, "cosine_with_restarts"),
72            LearningRateScheduler::Linear => write!(f, "linear"),
73            LearningRateScheduler::Polynomial => write!(f, "polynomial")
74        }
75    }
76}