Skip to main content

voirs_dataset/
lib.rs

1//! # VoiRS Dataset Utilities
2//!
3//! Dataset loading, preprocessing, and management utilities for training
4//! and evaluation of VoiRS speech synthesis models.
5
6// Allow pedantic lints that are acceptable for audio/DSP processing code
7#![allow(clippy::cast_precision_loss)] // Acceptable for audio sample conversions
8#![allow(clippy::cast_possible_truncation)] // Controlled truncation in audio processing
9#![allow(clippy::cast_sign_loss)] // Intentional in index calculations
10#![allow(clippy::missing_errors_doc)] // Many internal functions with self-documenting error types
11#![allow(clippy::missing_panics_doc)] // Panics are documented where relevant
12#![allow(clippy::unused_self)] // Some trait implementations require &self for consistency
13#![allow(clippy::must_use_candidate)] // Not all return values need must_use annotation
14#![allow(clippy::doc_markdown)] // Technical terms don't all need backticks
15#![allow(clippy::unnecessary_wraps)] // Result wrappers maintained for API consistency
16#![allow(clippy::float_cmp)] // Exact float comparisons are intentional in some contexts
17#![allow(clippy::match_same_arms)] // Pattern matching clarity sometimes requires duplication
18#![allow(clippy::module_name_repetitions)] // Type names often repeat module names
19#![allow(clippy::struct_excessive_bools)] // Config structs naturally have many boolean flags
20#![allow(clippy::too_many_lines)] // Some functions are inherently complex
21#![allow(clippy::needless_pass_by_value)] // Some functions designed for ownership transfer
22#![allow(clippy::similar_names)] // Many similar variable names in algorithms
23#![allow(clippy::unused_async)] // Public API functions may need async for consistency
24#![allow(clippy::needless_range_loop)] // Range loops sometimes clearer than iterators
25#![allow(clippy::uninlined_format_args)] // Explicit argument names can improve clarity
26#![allow(clippy::manual_clamp)] // Manual clamping sometimes clearer
27#![allow(clippy::return_self_not_must_use)] // Not all builder methods need must_use
28#![allow(clippy::cast_possible_wrap)] // Controlled wrapping in processing code
29#![allow(clippy::cast_lossless)] // Explicit casts preferred for clarity
30#![allow(clippy::wildcard_imports)] // Prelude imports are convenient and standard
31#![allow(clippy::format_push_string)] // Sometimes more readable than alternative
32#![allow(clippy::redundant_closure_for_method_calls)] // Closures sometimes needed for type inference
33
34use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36use thiserror::Error;
37
38/// Result type for dataset operations
39pub type Result<T> = std::result::Result<T, DatasetError>;
40
/// Dataset-specific error types.
///
/// Variants whose payload is marked `#[from]` get a `From` impl generated by
/// `thiserror`, so `?` converts those source errors automatically. The
/// `Display` text of each variant is given by its `#[error(...)]` attribute.
#[derive(Error, Debug)]
pub enum DatasetError {
    /// Wraps a [`std::io::Error`]; convertible via `?`.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// Loading a dataset failed; the string describes why.
    #[error("Dataset loading failed: {0}")]
    LoadError(String),

    /// Input did not have the expected format.
    #[error("Invalid format: {0}")]
    FormatError(String),

    /// Invalid or inconsistent configuration.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// Failure while processing audio.
    #[error("Audio processing error: {0}")]
    AudioError(String),

    /// Failure during a network operation.
    #[error("Network error: {0}")]
    NetworkError(String),

    /// Validation of data failed.
    #[error("Validation error: {0}")]
    ValidationError(String),

    /// Failure during preprocessing.
    #[error("Preprocessing error: {0}")]
    PreprocessingError(String),

    /// An index was out of bounds; carries the offending index.
    #[error("Index out of bounds: {0}")]
    IndexError(usize),

    /// Wraps a `csv::Error`; convertible via `?`.
    #[error("CSV error: {0}")]
    CsvError(#[from] csv::Error),

    /// Wraps a `hound::Error` (WAV I/O); convertible via `?`.
    #[error("Audio file error: {0}")]
    HoundError(#[from] hound::Error),

    /// Wraps a `serde_json::Error`; convertible via `?`.
    #[error("JSON serialization error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Failure while splitting a dataset.
    #[error("Dataset split error: {0}")]
    SplitError(String),

    /// Generic processing failure.
    #[error("Processing error: {0}")]
    ProcessingError(String),

    /// Memory-related failure.
    #[error("Memory error: {0}")]
    MemoryError(String),

    /// Failure talking to cloud storage.
    #[error("Cloud storage error: {0}")]
    CloudStorage(String),

    /// Failure in a Git operation.
    #[error("Git error: {0}")]
    Git(String),

    /// Failure in an MLOps integration.
    #[error("MLOps error: {0}")]
    MLOps(String),

    /// Configuration problem.
    ///
    /// NOTE(review): renders exactly the same "Configuration error: ..."
    /// message as `ConfigError`, so the two cannot be told apart from their
    /// `Display` output; consider consolidating on one variant.
    #[error("Configuration error: {0}")]
    Configuration(String),
}
101
/// Language codes supported by VoiRS.
///
/// Use [`LanguageCode::as_str`] for the textual tag (e.g. `"en-US"`). No
/// `#[serde(rename...)]` attributes are present, so by default serde
/// (de)serializes these by variant name (e.g. `"EnUs"`), not by that tag.
///
/// The derived `PartialOrd`/`Ord` follow declaration order. Reordering the
/// variants therefore changes how codes sort (e.g. as keys of ordered maps).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub enum LanguageCode {
    /// English (US)
    EnUs,
    /// English (UK)
    EnGb,
    /// Japanese
    Ja,
    /// Mandarin Chinese
    ZhCn,
    /// Korean
    Ko,
    /// German
    De,
    /// French
    Fr,
    /// Spanish
    Es,
}
122
123impl LanguageCode {
124    /// Get string representation
125    pub fn as_str(&self) -> &'static str {
126        match self {
127            LanguageCode::EnUs => "en-US",
128            LanguageCode::EnGb => "en-GB",
129            LanguageCode::Ja => "ja",
130            LanguageCode::ZhCn => "zh-CN",
131            LanguageCode::Ko => "ko",
132            LanguageCode::De => "de",
133            LanguageCode::Fr => "fr",
134            LanguageCode::Es => "es",
135        }
136    }
137}
138
/// A phoneme with its symbol and optional features.
///
/// [`Phoneme::new`] creates one with no features and no duration.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Phoneme {
    /// Phoneme symbol (IPA or language-specific)
    pub symbol: String,
    /// Optional free-form phoneme features as string key/value pairs
    /// (`None` when no features are attached)
    pub features: Option<HashMap<String, String>>,
    /// Duration in seconds (if available)
    pub duration: Option<f32>,
}
149
150impl Phoneme {
151    /// Create new phoneme
152    pub fn new<S: Into<String>>(symbol: S) -> Self {
153        Self {
154            symbol: symbol.into(),
155            features: None,
156            duration: None,
157        }
158    }
159}
160
/// Audio data structure with efficient processing capabilities.
///
/// Samples are `f32` values, interleaved for multi-channel audio
/// (`[l0, r0, l1, r1, ...]` for stereo). Fields are private; use the accessor
/// methods. The derived serde impls serialize fields in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioData {
    /// Audio samples (interleaved for multi-channel)
    samples: Vec<f32>,
    /// Sample rate in Hz
    sample_rate: u32,
    /// Number of channels
    channels: u32,
    /// Optional free-form string metadata
    metadata: HashMap<String, String>,
}
173
/// Audio buffer for holding PCM audio data (legacy compatibility).
///
/// A plain alias of [`AudioData`], not a distinct type, so the two names are
/// fully interchangeable.
pub type AudioBuffer = AudioData;
176
177impl AudioData {
178    /// Create new audio data
179    pub fn new(samples: Vec<f32>, sample_rate: u32, channels: u32) -> Self {
180        Self {
181            samples,
182            sample_rate,
183            channels,
184            metadata: HashMap::new(),
185        }
186    }
187
188    /// Create silence
189    pub fn silence(duration: f32, sample_rate: u32, channels: u32) -> Self {
190        let num_samples = (duration * sample_rate as f32 * channels as f32) as usize;
191        Self::new(vec![0.0; num_samples], sample_rate, channels)
192    }
193
194    /// Get duration in seconds
195    pub fn duration(&self) -> f32 {
196        self.samples.len() as f32 / (self.sample_rate * self.channels) as f32
197    }
198
199    /// Check if buffer is empty
200    pub fn is_empty(&self) -> bool {
201        self.samples.is_empty()
202    }
203
204    /// Get sample rate
205    pub fn sample_rate(&self) -> u32 {
206        self.sample_rate
207    }
208
209    /// Get number of channels
210    pub fn channels(&self) -> u32 {
211        self.channels
212    }
213
214    /// Get samples
215    pub fn samples(&self) -> &[f32] {
216        &self.samples
217    }
218
219    /// Get mutable samples
220    pub fn samples_mut(&mut self) -> &mut [f32] {
221        &mut self.samples
222    }
223
224    /// Get metadata
225    pub fn metadata(&self) -> &HashMap<String, String> {
226        &self.metadata
227    }
228
229    /// Add metadata
230    pub fn add_metadata(&mut self, key: String, value: String) {
231        self.metadata.insert(key, value);
232    }
233
234    /// Resample audio to new sample rate using high-quality linear interpolation
235    pub fn resample(&self, new_sample_rate: u32) -> Result<AudioData> {
236        if new_sample_rate == self.sample_rate {
237            return Ok(self.clone());
238        }
239
240        if self.samples.is_empty() {
241            return Ok(AudioData::new(vec![], new_sample_rate, self.channels));
242        }
243
244        // High-quality linear interpolation resampling
245        let ratio = self.sample_rate as f64 / new_sample_rate as f64;
246        let new_length = (self.samples.len() as f64 / ratio) as usize;
247        let mut new_samples = Vec::with_capacity(new_length);
248
249        for i in 0..new_length {
250            let src_index = i as f64 * ratio;
251            let index_floor = src_index.floor() as usize;
252            let index_ceil = (index_floor + 1).min(self.samples.len() - 1);
253            let fraction = src_index - index_floor as f64;
254
255            if index_floor >= self.samples.len() {
256                new_samples.push(0.0);
257            } else if index_floor == index_ceil {
258                // At the boundary, use the last sample
259                new_samples.push(self.samples[index_floor]);
260            } else {
261                // Linear interpolation between adjacent samples
262                let sample1 = self.samples[index_floor];
263                let sample2 = self.samples[index_ceil];
264                let interpolated = sample1 + (sample2 - sample1) * fraction as f32;
265                new_samples.push(interpolated);
266            }
267        }
268
269        Ok(AudioData::new(new_samples, new_sample_rate, self.channels))
270    }
271
272    /// High-quality resampling using windowed sinc interpolation
273    pub fn resample_windowed_sinc(&self, new_sample_rate: u32) -> Result<AudioData> {
274        if new_sample_rate == self.sample_rate {
275            return Ok(self.clone());
276        }
277
278        if self.samples.is_empty() {
279            return Ok(AudioData::new(vec![], new_sample_rate, self.channels));
280        }
281
282        // Windowed sinc resampling parameters
283        const FILTER_LENGTH: usize = 128;
284        const KAISER_BETA: f64 = 8.6;
285
286        let ratio = new_sample_rate as f64 / self.sample_rate as f64;
287        let new_length = (self.samples.len() as f64 * ratio) as usize;
288        let mut new_samples = Vec::with_capacity(new_length);
289
290        // Precompute Kaiser window coefficients
291        let kaiser_window = Self::kaiser_window(FILTER_LENGTH, KAISER_BETA);
292
293        // Calculate filter cutoff frequency
294        let cutoff = if ratio < 1.0 { ratio } else { 1.0 };
295
296        for i in 0..new_length {
297            let src_index = i as f64 / ratio;
298            let mut sample = 0.0f64;
299
300            // Apply windowed sinc filter
301            for (j, &window_coeff) in kaiser_window.iter().enumerate().take(FILTER_LENGTH) {
302                let filter_index = j as i32 - (FILTER_LENGTH as i32 / 2);
303                let sample_index = src_index + filter_index as f64;
304
305                if sample_index >= 0.0 && sample_index < self.samples.len() as f64 {
306                    let t = sample_index - sample_index.floor();
307                    let src_sample = if t == 0.0 {
308                        self.samples[sample_index as usize] as f64
309                    } else {
310                        // Linear interpolation between samples
311                        let idx = sample_index.floor() as usize;
312                        let next_idx = (idx + 1).min(self.samples.len() - 1);
313                        let s1 = self.samples[idx] as f64;
314                        let s2 = self.samples[next_idx] as f64;
315                        s1 + (s2 - s1) * t
316                    };
317
318                    // Windowed sinc coefficient
319                    let x = (filter_index as f64 - (src_index - src_index.floor())) * cutoff;
320                    let sinc_val = if x.abs() < 1e-10 {
321                        cutoff
322                    } else {
323                        let pi_x = std::f64::consts::PI * x;
324                        (pi_x.sin() / pi_x) * cutoff
325                    };
326
327                    sample += src_sample * sinc_val * window_coeff;
328                }
329            }
330
331            new_samples.push(sample.clamp(-1.0, 1.0) as f32);
332        }
333
334        Ok(AudioData::new(new_samples, new_sample_rate, self.channels))
335    }
336
337    /// Generate Kaiser window coefficients
338    fn kaiser_window(length: usize, beta: f64) -> Vec<f64> {
339        let mut window = Vec::with_capacity(length);
340        let alpha = (length - 1) as f64 / 2.0;
341        let i0_beta = Self::modified_bessel_i0(beta);
342
343        for i in 0..length {
344            let x = (i as f64 - alpha) / alpha;
345            let arg = beta * (1.0 - x * x).sqrt();
346            window.push(Self::modified_bessel_i0(arg) / i0_beta);
347        }
348
349        window
350    }
351
352    /// Modified Bessel function of the first kind (I0)
353    fn modified_bessel_i0(x: f64) -> f64 {
354        let mut sum = 1.0;
355        let mut term = 1.0;
356        let x_squared = x * x;
357
358        for k in 1..=50 {
359            term *= x_squared / (4.0 * k as f64 * k as f64);
360            sum += term;
361            if term < 1e-15 * sum {
362                break;
363            }
364        }
365
366        sum
367    }
368
369    /// Normalize audio amplitude
370    pub fn normalize(&mut self) -> Result<()> {
371        if self.samples.is_empty() {
372            return Ok(());
373        }
374
375        use crate::audio::simd::SimdAudioProcessor;
376        let max_amplitude = SimdAudioProcessor::find_peak(&self.samples);
377
378        if max_amplitude > 0.0 {
379            let scale = 1.0 / max_amplitude;
380            SimdAudioProcessor::apply_gain(&mut self.samples, scale);
381        }
382
383        Ok(())
384    }
385
386    /// Calculate RMS (Root Mean Square) of the audio
387    pub fn rms(&self) -> Option<f32> {
388        if self.samples.is_empty() {
389            return None;
390        }
391
392        use crate::audio::simd::SimdAudioProcessor;
393        let rms = SimdAudioProcessor::calculate_rms(&self.samples);
394        Some(rms)
395    }
396
397    /// Calculate peak amplitude of the audio
398    pub fn peak(&self) -> Option<f32> {
399        if self.samples.is_empty() {
400            return None;
401        }
402
403        use crate::audio::simd::SimdAudioProcessor;
404        let peak = SimdAudioProcessor::find_peak(&self.samples);
405        Some(peak)
406    }
407
408    /// Calculate LUFS (Loudness Units Full Scale) of the audio
409    /// This is a perceptual loudness measurement following ITU-R BS.1770-4
410    pub fn lufs(&self) -> Option<f32> {
411        if self.samples.is_empty() {
412            return None;
413        }
414
415        let loudness = self.calculate_integrated_loudness();
416        Some(loudness)
417    }
418
419    /// Calculate integrated loudness following ITU-R BS.1770-4
420    fn calculate_integrated_loudness(&self) -> f32 {
421        // For simplicity, we'll implement a basic LUFS calculation
422        // Full implementation would include K-weighting filter and gating
423
424        // Apply basic pre-filter (approximation of K-weighting)
425        let filtered_samples = self.apply_k_weighting_approximation();
426
427        // Calculate mean square with gating
428        let mean_square = self.calculate_gated_mean_square(&filtered_samples);
429
430        // Convert to LUFS
431        if mean_square > 0.0 {
432            -0.691 + 10.0 * mean_square.log10()
433        } else {
434            -70.0 // Minimum practical LUFS value
435        }
436    }
437
438    /// Apply K-weighting filter approximation
439    fn apply_k_weighting_approximation(&self) -> Vec<f32> {
440        // Simplified K-weighting using high-shelf filter approximation
441        // Full implementation would use proper biquad filters
442        let mut filtered = self.samples.clone();
443
444        // Simple high-frequency emphasis (approximating K-weighting)
445        for i in 1..filtered.len() {
446            filtered[i] = filtered[i] + 0.1 * (filtered[i] - filtered[i - 1]);
447        }
448
449        filtered
450    }
451
452    /// Calculate gated mean square for LUFS measurement
453    fn calculate_gated_mean_square(&self, samples: &[f32]) -> f32 {
454        // Block size for gating (400ms at typical sample rates)
455        let block_size = (0.4 * self.sample_rate as f32) as usize;
456        if block_size == 0 || samples.len() < block_size {
457            // Fallback for short audio
458            return samples.iter().map(|&x| x * x).sum::<f32>() / samples.len() as f32;
459        }
460
461        let mut block_powers = Vec::new();
462
463        // Calculate power for each block
464        for chunk in samples.chunks(block_size) {
465            let power = chunk.iter().map(|&x| x * x).sum::<f32>() / chunk.len() as f32;
466            block_powers.push(power);
467        }
468
469        // Apply relative gate (-70 LUFS)
470        let relative_threshold = block_powers.iter().sum::<f32>() / block_powers.len() as f32 * 0.1; // -10dB relative
471
472        let gated_powers: Vec<f32> = block_powers
473            .into_iter()
474            .filter(|&power| power >= relative_threshold)
475            .collect();
476
477        if gated_powers.is_empty() {
478            relative_threshold
479        } else {
480            gated_powers.iter().sum::<f32>() / gated_powers.len() as f32
481        }
482    }
483
484    /// Normalize audio to target RMS level
485    pub fn normalize_rms(&mut self, target_rms: f32) -> Result<()> {
486        if let Some(current_rms) = self.rms() {
487            if current_rms > 0.0 {
488                let scale = target_rms / current_rms;
489                use crate::audio::simd::SimdAudioProcessor;
490                SimdAudioProcessor::apply_gain(&mut self.samples, scale);
491            }
492        }
493        Ok(())
494    }
495
496    /// Normalize audio to target peak level
497    pub fn normalize_peak(&mut self, target_peak: f32) -> Result<()> {
498        if let Some(current_peak) = self.peak() {
499            if current_peak > 0.0 {
500                let scale = target_peak / current_peak;
501                use crate::audio::simd::SimdAudioProcessor;
502                SimdAudioProcessor::apply_gain(&mut self.samples, scale);
503            }
504        }
505        Ok(())
506    }
507
508    /// Normalize audio to target LUFS level
509    pub fn normalize_lufs(&mut self, target_lufs: f32) -> Result<()> {
510        if let Some(current_lufs) = self.lufs() {
511            let lufs_difference = target_lufs - current_lufs;
512            let scale = 10.0_f32.powf(lufs_difference / 20.0); // Convert dB to linear scale
513            use crate::audio::simd::SimdAudioProcessor;
514            SimdAudioProcessor::apply_gain(&mut self.samples, scale);
515        }
516        Ok(())
517    }
518
519    /// Comprehensive normalization with multiple options
520    pub fn normalize_comprehensive(&mut self, config: NormalizationConfig) -> Result<()> {
521        match config.method {
522            NormalizationMethod::Peak => {
523                self.normalize_peak(config.target_level)?;
524            }
525            NormalizationMethod::Rms => {
526                self.normalize_rms(config.target_level)?;
527            }
528            NormalizationMethod::Lufs => {
529                self.normalize_lufs(config.target_level)?;
530            }
531        }
532
533        // Apply optional limiting to prevent clipping
534        if config.apply_limiting {
535            self.apply_soft_limiter(config.limiter_threshold)?;
536        }
537
538        Ok(())
539    }
540
541    /// Apply soft limiting to prevent clipping
542    fn apply_soft_limiter(&mut self, threshold: f32) -> Result<()> {
543        for sample in &mut self.samples {
544            let abs_sample = sample.abs();
545            if abs_sample > threshold {
546                // Soft limiting using tanh compression
547                let sign = sample.signum();
548                let compressed = threshold * (abs_sample / threshold).tanh();
549                *sample = sign * compressed;
550            }
551        }
552        Ok(())
553    }
554}
555
/// Audio file formats supported by VoiRS.
///
/// A plain tag enum: the variants carry no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AudioFormat {
    /// WAV format
    Wav,
    /// FLAC format
    Flac,
    /// MP3 format
    Mp3,
    /// OGG Vorbis format
    Ogg,
    /// OPUS format
    Opus,
}
570
571pub mod audio;
572pub mod augmentation;
573pub mod cache;
574pub mod datasets;
575pub mod error;
576pub mod export;
577pub mod integration;
578pub mod metadata;
579pub mod ml;
580pub mod parallel;
581pub mod performance;
582pub mod processing;
583pub mod profiling;
584pub mod quality;
585pub mod research;
586pub mod sampling;
587pub mod streaming;
588pub mod traits;
589pub mod utils;
590pub mod versioning;
591
592// Legacy modules for backward compatibility
593pub mod formats;
594pub mod loaders;
595pub mod preprocessors;
596pub mod splits;
597pub mod validation;
598
599// Re-export split types for convenience
600pub use splits::{DatasetSplit, DatasetSplits, SplitConfig, SplitStatistics, SplitStrategy};
601
/// Speaker information structure.
///
/// `id` is what `DatasetSample::speaker_id` returns and what
/// `MemoryDataset`'s statistics group speakers by. The descriptive fields are
/// free-form strings; this type enforces no fixed vocabulary.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpeakerInfo {
    /// Speaker identifier
    pub id: String,
    /// Speaker name (if available)
    pub name: Option<String>,
    /// Speaker gender (free-form string, if available)
    pub gender: Option<String>,
    /// Speaker age (if available)
    pub age: Option<u32>,
    /// Speaker accent/region (free-form string, if available)
    pub accent: Option<String>,
    /// Additional speaker metadata
    pub metadata: HashMap<String, String>,
}
618
/// Quality metrics for audio samples.
///
/// Every metric is optional; the derived `Default` leaves all of them `None`
/// (not measured). The ranges noted below are conventions and are not
/// enforced by this type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct QualityMetrics {
    /// Signal-to-noise ratio in dB
    pub snr: Option<f32>,
    /// Clipping percentage
    pub clipping: Option<f32>,
    /// Dynamic range in dB
    pub dynamic_range: Option<f32>,
    /// Spectral quality score (0-1)
    pub spectral_quality: Option<f32>,
    /// Overall quality score (0-1)
    pub overall_quality: Option<f32>,
}
633
/// Audio normalization method.
///
/// Selects how `NormalizationConfig::target_level` is interpreted by
/// `AudioData::normalize_comprehensive`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NormalizationMethod {
    /// Normalize to peak amplitude (linear, `target_level` is the peak)
    Peak,
    /// Normalize to RMS level (linear, `target_level` is the RMS)
    Rms,
    /// Normalize to LUFS level (perceptual loudness, `target_level` in LUFS)
    Lufs,
}
644
/// Audio normalization configuration, consumed by
/// `AudioData::normalize_comprehensive`.
///
/// Normalization to `target_level` runs first. If `apply_limiting` is set, a
/// soft limiter using `limiter_threshold` runs afterwards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationConfig {
    /// Normalization method to use
    pub method: NormalizationMethod,
    /// Target level (peak: 0.0-1.0, RMS: 0.0-1.0, LUFS: -70.0 to 0.0 LUFS)
    pub target_level: f32,
    /// Apply soft limiting to prevent clipping
    pub apply_limiting: bool,
    /// Limiter threshold (0.0-1.0)
    pub limiter_threshold: f32,
}
657
658impl Default for NormalizationConfig {
659    fn default() -> Self {
660        Self {
661            method: NormalizationMethod::Peak,
662            target_level: 0.9,
663            apply_limiting: true,
664            limiter_threshold: 0.95,
665        }
666    }
667}
668
669impl NormalizationConfig {
670    /// Create configuration for peak normalization
671    pub fn peak(target_level: f32) -> Self {
672        Self {
673            method: NormalizationMethod::Peak,
674            target_level,
675            apply_limiting: true,
676            limiter_threshold: 0.95,
677        }
678    }
679
680    /// Create configuration for RMS normalization
681    pub fn rms(target_level: f32) -> Self {
682        Self {
683            method: NormalizationMethod::Rms,
684            target_level,
685            apply_limiting: true,
686            limiter_threshold: 0.95,
687        }
688    }
689
690    /// Create configuration for LUFS normalization
691    pub fn lufs(target_lufs: f32) -> Self {
692        Self {
693            method: NormalizationMethod::Lufs,
694            target_level: target_lufs,
695            apply_limiting: true,
696            limiter_threshold: 0.95,
697        }
698    }
699}
700
/// Dataset sample with comprehensive metadata.
///
/// Constructed with [`DatasetSample::new`] and refined with the `with_*`
/// builder methods. `DatasetItem` is a legacy alias of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetSample {
    /// Unique identifier for this sample
    pub id: String,

    /// Original text
    pub text: String,

    /// Audio data
    pub audio: AudioData,

    /// Speaker information (if available)
    pub speaker: Option<SpeakerInfo>,

    /// Language of the text
    pub language: LanguageCode,

    /// Quality metrics
    pub quality: QualityMetrics,

    /// Phoneme sequence (if available)
    pub phonemes: Option<Vec<Phoneme>>,

    /// Additional metadata as arbitrary JSON values keyed by string
    pub metadata: HashMap<String, serde_json::Value>,
}
728
/// Dataset item containing text, phonemes, and audio (legacy compatibility).
///
/// A plain alias of [`DatasetSample`]; the two names denote the same type.
pub type DatasetItem = DatasetSample;
731
732impl DatasetSample {
733    /// Create new dataset sample
734    pub fn new(id: String, text: String, audio: AudioData, language: LanguageCode) -> Self {
735        Self {
736            id,
737            text,
738            audio,
739            speaker: None,
740            language,
741            quality: QualityMetrics {
742                snr: None,
743                clipping: None,
744                dynamic_range: None,
745                spectral_quality: None,
746                overall_quality: None,
747            },
748            phonemes: None,
749            metadata: HashMap::new(),
750        }
751    }
752
753    /// Set phonemes
754    pub fn with_phonemes(mut self, phonemes: Vec<Phoneme>) -> Self {
755        self.phonemes = Some(phonemes);
756        self
757    }
758
759    /// Set speaker information
760    pub fn with_speaker(mut self, speaker: SpeakerInfo) -> Self {
761        self.speaker = Some(speaker);
762        self
763    }
764
765    /// Set quality metrics
766    pub fn with_quality(mut self, quality: QualityMetrics) -> Self {
767        self.quality = quality;
768        self
769    }
770
771    /// Add metadata
772    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
773        self.metadata.insert(key, value);
774        self
775    }
776
777    /// Get duration in seconds
778    pub fn duration(&self) -> f32 {
779        self.audio.duration()
780    }
781
782    /// Get speaker ID (for backward compatibility)
783    pub fn speaker_id(&self) -> Option<&str> {
784        self.speaker.as_ref().map(|s| s.id.as_str())
785    }
786}
787
788/// Dataset trait for different dataset formats
789pub trait Dataset {
790    /// Get dataset name
791    fn name(&self) -> &str;
792
793    /// Get number of items in dataset
794    fn len(&self) -> usize;
795
796    /// Check if dataset is empty
797    fn is_empty(&self) -> bool {
798        self.len() == 0
799    }
800
801    /// Get item by index
802    fn get_item(&self, index: usize) -> Result<DatasetItem>;
803
804    /// Get all items (for small datasets)
805    fn get_all_items(&self) -> Result<Vec<DatasetItem>> {
806        (0..self.len()).map(|i| self.get_item(i)).collect()
807    }
808
809    /// Get dataset statistics
810    fn statistics(&self) -> DatasetStatistics;
811
812    /// Validate dataset
813    fn validate(&self) -> Result<ValidationReport>;
814}
815
/// Dataset statistics.
///
/// Durations are in seconds (`DatasetSample::duration`). In `MemoryDataset`'s
/// implementation, `average_duration` is `total_duration / total_items`, only
/// items that have a speaker are counted in `speaker_distribution`, and an
/// empty dataset yields all-zero values with empty distributions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStatistics {
    /// Total number of items
    pub total_items: usize,

    /// Total duration in seconds
    pub total_duration: f32,

    /// Average duration per item in seconds
    pub average_duration: f32,

    /// Number of items per language
    pub language_distribution: std::collections::HashMap<LanguageCode, usize>,

    /// Number of items per speaker ID (if applicable)
    pub speaker_distribution: std::collections::HashMap<String, usize>,

    /// Text length statistics
    pub text_length_stats: LengthStatistics,

    /// Audio duration statistics
    pub duration_stats: DurationStatistics,
}
840
/// Length statistics for text.
///
/// `MemoryDataset::statistics` fills this from `String::len`, so lengths there
/// are UTF-8 byte counts, not character counts. For example, most CJK
/// characters take 3 bytes each.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LengthStatistics {
    /// Shortest text length
    pub min: usize,
    /// Longest text length
    pub max: usize,
    /// Mean text length
    pub mean: f32,
    /// Median text length. NOTE(review): stored as `usize`, so for an even
    /// number of items this is presumably one of the middle values rather than
    /// their average — confirm in `calculate_length_stats`.
    pub median: usize,
    /// Standard deviation of text length (population vs. sample form is
    /// decided by `calculate_length_stats`, not visible here)
    pub std_dev: f32,
}
850
/// Duration statistics for audio, in seconds.
///
/// `MemoryDataset::statistics` fills this from `DatasetSample::duration`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DurationStatistics {
    /// Shortest duration in seconds
    pub min: f32,
    /// Longest duration in seconds
    pub max: f32,
    /// Mean duration in seconds
    pub mean: f32,
    /// Median duration in seconds
    pub median: f32,
    /// Standard deviation of duration in seconds
    pub std_dev: f32,
}
860
/// Dataset validation report.
///
/// Messages are human-readable strings. `MemoryDataset::validate` prefixes
/// them with the item index (e.g. `"Item 3: Empty text"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationReport {
    /// Whether the dataset is valid
    pub is_valid: bool,

    /// List of errors found
    pub errors: Vec<String>,

    /// List of non-fatal warnings
    pub warnings: Vec<String>,

    /// Number of items validated
    pub items_validated: usize,
}
876
/// In-memory dataset implementation.
///
/// All items live in a `Vec` in insertion order; `get_item` returns a clone
/// of the stored item.
pub struct MemoryDataset {
    /// Dataset name returned by `Dataset::name`
    name: String,
    /// Items in insertion order; indices match `Dataset::get_item`
    items: Vec<DatasetItem>,
}
882
883impl MemoryDataset {
884    /// Create new in-memory dataset
885    pub fn new(name: String) -> Self {
886        Self {
887            name,
888            items: Vec::new(),
889        }
890    }
891
892    /// Add item to dataset
893    pub fn add_item(&mut self, item: DatasetItem) {
894        self.items.push(item);
895    }
896
897    /// Add multiple items
898    pub fn add_items(&mut self, items: Vec<DatasetItem>) {
899        self.items.extend(items);
900    }
901
902    /// Clear all items
903    pub fn clear(&mut self) {
904        self.items.clear();
905    }
906}
907
908impl Dataset for MemoryDataset {
909    fn name(&self) -> &str {
910        &self.name
911    }
912
913    fn len(&self) -> usize {
914        self.items.len()
915    }
916
917    fn get_item(&self, index: usize) -> Result<DatasetItem> {
918        self.items.get(index).cloned().ok_or_else(|| {
919            DatasetError::ConfigError(format!("Dataset index {index} out of bounds"))
920        })
921    }
922
923    fn statistics(&self) -> DatasetStatistics {
924        if self.items.is_empty() {
925            return DatasetStatistics {
926                total_items: 0,
927                total_duration: 0.0,
928                average_duration: 0.0,
929                language_distribution: std::collections::HashMap::new(),
930                speaker_distribution: std::collections::HashMap::new(),
931                text_length_stats: LengthStatistics {
932                    min: 0,
933                    max: 0,
934                    mean: 0.0,
935                    median: 0,
936                    std_dev: 0.0,
937                },
938                duration_stats: DurationStatistics {
939                    min: 0.0,
940                    max: 0.0,
941                    mean: 0.0,
942                    median: 0.0,
943                    std_dev: 0.0,
944                },
945            };
946        }
947
948        let total_items = self.items.len();
949        let total_duration: f32 = self.items.iter().map(DatasetSample::duration).sum();
950        let average_duration = total_duration / total_items as f32;
951
952        // Language distribution
953        let mut language_distribution = std::collections::HashMap::new();
954        for item in &self.items {
955            *language_distribution.entry(item.language).or_insert(0) += 1;
956        }
957
958        // Speaker distribution
959        let mut speaker_distribution = std::collections::HashMap::new();
960        for item in &self.items {
961            if let Some(speaker) = item.speaker_id() {
962                *speaker_distribution.entry(speaker.to_string()).or_insert(0) += 1;
963            }
964        }
965
966        // Text length statistics
967        let text_lengths: Vec<usize> = self.items.iter().map(|item| item.text.len()).collect();
968        let text_length_stats = calculate_length_stats(&text_lengths);
969
970        // Duration statistics
971        let durations: Vec<f32> = self.items.iter().map(DatasetSample::duration).collect();
972        let duration_stats = calculate_duration_stats(&durations);
973
974        DatasetStatistics {
975            total_items,
976            total_duration,
977            average_duration,
978            language_distribution,
979            speaker_distribution,
980            text_length_stats,
981            duration_stats,
982        }
983    }
984
985    fn validate(&self) -> Result<ValidationReport> {
986        let mut errors = Vec::new();
987        let mut warnings = Vec::new();
988
989        for (i, item) in self.items.iter().enumerate() {
990            // Check for empty text
991            if item.text.trim().is_empty() {
992                errors.push(format!("Item {i}: Empty text"));
993            }
994
995            // Check for very short audio
996            if item.duration() < 0.1 {
997                warnings.push(format!(
998                    "Item {}: Very short audio ({:.3}s)",
999                    i,
1000                    item.duration()
1001                ));
1002            }
1003
1004            // Check for very long audio
1005            if item.duration() > 30.0 {
1006                warnings.push(format!(
1007                    "Item {}: Very long audio ({:.1}s)",
1008                    i,
1009                    item.duration()
1010                ));
1011            }
1012
1013            // Check for empty audio
1014            if item.audio.is_empty() {
1015                errors.push(format!("Item {i}: Empty audio"));
1016            }
1017        }
1018
1019        Ok(ValidationReport {
1020            is_valid: errors.is_empty(),
1021            errors,
1022            warnings,
1023            items_validated: self.items.len(),
1024        })
1025    }
1026}
1027
1028/// Calculate length statistics
1029fn calculate_length_stats(values: &[usize]) -> LengthStatistics {
1030    if values.is_empty() {
1031        return LengthStatistics {
1032            min: 0,
1033            max: 0,
1034            mean: 0.0,
1035            median: 0,
1036            std_dev: 0.0,
1037        };
1038    }
1039
1040    let mut sorted = values.to_vec();
1041    sorted.sort_unstable();
1042
1043    let min = sorted[0];
1044    let max = sorted[sorted.len() - 1];
1045    let sum: usize = values.iter().sum();
1046    let mean = sum as f32 / values.len() as f32;
1047    let median = sorted[sorted.len() / 2];
1048
1049    let variance: f32 = values
1050        .iter()
1051        .map(|&x| (x as f32 - mean).powi(2))
1052        .sum::<f32>()
1053        / values.len() as f32;
1054    let std_dev = variance.sqrt();
1055
1056    LengthStatistics {
1057        min,
1058        max,
1059        mean,
1060        median,
1061        std_dev,
1062    }
1063}
1064
1065/// Calculate duration statistics
1066fn calculate_duration_stats(values: &[f32]) -> DurationStatistics {
1067    if values.is_empty() {
1068        return DurationStatistics {
1069            min: 0.0,
1070            max: 0.0,
1071            mean: 0.0,
1072            median: 0.0,
1073            std_dev: 0.0,
1074        };
1075    }
1076
1077    let mut sorted = values.to_vec();
1078    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
1079
1080    let min = sorted[0];
1081    let max = sorted[sorted.len() - 1];
1082    let sum: f32 = values.iter().sum();
1083    let mean = sum / values.len() as f32;
1084    let median = sorted[sorted.len() / 2];
1085
1086    let variance: f32 =
1087        values.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
1088    let std_dev = variance.sqrt();
1089
1090    DurationStatistics {
1091        min,
1092        max,
1093        mean,
1094        median,
1095        std_dev,
1096    }
1097}
1098
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LanguageCode;

    #[test]
    fn test_dataset_item_creation() {
        let audio = AudioBuffer::silence(1.0, 22050, 1);
        let item = DatasetItem::new(
            "test-001".to_string(),
            "Hello, world!".to_string(),
            audio,
            LanguageCode::EnUs,
        );

        assert_eq!(item.id, "test-001");
        assert_eq!(item.text, "Hello, world!");
        assert_eq!(item.language, LanguageCode::EnUs);
        assert!(item.phonemes.is_none());
        assert!(item.speaker_id().is_none());
    }

    #[test]
    fn test_memory_dataset() {
        let mut dataset = MemoryDataset::new("test-dataset".to_string());

        // Add test items
        for i in 0..3 {
            let audio = AudioBuffer::silence(1.0, 22050, 1);
            let item = DatasetItem::new(
                format!("item-{i:03}"),
                format!("Text number {i}"),
                audio,
                LanguageCode::EnUs,
            );
            dataset.add_item(item);
        }

        assert_eq!(dataset.name(), "test-dataset");
        assert_eq!(dataset.len(), 3);
        assert!(!dataset.is_empty());

        // Test item retrieval
        let item = dataset.get_item(1).unwrap();
        assert_eq!(item.id, "item-001");
        assert_eq!(item.text, "Text number 1");

        // Test statistics
        let stats = dataset.statistics();
        assert_eq!(stats.total_items, 3);
        assert!(stats.total_duration > 0.0);
        assert_eq!(stats.language_distribution[&LanguageCode::EnUs], 3);

        // Test validation
        let report = dataset.validate().unwrap();
        assert!(report.is_valid);
        assert_eq!(report.items_validated, 3);
    }

    #[test]
    fn test_memory_dataset_empty() {
        let dataset = MemoryDataset::new("empty".to_string());

        assert!(dataset.is_empty());
        assert_eq!(dataset.len(), 0);
        assert!(dataset.get_item(0).is_err());

        // The empty path must report zeros, not NaN from a 0/0 average.
        let stats = dataset.statistics();
        assert_eq!(stats.total_items, 0);
        assert_eq!(stats.total_duration, 0.0);
        assert_eq!(stats.average_duration, 0.0);
        assert!(stats.language_distribution.is_empty());
        assert!(stats.speaker_distribution.is_empty());

        let report = dataset.validate().unwrap();
        assert!(report.is_valid);
        assert_eq!(report.items_validated, 0);
    }

    #[test]
    fn test_memory_dataset_get_item_out_of_bounds() {
        let mut dataset = MemoryDataset::new("bounds".to_string());
        let audio = AudioBuffer::silence(1.0, 22050, 1);
        dataset.add_item(DatasetItem::new(
            "only".to_string(),
            "Only item".to_string(),
            audio,
            LanguageCode::EnUs,
        ));

        assert!(dataset.get_item(0).is_ok());
        assert!(dataset.get_item(1).is_err());
    }

    #[test]
    fn test_length_stats_upper_median() {
        // Even-sized input: median is the upper middle element of the sorted data.
        let stats = calculate_length_stats(&[4, 1, 3, 2]);
        assert_eq!(stats.min, 1);
        assert_eq!(stats.max, 4);
        assert_eq!(stats.median, 3);
        assert!((stats.mean - 2.5).abs() < 1e-6);
    }

    #[test]
    fn test_windowed_sinc_resampling() {
        // Test with sine wave to verify frequency preservation
        let sample_rate = 44100;
        let new_sample_rate = 22050;
        let frequency = 1000.0; // 1kHz test tone
        let duration = 0.1; // 100ms

        let mut samples = Vec::new();
        let num_samples = (sample_rate as f32 * duration) as usize;

        // Generate sine wave
        for i in 0..num_samples {
            let t = i as f32 / sample_rate as f32;
            let sample = (2.0 * std::f32::consts::PI * frequency * t).sin();
            samples.push(sample);
        }

        let original_audio = AudioData::new(samples, sample_rate, 1);
        let resampled = original_audio
            .resample_windowed_sinc(new_sample_rate)
            .unwrap();

        // Check that the resampled audio has the correct sample rate and length
        assert_eq!(resampled.sample_rate(), new_sample_rate);
        let expected_length = (num_samples * new_sample_rate as usize) / sample_rate as usize;
        assert!((resampled.samples().len() as i32 - expected_length as i32).abs() <= 1);

        // Verify that the frequency content is preserved (basic check)
        let resampled_samples = resampled.samples();
        assert!(!resampled_samples.is_empty());

        // Check that the signal hasn't been completely distorted
        let original_rms = original_audio.rms().unwrap();
        let resampled_rms = resampled.rms().unwrap();
        assert!((original_rms - resampled_rms).abs() < 0.1);
    }

    #[test]
    fn test_windowed_sinc_resampling_same_rate() {
        let samples = vec![1.0, -1.0, 1.0, -1.0];
        let audio = AudioData::new(samples.clone(), 44100, 1);

        let result = audio.resample_windowed_sinc(44100).unwrap();

        assert_eq!(result.sample_rate(), 44100);
        assert_eq!(result.samples(), &samples);
    }

    #[test]
    fn test_windowed_sinc_resampling_empty() {
        let audio = AudioData::new(vec![], 44100, 1);
        let result = audio.resample_windowed_sinc(22050).unwrap();

        assert_eq!(result.sample_rate(), 22050);
        assert!(result.samples().is_empty());
    }

    #[test]
    fn test_windowed_sinc_upsampling() {
        let samples = vec![1.0, 0.0, -1.0, 0.0];
        let audio = AudioData::new(samples, 22050, 1);

        let result = audio.resample_windowed_sinc(44100).unwrap();

        assert_eq!(result.sample_rate(), 44100);
        assert_eq!(result.samples().len(), 8); // Double the length

        // Check that the signal quality is maintained
        let rms = result.rms().unwrap();
        assert!(rms > 0.0);
    }

    #[test]
    fn test_windowed_sinc_downsampling() {
        let mut samples = Vec::new();
        for i in 0..88 {
            samples.push((i as f32 / 88.0).sin());
        }
        let audio = AudioData::new(samples, 44100, 1);

        let result = audio.resample_windowed_sinc(22050).unwrap();

        assert_eq!(result.sample_rate(), 22050);
        assert_eq!(result.samples().len(), 44); // Half the length

        // Check that the signal quality is maintained
        let rms = result.rms().unwrap();
        assert!(rms > 0.0);
    }

    #[test]
    fn test_kaiser_window_properties() {
        let window = AudioData::kaiser_window(64, 8.6);

        // Kaiser window should have symmetric properties
        assert_eq!(window.len(), 64);
        assert!((window[0] - window[63]).abs() < 1e-10);
        assert!((window[16] - window[47]).abs() < 1e-10);

        // Maximum should be at the center
        let max_val = window.iter().fold(0.0f64, |a, &b| a.max(b));
        assert!((window[31] - max_val).abs() < 1e-10);
    }

    #[test]
    fn test_modified_bessel_i0_known_values() {
        // Test known values of modified Bessel function I0
        assert!((AudioData::modified_bessel_i0(0.0) - 1.0).abs() < 1e-10);
        assert!((AudioData::modified_bessel_i0(1.0) - 1.2660658777520084).abs() < 1e-10);
        assert!((AudioData::modified_bessel_i0(2.0) - 2.2795853023360673).abs() < 1e-10);
    }
}