stable_diffusion_trainer/trainer/
mod.rs

//! The Trainer module contains the training configuration and the training process.
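//!
//! A minimal usage sketch, assuming the crate's `Environment` and `Parameters`
//! types are in scope (marked `ignore` because `start` spawns external
//! processes, and the `Parameters` construction is left as a placeholder):
//!
//! ```ignore
//! let mut trainer = Trainer::new().with_environment(Environment::default());
//! let parameters: Parameters = todo!("describe the prompt, dataset, network and output");
//! trainer.start(&parameters);
//! ```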

use std::{path::{Path, PathBuf}, process::Command};

pub mod training;
pub mod output;
pub mod optimizer;
pub mod scheduler;
pub mod parameters;

pub use training::*;
pub use output::*;
pub use optimizer::*;
pub use scheduler::*;
pub use parameters::*;

use crate::{environment::Environment, model_file_format::ModelFileFormat, precision::FloatPrecision};

/// The Trainer structure.
pub struct Trainer {
    /// The environment to use for the training process.
    pub environment: Environment,
    /// The number of times to repeat the training images.
    pub training_images_repeat: usize,
    /// The number of times to repeat the regularization images.
    pub regularization_images_repeat: usize,
    /// The maximum resolution of the images to use for the training process.
    pub resolution: (usize, usize),
    /// The format to save the model as.
    pub save_model_as: ModelFileFormat,
    /// The module to use for the network.
    pub network_module: String,
    /// The learning rate for the text encoder.
    pub text_encoder_lr: f32,
    /// The learning rate for the unet.
    pub unet_lr: f32,
    /// The number of cycles for the learning rate scheduler.
    pub lr_scheduler_num_cycles: usize,
    /// The learning rate for the training process.
    pub learning_rate: f32,
    /// The number of warmup steps for the learning rate.
    pub lr_warmup_steps: usize,
    /// The batch size for the training process.
    pub train_batch_size: usize,
    /// The maximum number of training steps.
    pub max_train_steps: usize,
    /// The frequency, in epochs, at which to save the model.
    pub save_every_n_epochs: usize,
    /// The precision to use for mixed precision training.
    pub mixed_precision: FloatPrecision,
    /// The precision to use for saving the model.
    pub save_precision: FloatPrecision,
    /// The maximum gradient norm.
    pub max_grad_norm: f32,
    /// The maximum number of data loader workers.
    pub max_data_loader_n_workers: usize,
    /// The number of steps for the bucket resolution.
    pub bucket_reso_steps: usize,
    /// The noise offset.
    pub noise_offset: f32,
}

impl Default for Trainer {
    fn default() -> Self {
        Trainer {
            environment: Default::default(),
            training_images_repeat: 40,
            regularization_images_repeat: 1,
            resolution: (1024, 1024),
            save_model_as: ModelFileFormat::Safetensors,
            network_module: "networks.lora".to_string(),
            text_encoder_lr: 5e-05,
            unet_lr: 0.0001,
            lr_scheduler_num_cycles: 1,
            learning_rate: 0.0001,
            lr_warmup_steps: 48,
            train_batch_size: 1,
            max_train_steps: 480,
            save_every_n_epochs: 1,
            mixed_precision: FloatPrecision::F16,
            save_precision: FloatPrecision::F16,
            max_grad_norm: 1.0,
            max_data_loader_n_workers: 0,
            bucket_reso_steps: 64,
            noise_offset: 0.0,
        }
    }
}

impl Trainer {
    /// Create a new Trainer.
    pub fn new() -> Self {
        Default::default()
    }

    /// Set the environment for the training process.
    pub fn with_environment(mut self, environment: Environment) -> Self {
        self.environment = environment;
        self
    }

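    /// Resolve the working directory for a run: the `TRAINING_DIR`
    /// environment variable when set, otherwise a fresh UUID-named directory
    /// under the system temporary directory.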
    fn training_dir() -> PathBuf {
        if let Some(path) = std::env::var_os("TRAINING_DIR") {
            PathBuf::from(path)
        } else {
            std::env::temp_dir().join(uuid::Uuid::new_v4().to_string())
        }
    }

    /// Start the training process.
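    ///
    /// Runs the full pipeline: prepare the training directory layout,
    /// activate the Python environment, caption the images, launch the
    /// training run, and deactivate the environment.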
    pub fn start(&mut self, parameters: &Parameters) {
        let training_dir = Self::training_dir();
        self.prepare(parameters, &training_dir);
        self.activate();
        self.caption(parameters, &training_dir);
        self.train(parameters, &training_dir);
        self.deactivate();
    }

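    /// The directory holding the (repeated) training images.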
    fn image_dir(training_dir: &Path) -> PathBuf {
        training_dir.join("img")
    }

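    /// The directory holding the regularization images.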
    fn reg_dir(training_dir: &Path) -> PathBuf {
        training_dir.join("reg")
    }

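    /// The kohya_ss-style subject directory under the image directory, named
    /// `<repeats>_<instance> <class>`.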
    fn subject_dir(&self, parameters: &Parameters, training_dir: &Path) -> PathBuf {
        Self::image_dir(training_dir).join(format!("{}_{} {}", self.training_images_repeat, parameters.prompt.instance, parameters.prompt.class))
    }

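    /// Activate the Python environment used to run the training scripts.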
    fn activate(&mut self) {
        self.environment.activate();
    }

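    /// Deactivate the Python environment.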
    fn deactivate(&mut self) {
        self.environment.deactivate();
    }

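    /// Create the `log`, `model`, image and regularization directories under
    /// the training directory, then copy the training and regularization
    /// images into the layout expected by kohya_ss.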
    fn prepare(&self, parameters: &Parameters, training_dir: &Path) {
        let image_dir = self.subject_dir(parameters, training_dir);
        let class_dir = Self::reg_dir(training_dir).join(format!("{}_{}", self.regularization_images_repeat, parameters.prompt.class));
        std::fs::create_dir_all(training_dir.join("log")).unwrap();
        std::fs::create_dir_all(training_dir.join("model")).unwrap();
        std::fs::create_dir_all(&image_dir).unwrap();
        std::fs::create_dir_all(&class_dir).unwrap();
        println!("{}", parameters.dataset.training.path().display());
        for file in parameters.dataset.training.path().read_dir().unwrap() {
            let file = file.unwrap().path();
            let file_name = file.file_name().unwrap();
            std::fs::copy(&file, image_dir.join(file_name)).unwrap();
        }

        if let Some(regularization) = &parameters.dataset.regularization {
            for file in regularization.path().read_dir().unwrap() {
                let file = file.unwrap().path();
                let file_name = file.file_name().unwrap();
                std::fs::copy(&file, class_dir.join(file_name)).unwrap();
            }
        }
    }

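    /// Caption the training images with kohya_ss's BLIP `make_captions.py`
    /// script, then prepend the instance and class tokens to each generated
    /// `.txt` caption file.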
    fn caption(&self, parameters: &Parameters, training_dir: &Path) {
        let image_dir = self.subject_dir(parameters, training_dir);
        let python_executable = self.environment.python_executable_path();
        Command::new(python_executable)
            .arg(self.environment.kohya_ss().join("finetune").join("make_captions.py"))
            .args(["--batch_size", "1"])
            .args(["--num_beams", "1"])
            .args(["--top_p", "0.9"])
            .args(["--max_length", "75"])
            .args(["--min_length", "5"])
            .arg("--beam_search")
            .args(["--caption_extension", ".txt"])
            .arg(&image_dir)
            .args(["--caption_weights", "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"])
            .status()
            .expect("Failed to execute command");
        for txt in image_dir.read_dir().unwrap() {
            let txt = txt.unwrap().path();
            // Avoid panicking on entries that have no extension.
            if txt.extension().map_or(false, |extension| extension == "txt") {
                let content = format!("{} {} {}", parameters.prompt.instance, parameters.prompt.class, std::fs::read_to_string(&txt).unwrap());
                std::fs::write(txt, content).expect("Failed to update txt file");
            }
        }
    }

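    /// Launch kohya_ss's `sdxl_train_network.py` through `accelerate launch`,
    /// passing flags assembled from this `Trainer`'s fields and the given
    /// `Parameters`.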
    fn train(&self, parameters: &Parameters, training_dir: &Path) {
        Command::new("accelerate")
            .arg("launch")
            .arg("--num_cpu_threads_per_process=2")
            .arg(self.environment.kohya_ss().join("sdxl_train_network.py"))
            .args(["--train_data_dir", &Self::image_dir(training_dir).display().to_string()])
            .args(["--reg_data_dir", &Self::reg_dir(training_dir).display().to_string()])
            .args(["--output_dir", &parameters.output.directory.display().to_string()])
            .args(["--output_name", &parameters.output.name])
            .args(["--pretrained_model_name_or_path", &parameters.training.pretrained_model])
            .args(["--resolution", &format!("{},{}", self.resolution.0, self.resolution.1)])
            .args(["--save_model_as", &self.save_model_as.to_string()])
            .args(["--network_alpha", &parameters.network.alpha.to_string()])
            .args(["--network_module", &self.network_module])
            .args(["--network_dim", &parameters.network.dimension.to_string()])
            .args(["--text_encoder_lr", &self.text_encoder_lr.to_string()])
            .args(["--unet_lr", &self.unet_lr.to_string()])
            .args(["--lr_scheduler_num_cycles", &self.lr_scheduler_num_cycles.to_string()])
            .arg("--no_half_vae")
            .args(["--learning_rate", &self.learning_rate.to_string()])
            .args(["--lr_scheduler", &parameters.training.learning_rate.scheduler.to_string()])
            // .args(["--lr_warmup_steps", &self.lr_warmup_steps.to_string()])
            .args(["--train_batch_size", &self.train_batch_size.to_string()])
            // .args(["--max_train_steps", &self.max_train_steps.to_string()])
            .args(["--save_every_n_epochs", &self.save_every_n_epochs.to_string()])
            .args(["--mixed_precision", &self.mixed_precision.to_string()])
            .args(["--save_precision", &self.save_precision.to_string()])
            .args(["--optimizer_type", &parameters.training.optimizer.to_string()])
            .args(["--max_grad_norm", &self.max_grad_norm.to_string()])
            .args(["--max_data_loader_n_workers", &self.max_data_loader_n_workers.to_string()])
            // TODO: pass these optimizer arguments when moving to the Adafactor optimizer.
            // .args(["--optimizer_args", "scale_parameter=False", "relative_step=False", "warmup_init=False"])
            .arg("--xformers")
            .arg("--enable_bucket")
            .args(["--min_bucket_reso", "256"])
            .args(["--max_bucket_reso", "2048"])
            .args(["--bucket_reso_steps", &self.bucket_reso_steps.to_string()])
            .arg("--bucket_no_upscale")
            .args(["--noise_offset", &self.noise_offset.to_string()])
            .status()
            .expect("Failed to execute command");
    }
}