stable_diffusion_trainer/trainer/training/mod.rs

//! Training configuration: the pretrained model, optimizer, and learning rate used by the training process.
2
3use crate::{prelude::*, LearningRate, Optimizer};
4
/// Default value for `Training::pretrained_model`.
///
/// Used both as the serde fallback when the field is absent and by
/// `Training::default`.
fn default_pretrained_model() -> String {
    String::from("stabilityai/stable-diffusion-xl-base-1.0")
}
8
9/// The training configuration for the training process.
10#[derive(Debug, Serialize, Deserialize)]
11pub struct Training {
12    /// The name or path of the pretrained model to use for the training process.
13    #[serde(default = "default_pretrained_model")]
14    pub pretrained_model: String,
15    /// The optimizer to use for the training process.
16    pub optimizer: Optimizer,
17    /// The learning rate to use for the training process.
18    pub learning_rate: LearningRate
19}
20
21impl Default for Training {
22    fn default() -> Self {
23        let optimizer = Optimizer::Adafactor;
24        let learning_rate = LearningRate::default();
25        let pretrained_model = default_pretrained_model();
26        Training { optimizer, learning_rate, pretrained_model }
27    }
28}
29
30impl Training {
31    /// Create a new training configuration.
32    pub fn new() -> Self {
33        Default::default()
34    }
35
36    /// Set the pretrained model for the training process.
37    pub fn with_pretrained_model(mut self, pretrained_model: impl Into<String>) -> Self {
38        self.pretrained_model = pretrained_model.into();
39        self
40    }
41
42    /// Set the optimizer for the training process.
43    pub fn with_optimizer(mut self, optimizer: Optimizer) -> Self {
44        self.optimizer = optimizer;
45        self
46    }
47
48    /// Set the learning rate for the training process.
49    pub fn with_learning_rate(mut self, learning_rate: LearningRate) -> Self {
50        self.learning_rate = learning_rate;
51        self
52    }
53}