use std::{path::PathBuf, process::Command};
pub mod training;
pub mod output;
pub mod optimizer;
pub mod scheduler;
pub mod parameters;
pub use training::*;
pub use output::*;
pub use optimizer::*;
pub use scheduler::*;
pub use parameters::*;
use crate::{environment::Environment, model_file_format::ModelFileFormat, precision::FloatPrecision};
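/// Drives SDXL LoRA fine-tuning with the kohya_ss scripts: it stages a
/// DreamBooth-style dataset on disk, captions the images with BLIP, and then
/// launches `sdxl_train_network.py` through `accelerate`. The fields hold the
/// training hyperparameters that are not supplied via [`Parameters`].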
pub struct Trainer {
    pub environment: Environment,
    pub training_images_repeat: usize,
    pub regularization_images_repeat: usize,
    pub pretrained_model_name_or_path: String,
    pub resolution: (usize, usize),
    pub save_model_as: ModelFileFormat,
    pub network_module: String,
    pub text_encoder_lr: f32,
    pub unet_lr: f32,
    pub lr_scheduler_num_cycles: usize,
    pub learning_rate: f32,
    pub lr_warmup_steps: usize,
    pub train_batch_size: usize,
    pub max_train_steps: usize,
    pub save_every_n_epochs: usize,
    pub mixed_precision: FloatPrecision,
    pub save_precision: FloatPrecision,
    pub max_grad_norm: f32,
    pub max_data_loader_n_workers: usize,
    pub bucket_reso_steps: usize,
    pub noise_offset: f32,
}
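// Defaults target LoRA training against stabilityai/stable-diffusion-xl-base-1.0
// at 1024x1024 with fp16 mixed precision.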
impl Default for Trainer {
    fn default() -> Self {
        Trainer {
            environment: Default::default(),
            training_images_repeat: 40,
            regularization_images_repeat: 1,
            pretrained_model_name_or_path: "stabilityai/stable-diffusion-xl-base-1.0".to_string(),
            resolution: (1024, 1024),
            save_model_as: ModelFileFormat::Safetensors,
            network_module: "networks.lora".to_string(),
            text_encoder_lr: 5e-05,
            unet_lr: 0.0001,
            lr_scheduler_num_cycles: 1,
            learning_rate: 0.0001,
            lr_warmup_steps: 48,
            train_batch_size: 1,
            max_train_steps: 480,
            save_every_n_epochs: 1,
            mixed_precision: FloatPrecision::F16,
            save_precision: FloatPrecision::F16,
            max_grad_norm: 1.0,
            max_data_loader_n_workers: 0,
            bucket_reso_steps: 64,
            noise_offset: 0.0,
        }
    }
}
impl Trainer {
    pub fn new() -> Self {
        Default::default()
    }
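    /// Resolves the directory used to stage a training run: the `TRAINING_DIR`
    /// environment variable if set, otherwise a fresh UUID-named directory
    /// under the system temp dir.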
    fn training_dir() -> PathBuf {
        if let Some(path) = std::env::var_os("TRAINING_DIR") {
            PathBuf::from(path)
        } else {
            std::env::temp_dir().join(uuid::Uuid::new_v4().to_string())
        }
    }
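    /// Runs the full pipeline: activates the Python environment, stages the
    /// dataset into the kohya_ss folder layout, generates BLIP captions, and
    /// launches training.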
    pub fn start(&self, parameters: &Parameters) {
        let training_dir = Self::training_dir();
        self.activate();
        self.prepare(parameters, &training_dir);
        self.caption(parameters, &training_dir);
        self.train(parameters, &training_dir);
    }
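    // Directory layout expected by kohya_ss: training images under
    // `img/<repeats>_<instance> <class>` and regularization images under
    // `reg/<repeats>_<class>`.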
    fn image_dir(training_dir: &PathBuf) -> PathBuf {
        training_dir.join("img")
    }
    fn reg_dir(training_dir: &PathBuf) -> PathBuf {
        training_dir.join("reg")
    }
    fn subject_dir(&self, parameters: &Parameters, training_dir: &PathBuf) -> PathBuf {
        Self::image_dir(training_dir).join(format!(
            "{}_{} {}",
            self.training_images_repeat, parameters.prompt.instance, parameters.prompt.class
        ))
    }
    fn activate(&self) {
        self.environment.activate();
    }
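    /// Creates the log/model/img/reg directories under the training dir and
    /// copies the training and (optional) regularization images into place.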
    fn prepare(&self, parameters: &Parameters, training_dir: &PathBuf) {
        let image_dir = self.subject_dir(parameters, training_dir);
        let class_dir = Self::reg_dir(training_dir)
            .join(format!("{}_{}", self.regularization_images_repeat, parameters.prompt.class));
        std::fs::create_dir_all(training_dir.join("log")).unwrap();
        std::fs::create_dir_all(training_dir.join("model")).unwrap();
        std::fs::create_dir_all(&image_dir).unwrap();
        std::fs::create_dir_all(&class_dir).unwrap();
        for file in parameters.dataset.training.path().read_dir().unwrap() {
            let file = file.unwrap().path();
            let file_name = file.file_name().unwrap();
            std::fs::copy(&file, image_dir.join(file_name)).unwrap();
        }
        if let Some(regularization) = &parameters.dataset.regularization {
            for file in regularization.path().read_dir().unwrap() {
                let file = file.unwrap().path();
                let file_name = file.file_name().unwrap();
                std::fs::copy(&file, class_dir.join(file_name)).unwrap();
            }
        }
    }
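    /// Captions the training images with kohya_ss's BLIP script
    /// (`finetune/make_captions.py`), then prepends the instance and class
    /// tokens to every generated `.txt` caption.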
    fn caption(&self, parameters: &Parameters, training_dir: &PathBuf) {
        let image_dir = self.subject_dir(parameters, training_dir);
        let python_executable = self.environment.python_executable_path();
        Command::new(python_executable)
            .arg(self.environment.kohya_ss().join("finetune").join("make_captions.py"))
            .args(["--batch_size", "1"])
            .args(["--num_beams", "1"])
            .args(["--top_p", "0.9"])
            .args(["--max_length", "75"])
            .args(["--min_length", "5"])
            .arg("--beam_search")
            .args(["--caption_extension", ".txt"])
            .arg(&image_dir)
            .args([
                "--caption_weights",
                "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
            ])
            .status()
            .expect("Failed to execute command");
        for txt in image_dir.read_dir().unwrap() {
            let txt = txt.unwrap().path();
            // Only touch caption files; images and extensionless entries are skipped.
            if txt.extension().map_or(false, |extension| extension == "txt") {
                let content = format!(
                    "{} {} {}",
                    parameters.prompt.instance,
                    parameters.prompt.class,
                    std::fs::read_to_string(&txt).unwrap()
                );
                std::fs::write(txt, content).expect("Failed to update txt file");
            }
        }
    }
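    /// Launches SDXL LoRA training by invoking kohya_ss's
    /// `sdxl_train_network.py` through `accelerate launch` with the configured
    /// hyperparameters.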
    fn train(&self, parameters: &Parameters, training_dir: &PathBuf) {
        Command::new("accelerate")
            .arg("launch")
            .arg("--num_cpu_threads_per_process=2")
            .arg(self.environment.kohya_ss().join("sdxl_train_network.py"))
            .args(["--train_data_dir", &Self::image_dir(training_dir).display().to_string()])
            .args(["--reg_data_dir", &Self::reg_dir(training_dir).display().to_string()])
            .args(["--output_dir", &parameters.output.directory.display().to_string()])
            .args(["--output_name", &parameters.output.name])
            .args(["--pretrained_model_name_or_path", &self.pretrained_model_name_or_path])
            .args(["--resolution", &format!("{},{}", self.resolution.0, self.resolution.1)])
            .args(["--save_model_as", &self.save_model_as.to_string()])
            .args(["--network_alpha", &parameters.network.alpha.to_string()])
            .args(["--network_module", &self.network_module])
            .args(["--network_dim", &parameters.network.dimension.to_string()])
            .args(["--text_encoder_lr", &self.text_encoder_lr.to_string()])
            .args(["--unet_lr", &self.unet_lr.to_string()])
            .args(["--lr_scheduler_num_cycles", &self.lr_scheduler_num_cycles.to_string()])
            .arg("--no_half_vae")
            .args(["--learning_rate", &self.learning_rate.to_string()])
            .args(["--lr_scheduler", &parameters.training.learning_rate.scheduler.to_string()])
            .args(["--lr_warmup_steps", &self.lr_warmup_steps.to_string()])
            .args(["--train_batch_size", &self.train_batch_size.to_string()])
            .args(["--max_train_steps", &self.max_train_steps.to_string()])
            .args(["--save_every_n_epochs", &self.save_every_n_epochs.to_string()])
            .args(["--mixed_precision", &self.mixed_precision.to_string()])
            .args(["--save_precision", &self.save_precision.to_string()])
            .args(["--optimizer_type", &parameters.training.optimizer.to_string()])
            .args(["--max_grad_norm", &self.max_grad_norm.to_string()])
            .args(["--max_data_loader_n_workers", &self.max_data_loader_n_workers.to_string()])
            .args(["--optimizer_args", "scale_parameter=False", "relative_step=False", "warmup_init=False"])
            .arg("--xformers")
            .arg("--enable_bucket")
            .args(["--min_bucket_reso", "256"])
            .args(["--max_bucket_reso", "2048"])
            .args(["--bucket_reso_steps", &self.bucket_reso_steps.to_string()])
            .arg("--bucket_no_upscale")
            .args(["--noise_offset", &self.noise_offset.to_string()])
            .status()
            .expect("Failed to execute command");
    }
}