1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
//! Learning rate scheduler module.

use crate::prelude::*;
use std::fmt::Display;

/// Learning rate configuration: a base amount paired with a scheduler.
///
/// The `amount` field carries a serde default, so it may be omitted from
/// serialized input; `scheduler` carries no such attribute.
#[derive(Debug, Serialize, Deserialize)]
pub struct LearningRate {
    /// Base learning rate value. Falls back to `default_amount()` when the
    /// field is absent from the deserialized input.
    #[serde(default = "default_amount")]
    pub amount: f32,
    /// Scheduler variant associated with this learning rate.
    pub scheduler: LearningRateScheduler,
}

/// Returns the learning rate amount used when none is supplied (`0.001`).
fn default_amount() -> f32 {
    const DEFAULT_AMOUNT: f32 = 0.001;
    DEFAULT_AMOUNT
}

impl Default for LearningRate {
    /// Builds the default configuration: the value of `default_amount()`
    /// combined with the `Constant` scheduler.
    fn default() -> Self {
        Self {
            amount: default_amount(),
            scheduler: LearningRateScheduler::Constant,
        }
    }
}

impl LearningRate {
    /// Construct a `LearningRate` populated with its `Default` values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter: consumes `self` and returns it with `amount`
    /// replaced; every other field is carried over unchanged.
    pub fn with_amount(self, amount: f32) -> Self {
        Self { amount, ..self }
    }

    /// Builder-style setter: consumes `self` and returns it with `scheduler`
    /// replaced; every other field is carried over unchanged.
    pub fn with_scheduler(self, scheduler: LearningRateScheduler) -> Self {
        Self { scheduler, ..self }
    }
}

/// The learning rate scheduler enumeration.
///
/// Every variant is a plain unit variant, so the type is cheap to copy and
/// can be compared, hashed, and used as a map key. The `Display`
/// implementation renders each variant as a lowercase snake_case name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LearningRateScheduler {
    /// Adafactor learning rate scheduler.
    Adafactor,
    /// Constant learning rate scheduler.
    Constant,
    /// Constant with warmup learning rate scheduler.
    ConstantWithWarmup,
    /// Cosine learning rate scheduler.
    Cosine,
    /// Cosine with restarts learning rate scheduler.
    CosineWithRestarts,
    /// Linear learning rate scheduler.
    Linear,
    /// Polynomial learning rate scheduler.
    Polynomial,
}

impl Display for LearningRateScheduler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LearningRateScheduler::Adafactor => write!(f, "adafactor"),
            LearningRateScheduler::Constant => write!(f, "constant"),
            LearningRateScheduler::ConstantWithWarmup => write!(f, "constant_with_warmup"),
            LearningRateScheduler::Cosine => write!(f, "cosine"),
            LearningRateScheduler::CosineWithRestarts => write!(f, "cosine_with_restarts"),
            LearningRateScheduler::Linear => write!(f, "linear"),
            LearningRateScheduler::Polynomial => write!(f, "polynomial")
        }
    }
}