use crate::prelude::*;
use std::fmt::Display;
#[derive(Debug, Serialize, Deserialize)]
pub struct LearningRate {
#[serde(default = "default_amount")]
pub amount: f32,
pub scheduler: LearningRateScheduler
}
/// Default learning-rate value (`0.001`).
///
/// Serves as the serde fallback for `LearningRate::amount` and is also used by
/// `LearningRate::default`.
fn default_amount() -> f32 {
    1e-3
}
impl Default for LearningRate {
    /// Builds a learning rate of `default_amount()` (`0.001`) paired with the
    /// [`LearningRateScheduler::Constant`] scheduler.
    fn default() -> Self {
        Self {
            amount: default_amount(),
            scheduler: LearningRateScheduler::Constant,
        }
    }
}
impl LearningRate {
    /// Creates a learning rate with default settings; equivalent to [`LearningRate::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `self` with `amount` replaced by the given value.
    ///
    /// Consumes and returns the value builder-style, so the result must be bound;
    /// calling this as a bare statement would discard the change (hence `#[must_use]`).
    #[must_use]
    pub fn with_amount(self, amount: f32) -> Self {
        Self { amount, ..self }
    }

    /// Returns `self` with `scheduler` replaced by the given value.
    ///
    /// Consumes and returns the value builder-style, so the result must be bound;
    /// calling this as a bare statement would discard the change (hence `#[must_use]`).
    #[must_use]
    pub fn with_scheduler(self, scheduler: LearningRateScheduler) -> Self {
        Self { scheduler, ..self }
    }
}
/// Selects which learning-rate scheduler to use.
///
/// `Display` renders each variant in snake_case (e.g. `ConstantWithWarmup` is shown as
/// `constant_with_warmup`). NOTE(review): no serde container attributes are set, so
/// serde uses the variant name as written (e.g. `"ConstantWithWarmup"`). The serialized
/// and displayed forms therefore differ; confirm this is intended.
///
/// The enum is fieldless, so it is cheap to copy and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LearningRateScheduler {
    Adafactor,
    Constant,
    ConstantWithWarmup,
    Cosine,
    CosineWithRestarts,
    Linear,
    Polynomial,
}
impl Display for LearningRateScheduler {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LearningRateScheduler::Adafactor => write!(f, "adafactor"),
LearningRateScheduler::Constant => write!(f, "constant"),
LearningRateScheduler::ConstantWithWarmup => write!(f, "constant_with_warmup"),
LearningRateScheduler::Cosine => write!(f, "cosine"),
LearningRateScheduler::CosineWithRestarts => write!(f, "cosine_with_restarts"),
LearningRateScheduler::Linear => write!(f, "linear"),
LearningRateScheduler::Polynomial => write!(f, "polynomial")
}
}
}