// stable_diffusion_trainer/trainer/parameters/mod.rs

use std::path::{Path, PathBuf};

use crate::prelude::*;
use crate::{Network, Output, Prompt, Training, TrainingDataSet};

6#[derive(Debug, Serialize, Deserialize)]
8pub struct Parameters {
9 pub prompt: Prompt,
11 pub dataset: TrainingDataSet,
13 pub output: Output,
15 pub network: Network,
17 pub training: Training
19}
20
21impl Parameters {
22 pub fn from_file(path: impl Into<std::path::PathBuf>) -> anyhow::Result<Self> {
24 use path_slash::*;
25
26 let path = path.into().canonicalize()?;
27 let parent = path.parent().unwrap();
28 let file = std::fs::File::open(path.clone())?;
29 let reader = std::io::BufReader::new(file);
30 let mut parameters: Parameters = serde_json::from_reader(reader)?;
31
32 if parameters.dataset.training.path().is_relative() {
34 let path = std::path::PathBuf::from_slash(parent.to_slash().unwrap()).join(parameters.dataset.training.path());
35 parameters.dataset.training.set_path(path);
36 }
37 if parameters.dataset.regularization.is_some() && parameters.dataset.regularization.as_ref().unwrap().path().is_relative() {
38 let path = std::path::PathBuf::from_slash(parent.to_slash().unwrap()).join(parameters.dataset.regularization.as_ref().unwrap().path());
39 parameters.dataset.regularization.as_mut().unwrap().set_path(path);
40 }
41 if parameters.output.directory.is_relative() {
42 let path = std::path::PathBuf::from_slash(parent.to_slash().unwrap()).join(parameters.output.directory);
43 parameters.output.directory = path;
44 }
45 parameters.output.name = parameters.output.name
46 .replace("{network.dimension}", ¶meters.network.dimension.to_string())
47 .replace("{network.alpha}", ¶meters.network.alpha.to_string())
48 .replace("{prompt.instance}", ¶meters.prompt.instance)
49 .replace("{prompt.class}", ¶meters.prompt.class);
50
51 Ok(parameters)
52 }
53
54 pub fn new(prompt: Prompt, dataset: TrainingDataSet, output: Output) -> Self {
56 let network = Default::default();
57 let training = Default::default();
58 Parameters { prompt, dataset, output, network, training }
59 }
60
61 pub fn with_network(mut self, network: Network) -> Self {
63 self.network = network;
64 self
65 }
66
67 pub fn with_training(mut self, training: Training) -> Self {
69 self.training = training;
70 self
71 }
72}