Skip to main content

voirs_dataset/
traits.rs

1//! Core traits for dataset handling
2//!
3//! This module defines the main Dataset trait and related abstractions
4//! for working with speech synthesis datasets.
5
6use crate::{DatasetError, DatasetStatistics, ValidationReport};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11/// Result type for dataset operations
12pub type Result<T> = std::result::Result<T, DatasetError>;
13
14/// Dataset metadata containing information about the dataset
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct DatasetMetadata {
17    /// Dataset name
18    pub name: String,
19    /// Dataset version
20    pub version: String,
21    /// Dataset description
22    pub description: Option<String>,
23    /// Total number of samples
24    pub total_samples: usize,
25    /// Total duration in seconds
26    pub total_duration: f32,
27    /// Languages present in the dataset
28    pub languages: Vec<String>,
29    /// Speakers present in the dataset
30    pub speakers: Vec<String>,
31    /// License information
32    pub license: Option<String>,
33    /// Additional metadata
34    pub metadata: HashMap<String, serde_json::Value>,
35}
36
37// Re-export SplitConfig from splits module to maintain compatibility
38pub use crate::splits::SplitConfig;
39
/// A train / validation / test partition that holds one `T` per split.
///
/// Generic over `T`; no bounds are placed on it here.
#[derive(Debug)]
pub struct DatasetSplit<T> {
    /// Part used for training.
    pub train: T,
    /// Part used for validation.
    pub validation: T,
    /// Part held out for testing.
    pub test: T,
}
50
51/// Dataset split containing indices for train, validation, and test sets
52/// This is object-safe and can be used with trait objects
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct DatasetSplitIndices {
55    /// Training set indices
56    pub train: Vec<usize>,
57    /// Validation set indices
58    pub validation: Vec<usize>,
59    /// Test set indices
60    pub test: Vec<usize>,
61}
62
63impl DatasetSplitIndices {
64    /// Create a new dataset split with given indices
65    pub fn new(train: Vec<usize>, validation: Vec<usize>, test: Vec<usize>) -> Self {
66        Self {
67            train,
68            validation,
69            test,
70        }
71    }
72
73    /// Get the total number of samples across all splits
74    pub fn total_samples(&self) -> usize {
75        self.train.len() + self.validation.len() + self.test.len()
76    }
77
78    /// Get the ratio of each split
79    pub fn ratios(&self) -> (f32, f32, f32) {
80        let total = self.total_samples() as f32;
81        if total == 0.0 {
82            return (0.0, 0.0, 0.0);
83        }
84        (
85            self.train.len() as f32 / total,
86            self.validation.len() as f32 / total,
87            self.test.len() as f32 / total,
88        )
89    }
90
91    /// Validate that all indices are unique and within bounds
92    pub fn validate(&self, dataset_size: usize) -> Result<()> {
93        let mut all_indices = Vec::new();
94        all_indices.extend(&self.train);
95        all_indices.extend(&self.validation);
96        all_indices.extend(&self.test);
97
98        // Check for out-of-bounds indices
99        for &index in &all_indices {
100            if index >= dataset_size {
101                return Err(crate::DatasetError::IndexError(index));
102            }
103        }
104
105        // Check for duplicates
106        all_indices.sort_unstable();
107        for window in all_indices.windows(2) {
108            if window[0] == window[1] {
109                return Err(crate::DatasetError::ValidationError(format!(
110                    "Duplicate index {} found in splits",
111                    window[0]
112                )));
113            }
114        }
115
116        Ok(())
117    }
118}
119
120/// Main Dataset trait for different dataset implementations
121#[async_trait]
122pub trait Dataset: Send + Sync {
123    /// Dataset sample type
124    type Sample: DatasetSample;
125
126    /// Get the number of samples in the dataset
127    fn len(&self) -> usize;
128
129    /// Check if the dataset is empty
130    fn is_empty(&self) -> bool {
131        self.len() == 0
132    }
133
134    /// Get a sample by index
135    async fn get(&self, index: usize) -> Result<Self::Sample>;
136
137    /// Get multiple samples by indices
138    async fn get_batch(&self, indices: &[usize]) -> Result<Vec<Self::Sample>> {
139        let mut samples = Vec::with_capacity(indices.len());
140        for &index in indices {
141            samples.push(self.get(index).await?);
142        }
143        Ok(samples)
144    }
145
146    /// Create an iterator over all samples
147    fn iter(&self) -> DatasetIterator<Self::Sample> {
148        DatasetIterator::new(self.len())
149    }
150
151    /// Get dataset metadata
152    fn metadata(&self) -> &DatasetMetadata;
153
154    /// Split the dataset into train/validation/test index sets (object-safe)
155    /// Returns indices for each split rather than new dataset objects
156    async fn split_indices(&self, config: SplitConfig) -> Result<DatasetSplitIndices> {
157        use scirs2_core::random::seq::SliceRandom;
158        use scirs2_core::random::SeedableRng;
159
160        let dataset_size = self.len();
161        if dataset_size == 0 {
162            return Err(crate::DatasetError::SplitError(String::from(
163                "Cannot split empty dataset",
164            )));
165        }
166
167        // Validate split ratios
168        let sum = config.train_ratio + config.val_ratio + config.test_ratio;
169        if (sum - 1.0).abs() > 1e-6 {
170            return Err(crate::DatasetError::SplitError(format!(
171                "Split ratios must sum to 1.0, got {sum}"
172            )));
173        }
174
175        let mut indices: Vec<usize> = (0..dataset_size).collect();
176
177        // Handle different splitting strategies
178        match config.strategy {
179            crate::splits::SplitStrategy::Random => {
180                // Random shuffle
181                let mut rng = if let Some(seed) = config.seed {
182                    scirs2_core::random::Random::seed(seed)
183                } else {
184                    {
185                        let seed = std::time::SystemTime::now()
186                            .duration_since(std::time::UNIX_EPOCH)
187                            .map(|d| d.as_secs())
188                            .unwrap_or(0);
189                        scirs2_core::random::Random::seed(seed)
190                    }
191                };
192                indices.shuffle(&mut rng);
193            }
194            crate::splits::SplitStrategy::Stratified => {
195                // For stratified splitting, we need speaker information
196                // This is a simplified version - ideally we'd group by speaker first
197                let mut rng = if let Some(seed) = config.seed {
198                    scirs2_core::random::Random::seed(seed)
199                } else {
200                    {
201                        let seed = std::time::SystemTime::now()
202                            .duration_since(std::time::UNIX_EPOCH)
203                            .map(|d| d.as_secs())
204                            .unwrap_or(0);
205                        scirs2_core::random::Random::seed(seed)
206                    }
207                };
208                indices.shuffle(&mut rng);
209            }
210            _ => {
211                // For duration and text length based splitting, we would need to
212                // collect samples and sort them, but for now use random as fallback
213                let mut rng = if let Some(seed) = config.seed {
214                    scirs2_core::random::Random::seed(seed)
215                } else {
216                    {
217                        let seed = std::time::SystemTime::now()
218                            .duration_since(std::time::UNIX_EPOCH)
219                            .map(|d| d.as_secs())
220                            .unwrap_or(0);
221                        scirs2_core::random::Random::seed(seed)
222                    }
223                };
224                indices.shuffle(&mut rng);
225            }
226        }
227
228        // Calculate split sizes
229        let train_size = (dataset_size as f32 * config.train_ratio).round() as usize;
230        let val_size = (dataset_size as f32 * config.val_ratio).round() as usize;
231        let _test_size = dataset_size - train_size - val_size;
232
233        // Split indices
234        let train_indices = indices[0..train_size].to_vec();
235        let val_indices = indices[train_size..train_size + val_size].to_vec();
236        let test_indices = indices[train_size + val_size..].to_vec();
237
238        let split_indices = DatasetSplitIndices::new(train_indices, val_indices, test_indices);
239        split_indices.validate(self.len())?;
240        Ok(split_indices)
241    }
242
243    /// Get dataset statistics
244    async fn statistics(&self) -> Result<DatasetStatistics>;
245
246    /// Validate the dataset
247    async fn validate(&self) -> Result<ValidationReport>;
248
249    // Filter samples based on a predicate (not object-safe due to generics)
250    // async fn filter<F>(&self, predicate: F) -> Result<Vec<usize>>
251    // where
252    //     F: Fn(&Self::Sample) -> bool + Send + Sync,
253    // {
254    //     let mut filtered_indices = Vec::new();
255    //     for i in 0..self.len() {
256    //         let sample = self.get(i).await?;
257    //         if predicate(&sample) {
258    //             filtered_indices.push(i);
259    //         }
260    //     }
261    //     Ok(filtered_indices)
262    // }
263
264    /// Get a random sample
265    async fn get_random(&self) -> Result<Self::Sample> {
266        if self.is_empty() {
267            return Err(DatasetError::IndexError(0));
268        }
269        let len = self.len();
270        let index = {
271            use scirs2_core::random::{thread_rng, Rng};
272            let mut rng = thread_rng();
273            rng.gen_range(0..len)
274        };
275        self.get(index).await
276    }
277}
278
/// Iterator over the sample indices `0..total` of a dataset.
///
/// It yields `usize` indices, not samples. The type parameter `T` is carried
/// only as a marker (`PhantomData`) and is never stored.
pub struct DatasetIterator<T> {
    current: usize,
    total: usize,
    _phantom: std::marker::PhantomData<T>,
}

impl<T> DatasetIterator<T> {
    /// Create an iterator that yields every index in `0..total`, in ascending order.
    pub fn new(total: usize) -> Self {
        Self {
            current: 0,
            total,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T> Iterator for DatasetIterator<T> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current < self.total {
            let index = self.current;
            self.current += 1;
            Some(index)
        } else {
            None
        }
    }

    /// Exact: the number of indices not yet yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `current` only advances while it is below `total`, so this never
        // saturates in practice; `saturating_sub` just guards the invariant.
        let remaining = self.total.saturating_sub(self.current);
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact, which enables `len()` and lets `collect` preallocate.
impl<T> ExactSizeIterator for DatasetIterator<T> {}

// Once `current` reaches `total` it never changes again, so `next` keeps
// returning `None` after the first `None`.
impl<T> std::iter::FusedIterator for DatasetIterator<T> {}
309
/// Read-only accessors that every dataset sample type exposes.
///
/// Implementors must also be `Clone`, `Send` and `Sync`.
pub trait DatasetSample: Clone + Send + Sync {
    /// Unique identifier of this sample.
    fn id(&self) -> &str;

    /// Text content (transcript) associated with the sample.
    fn text(&self) -> &str;

    /// Length of the sample, in seconds.
    fn duration(&self) -> f32;

    /// Language of the sample.
    fn language(&self) -> &str;

    /// Speaker identifier, or `None` when the sample has no speaker information.
    fn speaker_id(&self) -> Option<&str>;
}
327
328// Implement DatasetSample for our DatasetSample struct
329impl DatasetSample for crate::DatasetSample {
330    fn id(&self) -> &str {
331        &self.id
332    }
333
334    fn text(&self) -> &str {
335        &self.text
336    }
337
338    fn duration(&self) -> f32 {
339        self.audio.duration()
340    }
341
342    fn language(&self) -> &str {
343        self.language.as_str()
344    }
345
346    fn speaker_id(&self) -> Option<&str> {
347        self.speaker.as_ref().map(|s| s.id.as_str())
348    }
349}