// stable_diffusion_trainer/trainer/mod.rs
use std::{
    path::{Path, PathBuf},
    process::Command,
};

pub mod training;
pub mod output;
pub mod optimizer;
pub mod scheduler;
pub mod parameters;

pub use training::*;
pub use output::*;
pub use optimizer::*;
pub use scheduler::*;
pub use parameters::*;

use crate::{environment::Environment, model_file_format::ModelFileFormat, precision::FloatPrecision};

19pub struct Trainer {
21 pub environment: Environment,
23 pub training_images_repeat: usize,
25 pub regularization_images_repeat: usize,
27 pub resolution: (usize, usize),
29 pub save_model_as: ModelFileFormat,
31 pub network_module: String,
33 pub text_encoder_lr: f32,
35 pub unet_lr: f32,
37 pub lr_scheduler_num_cycles: usize,
39 pub learning_rate: f32,
41 pub lr_warmup_steps: usize,
43 pub train_batch_size: usize,
45 pub max_train_steps: usize,
47 pub save_every_n_epochs: usize,
49 pub mixed_precision: FloatPrecision,
51 pub save_precision: FloatPrecision,
53 pub max_grad_norm: f32,
55 pub max_data_loader_n_workers: usize,
57 pub bucket_reso_steps: usize,
59 pub noise_offset: f32,
61}
62
63impl Default for Trainer {
64 fn default() -> Self {
65 Trainer {
66 environment: Default::default(),
67 training_images_repeat: 40,
68 regularization_images_repeat: 1,
69 resolution: (1024,1024),
70 save_model_as: ModelFileFormat::Safetensors,
71 network_module: "networks.lora".to_string(),
72 text_encoder_lr: 5e-05,
73 unet_lr: 0.0001,
74 lr_scheduler_num_cycles: 1,
75 learning_rate: 0.0001,
76 lr_warmup_steps: 48,
77 train_batch_size: 1,
78 max_train_steps: 480,
79 save_every_n_epochs: 1,
80 mixed_precision: FloatPrecision::F16,
81 save_precision: FloatPrecision::F16,
82 max_grad_norm: 1.0,
83 max_data_loader_n_workers: 0,
84 bucket_reso_steps: 64,
85 noise_offset: 0.0,
86 }
87 }
88}
89
90impl Trainer {
91 pub fn new() -> Self {
93 Default::default()
94 }
95
96 pub fn with_environment(mut self, environment: Environment) -> Self {
98 self.environment = environment;
99 self
100 }
101
102 fn training_dir() -> PathBuf {
103 if let Some(path) = std::env::var_os("TRAINING_DIR") {
104 PathBuf::from(path)
105 } else {
106 std::env::temp_dir().join(uuid::Uuid::new_v4().to_string())
107 }
108 }
109
110 pub fn start(&mut self, parameters: &Parameters) {
112 let training_dir = Self::training_dir();
113 self.prepare(parameters, &training_dir);
114 self.activate();
115 self.caption(parameters, &training_dir);
116 self.train(parameters, &training_dir);
117 self.deactivate();
118 }
119
120 fn image_dir(training_dir: &PathBuf) -> PathBuf {
121 training_dir.join("img")
122 }
123
124 fn reg_dir(training_dir: &PathBuf) -> PathBuf {
125 training_dir.join("reg")
126 }
127
128 fn subject_dir(&self, parameters: &Parameters, training_dir: &PathBuf) -> PathBuf {
129 Self::image_dir(training_dir).join(format!("{}_{} {}", self.training_images_repeat, parameters.prompt.instance, parameters.prompt.class))
130 }
131
132 fn activate(&mut self) {
133 self.environment.activate();
134 }
135
136 fn deactivate(&mut self) {
137 self.environment.deactivate();
138 }
139
140 fn prepare(&self, parameters: &Parameters, training_dir: &PathBuf) {
141 let image_dir = self.subject_dir(parameters, training_dir);
142 let class_dir = Self::reg_dir(training_dir).join(format!("{}_{}", self.regularization_images_repeat, parameters.prompt.class));
143 std::fs::create_dir_all(training_dir.join("log")).unwrap();
144 std::fs::create_dir_all(training_dir.join("model")).unwrap();
145 std::fs::create_dir_all(&image_dir).unwrap();
146 std::fs::create_dir_all(&class_dir).unwrap();
147 println!("{}", parameters.dataset.training.path().display());
148 for file in parameters.dataset.training.path().read_dir().unwrap() {
149 let file = file.unwrap().path();
150 let file_name = file.file_name().unwrap();
151 std::fs::copy(&file, image_dir.join(file_name)).unwrap();
152 }
153
154 if let Some(regularization) = ¶meters.dataset.regularization {
155 for file in regularization.path().read_dir().unwrap() {
156 let file = file.unwrap().path();
157 let file_name = file.file_name().unwrap();
158 std::fs::copy(&file, class_dir.join(file_name)).unwrap();
159 }
160 }
161 }
162
163 fn caption(&self, parameters: &Parameters, training_dir: &PathBuf) {
164 let image_dir = self.subject_dir(parameters, training_dir);
165 let python_executable = self.environment.python_executable_path();
166 Command::new(python_executable)
167 .arg(self.environment.kohya_ss().join("finetune").join("make_captions.py"))
168 .args(["--batch_size", "1"])
169 .args(["--num_beams", "1"])
170 .args(["--top_p", "0.9"])
171 .args(["--max_length", "75"])
172 .args(["--min_length", "5"])
173 .arg("--beam_search")
174 .args(["--caption_extension", ".txt"])
175 .arg(&image_dir)
176 .args(["--caption_weights", "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"])
177 .status()
178 .expect("Failed to execute command");
179 for txt in image_dir.read_dir().unwrap() {
180 let txt = txt.unwrap().path();
181 if txt.extension().unwrap() == "txt" {
182 let content = format!("{} {} {}", parameters.prompt.instance, parameters.prompt.class, std::fs::read_to_string(&txt).unwrap());
183 std::fs::write(txt, content).expect("Failed to update txt file");
184 }
185 }
186 }
187
188 fn train(&self, parameters: &Parameters, training_dir: &PathBuf) {
189 Command::new("accelerate")
190 .arg("launch")
191 .arg("--num_cpu_threads_per_process=2")
192 .arg(self.environment.kohya_ss().join("sdxl_train_network.py"))
193 .args(["--train_data_dir", &Self::image_dir(training_dir).display().to_string()])
194 .args(["--reg_data_dir", &Self::reg_dir(training_dir).display().to_string()])
195 .args(["--output_dir", ¶meters.output.directory.display().to_string()])
196 .args(["--output_name", ¶meters.output.name])
197 .args(["--pretrained_model_name_or_path", ¶meters.training.pretrained_model])
198 .args(["--resolution", &format!("{},{}", self.resolution.0, self.resolution.1)])
199 .args(["--save_model_as", &self.save_model_as.to_string()])
200 .args(["--network_alpha", ¶meters.network.alpha.to_string()])
201 .args(["--network_module", &self.network_module])
202 .args(["--network_dim", ¶meters.network.dimension.to_string()])
203 .args(["--text_encoder_lr", &self.text_encoder_lr.to_string()])
204 .args(["--unet_lr", &self.unet_lr.to_string()])
205 .args(["--lr_scheduler_num_cycles", &self.lr_scheduler_num_cycles.to_string()])
206 .arg("--no_half_vae")
207 .args(["--learning_rate", &self.learning_rate.to_string()])
208 .args(["--lr_scheduler", ¶meters.training.learning_rate.scheduler.to_string()])
209 .args(["--train_batch_size", &self.train_batch_size.to_string()])
211 .args(["--save_every_n_epochs", &self.save_every_n_epochs.to_string()])
213 .args(["--mixed_precision", &self.mixed_precision.to_string()])
214 .args(["--save_precision", &self.save_precision.to_string()])
215 .args(["--optimizer_type", ¶meters.training.optimizer.to_string()])
216 .args(["--max_grad_norm", &self.max_grad_norm.to_string()])
217 .args(["--max_data_loader_n_workers", &self.max_data_loader_n_workers.to_string()])
218
219 .arg("--xformers")
223 .arg("--enable_bucket")
224 .args(["--min_bucket_reso", "256"])
225 .args(["--max_bucket_reso", "2048"])
226 .args(["--bucket_reso_steps", &self.bucket_reso_steps.to_string()])
227 .arg("--bucket_no_upscale")
228 .args(["--noise_offset", &self.noise_offset.to_string()])
229 .status()
230 .expect("Failed to execute command");
231 }
232}