stable_diffusion_trainer/trainer/scheduler/
mod.rs1use crate::prelude::*;
4use std::fmt::Display;
5
6#[derive(Debug, Serialize, Deserialize)]
8pub struct LearningRate {
9 #[serde(default = "default_amount")]
11 pub amount: f32,
12 pub scheduler: LearningRateScheduler
14}
15
/// Default base learning rate (0.001).
///
/// Used as the serde default for `LearningRate::amount` and by `LearningRate::default`.
fn default_amount() -> f32 {
    1e-3
}
17
18impl Default for LearningRate {
19 fn default() -> Self {
20 let amount = default_amount();
21 let scheduler = LearningRateScheduler::Constant;
22 LearningRate { amount, scheduler }
23 }
24}
25
26impl LearningRate {
27 pub fn new() -> Self {
29 Default::default()
30 }
31
32 pub fn with_amount(mut self, amount: f32) -> Self {
34 self.amount = amount;
35 self
36 }
37
38 pub fn with_scheduler(mut self, scheduler: LearningRateScheduler) -> Self {
40 self.scheduler = scheduler;
41 self
42 }
43}
44
45#[derive(Debug, Serialize, Deserialize)]
47pub enum LearningRateScheduler {
48 Adafactor,
50 Constant,
52 ConstantWithWarmup,
54 Cosine,
56 CosineWithRestarts,
58 Linear,
60 Polynomial
62}
63
64impl Display for LearningRateScheduler {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 match self {
67 LearningRateScheduler::Adafactor => write!(f, "adafactor"),
68 LearningRateScheduler::Constant => write!(f, "constant"),
69 LearningRateScheduler::ConstantWithWarmup => write!(f, "constant_with_warmup"),
70 LearningRateScheduler::Cosine => write!(f, "cosine"),
71 LearningRateScheduler::CosineWithRestarts => write!(f, "cosine_with_restarts"),
72 LearningRateScheduler::Linear => write!(f, "linear"),
73 LearningRateScheduler::Polynomial => write!(f, "polynomial")
74 }
75 }
76}