stable_diffusion_trainer/trainer/parameters/
mod.rs

1//! Trainer's parameters.
2
3use crate::prelude::*;
4use crate::{Network, Output, Prompt, Training, TrainingDataSet};
5
6/// The parameters structure.
7#[derive(Debug, Serialize, Deserialize)]
8pub struct Parameters {
9    /// The prompt to use for the training process.
10    pub prompt: Prompt,
11    /// The dataset to use for the training process.
12    pub dataset: TrainingDataSet,
13    /// The output to use for the training process.
14    pub output: Output,
15    /// The network to use for the training process.
16    pub network: Network,
17    /// The training to use for the training process.
18    pub training: Training
19}
20
21impl Parameters {
22    /// Get the parameters from a file.
23    pub fn from_file(path: impl Into<std::path::PathBuf>) -> anyhow::Result<Self> {
24        use path_slash::*;
25
26        let path = path.into().canonicalize()?;
27        let parent = path.parent().unwrap();
28        let file = std::fs::File::open(path.clone())?;
29        let reader = std::io::BufReader::new(file);
30        let mut parameters: Parameters = serde_json::from_reader(reader)?;
31
32        // TODO: Simplify this. Wrap it in a function.
33        if parameters.dataset.training.path().is_relative() {
34            let path = std::path::PathBuf::from_slash(parent.to_slash().unwrap()).join(parameters.dataset.training.path());
35            parameters.dataset.training.set_path(path);
36        }
37        if parameters.dataset.regularization.is_some() && parameters.dataset.regularization.as_ref().unwrap().path().is_relative() {
38            let path = std::path::PathBuf::from_slash(parent.to_slash().unwrap()).join(parameters.dataset.regularization.as_ref().unwrap().path());
39            parameters.dataset.regularization.as_mut().unwrap().set_path(path);
40        }
41        if parameters.output.directory.is_relative() {
42            let path = std::path::PathBuf::from_slash(parent.to_slash().unwrap()).join(parameters.output.directory);
43            parameters.output.directory = path;
44        }
45        parameters.output.name = parameters.output.name
46            .replace("{network.dimension}", &parameters.network.dimension.to_string())
47            .replace("{network.alpha}", &parameters.network.alpha.to_string())
48            .replace("{prompt.instance}", &parameters.prompt.instance)
49            .replace("{prompt.class}", &parameters.prompt.class);
50
51        Ok(parameters)
52    }
53
54    /// Create a new parameters structure.
55    pub fn new(prompt: Prompt, dataset: TrainingDataSet, output: Output) -> Self {
56        let network = Default::default();
57        let training = Default::default();
58        Parameters { prompt, dataset, output, network, training }
59    }
60
61    /// Set the network configuration to use for the training process.
62    pub fn with_network(mut self, network: Network) -> Self {
63        self.network = network;
64        self
65    }
66
67    /// Set the training configuration to use for the training process.
68    pub fn with_training(mut self, training: Training) -> Self {
69        self.training = training;
70        self
71    }
72}