// stable_diffusion_trainer/trainer/optimizer/mod.rs

//! Optimizer module for the trainer.
2
3use crate::prelude::*;
4use std::fmt::Display;
5
6/// The optimizer to use for the training process.
7#[derive(Debug, Serialize, Deserialize)]
8pub enum Optimizer {
9    /// AdamW optimizer.
10    AdamW,
11    /// AdamW 8-bit optimizer.
12    AdamW8bit,
13    /// Adafactor optimizer.
14    Adafactor,
15    /// DAdaptation optimizer.
16    DAdaptation,
17    /// DAdaptationGrad optimizer.
18    DAdaptationGrad,
19    /// DAdaptAdam optimizer.
20    DAdaptAdam,
21    /// DAdaptAdan optimizer.
22    DAdaptAdan,
23    /// DAdaptAdamIP optimizer.
24    DAdaptAdamIP,
25    /// DAdaptAdamReprint optimizer.
26    DAdaptAdamReprint,
27    /// DAdaptLion optimizer.
28    DAdaptLion,
29    /// DAdaptSGD optimizer.
30    DAdaptSGD,
31    /// Lion optimizer.
32    Lion,
33    /// Lion 8-bit optimizer.
34    Lion8bit,
35    /// PagedAdamW 8-bit optimizer.
36    PagedAdamW8bit,
37    /// PagedAdamW 32-bit optimizer.
38    PagedAdamW32bit,
39    /// PagedLion 8-bit optimizer.
40    PagedLion8bit,
41    /// Prodigy optimizer.
42    Prodigy,
43    /// SGDNesterov optimizer.
44    SGDNesterov,
45    /// SGDNesterov 8-bit optimizer.
46    SGDNesterov8bit
47}
48
49impl Display for Optimizer {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        match self {
52            Optimizer::AdamW => write!(f, "AdamW"),
53            Optimizer::AdamW8bit => write!(f, "AdamW8bit"),
54            Optimizer::Adafactor => write!(f, "Adafactor"),
55            Optimizer::DAdaptation => write!(f, "DAdaptation"),
56            Optimizer::DAdaptationGrad => write!(f, "DAdaptationGrad"),
57            Optimizer::DAdaptAdam => write!(f, "DAdaptAdam"),
58            Optimizer::DAdaptAdan => write!(f, "DAdaptAdan"),
59            Optimizer::DAdaptAdamIP => write!(f, "DAdaptAdamIP"),
60            Optimizer::DAdaptAdamReprint => write!(f, "DAdaptAdamReprint"),
61            Optimizer::DAdaptLion => write!(f, "DAdaptLion"),
62            Optimizer::DAdaptSGD => write!(f, "DAdaptSGD"),
63            Optimizer::Lion => write!(f, "Lion"),
64            Optimizer::Lion8bit => write!(f, "Lion8bit"),
65            Optimizer::PagedAdamW8bit => write!(f, "PagedAdamW8bit"),
66            Optimizer::PagedAdamW32bit => write!(f, "PagedAdamW32bit"),
67            Optimizer::PagedLion8bit => write!(f, "PagedLion8bit"),
68            Optimizer::Prodigy => write!(f, "Prodigy"),
69            Optimizer::SGDNesterov => write!(f, "SGDNesterov"),
70            Optimizer::SGDNesterov8bit => write!(f, "SGDNesterov8bit")
71        }
72    }
73}