1use torsh_core::error::Result;
6
7use super::types::{FeatureStats, Subset, TensorDataset};
8
/// A map-style dataset with random access by index.
///
/// Implementors expose a fixed number of items via [`len`](Dataset::len)
/// and fetch a single sample by position via [`get`](Dataset::get).
/// `Send + Sync` is required so a dataset can be shared across
/// data-loading workers.
pub trait Dataset: Send + Sync {
    /// The type of a single sample returned by [`get`](Dataset::get).
    type Item;
    /// Total number of items in the dataset.
    fn len(&self) -> usize;
    /// Returns `true` when the dataset contains no items.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Fetches the item at `index`; errors on invalid indices or
    /// underlying storage failures.
    fn get(&self, index: usize) -> Result<Self::Item>;
}
/// An iterable-style dataset that yields samples sequentially rather than
/// by random access (e.g. streamed or generated data).
pub trait IterableDataset: Send + Sync {
    /// The type of a single sample produced by the iterator.
    type Item;
    /// Iterator over samples; each step may fail, hence `Result` items.
    type Iter: Iterator<Item = Result<Self::Item>> + Send;
    /// Creates a fresh iterator over the dataset's items.
    fn iter(&self) -> Self::Iter;
}
36pub fn random_split<D>(
38 dataset: D,
39 lengths: &[usize],
40 generator: Option<u64>,
41) -> Result<Vec<Subset<D>>>
42where
43 D: Dataset + Clone,
44{
45 let total_length: usize = lengths.iter().sum();
46 if total_length != dataset.len() {
47 return Err(torsh_core::error::TorshError::InvalidArgument(format!(
48 "Sum of lengths {} does not equal dataset length {}",
49 total_length,
50 dataset.len()
51 )));
52 }
53 let mut indices: Vec<usize> = (0..dataset.len()).collect();
54 if let Some(_seed) = generator {
55 use scirs2_core::random::prelude::*;
56 use scirs2_core::random::seq::ScientificSliceRandom;
57 let mut rng = thread_rng();
58 indices.scientific_shuffle(&mut rng);
59 }
60 let mut subsets = Vec::with_capacity(lengths.len());
61 let mut offset = 0;
62 for &length in lengths {
63 let subset_indices = indices[offset..offset + length].to_vec();
64 subsets.push(Subset::new(dataset.clone(), subset_indices));
65 offset += length;
66 }
67 Ok(subsets)
68}
/// A dataset producing a (possibly unbounded) stream of samples.
pub trait StreamingDataset: Send + Sync {
    /// The type of a single streamed sample.
    type Item;
    /// Stream of samples; each step may fail, hence `Result` items.
    type Stream: Iterator<Item = Result<Self::Item>> + Send;
    /// Opens a new stream over the data.
    fn stream(&self) -> Self::Stream;
    /// Whether more data may still be produced. Defaults to `true`, which
    /// suits infinite sources; finite sources should override this.
    fn has_more(&self) -> bool {
        true
    }
    /// Resets the source to its initial position. Default is a no-op.
    fn reset(&self) -> Result<()> {
        Ok(())
    }
}
89pub fn dataset_statistics(dataset: &TensorDataset<f32>) -> Result<Vec<FeatureStats>> {
94 if dataset.len() == 0 {
95 return Ok(Vec::new());
96 }
97 let first_item = dataset.get(0)?;
98 if first_item.is_empty() {
99 return Ok(Vec::new());
100 }
101 let features_tensor = &first_item[0];
102 let n_features = features_tensor.numel();
103 let mut feature_data: Vec<Vec<f32>> = vec![Vec::with_capacity(dataset.len()); n_features];
104 for i in 0..dataset.len() {
105 let item = dataset.get(i)?;
106 if item.is_empty() {
107 continue;
108 }
109 let features = &item[0];
110 for feat_idx in 0..n_features.min(features.numel()) {
111 if let Ok(indices) = torsh_tensor::Tensor::from_vec(vec![feat_idx as i64], &[1]) {
112 if let Ok(value_tensor) = features.index_select(0, &indices) {
113 if let Ok(value) = value_tensor.item() {
114 feature_data[feat_idx].push(value);
115 }
116 }
117 }
118 }
119 }
120 Ok(feature_data
121 .iter()
122 .map(|data| FeatureStats::from_data(data))
123 .collect())
124}
125pub fn stratified_split<D>(
130 dataset: D,
131 labels: &[usize],
132 train_ratio: f32,
133 val_ratio: Option<f32>,
134 random_seed: Option<u64>,
135) -> Result<(Subset<D>, Subset<D>, Option<Subset<D>>)>
136where
137 D: Dataset + Clone,
138{
139 if train_ratio <= 0.0 || train_ratio >= 1.0 {
140 return Err(torsh_core::error::TorshError::InvalidArgument(
141 "train_ratio must be between 0 and 1".to_string(),
142 ));
143 }
144 let has_val = val_ratio.is_some();
145 let val_r = val_ratio.unwrap_or(0.0);
146 if has_val && (train_ratio + val_r >= 1.0) {
147 return Err(torsh_core::error::TorshError::InvalidArgument(
148 "train_ratio + val_ratio must be less than 1".to_string(),
149 ));
150 }
151 if labels.len() != dataset.len() {
152 return Err(torsh_core::error::TorshError::InvalidArgument(
153 "labels length must equal dataset length".to_string(),
154 ));
155 }
156 let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
157 std::collections::HashMap::new();
158 for (idx, &label) in labels.iter().enumerate() {
159 class_indices.entry(label).or_default().push(idx);
160 }
161 use scirs2_core::random::prelude::*;
162 use scirs2_core::random::seq::ScientificSliceRandom;
163 use scirs2_core::random::SeedableRng;
164 let mut rng = if let Some(seed) = random_seed {
165 StdRng::seed_from_u64(seed)
166 } else {
167 use std::time::SystemTime;
168 let seed = SystemTime::now()
169 .duration_since(SystemTime::UNIX_EPOCH)
170 .expect("time should be after UNIX_EPOCH")
171 .as_secs();
172 StdRng::seed_from_u64(seed)
173 };
174 let mut train_indices = Vec::new();
175 let mut val_indices = Vec::new();
176 let mut test_indices = Vec::new();
177 for (_class, mut indices) in class_indices {
178 indices.scientific_shuffle(&mut rng);
179 let n_train = (indices.len() as f32 * train_ratio).round() as usize;
180 let n_val = if has_val {
181 (indices.len() as f32 * val_r).round() as usize
182 } else {
183 0
184 };
185 train_indices.extend_from_slice(&indices[0..n_train]);
186 if has_val {
187 val_indices.extend_from_slice(&indices[n_train..n_train + n_val]);
188 test_indices.extend_from_slice(&indices[n_train + n_val..]);
189 } else {
190 test_indices.extend_from_slice(&indices[n_train..]);
191 }
192 }
193 let train_subset = Subset::new(dataset.clone(), train_indices);
194 let test_subset = Subset::new(dataset.clone(), test_indices);
195 let val_subset = if has_val {
196 Some(Subset::new(dataset, val_indices))
197 } else {
198 None
199 };
200 Ok((train_subset, test_subset, val_subset))
201}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::types::*;
    use torsh_tensor::creation::*;

    // A TensorDataset built from several tensors yields one entry per
    // tensor for each item.
    #[test]
    fn test_tensor_dataset() {
        let data = ones::<f32>(&[10, 3]).unwrap();
        let labels = zeros::<f32>(&[10]).unwrap();
        let dataset = TensorDataset::from_tensors(vec![data, labels]);
        assert_eq!(dataset.len(), 10);
        let item = dataset.get(0).unwrap();
        assert_eq!(item.len(), 2);
    }

    // ConcatDataset maps a global index to a (dataset, local index) pair
    // and returns None past the end.
    #[test]
    fn test_concat_dataset() {
        let ds1 = TensorDataset::from_tensor(ones::<f32>(&[5, 3]).unwrap());
        let ds2 = TensorDataset::from_tensor(zeros::<f32>(&[3, 3]).unwrap());
        let concat = ConcatDataset::new(vec![ds1, ds2]);
        assert_eq!(concat.len(), 8);
        assert_eq!(concat.dataset_idx(0), Some((0, 0)));
        assert_eq!(concat.dataset_idx(4), Some((0, 4)));
        assert_eq!(concat.dataset_idx(5), Some((1, 0)));
        assert_eq!(concat.dataset_idx(7), Some((1, 2)));
        assert_eq!(concat.dataset_idx(8), None);
    }

    // Subset exposes only the selected indices; out-of-range access errors.
    #[test]
    fn test_subset() {
        let dataset = TensorDataset::from_tensor(ones::<f32>(&[10, 3]).unwrap());
        let subset = Subset::new(dataset, vec![0, 2, 4, 6, 8]);
        assert_eq!(subset.len(), 5);
        assert!(subset.get(0).is_ok());
        assert!(subset.get(5).is_err());
    }

    // Minimal IterableDataset implementation used by the chaining tests.
    #[derive(Clone)]
    struct SimpleIterableDataset {
        data: Vec<i32>,
    }
    impl IterableDataset for SimpleIterableDataset {
        type Item = i32;
        type Iter = std::iter::Map<std::vec::IntoIter<i32>, fn(i32) -> Result<i32>>;
        fn iter(&self) -> Self::Iter {
            self.data.clone().into_iter().map(|x| Ok(x) as Result<i32>)
        }
    }

    // Chained datasets are iterated in order, end to end.
    #[test]
    fn test_chain_dataset() {
        let ds1 = SimpleIterableDataset {
            data: vec![1, 2, 3],
        };
        let ds2 = SimpleIterableDataset {
            data: vec![4, 5, 6],
        };
        let ds3 = SimpleIterableDataset {
            data: vec![7, 8, 9],
        };
        let chain = ChainDataset::new(vec![ds1, ds2, ds3]);
        let collected: Result<Vec<_>> = chain.iter().collect();
        assert!(collected.is_ok());
        let values = collected.unwrap();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    // Chaining zero datasets yields an empty iteration, not an error.
    #[test]
    fn test_chain_dataset_empty() {
        let chain: ChainDataset<SimpleIterableDataset> = ChainDataset::new(vec![]);
        let collected: Result<Vec<_>> = chain.iter().collect();
        assert!(collected.is_ok());
        let values = collected.unwrap();
        assert_eq!(values, Vec::<i32>::new());
    }

    // Empty datasets interleaved in a chain are skipped transparently.
    #[test]
    fn test_chain_dataset_with_empty_datasets() {
        let ds1 = SimpleIterableDataset { data: vec![] };
        let ds2 = SimpleIterableDataset {
            data: vec![1, 2, 3],
        };
        let ds3 = SimpleIterableDataset { data: vec![] };
        let ds4 = SimpleIterableDataset { data: vec![4, 5] };
        let chain = ChainDataset::new(vec![ds1, ds2, ds3, ds4]);
        let collected: Result<Vec<_>> = chain.iter().collect();
        assert!(collected.is_ok());
        let values = collected.unwrap();
        assert_eq!(values, vec![1, 2, 3, 4, 5]);
    }

    // An InfiniteDataset invokes its generator closure once per pull.
    #[test]
    fn test_infinite_dataset() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let dataset = InfiniteDataset::new(move || {
            let val = counter_clone.fetch_add(1, Ordering::SeqCst);
            Ok(val)
        });
        assert!(dataset.has_more());
        let mut stream = dataset.stream();
        assert_eq!(stream.next().unwrap().unwrap(), 0);
        assert_eq!(stream.next().unwrap().unwrap(), 1);
        assert_eq!(stream.next().unwrap().unwrap(), 2);
    }

    // Buffered streaming (with prefetch) passes items through unchanged,
    // including past the buffer size (5 here, 10 items pulled).
    #[test]
    fn test_buffered_streaming_dataset() {
        let dataset = InfiniteDataset::new(|| Ok(42i32));
        let buffered = BufferedStreamingDataset::new(dataset, 5).with_prefetch(true);
        assert!(buffered.has_more());
        let mut stream = buffered.stream();
        for _ in 0..10 {
            assert_eq!(stream.next().unwrap().unwrap(), 42);
        }
    }

    // Transforms apply in insertion order: (5 * 2) + 1 == 11.
    #[test]
    fn test_data_pipeline() {
        let pipeline = DataPipeline::new()
            .add_transform(|x: i32| Ok(x * 2))
            .add_transform(|x: i32| Ok(x + 1));
        let result = pipeline.apply(5).unwrap();
        assert_eq!(result, 11);
    }

    // A pipeline wrapped around a streaming dataset transforms every item.
    #[test]
    fn test_pipeline_streaming_dataset() {
        let dataset = InfiniteDataset::new(|| Ok(5i32));
        let pipeline = DataPipeline::new()
            .add_transform(|x: i32| Ok(x * 2))
            .add_transform(|x: i32| Ok(x + 1));
        let pipeline_dataset = PipelineStreamingDataset::new(dataset, pipeline);
        assert!(pipeline_dataset.has_more());
        let mut stream = pipeline_dataset.stream();
        for _ in 0..5 {
            assert_eq!(stream.next().unwrap().unwrap(), 11);
        }
    }

    // Smoke test: items can be pushed through the RealTimeDataset sender
    // and a stream can be opened afterwards.
    #[test]
    fn test_real_time_dataset() {
        let (dataset, _receiver) = RealTimeDataset::<i32>::new();
        let sender = dataset.sender();
        {
            let sender_lock = sender.lock().expect("lock should not be poisoned");
            sender_lock.send(1).unwrap();
            sender_lock.send(2).unwrap();
            sender_lock.send(3).unwrap();
        }
        assert!(dataset.has_more());
        let _stream = dataset.stream();
    }

    // Adapting a map-style dataset to streaming yields exactly len() items.
    #[test]
    fn test_dataset_to_streaming() {
        let tensor = ones::<f32>(&[5, 3]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let streaming = DatasetToStreaming::new(dataset);
        assert!(streaming.has_more());
        let stream = streaming.stream();
        let mut count = 0;
        for result in stream {
            assert!(result.is_ok());
            count += 1;
            if count >= 5 {
                break;
            }
        }
        assert_eq!(count, 5);
    }

    // With repeat(), the stream cycles past the dataset length (3 items,
    // 10 pulls succeed).
    #[test]
    fn test_dataset_to_streaming_repeat() {
        let tensor = ones::<f32>(&[3, 2]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let streaming = DatasetToStreaming::new(dataset).repeat();
        assert!(streaming.has_more());
        let stream = streaming.stream();
        let mut count = 0;
        for result in stream {
            assert!(result.is_ok());
            count += 1;
            if count >= 10 {
                break;
            }
        }
        assert_eq!(count, 10);
    }

    // The default StreamingDataset::reset implementation succeeds.
    #[test]
    fn test_streaming_dataset_reset() {
        let dataset = InfiniteDataset::new(|| Ok(42i32));
        let buffered = BufferedStreamingDataset::new(dataset, 3);
        assert!(buffered.reset().is_ok());
    }

    // Sequential access (0..10) is recognized: 9 of 10 accesses follow
    // their predecessor.
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_sequential_access() {
        use std::thread;
        use std::time::Duration;
        let tensor = ones::<f32>(&[10, 2]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..10 {
            let _ = profiled.get(i).unwrap();
            // Sleep so per-access timing is measurably non-zero.
            thread::sleep(Duration::from_micros(100));
        }
        let stats = profiled.stats();
        assert_eq!(stats.total_accesses, 10);
        assert_eq!(stats.sequential_accesses, 9);
        assert!(stats.sequential_ratio > 0.8);
        assert!(stats.avg_access_time_us > 0.0);
        assert!(stats.throughput_accesses_per_sec > 0.0);
    }

    // A non-monotone access pattern records zero sequential accesses.
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_random_access() {
        let tensor = ones::<f32>(&[10, 2]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        let indices = [0, 5, 2, 8, 1];
        for &i in &indices {
            let _ = profiled.get(i).unwrap();
        }
        let stats = profiled.stats();
        assert_eq!(stats.total_accesses, 5);
        assert_eq!(stats.sequential_accesses, 0);
        assert_eq!(stats.sequential_ratio, 0.0);
    }

    // Profiler hints mention the detected access pattern after a
    // sequential scan.
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_hints() {
        let tensor = ones::<f32>(&[100, 2]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..20 {
            let _ = profiled.get(i).unwrap();
        }
        let hints = profiled.hints();
        assert!(!hints.is_empty());
        assert!(hints
            .iter()
            .any(|h| h.contains("sequential") || h.contains("good")));
    }

    // reset() clears accumulated profiler counters.
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_reset() {
        let tensor = ones::<f32>(&[10, 2]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..5 {
            let _ = profiled.get(i).unwrap();
        }
        assert_eq!(profiled.stats().total_accesses, 5);
        profiled.profiler().reset();
        assert_eq!(profiled.stats().total_accesses, 0);
    }

    // The Display impl of profiler stats includes a header and counts.
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_display() {
        let tensor = ones::<f32>(&[10, 2]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..5 {
            let _ = profiled.get(i).unwrap();
        }
        let stats_string = format!("{}", profiled.stats());
        assert!(stats_string.contains("Dataset Profile Statistics"));
        assert!(stats_string.contains("Total Accesses: 5"));
    }

    // FeatureStats over 1..=5: mean 3, min 1, max 5, population std ~1.414.
    #[test]
    fn test_feature_stats() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = FeatureStats::from_data(&data);
        assert_eq!(stats.count, 5);
        assert_eq!(stats.mean, 3.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert!((stats.std - 1.4142).abs() < 0.01);
    }

    // Empty input yields zeroed statistics rather than NaN or panic.
    #[test]
    fn test_feature_stats_empty() {
        let data: Vec<f32> = vec![];
        let stats = FeatureStats::from_data(&data);
        assert_eq!(stats.count, 0);
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.std, 0.0);
    }

    // dataset_statistics returns one FeatureStats per feature column, each
    // covering all samples with consistent min <= mean <= max.
    #[test]
    fn test_dataset_statistics() {
        let data = torsh_tensor::creation::randn::<f32>(&[10, 3]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let stats = dataset_statistics(&dataset).unwrap();
        assert_eq!(stats.len(), 3);
        for stat in &stats {
            assert_eq!(stat.count, 10);
            assert!(stat.min <= stat.mean);
            assert!(stat.mean <= stat.max);
            assert!(stat.std >= 0.0);
        }
    }

    // A zero-row dataset produces no statistics.
    #[test]
    fn test_dataset_statistics_empty() {
        let data = torsh_tensor::creation::zeros::<f32>(&[0, 3]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let stats = dataset_statistics(&dataset).unwrap();
        assert_eq!(stats.len(), 0);
    }

    // 5-fold split of 100 samples: disjoint 80/20 train/val per fold,
    // all indices in range.
    #[test]
    fn test_kfold_basic() {
        let kfold = KFold::new(5, false, Some(42));
        let folds = kfold.split(100);
        assert_eq!(folds.len(), 5);
        for (fold_idx, (train_indices, val_indices)) in folds.iter().enumerate() {
            assert_eq!(val_indices.len(), 20);
            assert_eq!(train_indices.len(), 80);
            for &val_idx in val_indices {
                assert!(!train_indices.contains(&val_idx));
            }
            for &idx in train_indices.iter().chain(val_indices.iter()) {
                assert!(idx < 100);
            }
            println!(
                "Fold {}: train={}, val={}",
                fold_idx,
                train_indices.len(),
                val_indices.len()
            );
        }
    }

    // Shuffled folds differ from the unshuffled ordered folds.
    #[test]
    fn test_kfold_shuffle() {
        let kfold_shuffled = KFold::new(3, true, Some(42));
        let kfold_unshuffled = KFold::new(3, false, None);
        let folds_shuffled = kfold_shuffled.split(30);
        let folds_unshuffled = kfold_unshuffled.split(30);
        assert_eq!(folds_shuffled.len(), folds_unshuffled.len());
        let shuffled_val = &folds_shuffled[0].1;
        let unshuffled_val = &folds_unshuffled[0].1;
        assert_eq!(unshuffled_val, &(0..10).collect::<Vec<_>>());
        assert_ne!(shuffled_val, unshuffled_val);
    }

    // 10 samples over 3 folds: the remainder goes to the last fold (3/3/4)
    // and every sample appears exactly once as validation.
    #[test]
    fn test_kfold_uneven_split() {
        let kfold = KFold::new(3, false, None);
        let folds = kfold.split(10);
        assert_eq!(folds.len(), 3);
        assert_eq!(folds[0].1.len(), 3);
        assert_eq!(folds[1].1.len(), 3);
        assert_eq!(folds[2].1.len(), 4);
        let all_val_samples: usize = folds.iter().map(|(_, val)| val.len()).sum();
        assert_eq!(all_val_samples, 10);
    }

    // KFold construction panics for fewer than 2 splits.
    #[test]
    #[should_panic(expected = "n_splits must be at least 2")]
    fn test_kfold_invalid_splits() {
        KFold::new(1, false, None);
    }

    // Binary labels, 60/20/20 split: per-class proportions are preserved.
    #[test]
    fn test_stratified_split_binary() {
        let data = ones::<f32>(&[100, 5]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..100).map(|i| if i < 50 { 0 } else { 1 }).collect();
        let (train, test, val) =
            stratified_split(dataset, &labels, 0.6, Some(0.2), Some(42)).unwrap();
        assert_eq!(train.len(), 60);
        assert!(val.is_some());
        assert_eq!(val.as_ref().unwrap().len(), 20);
        assert_eq!(test.len(), 20);
        println!(
            "Stratified split: train={}, val={}, test={}",
            train.len(),
            val.as_ref().unwrap().len(),
            test.len()
        );
    }

    // Three balanced classes, 70/30 split without validation.
    #[test]
    fn test_stratified_split_multi_class() {
        let data = ones::<f32>(&[90, 5]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..90).map(|i| i / 30).collect();
        let (train, test, _val) = stratified_split(dataset, &labels, 0.7, None, Some(42)).unwrap();
        assert_eq!(train.len(), 63);
        assert_eq!(test.len(), 27);
        println!(
            "Multi-class split: train={}, test={}",
            train.len(),
            test.len()
        );
    }

    // Omitting val_ratio returns None for the validation subset.
    #[test]
    fn test_stratified_split_no_val() {
        let data = ones::<f32>(&[50, 3]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..50).map(|i| i % 2).collect();
        let (train, test, val) = stratified_split(dataset, &labels, 0.8, None, Some(42)).unwrap();
        assert_eq!(train.len(), 40);
        assert_eq!(test.len(), 10);
        assert!(val.is_none());
    }

    // Out-of-range ratios (train_ratio == 1.0, train + val > 1) are
    // rejected with an error.
    #[test]
    fn test_stratified_split_invalid_ratio() {
        let data = ones::<f32>(&[50, 3]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..50).map(|i| i % 2).collect();
        let result = stratified_split(dataset.clone(), &labels, 1.0, None, None);
        assert!(result.is_err());
        let result = stratified_split(dataset, &labels, 0.7, Some(0.4), None);
        assert!(result.is_err());
    }

    // Label/dataset length mismatch is rejected with an error.
    #[test]
    fn test_stratified_split_mismatched_labels() {
        let data = ones::<f32>(&[50, 3]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = vec![0, 1];
        let result = stratified_split(dataset, &labels, 0.8, None, None);
        assert!(result.is_err());
    }

    // The same seed produces identical KFold assignments.
    #[test]
    fn test_kfold_reproducibility() {
        let kfold1 = KFold::new(5, true, Some(42));
        let kfold2 = KFold::new(5, true, Some(42));
        let folds1 = kfold1.split(50);
        let folds2 = kfold2.split(50);
        for (f1, f2) in folds1.iter().zip(folds2.iter()) {
            assert_eq!(f1.0, f2.0);
            assert_eq!(f1.1, f2.1);
        }
    }

    // The same seed yields stratified splits of identical sizes.
    // NOTE(review): this only compares lengths, not the selected indices;
    // an index-level comparison would be a stronger reproducibility check.
    #[test]
    fn test_stratified_split_reproducibility() {
        let data = ones::<f32>(&[100, 5]).unwrap();
        let labels: Vec<usize> = (0..100).map(|i| i % 3).collect();
        let (train1, test1, _) = stratified_split(
            TensorDataset::from_tensor(data.clone()),
            &labels,
            0.7,
            None,
            Some(42),
        )
        .unwrap();
        let (train2, test2, _) = stratified_split(
            TensorDataset::from_tensor(data),
            &labels,
            0.7,
            None,
            Some(42),
        )
        .unwrap();
        assert_eq!(train1.len(), train2.len());
        assert_eq!(test1.len(), test2.len());
    }
}