use crate::{prelude::*, LearningRate, Optimizer};
/// Default value for `Training::pretrained_model`.
///
/// Serves as the serde fallback when the field is missing from the input. It is also
/// used by `Training::default()`.
fn default_pretrained_model() -> String {
    String::from("stabilityai/stable-diffusion-xl-base-1.0")
}
/// Training configuration: which pretrained model to start from, the optimizer,
/// and the learning rate.
///
/// Only `pretrained_model` has a serde default. `optimizer` and `learning_rate` carry
/// no `#[serde(default)]` attribute. Unless their types handle missing fields
/// specially, serde therefore requires them to be present when deserializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Training {
/// Identifier of the pretrained model to train from.
/// When absent from deserialized input, this falls back to `default_pretrained_model()`.
#[serde(default = "default_pretrained_model")]
pub pretrained_model: String,
/// Optimizer to use; `Training::default()` selects `Optimizer::Adafactor`.
pub optimizer: Optimizer,
/// Learning-rate setting; `Training::default()` uses `LearningRate::default()`.
pub learning_rate: LearningRate
}
impl Default for Training {
    /// Builds a configuration using the Adafactor optimizer, the default learning rate,
    /// and the default pretrained model identifier.
    fn default() -> Self {
        // Field initializers are evaluated in the order written, matching the original
        // binding order (optimizer, learning rate, pretrained model).
        Self {
            optimizer: Optimizer::Adafactor,
            learning_rate: LearningRate::default(),
            pretrained_model: default_pretrained_model(),
        }
    }
}
impl Training {
    /// Creates a configuration identical to `Training::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns this configuration with `pretrained_model` replaced.
    ///
    /// Accepts anything convertible into a `String` (for example `&str` or `String`).
    /// `self` is consumed, so the returned value must be used.
    #[must_use = "builder methods consume `self`; the updated configuration is only in the return value"]
    pub fn with_pretrained_model(mut self, pretrained_model: impl Into<String>) -> Self {
        self.pretrained_model = pretrained_model.into();
        self
    }

    /// Returns this configuration with `optimizer` replaced.
    #[must_use = "builder methods consume `self`; the updated configuration is only in the return value"]
    pub fn with_optimizer(mut self, optimizer: Optimizer) -> Self {
        self.optimizer = optimizer;
        self
    }

    /// Returns this configuration with `learning_rate` replaced.
    #[must_use = "builder methods consume `self`; the updated configuration is only in the return value"]
    pub fn with_learning_rate(mut self, learning_rate: LearningRate) -> Self {
        self.learning_rate = learning_rate;
        self
    }
}