// stable_diffusion_trainer/trainer/training/mod.rs
use crate::{prelude::*, LearningRate, Optimizer};
4
/// Model identifier used when no `pretrained_model` is given.
///
/// Used both as the serde fallback for `Training::pretrained_model` and by
/// `Training::default()`. The value looks like a Hugging Face Hub repo id;
/// NOTE(review): confirm how the loader resolves it.
fn default_pretrained_model() -> String {
    String::from("stabilityai/stable-diffusion-xl-base-1.0")
}
8
9#[derive(Debug, Serialize, Deserialize)]
11pub struct Training {
12 #[serde(default = "default_pretrained_model")]
14 pub pretrained_model: String,
15 pub optimizer: Optimizer,
17 pub learning_rate: LearningRate
19}
20
21impl Default for Training {
22 fn default() -> Self {
23 let optimizer = Optimizer::Adafactor;
24 let learning_rate = LearningRate::default();
25 let pretrained_model = default_pretrained_model();
26 Training { optimizer, learning_rate, pretrained_model }
27 }
28}
29
30impl Training {
31 pub fn new() -> Self {
33 Default::default()
34 }
35
36 pub fn with_pretrained_model(mut self, pretrained_model: impl Into<String>) -> Self {
38 self.pretrained_model = pretrained_model.into();
39 self
40 }
41
42 pub fn with_optimizer(mut self, optimizer: Optimizer) -> Self {
44 self.optimizer = optimizer;
45 self
46 }
47
48 pub fn with_learning_rate(mut self, learning_rate: LearningRate) -> Self {
50 self.learning_rate = learning_rate;
51 self
52 }
53}