voirs_cli/commands/train/
data_loader.rs1use std::path::{Path, PathBuf};
34use voirs_dataset::{
35 loaders::LjSpeechLoader,
36 processing::features::{extract_mel_spectrogram, MelSpectrogramConfig},
37 traits::Dataset,
38 DatasetSample,
39};
40use voirs_sdk::Result;
41
42pub struct VocoderDataLoader {
44 samples: Vec<DatasetSample>,
46 mel_config: MelSpectrogramConfig,
48 current_index: usize,
50}
51
52impl VocoderDataLoader {
53 pub async fn load<P: AsRef<Path>>(data_dir: P) -> Result<Self> {
55 let is_valid = LjSpeechLoader::is_valid_dataset(data_dir.as_ref());
57
58 let dataset = if is_valid {
59 LjSpeechLoader::load(data_dir).await.map_err(|e| {
60 voirs_sdk::VoirsError::config_error(format!("Failed to load dataset: {}", e))
61 })?
62 } else {
63 return Err(voirs_sdk::VoirsError::config_error(format!(
64 "Unsupported dataset format at {:?}. Currently only LJSpeech is supported.",
65 data_dir.as_ref()
66 )));
67 };
68
69 let num_samples = dataset.len();
71 let mut samples = Vec::with_capacity(num_samples);
72
73 for i in 0..num_samples {
74 match dataset.get(i).await {
75 Ok(sample) => samples.push(sample),
76 Err(e) => {
77 eprintln!("Warning: Failed to load sample {}: {}", i, e);
79 }
80 }
81 }
82
83 if samples.is_empty() {
84 return Err(voirs_sdk::VoirsError::config_error(
85 "No valid samples found in dataset".to_string(),
86 ));
87 }
88
89 Ok(Self {
90 samples,
91 mel_config: MelSpectrogramConfig::default(),
92 current_index: 0,
93 })
94 }
95
96 pub fn len(&self) -> usize {
98 self.samples.len()
99 }
100
101 pub fn is_empty(&self) -> bool {
103 self.samples.is_empty()
104 }
105
106 pub fn get_batch(&mut self, batch_size: usize) -> Result<VocoderBatch> {
108 let mut batch_audio = Vec::new();
109 let mut batch_mels = Vec::new();
110
111 for _ in 0..batch_size {
112 if self.current_index >= self.samples.len() {
113 self.current_index = 0;
115 }
116
117 let sample = &self.samples[self.current_index];
118 self.current_index += 1;
119
120 let mel_result = extract_mel_spectrogram(
122 &sample.audio,
123 self.mel_config.n_mels,
124 self.mel_config.n_fft,
125 self.mel_config.hop_length,
126 )
127 .map_err(|e| {
128 voirs_sdk::VoirsError::config_error(format!(
129 "Failed to extract mel spectrogram: {}",
130 e
131 ))
132 })?;
133
134 let mel_matrix = mel_result.as_matrix();
136
137 batch_audio.push(sample.audio.samples().to_vec());
138 batch_mels.push(mel_matrix);
139 }
140
141 Ok(VocoderBatch {
142 audio: batch_audio,
143 mels: batch_mels,
144 })
145 }
146
147 pub fn reset(&mut self) {
149 self.current_index = 0;
150 }
151
152 pub fn current_index(&self) -> usize {
154 self.current_index
155 }
156
157 pub fn set_index(&mut self, index: usize) {
159 self.current_index = index.min(self.samples.len());
160 }
161}
162
/// One batch of training data: raw audio paired with its mel spectrograms.
///
/// Entry `i` of `audio` and entry `i` of `mels` are pushed together by the
/// data loader, so both vectors have the same length in batches it produces.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct VocoderBatch {
    /// Audio samples, one inner vector per batch item.
    pub audio: Vec<Vec<f32>>,
    /// Mel spectrogram matrices, one per batch item.
    pub mels: Vec<Vec<Vec<f32>>>,
}

impl VocoderBatch {
    /// Returns the number of items in the batch, counted from `audio`.
    pub fn len(&self) -> usize {
        self.audio.len()
    }

    /// Returns `true` when the batch holds no audio items.
    pub fn is_empty(&self) -> bool {
        self.audio.is_empty()
    }
}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Locates the LJSpeech dataset used by the dataset-backed tests.
    ///
    /// A readable `LJSPEECH_PATH` environment variable takes precedence;
    /// otherwise `<temp dir>/voirs/datasets/LJSpeech-1.1` is returned.
    fn resolve_ljspeech_path() -> PathBuf {
        match env::var("LJSPEECH_PATH") {
            Ok(custom) => PathBuf::from(custom),
            Err(_) => {
                let mut fallback = std::env::temp_dir();
                fallback.push("voirs");
                fallback.push("datasets");
                fallback.push("LJSpeech-1.1");
                fallback
            }
        }
    }

    /// Loading a real dataset succeeds and yields at least one sample.
    #[tokio::test]
    async fn test_vocoder_data_loader_basic() {
        let dataset_dir = resolve_ljspeech_path();

        // The dataset is optional; without it the test is a no-op.
        if !dataset_dir.exists() {
            eprintln!(
                "Skipping test: LJSpeech dataset not found at {}",
                dataset_dir.display()
            );
            return;
        }

        let load_result = VocoderDataLoader::load(&dataset_dir).await;
        assert!(load_result.is_ok(), "Failed to load dataset");

        let data_loader = load_result.unwrap();
        assert_ne!(data_loader.len(), 0, "Dataset should not be empty");
        assert!(!data_loader.is_empty(), "Dataset should not be empty");
    }

    /// A batch has the requested size and non-empty mel spectrograms.
    #[tokio::test]
    async fn test_batch_generation() {
        let dataset_dir = resolve_ljspeech_path();

        if !dataset_dir.exists() {
            eprintln!("Skipping test: LJSpeech dataset not found");
            return;
        }

        let batch_size = 4;
        let mut data_loader = VocoderDataLoader::load(&dataset_dir).await.unwrap();
        let batch = data_loader.get_batch(batch_size).unwrap();

        assert_eq!(batch.len(), batch_size, "Batch size should match");
        assert_eq!(
            batch.audio.len(),
            batch_size,
            "Audio batch size should match"
        );
        assert_eq!(batch.mels.len(), batch_size, "Mel batch size should match");

        for mel in batch.mels.iter() {
            assert!(!mel.is_empty(), "Mel spectrogram should not be empty");
            assert!(!mel[0].is_empty(), "Mel spectrogram should have features");
        }
    }

    /// Requesting more batches than the dataset holds keeps producing full batches.
    #[tokio::test]
    async fn test_batch_wraparound() {
        let dataset_dir = resolve_ljspeech_path();

        if !dataset_dir.exists() {
            eprintln!("Skipping test: LJSpeech dataset not found");
            return;
        }

        let mut data_loader = VocoderDataLoader::load(&dataset_dir).await.unwrap();
        let batch_size = 4;

        // Two batches beyond one full pass force the cursor to wrap.
        let num_batches = (data_loader.len() / batch_size) + 2;

        for i in 0..num_batches {
            let outcome = data_loader.get_batch(batch_size);
            assert!(outcome.is_ok(), "Batch generation failed at iteration {}", i);
            assert_eq!(outcome.unwrap().len(), batch_size);
        }
    }

    /// `len` and `is_empty` report the number of audio items in a batch.
    #[test]
    fn test_vocoder_batch_properties() {
        let filled = VocoderBatch {
            audio: vec![vec![0.0; 100]; 4],
            mels: vec![vec![vec![0.0; 80]; 10]; 4],
        };
        assert_eq!(filled.len(), 4);
        assert!(!filled.is_empty());

        let empty_batch = VocoderBatch {
            audio: Vec::new(),
            mels: Vec::new(),
        };
        assert_eq!(empty_batch.len(), 0);
        assert!(empty_batch.is_empty());
    }

    /// Loading from a directory that does not exist reports an error.
    #[tokio::test]
    async fn test_invalid_dataset_path() {
        let outcome = VocoderDataLoader::load("/nonexistent/path/to/dataset").await;

        assert!(outcome.is_err(), "Should fail with invalid path");
    }

    /// Every row of a produced mel matrix has 80 entries.
    #[tokio::test]
    async fn test_mel_spectrogram_shape() {
        let dataset_dir = resolve_ljspeech_path();

        if !dataset_dir.exists() {
            eprintln!("Skipping test: LJSpeech dataset not found");
            return;
        }

        let mut data_loader = VocoderDataLoader::load(&dataset_dir).await.unwrap();
        let batch = data_loader.get_batch(1).unwrap();

        assert_eq!(batch.mels.len(), 1);

        let first_mel = &batch.mels[0];
        assert!(!first_mel.is_empty(), "Mel spectrogram should have frames");

        for frame in first_mel.iter() {
            assert_eq!(frame.len(), 80, "Each frame should have 80 mel bins");
        }
    }
}