Skip to main content

voirs_cli/commands/train/
data_loader.rs

1//! Training data loader for vocoder training
2//!
3//! This module provides data loading capabilities for neural vocoder training.
4//! It supports the LJSpeech dataset format and can be extended to support
5//! other TTS datasets.
6//!
7//! # Features
8//!
9//! - Async dataset loading with error handling
10//! - Real-time mel spectrogram extraction using FFT
11//! - Efficient batch generation with sample wraparound
12//! - Support for LJSpeech format with automatic validation
13//!
14//! # Example
15//!
16//! ```no_run
17//! use voirs_cli::commands::train::data_loader::VocoderDataLoader;
18//!
19//! #[tokio::main]
20//! async fn main() -> voirs_sdk::Result<()> {
21//!     let mut loader = VocoderDataLoader::load("./data/LJSpeech-1.1").await?;
22//!
23//!     println!("Loaded {} samples", loader.len());
24//!
25//!     let batch = loader.get_batch(4)?;
26//!     println!("Batch contains {} audio samples and {} mel spectrograms",
27//!              batch.audio.len(), batch.mels.len());
28//!
29//!     Ok(())
30//! }
31//! ```
32
33use std::path::{Path, PathBuf};
34use voirs_dataset::{
35    loaders::LjSpeechLoader,
36    processing::features::{extract_mel_spectrogram, MelSpectrogramConfig},
37    traits::Dataset,
38    DatasetSample,
39};
40use voirs_sdk::Result;
41
/// Vocoder training data loader
///
/// Keeps the loaded dataset samples in memory and hands them out in batches
/// through `get_batch`, remembering the read position between calls.
pub struct VocoderDataLoader {
    /// Dataset samples that were loaded successfully
    samples: Vec<DatasetSample>,
    /// Mel spectrogram configuration; `get_batch` reads `n_mels`, `n_fft`
    /// and `hop_length` from it
    mel_config: MelSpectrogramConfig,
    /// Index of the next sample `get_batch` will read
    current_index: usize,
}
51
52impl VocoderDataLoader {
53    /// Load dataset from directory
54    pub async fn load<P: AsRef<Path>>(data_dir: P) -> Result<Self> {
55        // Try to load as LJSpeech dataset
56        let is_valid = LjSpeechLoader::is_valid_dataset(data_dir.as_ref());
57
58        let dataset = if is_valid {
59            LjSpeechLoader::load(data_dir).await.map_err(|e| {
60                voirs_sdk::VoirsError::config_error(format!("Failed to load dataset: {}", e))
61            })?
62        } else {
63            return Err(voirs_sdk::VoirsError::config_error(format!(
64                "Unsupported dataset format at {:?}. Currently only LJSpeech is supported.",
65                data_dir.as_ref()
66            )));
67        };
68
69        // Get all samples
70        let num_samples = dataset.len();
71        let mut samples = Vec::with_capacity(num_samples);
72
73        for i in 0..num_samples {
74            match dataset.get(i).await {
75                Ok(sample) => samples.push(sample),
76                Err(e) => {
77                    // Log warning and continue with other samples
78                    eprintln!("Warning: Failed to load sample {}: {}", i, e);
79                }
80            }
81        }
82
83        if samples.is_empty() {
84            return Err(voirs_sdk::VoirsError::config_error(
85                "No valid samples found in dataset".to_string(),
86            ));
87        }
88
89        Ok(Self {
90            samples,
91            mel_config: MelSpectrogramConfig::default(),
92            current_index: 0,
93        })
94    }
95
96    /// Get total number of samples
97    pub fn len(&self) -> usize {
98        self.samples.len()
99    }
100
101    /// Check if dataset is empty
102    pub fn is_empty(&self) -> bool {
103        self.samples.is_empty()
104    }
105
106    /// Get batch of audio samples with mel spectrograms
107    pub fn get_batch(&mut self, batch_size: usize) -> Result<VocoderBatch> {
108        let mut batch_audio = Vec::new();
109        let mut batch_mels = Vec::new();
110
111        for _ in 0..batch_size {
112            if self.current_index >= self.samples.len() {
113                // Wrap around to beginning (epoch completed)
114                self.current_index = 0;
115            }
116
117            let sample = &self.samples[self.current_index];
118            self.current_index += 1;
119
120            // Extract mel spectrogram
121            let mel_result = extract_mel_spectrogram(
122                &sample.audio,
123                self.mel_config.n_mels,
124                self.mel_config.n_fft,
125                self.mel_config.hop_length,
126            )
127            .map_err(|e| {
128                voirs_sdk::VoirsError::config_error(format!(
129                    "Failed to extract mel spectrogram: {}",
130                    e
131                ))
132            })?;
133
134            // Convert to Vec<Vec<f32>> (frames x mels)
135            let mel_matrix = mel_result.as_matrix();
136
137            batch_audio.push(sample.audio.samples().to_vec());
138            batch_mels.push(mel_matrix);
139        }
140
141        Ok(VocoderBatch {
142            audio: batch_audio,
143            mels: batch_mels,
144        })
145    }
146
147    /// Reset iterator to beginning
148    pub fn reset(&mut self) {
149        self.current_index = 0;
150    }
151
152    /// Get current index in dataset
153    pub fn current_index(&self) -> usize {
154        self.current_index
155    }
156
157    /// Set current index in dataset
158    pub fn set_index(&mut self, index: usize) {
159        self.current_index = index.min(self.samples.len());
160    }
161}
162
/// Batch of vocoder training data
///
/// `audio` and `mels` are parallel: entry `i` of `mels` is the mel
/// spectrogram computed from entry `i` of `audio`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct VocoderBatch {
    /// Audio waveforms (batch_size x samples)
    pub audio: Vec<Vec<f32>>,
    /// Mel spectrograms (batch_size x frames x n_mels)
    pub mels: Vec<Vec<Vec<f32>>>,
}

impl VocoderBatch {
    /// Get batch size (the number of audio waveforms in the batch)
    pub fn len(&self) -> usize {
        self.audio.len()
    }

    /// Check if batch is empty (contains no audio waveforms)
    pub fn is_empty(&self) -> bool {
        self.audio.is_empty()
    }
}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Location of the LJSpeech dataset: `$LJSPEECH_PATH` when set, otherwise
    /// `<temp dir>/voirs/datasets/LJSpeech-1.1`.
    fn resolve_ljspeech_path() -> PathBuf {
        env::var("LJSPEECH_PATH")
            .map(PathBuf::from)
            .unwrap_or_else(|_| {
                std::env::temp_dir()
                    .join("voirs")
                    .join("datasets")
                    .join("LJSpeech-1.1")
            })
    }

    /// Returns the dataset path when it exists on disk. Otherwise prints a
    /// skip notice and returns `None` so the calling test can bail out early.
    fn ljspeech_path_or_skip() -> Option<PathBuf> {
        let path = resolve_ljspeech_path();
        if path.exists() {
            Some(path)
        } else {
            eprintln!(
                "Skipping test: LJSpeech dataset not found at {}",
                path.display()
            );
            None
        }
    }

    #[tokio::test]
    async fn test_vocoder_data_loader_basic() {
        // Test with LJSpeech dataset if available
        let ljspeech_path = match ljspeech_path_or_skip() {
            Some(path) => path,
            None => return,
        };

        let loader = VocoderDataLoader::load(&ljspeech_path).await;
        assert!(loader.is_ok(), "Failed to load dataset");

        let loader = loader.unwrap();
        assert!(!loader.is_empty(), "Dataset should not be empty");
        assert_ne!(loader.len(), 0, "Dataset should not be empty");
    }

    #[tokio::test]
    async fn test_batch_generation() {
        let ljspeech_path = match ljspeech_path_or_skip() {
            Some(path) => path,
            None => return,
        };

        let mut loader = VocoderDataLoader::load(&ljspeech_path).await.unwrap();
        let batch_size = 4;
        let batch = loader.get_batch(batch_size).unwrap();

        assert_eq!(batch.len(), batch_size, "Batch size should match");
        assert_eq!(
            batch.audio.len(),
            batch_size,
            "Audio batch size should match"
        );
        assert_eq!(batch.mels.len(), batch_size, "Mel batch size should match");

        // Verify mel spectrogram dimensions
        for mel in &batch.mels {
            assert!(!mel.is_empty(), "Mel spectrogram should not be empty");
            assert!(!mel[0].is_empty(), "Mel spectrogram should have features");
        }
    }

    #[tokio::test]
    async fn test_batch_wraparound() {
        let ljspeech_path = match ljspeech_path_or_skip() {
            Some(path) => path,
            None => return,
        };

        let mut loader = VocoderDataLoader::load(&ljspeech_path).await.unwrap();
        let total_samples = loader.len();
        let batch_size = 4;

        // Start on the last sample so a single batch crosses the epoch
        // boundary; iterating the whole dataset to get there would run a mel
        // extraction for every clip.
        let start = total_samples - 1;
        loader.set_index(start);
        assert_eq!(loader.current_index(), start);

        let batch = loader.get_batch(batch_size);
        assert!(
            batch.is_ok(),
            "Batch generation failed across the epoch boundary"
        );
        assert_eq!(batch.unwrap().len(), batch_size);

        // The cursor wraps lazily (on the next read), so it ends in 1..=total_samples
        let expected_index = (start + batch_size - 1) % total_samples + 1;
        assert_eq!(loader.current_index(), expected_index);

        // Out-of-range positions are clamped to the dataset length
        loader.set_index(total_samples + 10);
        assert_eq!(loader.current_index(), total_samples);

        loader.reset();
        assert_eq!(loader.current_index(), 0);
    }

    #[test]
    fn test_vocoder_batch_properties() {
        let batch = VocoderBatch {
            audio: vec![vec![0.0; 100]; 4],
            mels: vec![vec![vec![0.0; 80]; 10]; 4],
        };

        assert_eq!(batch.len(), 4);
        assert!(!batch.is_empty());

        let empty_batch = VocoderBatch {
            audio: vec![],
            mels: vec![],
        };

        assert_eq!(empty_batch.len(), 0);
        assert!(empty_batch.is_empty());
    }

    #[tokio::test]
    async fn test_invalid_dataset_path() {
        let invalid_path = "/nonexistent/path/to/dataset";
        let result = VocoderDataLoader::load(invalid_path).await;

        assert!(result.is_err(), "Should fail with invalid path");
    }

    #[tokio::test]
    async fn test_mel_spectrogram_shape() {
        let ljspeech_path = match ljspeech_path_or_skip() {
            Some(path) => path,
            None => return,
        };

        let mut loader = VocoderDataLoader::load(&ljspeech_path).await.unwrap();
        let batch = loader.get_batch(1).unwrap();

        assert_eq!(batch.mels.len(), 1);

        let mel = &batch.mels[0];
        assert!(!mel.is_empty(), "Mel spectrogram should have frames");

        // Check each frame has correct number of mel bins (80)
        for frame in mel {
            assert_eq!(frame.len(), 80, "Each frame should have 80 mel bins");
        }
    }
}