// voirs_cli/commands/train/mod.rs
//! Training command implementations
//!
//! This module provides CLI commands for training VoiRS models:
//! - Vocoder training (HiFi-GAN, DiffWave)
//! - Acoustic model training (VITS, FastSpeech2)
//! - G2P model training
//! - Training progress monitoring and visualization

9pub mod acoustic;
10pub mod data_loader;
11pub mod g2p;
12pub mod progress;
13pub mod vocoder;
14
15use crate::GlobalOptions;
16use clap::Subcommand;
17use std::path::PathBuf;
18use voirs_sdk::Result;
19
20/// Training subcommands
21#[derive(Debug, Clone, Subcommand)]
22pub enum TrainCommands {
23    /// Train vocoder model (HiFi-GAN, DiffWave)
24    Vocoder {
25        /// Model type (hifigan, diffwave)
26        #[arg(long, default_value = "diffwave")]
27        model_type: String,
28
29        /// Training data directory
30        #[arg(long)]
31        data: PathBuf,
32
33        /// Output directory for checkpoints
34        #[arg(short, long, default_value = "checkpoints/vocoder")]
35        output: PathBuf,
36
37        /// Training config file (TOML/JSON)
38        #[arg(short, long)]
39        config: Option<PathBuf>,
40
41        /// Number of epochs
42        #[arg(long, default_value = "1000")]
43        epochs: usize,
44
45        /// Batch size
46        #[arg(long, default_value = "16")]
47        batch_size: usize,
48
49        /// Learning rate
50        #[arg(long, default_value = "0.0002")]
51        lr: f64,
52
53        /// Learning rate scheduler (none, step, cosine, exponential, onecycle)
54        #[arg(long, default_value = "none")]
55        lr_scheduler: String,
56
57        /// LR scheduler step size (for step scheduler)
58        #[arg(long, default_value = "100")]
59        lr_step_size: usize,
60
61        /// LR scheduler gamma (decay factor)
62        #[arg(long, default_value = "0.1")]
63        lr_gamma: f64,
64
65        /// Enable early stopping
66        #[arg(long)]
67        early_stopping: bool,
68
69        /// Early stopping patience (epochs)
70        #[arg(long, default_value = "50")]
71        patience: usize,
72
73        /// Early stopping minimum delta
74        #[arg(long, default_value = "0.0001")]
75        min_delta: f64,
76
77        /// Validation frequency (epochs)
78        #[arg(long, default_value = "5")]
79        val_frequency: usize,
80
81        /// Warmup steps
82        #[arg(long, default_value = "0")]
83        warmup_steps: usize,
84
85        /// Gradient clipping value (0 = disabled)
86        #[arg(long, default_value = "1.0")]
87        grad_clip: f64,
88
89        /// Save checkpoint every N epochs
90        #[arg(long, default_value = "10")]
91        save_frequency: usize,
92
93        /// Resume from checkpoint
94        #[arg(long)]
95        resume: Option<PathBuf>,
96
97        /// Use GPU if available
98        #[arg(long)]
99        gpu: bool,
100    },
101
102    /// Train acoustic model (VITS, FastSpeech2)
103    Acoustic {
104        /// Model type (vits, fastspeech2)
105        #[arg(long, default_value = "vits")]
106        model_type: String,
107
108        /// Training data directory
109        #[arg(long)]
110        data: PathBuf,
111
112        /// Output directory for checkpoints
113        #[arg(short, long, default_value = "checkpoints/acoustic")]
114        output: PathBuf,
115
116        /// Training config file (TOML/JSON)
117        #[arg(short, long)]
118        config: Option<PathBuf>,
119
120        /// Number of epochs
121        #[arg(long, default_value = "500")]
122        epochs: usize,
123
124        /// Batch size
125        #[arg(long, default_value = "32")]
126        batch_size: usize,
127
128        /// Learning rate
129        #[arg(long, default_value = "0.0001")]
130        lr: f64,
131
132        /// Resume from checkpoint
133        #[arg(long)]
134        resume: Option<PathBuf>,
135
136        /// Use GPU if available
137        #[arg(long)]
138        gpu: bool,
139    },
140
141    /// Train G2P model
142    G2p {
143        /// Language code (en, ja, etc.)
144        #[arg(long, default_value = "en")]
145        language: String,
146
147        /// Dictionary file (pronunciation dictionary)
148        #[arg(long)]
149        dictionary: PathBuf,
150
151        /// Output model path
152        #[arg(short, long, default_value = "models/g2p.safetensors")]
153        output: PathBuf,
154
155        /// Training config file (TOML/JSON)
156        #[arg(short, long)]
157        config: Option<PathBuf>,
158
159        /// Number of epochs
160        #[arg(long, default_value = "100")]
161        epochs: usize,
162
163        /// Learning rate
164        #[arg(long, default_value = "0.001")]
165        lr: f64,
166    },
167}
168
/// Training configuration for advanced options.
///
/// Bundles the scheduler, early-stopping, and checkpointing knobs that are
/// shared between the CLI flags on [`TrainCommands::Vocoder`] and the trainer
/// backends. Constructed either from parsed CLI flags or via `Default`.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Scheduler name: "none", "step", "cosine", "exponential", or "onecycle".
    pub lr_scheduler: String,
    /// Step interval in epochs (used by the "step" scheduler).
    pub lr_step_size: usize,
    /// Multiplicative decay factor applied by the scheduler.
    pub lr_gamma: f64,
    /// Whether early stopping is enabled at all.
    pub early_stopping: bool,
    /// Epochs without improvement before early stopping triggers.
    pub patience: usize,
    /// Minimum loss improvement counted as progress by early stopping.
    pub min_delta: f64,
    /// Run validation every N epochs.
    pub val_frequency: usize,
    /// Number of LR warmup steps before the scheduler takes over.
    pub warmup_steps: usize,
    /// Gradient clipping value (0 = disabled).
    pub grad_clip: f64,
    /// Save a checkpoint every N epochs.
    pub save_frequency: usize,
}
183
184impl Default for TrainingConfig {
185    fn default() -> Self {
186        Self {
187            lr_scheduler: "none".to_string(),
188            lr_step_size: 100,
189            lr_gamma: 0.1,
190            early_stopping: false,
191            patience: 50,
192            min_delta: 0.0001,
193            val_frequency: 5,
194            warmup_steps: 0,
195            grad_clip: 1.0,
196            save_frequency: 10,
197        }
198    }
199}
200
201/// Execute training command
202pub async fn execute_train_command(command: TrainCommands, global: &GlobalOptions) -> Result<()> {
203    match command {
204        TrainCommands::Vocoder {
205            model_type,
206            data,
207            output,
208            config,
209            epochs,
210            batch_size,
211            lr,
212            lr_scheduler,
213            lr_step_size,
214            lr_gamma,
215            early_stopping,
216            patience,
217            min_delta,
218            val_frequency,
219            warmup_steps,
220            grad_clip,
221            save_frequency,
222            resume,
223            gpu,
224        } => {
225            let training_config = TrainingConfig {
226                lr_scheduler,
227                lr_step_size,
228                lr_gamma,
229                early_stopping,
230                patience,
231                min_delta,
232                val_frequency,
233                warmup_steps,
234                grad_clip,
235                save_frequency,
236            };
237
238            let args = vocoder::VocoderTrainingArgs {
239                model_type,
240                data,
241                output,
242                config,
243                epochs,
244                batch_size,
245                lr,
246                resume,
247                use_gpu: gpu || global.gpu,
248                training_config,
249            };
250
251            vocoder::run_train_vocoder(args, global).await
252        }
253        TrainCommands::Acoustic {
254            model_type,
255            data,
256            output,
257            config,
258            epochs,
259            batch_size,
260            lr,
261            resume,
262            gpu,
263        } => {
264            let args = acoustic::AcousticModelTrainingArgs {
265                model_type: model_type.clone(),
266                data: data.clone(),
267                output: output.clone(),
268                config: config.clone(),
269                epochs,
270                batch_size,
271                lr,
272                resume: resume.clone(),
273                use_gpu: gpu || global.gpu,
274            };
275            acoustic::run_train_acoustic(args, global).await
276        }
277        TrainCommands::G2p {
278            language,
279            dictionary,
280            output,
281            config,
282            epochs,
283            lr,
284        } => g2p::run_train_g2p(language, dictionary, output, config, epochs, lr, global).await,
285    }
286}