1use crate::{DatasetError, DatasetStatistics, ValidationReport};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11pub type Result<T> = std::result::Result<T, DatasetError>;
13
/// Descriptive metadata for a dataset, as returned by [`Dataset::metadata`].
///
/// Serializable with serde so it can be persisted alongside the data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetMetadata {
    /// Dataset name.
    pub name: String,
    /// Dataset version string.
    pub version: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Declared number of samples in the dataset.
    pub total_samples: usize,
    /// Combined duration of the samples.
    /// NOTE(review): units are not established here — presumably seconds,
    /// matching `DatasetSample::duration`; confirm.
    pub total_duration: f32,
    /// Language identifiers present in the dataset.
    pub languages: Vec<String>,
    /// Speaker identifiers present in the dataset (presumably the same IDs that
    /// `DatasetSample::speaker_id` returns — confirm).
    pub speakers: Vec<String>,
    /// Optional license name or text.
    pub license: Option<String>,
    /// Free-form additional key/value metadata stored as JSON values.
    pub metadata: HashMap<String, serde_json::Value>,
}
36
37pub use crate::splits::SplitConfig;
39
/// A dataset partitioned into training, validation and test parts.
///
/// `T` is whatever represents one part (e.g. a sub-dataset); this type only
/// groups the three parts together.
#[derive(Debug)]
pub struct DatasetSplit<T> {
    /// Training partition.
    pub train: T,
    /// Validation partition.
    pub validation: T,
    /// Test partition.
    pub test: T,
}
50
/// Sample indices assigned to the train, validation and test partitions of a
/// dataset.
///
/// Produced by `Dataset::split_indices`; use `validate` to check that the
/// indices are in range and disjoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetSplitIndices {
    /// Indices of the training samples.
    pub train: Vec<usize>,
    /// Indices of the validation samples.
    pub validation: Vec<usize>,
    /// Indices of the test samples.
    pub test: Vec<usize>,
}
62
63impl DatasetSplitIndices {
64 pub fn new(train: Vec<usize>, validation: Vec<usize>, test: Vec<usize>) -> Self {
66 Self {
67 train,
68 validation,
69 test,
70 }
71 }
72
73 pub fn total_samples(&self) -> usize {
75 self.train.len() + self.validation.len() + self.test.len()
76 }
77
78 pub fn ratios(&self) -> (f32, f32, f32) {
80 let total = self.total_samples() as f32;
81 if total == 0.0 {
82 return (0.0, 0.0, 0.0);
83 }
84 (
85 self.train.len() as f32 / total,
86 self.validation.len() as f32 / total,
87 self.test.len() as f32 / total,
88 )
89 }
90
91 pub fn validate(&self, dataset_size: usize) -> Result<()> {
93 let mut all_indices = Vec::new();
94 all_indices.extend(&self.train);
95 all_indices.extend(&self.validation);
96 all_indices.extend(&self.test);
97
98 for &index in &all_indices {
100 if index >= dataset_size {
101 return Err(crate::DatasetError::IndexError(index));
102 }
103 }
104
105 all_indices.sort_unstable();
107 for window in all_indices.windows(2) {
108 if window[0] == window[1] {
109 return Err(crate::DatasetError::ValidationError(format!(
110 "Duplicate index {} found in splits",
111 window[0]
112 )));
113 }
114 }
115
116 Ok(())
117 }
118}
119
120#[async_trait]
122pub trait Dataset: Send + Sync {
123 type Sample: DatasetSample;
125
126 fn len(&self) -> usize;
128
129 fn is_empty(&self) -> bool {
131 self.len() == 0
132 }
133
134 async fn get(&self, index: usize) -> Result<Self::Sample>;
136
137 async fn get_batch(&self, indices: &[usize]) -> Result<Vec<Self::Sample>> {
139 let mut samples = Vec::with_capacity(indices.len());
140 for &index in indices {
141 samples.push(self.get(index).await?);
142 }
143 Ok(samples)
144 }
145
146 fn iter(&self) -> DatasetIterator<Self::Sample> {
148 DatasetIterator::new(self.len())
149 }
150
151 fn metadata(&self) -> &DatasetMetadata;
153
154 async fn split_indices(&self, config: SplitConfig) -> Result<DatasetSplitIndices> {
157 use scirs2_core::random::seq::SliceRandom;
158 use scirs2_core::random::SeedableRng;
159
160 let dataset_size = self.len();
161 if dataset_size == 0 {
162 return Err(crate::DatasetError::SplitError(String::from(
163 "Cannot split empty dataset",
164 )));
165 }
166
167 let sum = config.train_ratio + config.val_ratio + config.test_ratio;
169 if (sum - 1.0).abs() > 1e-6 {
170 return Err(crate::DatasetError::SplitError(format!(
171 "Split ratios must sum to 1.0, got {sum}"
172 )));
173 }
174
175 let mut indices: Vec<usize> = (0..dataset_size).collect();
176
177 match config.strategy {
179 crate::splits::SplitStrategy::Random => {
180 let mut rng = if let Some(seed) = config.seed {
182 scirs2_core::random::Random::seed(seed)
183 } else {
184 {
185 let seed = std::time::SystemTime::now()
186 .duration_since(std::time::UNIX_EPOCH)
187 .map(|d| d.as_secs())
188 .unwrap_or(0);
189 scirs2_core::random::Random::seed(seed)
190 }
191 };
192 indices.shuffle(&mut rng);
193 }
194 crate::splits::SplitStrategy::Stratified => {
195 let mut rng = if let Some(seed) = config.seed {
198 scirs2_core::random::Random::seed(seed)
199 } else {
200 {
201 let seed = std::time::SystemTime::now()
202 .duration_since(std::time::UNIX_EPOCH)
203 .map(|d| d.as_secs())
204 .unwrap_or(0);
205 scirs2_core::random::Random::seed(seed)
206 }
207 };
208 indices.shuffle(&mut rng);
209 }
210 _ => {
211 let mut rng = if let Some(seed) = config.seed {
214 scirs2_core::random::Random::seed(seed)
215 } else {
216 {
217 let seed = std::time::SystemTime::now()
218 .duration_since(std::time::UNIX_EPOCH)
219 .map(|d| d.as_secs())
220 .unwrap_or(0);
221 scirs2_core::random::Random::seed(seed)
222 }
223 };
224 indices.shuffle(&mut rng);
225 }
226 }
227
228 let train_size = (dataset_size as f32 * config.train_ratio).round() as usize;
230 let val_size = (dataset_size as f32 * config.val_ratio).round() as usize;
231 let _test_size = dataset_size - train_size - val_size;
232
233 let train_indices = indices[0..train_size].to_vec();
235 let val_indices = indices[train_size..train_size + val_size].to_vec();
236 let test_indices = indices[train_size + val_size..].to_vec();
237
238 let split_indices = DatasetSplitIndices::new(train_indices, val_indices, test_indices);
239 split_indices.validate(self.len())?;
240 Ok(split_indices)
241 }
242
243 async fn statistics(&self) -> Result<DatasetStatistics>;
245
246 async fn validate(&self) -> Result<ValidationReport>;
248
249 async fn get_random(&self) -> Result<Self::Sample> {
266 if self.is_empty() {
267 return Err(DatasetError::IndexError(0));
268 }
269 let len = self.len();
270 let index = {
271 use scirs2_core::random::{thread_rng, Rng};
272 let mut rng = thread_rng();
273 rng.gen_range(0..len)
274 };
275 self.get(index).await
276 }
277}
278
/// Iterator over the sample indices `0..total` of a dataset.
///
/// Yields indices only. The type parameter `T` records which sample type the
/// indices refer to; no `T` value is stored.
pub struct DatasetIterator<T> {
    /// Next index to yield.
    current: usize,
    /// One past the last index to yield.
    total: usize,
    _phantom: std::marker::PhantomData<T>,
}

impl<T> DatasetIterator<T> {
    /// Creates an iterator that yields `0, 1, ..., total - 1`.
    pub fn new(total: usize) -> Self {
        Self {
            current: 0,
            total,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T> Iterator for DatasetIterator<T> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current < self.total {
            let index = self.current;
            self.current += 1;
            Some(index)
        } else {
            None
        }
    }

    /// Exact bounds, so `collect` and similar adapters can preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.total.saturating_sub(self.current);
        (remaining, Some(remaining))
    }
}

/// `size_hint` is exact, so `len()` is accurate.
impl<T> ExactSizeIterator for DatasetIterator<T> {}

/// Once `current == total`, `next` keeps returning `None`.
impl<T> std::iter::FusedIterator for DatasetIterator<T> {}
309
/// Read-only view of a single dataset sample.
///
/// Implementors must be cheaply shareable across threads (`Clone + Send + Sync`).
pub trait DatasetSample: Clone + Send + Sync {
    /// Identifier of this sample.
    fn id(&self) -> &str;

    /// Text associated with this sample (presumably its transcript — confirm).
    fn text(&self) -> &str;

    /// Duration of the sample's audio.
    /// NOTE(review): units not established here — presumably seconds; confirm.
    fn duration(&self) -> f32;

    /// Language identifier of this sample.
    fn language(&self) -> &str;

    /// Speaker identifier, or `None` when the sample has no speaker information.
    fn speaker_id(&self) -> Option<&str>;
}
327
328impl DatasetSample for crate::DatasetSample {
330 fn id(&self) -> &str {
331 &self.id
332 }
333
334 fn text(&self) -> &str {
335 &self.text
336 }
337
338 fn duration(&self) -> f32 {
339 self.audio.duration()
340 }
341
342 fn language(&self) -> &str {
343 self.language.as_str()
344 }
345
346 fn speaker_id(&self) -> Option<&str> {
347 self.speaker.as_ref().map(|s| s.id.as_str())
348 }
349}