voirs_cli/commands/train/
mod.rs1pub mod acoustic;
10pub mod data_loader;
11pub mod g2p;
12pub mod progress;
13pub mod vocoder;
14
15use crate::GlobalOptions;
16use clap::Subcommand;
17use std::path::PathBuf;
18use voirs_sdk::Result;
19
20#[derive(Debug, Clone, Subcommand)]
22pub enum TrainCommands {
23 Vocoder {
25 #[arg(long, default_value = "diffwave")]
27 model_type: String,
28
29 #[arg(long)]
31 data: PathBuf,
32
33 #[arg(short, long, default_value = "checkpoints/vocoder")]
35 output: PathBuf,
36
37 #[arg(short, long)]
39 config: Option<PathBuf>,
40
41 #[arg(long, default_value = "1000")]
43 epochs: usize,
44
45 #[arg(long, default_value = "16")]
47 batch_size: usize,
48
49 #[arg(long, default_value = "0.0002")]
51 lr: f64,
52
53 #[arg(long, default_value = "none")]
55 lr_scheduler: String,
56
57 #[arg(long, default_value = "100")]
59 lr_step_size: usize,
60
61 #[arg(long, default_value = "0.1")]
63 lr_gamma: f64,
64
65 #[arg(long)]
67 early_stopping: bool,
68
69 #[arg(long, default_value = "50")]
71 patience: usize,
72
73 #[arg(long, default_value = "0.0001")]
75 min_delta: f64,
76
77 #[arg(long, default_value = "5")]
79 val_frequency: usize,
80
81 #[arg(long, default_value = "0")]
83 warmup_steps: usize,
84
85 #[arg(long, default_value = "1.0")]
87 grad_clip: f64,
88
89 #[arg(long, default_value = "10")]
91 save_frequency: usize,
92
93 #[arg(long)]
95 resume: Option<PathBuf>,
96
97 #[arg(long)]
99 gpu: bool,
100 },
101
102 Acoustic {
104 #[arg(long, default_value = "vits")]
106 model_type: String,
107
108 #[arg(long)]
110 data: PathBuf,
111
112 #[arg(short, long, default_value = "checkpoints/acoustic")]
114 output: PathBuf,
115
116 #[arg(short, long)]
118 config: Option<PathBuf>,
119
120 #[arg(long, default_value = "500")]
122 epochs: usize,
123
124 #[arg(long, default_value = "32")]
126 batch_size: usize,
127
128 #[arg(long, default_value = "0.0001")]
130 lr: f64,
131
132 #[arg(long)]
134 resume: Option<PathBuf>,
135
136 #[arg(long)]
138 gpu: bool,
139 },
140
141 G2p {
143 #[arg(long, default_value = "en")]
145 language: String,
146
147 #[arg(long)]
149 dictionary: PathBuf,
150
151 #[arg(short, long, default_value = "models/g2p.safetensors")]
153 output: PathBuf,
154
155 #[arg(short, long)]
157 config: Option<PathBuf>,
158
159 #[arg(long, default_value = "100")]
161 epochs: usize,
162
163 #[arg(long, default_value = "0.001")]
165 lr: f64,
166 },
167}
168
/// Extra training settings gathered from the `Vocoder` subcommand.
///
/// [`execute_train_command`] fills every field from the same-named
/// [`TrainCommands::Vocoder`] argument and passes the struct on to the vocoder
/// trainer. No validation happens at this level.
#[derive(Clone, Debug)]
pub struct TrainingConfig {
    /// Name of the learning-rate scheduler.
    pub lr_scheduler: String,
    /// Step size handed to the learning-rate scheduler.
    pub lr_step_size: usize,
    /// Gamma handed to the learning-rate scheduler.
    pub lr_gamma: f64,
    /// Whether early stopping was requested.
    pub early_stopping: bool,
    /// Patience setting for early stopping.
    pub patience: usize,
    /// Minimum-delta setting for early stopping.
    pub min_delta: f64,
    /// Validation frequency.
    pub val_frequency: usize,
    /// Number of warmup steps.
    pub warmup_steps: usize,
    /// Gradient clipping value.
    pub grad_clip: f64,
    /// Save frequency.
    pub save_frequency: usize,
}
183
184impl Default for TrainingConfig {
185 fn default() -> Self {
186 Self {
187 lr_scheduler: "none".to_string(),
188 lr_step_size: 100,
189 lr_gamma: 0.1,
190 early_stopping: false,
191 patience: 50,
192 min_delta: 0.0001,
193 val_frequency: 5,
194 warmup_steps: 0,
195 grad_clip: 1.0,
196 save_frequency: 10,
197 }
198 }
199}
200
201pub async fn execute_train_command(command: TrainCommands, global: &GlobalOptions) -> Result<()> {
203 match command {
204 TrainCommands::Vocoder {
205 model_type,
206 data,
207 output,
208 config,
209 epochs,
210 batch_size,
211 lr,
212 lr_scheduler,
213 lr_step_size,
214 lr_gamma,
215 early_stopping,
216 patience,
217 min_delta,
218 val_frequency,
219 warmup_steps,
220 grad_clip,
221 save_frequency,
222 resume,
223 gpu,
224 } => {
225 let training_config = TrainingConfig {
226 lr_scheduler,
227 lr_step_size,
228 lr_gamma,
229 early_stopping,
230 patience,
231 min_delta,
232 val_frequency,
233 warmup_steps,
234 grad_clip,
235 save_frequency,
236 };
237
238 let args = vocoder::VocoderTrainingArgs {
239 model_type,
240 data,
241 output,
242 config,
243 epochs,
244 batch_size,
245 lr,
246 resume,
247 use_gpu: gpu || global.gpu,
248 training_config,
249 };
250
251 vocoder::run_train_vocoder(args, global).await
252 }
253 TrainCommands::Acoustic {
254 model_type,
255 data,
256 output,
257 config,
258 epochs,
259 batch_size,
260 lr,
261 resume,
262 gpu,
263 } => {
264 let args = acoustic::AcousticModelTrainingArgs {
265 model_type: model_type.clone(),
266 data: data.clone(),
267 output: output.clone(),
268 config: config.clone(),
269 epochs,
270 batch_size,
271 lr,
272 resume: resume.clone(),
273 use_gpu: gpu || global.gpu,
274 };
275 acoustic::run_train_acoustic(args, global).await
276 }
277 TrainCommands::G2p {
278 language,
279 dictionary,
280 output,
281 config,
282 epochs,
283 lr,
284 } => g2p::run_train_g2p(language, dictionary, output, config, epochs, lr, global).await,
285 }
286}