// torsh_data/dataset/functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use torsh_core::error::Result;
6
7use super::types::{FeatureStats, Subset, TensorDataset};
8
/// A map-style dataset
///
/// Represents a dataset that supports random access with a known length.
/// Implementors must be `Send + Sync` so datasets can be shared across
/// data-loading worker threads.
pub trait Dataset: Send + Sync {
    /// The type of items returned by the dataset
    type Item;
    /// Returns the number of items in the dataset
    fn len(&self) -> usize;
    /// Returns true if the dataset is empty
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Get a single item from the dataset
    ///
    /// Returns an error for out-of-range indices or when the underlying
    /// storage fails to produce the item.
    fn get(&self, index: usize) -> Result<Self::Item>;
}
/// An iterable-style dataset
///
/// Represents a dataset that can be iterated over but may not support
/// random access or have a known length.
pub trait IterableDataset: Send + Sync {
    /// The type of items returned by the dataset
    type Item;
    /// The iterator type
    ///
    /// Each step yields `Result<Item>` so item-level failures can surface
    /// without aborting the whole iteration.
    type Iter: Iterator<Item = Result<Self::Item>> + Send;
    /// Create an iterator over the dataset
    ///
    /// Each call produces a fresh iterator starting from the beginning.
    fn iter(&self) -> Self::Iter;
}
36/// Split a dataset into train and validation sets
37pub fn random_split<D>(
38    dataset: D,
39    lengths: &[usize],
40    generator: Option<u64>,
41) -> Result<Vec<Subset<D>>>
42where
43    D: Dataset + Clone,
44{
45    let total_length: usize = lengths.iter().sum();
46    if total_length != dataset.len() {
47        return Err(torsh_core::error::TorshError::InvalidArgument(format!(
48            "Sum of lengths {} does not equal dataset length {}",
49            total_length,
50            dataset.len()
51        )));
52    }
53    let mut indices: Vec<usize> = (0..dataset.len()).collect();
54    if let Some(_seed) = generator {
55        use scirs2_core::random::prelude::*;
56        use scirs2_core::random::seq::ScientificSliceRandom;
57        let mut rng = thread_rng();
58        indices.scientific_shuffle(&mut rng);
59    }
60    let mut subsets = Vec::with_capacity(lengths.len());
61    let mut offset = 0;
62    for &length in lengths {
63        let subset_indices = indices[offset..offset + length].to_vec();
64        subsets.push(Subset::new(dataset.clone(), subset_indices));
65        offset += length;
66    }
67    Ok(subsets)
68}
/// Streaming dataset interface for real-time data processing
///
/// This trait represents datasets that can continuously produce data,
/// potentially from real-time sources or infinite data generators.
pub trait StreamingDataset: Send + Sync {
    /// The type of items returned by the dataset
    type Item;
    /// The streaming iterator type
    type Stream: Iterator<Item = Result<Self::Item>> + Send;
    /// Create a stream over the dataset
    fn stream(&self) -> Self::Stream;
    /// Check if the stream has more data available
    ///
    /// Defaults to `true`, appropriate for infinite/real-time sources;
    /// finite sources should override this.
    fn has_more(&self) -> bool {
        true
    }
    /// Reset the stream to the beginning (if supported)
    ///
    /// The default is a no-op that reports success; sources that cannot
    /// rewind may override this to return an error instead.
    fn reset(&self) -> Result<()> {
        Ok(())
    }
}
89/// Compute statistics for a tensor dataset
90///
91/// Returns feature statistics for each feature dimension in the dataset.
92/// Only works with `TensorDataset<f32>` where the first tensor contains the features.
93pub fn dataset_statistics(dataset: &TensorDataset<f32>) -> Result<Vec<FeatureStats>> {
94    if dataset.len() == 0 {
95        return Ok(Vec::new());
96    }
97    let first_item = dataset.get(0)?;
98    if first_item.is_empty() {
99        return Ok(Vec::new());
100    }
101    let features_tensor = &first_item[0];
102    let n_features = features_tensor.numel();
103    let mut feature_data: Vec<Vec<f32>> = vec![Vec::with_capacity(dataset.len()); n_features];
104    for i in 0..dataset.len() {
105        let item = dataset.get(i)?;
106        if item.is_empty() {
107            continue;
108        }
109        let features = &item[0];
110        for feat_idx in 0..n_features.min(features.numel()) {
111            if let Ok(indices) = torsh_tensor::Tensor::from_vec(vec![feat_idx as i64], &[1]) {
112                if let Ok(value_tensor) = features.index_select(0, &indices) {
113                    if let Ok(value) = value_tensor.item() {
114                        feature_data[feat_idx].push(value);
115                    }
116                }
117            }
118        }
119    }
120    Ok(feature_data
121        .iter()
122        .map(|data| FeatureStats::from_data(data))
123        .collect())
124}
125/// Stratified split that preserves class distribution
126///
127/// Splits data into train/val/test sets while maintaining the same class distribution
128/// in each split as in the original dataset.
129pub fn stratified_split<D>(
130    dataset: D,
131    labels: &[usize],
132    train_ratio: f32,
133    val_ratio: Option<f32>,
134    random_seed: Option<u64>,
135) -> Result<(Subset<D>, Subset<D>, Option<Subset<D>>)>
136where
137    D: Dataset + Clone,
138{
139    if train_ratio <= 0.0 || train_ratio >= 1.0 {
140        return Err(torsh_core::error::TorshError::InvalidArgument(
141            "train_ratio must be between 0 and 1".to_string(),
142        ));
143    }
144    let has_val = val_ratio.is_some();
145    let val_r = val_ratio.unwrap_or(0.0);
146    if has_val && (train_ratio + val_r >= 1.0) {
147        return Err(torsh_core::error::TorshError::InvalidArgument(
148            "train_ratio + val_ratio must be less than 1".to_string(),
149        ));
150    }
151    if labels.len() != dataset.len() {
152        return Err(torsh_core::error::TorshError::InvalidArgument(
153            "labels length must equal dataset length".to_string(),
154        ));
155    }
156    let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
157        std::collections::HashMap::new();
158    for (idx, &label) in labels.iter().enumerate() {
159        class_indices.entry(label).or_default().push(idx);
160    }
161    use scirs2_core::random::prelude::*;
162    use scirs2_core::random::seq::ScientificSliceRandom;
163    use scirs2_core::random::SeedableRng;
164    let mut rng = if let Some(seed) = random_seed {
165        StdRng::seed_from_u64(seed)
166    } else {
167        use std::time::SystemTime;
168        let seed = SystemTime::now()
169            .duration_since(SystemTime::UNIX_EPOCH)
170            .expect("time should be after UNIX_EPOCH")
171            .as_secs();
172        StdRng::seed_from_u64(seed)
173    };
174    let mut train_indices = Vec::new();
175    let mut val_indices = Vec::new();
176    let mut test_indices = Vec::new();
177    for (_class, mut indices) in class_indices {
178        indices.scientific_shuffle(&mut rng);
179        let n_train = (indices.len() as f32 * train_ratio).round() as usize;
180        let n_val = if has_val {
181            (indices.len() as f32 * val_r).round() as usize
182        } else {
183            0
184        };
185        train_indices.extend_from_slice(&indices[0..n_train]);
186        if has_val {
187            val_indices.extend_from_slice(&indices[n_train..n_train + n_val]);
188            test_indices.extend_from_slice(&indices[n_train + n_val..]);
189        } else {
190            test_indices.extend_from_slice(&indices[n_train..]);
191        }
192    }
193    let train_subset = Subset::new(dataset.clone(), train_indices);
194    let test_subset = Subset::new(dataset.clone(), test_indices);
195    let val_subset = if has_val {
196        Some(Subset::new(dataset, val_indices))
197    } else {
198        None
199    };
200    Ok((train_subset, test_subset, val_subset))
201}
#[cfg(test)]
mod tests {
    // Unit tests for the dataset traits, split helpers (random / stratified /
    // k-fold), dataset statistics, streaming adapters, and the access profiler.
    use super::*;
    use crate::dataset::types::*;
    use torsh_tensor::creation::*;
    // A TensorDataset built from N tensors yields one item per leading-dim row,
    // each item containing a slice from every tensor.
    #[test]
    fn test_tensor_dataset() {
        let data = ones::<f32>(&[10, 3]).unwrap();
        let labels = zeros::<f32>(&[10]).unwrap();
        let dataset = TensorDataset::from_tensors(vec![data, labels]);
        assert_eq!(dataset.len(), 10);
        let item = dataset.get(0).unwrap();
        assert_eq!(item.len(), 2);
    }
    // dataset_idx maps a global index to (dataset index, local index) and
    // returns None past the end.
    #[test]
    fn test_concat_dataset() {
        let ds1 = TensorDataset::from_tensor(ones::<f32>(&[5, 3]).unwrap());
        let ds2 = TensorDataset::from_tensor(zeros::<f32>(&[3, 3]).unwrap());
        let concat = ConcatDataset::new(vec![ds1, ds2]);
        assert_eq!(concat.len(), 8);
        assert_eq!(concat.dataset_idx(0), Some((0, 0)));
        assert_eq!(concat.dataset_idx(4), Some((0, 4)));
        assert_eq!(concat.dataset_idx(5), Some((1, 0)));
        assert_eq!(concat.dataset_idx(7), Some((1, 2)));
        assert_eq!(concat.dataset_idx(8), None);
    }
    // Subset exposes only the chosen indices and errors on out-of-range access.
    #[test]
    fn test_subset() {
        let dataset = TensorDataset::from_tensor(ones::<f32>(&[10, 3]).unwrap());
        let subset = Subset::new(dataset, vec![0, 2, 4, 6, 8]);
        assert_eq!(subset.len(), 5);
        assert!(subset.get(0).is_ok());
        assert!(subset.get(5).is_err());
    }
    // Minimal IterableDataset backed by a Vec, used by the chain tests below.
    #[derive(Clone)]
    struct SimpleIterableDataset {
        data: Vec<i32>,
    }
    impl IterableDataset for SimpleIterableDataset {
        type Item = i32;
        type Iter = std::iter::Map<std::vec::IntoIter<i32>, fn(i32) -> Result<i32>>;
        fn iter(&self) -> Self::Iter {
            self.data.clone().into_iter().map(|x| Ok(x) as Result<i32>)
        }
    }
    // ChainDataset concatenates iterable datasets in order.
    #[test]
    fn test_chain_dataset() {
        let ds1 = SimpleIterableDataset {
            data: vec![1, 2, 3],
        };
        let ds2 = SimpleIterableDataset {
            data: vec![4, 5, 6],
        };
        let ds3 = SimpleIterableDataset {
            data: vec![7, 8, 9],
        };
        let chain = ChainDataset::new(vec![ds1, ds2, ds3]);
        let collected: Result<Vec<_>> = chain.iter().collect();
        assert!(collected.is_ok());
        let values = collected.unwrap();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }
    #[test]
    fn test_chain_dataset_empty() {
        let chain: ChainDataset<SimpleIterableDataset> = ChainDataset::new(vec![]);
        let collected: Result<Vec<_>> = chain.iter().collect();
        assert!(collected.is_ok());
        let values = collected.unwrap();
        assert_eq!(values, Vec::<i32>::new());
    }
    // Empty constituent datasets are skipped transparently.
    #[test]
    fn test_chain_dataset_with_empty_datasets() {
        let ds1 = SimpleIterableDataset { data: vec![] };
        let ds2 = SimpleIterableDataset {
            data: vec![1, 2, 3],
        };
        let ds3 = SimpleIterableDataset { data: vec![] };
        let ds4 = SimpleIterableDataset { data: vec![4, 5] };
        let chain = ChainDataset::new(vec![ds1, ds2, ds3, ds4]);
        let collected: Result<Vec<_>> = chain.iter().collect();
        assert!(collected.is_ok());
        let values = collected.unwrap();
        assert_eq!(values, vec![1, 2, 3, 4, 5]);
    }
    // InfiniteDataset calls its generator closure once per stream step.
    #[test]
    fn test_infinite_dataset() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let dataset = InfiniteDataset::new(move || {
            let val = counter_clone.fetch_add(1, Ordering::SeqCst);
            Ok(val)
        });
        assert!(dataset.has_more());
        let mut stream = dataset.stream();
        assert_eq!(stream.next().unwrap().unwrap(), 0);
        assert_eq!(stream.next().unwrap().unwrap(), 1);
        assert_eq!(stream.next().unwrap().unwrap(), 2);
    }
    // Buffering (capacity 5) must be transparent: 10 reads still succeed.
    #[test]
    fn test_buffered_streaming_dataset() {
        let dataset = InfiniteDataset::new(|| Ok(42i32));
        let buffered = BufferedStreamingDataset::new(dataset, 5).with_prefetch(true);
        assert!(buffered.has_more());
        let mut stream = buffered.stream();
        for _ in 0..10 {
            assert_eq!(stream.next().unwrap().unwrap(), 42);
        }
    }
    // Transforms apply in insertion order: (5 * 2) + 1 = 11.
    #[test]
    fn test_data_pipeline() {
        let pipeline = DataPipeline::new()
            .add_transform(|x: i32| Ok(x * 2))
            .add_transform(|x: i32| Ok(x + 1));
        let result = pipeline.apply(5).unwrap();
        assert_eq!(result, 11);
    }
    #[test]
    fn test_pipeline_streaming_dataset() {
        let dataset = InfiniteDataset::new(|| Ok(5i32));
        let pipeline = DataPipeline::new()
            .add_transform(|x: i32| Ok(x * 2))
            .add_transform(|x: i32| Ok(x + 1));
        let pipeline_dataset = PipelineStreamingDataset::new(dataset, pipeline);
        assert!(pipeline_dataset.has_more());
        let mut stream = pipeline_dataset.stream();
        for _ in 0..5 {
            assert_eq!(stream.next().unwrap().unwrap(), 11);
        }
    }
    // Smoke test: items sent through the channel are visible to the dataset.
    #[test]
    fn test_real_time_dataset() {
        let (dataset, _receiver) = RealTimeDataset::<i32>::new();
        let sender = dataset.sender();
        {
            let sender_lock = sender.lock().expect("lock should not be poisoned");
            sender_lock.send(1).unwrap();
            sender_lock.send(2).unwrap();
            sender_lock.send(3).unwrap();
        }
        assert!(dataset.has_more());
        let _stream = dataset.stream();
    }
    // A 5-item map-style dataset adapted to streaming yields exactly 5 items.
    #[test]
    fn test_dataset_to_streaming() {
        let tensor = ones::<f32>(&[5, 3]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let streaming = DatasetToStreaming::new(dataset);
        assert!(streaming.has_more());
        let stream = streaming.stream();
        let mut count = 0;
        for result in stream {
            assert!(result.is_ok());
            count += 1;
            if count >= 5 {
                break;
            }
        }
        assert_eq!(count, 5);
    }
    // With repeat(), a 3-item dataset keeps cycling (10 reads succeed).
    #[test]
    fn test_dataset_to_streaming_repeat() {
        let tensor = ones::<f32>(&[3, 2]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let streaming = DatasetToStreaming::new(dataset).repeat();
        assert!(streaming.has_more());
        let stream = streaming.stream();
        let mut count = 0;
        for result in stream {
            assert!(result.is_ok());
            count += 1;
            if count >= 10 {
                break;
            }
        }
        assert_eq!(count, 10);
    }
    #[test]
    fn test_streaming_dataset_reset() {
        let dataset = InfiniteDataset::new(|| Ok(42i32));
        let buffered = BufferedStreamingDataset::new(dataset, 3);
        assert!(buffered.reset().is_ok());
    }
    // 10 in-order accesses => 9 sequential transitions.
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_sequential_access() {
        use std::thread;
        use std::time::Duration;
        let tensor = ones::<f32>(&[10, 2]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..10 {
            let _ = profiled.get(i).unwrap();
            // Sleep keeps per-access timing measurably above zero.
            thread::sleep(Duration::from_micros(100));
        }
        let stats = profiled.stats();
        assert_eq!(stats.total_accesses, 10);
        assert_eq!(stats.sequential_accesses, 9);
        assert!(stats.sequential_ratio > 0.8);
        assert!(stats.avg_access_time_us > 0.0);
        assert!(stats.throughput_accesses_per_sec > 0.0);
    }
    // No consecutive indices in the pattern => zero sequential transitions.
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_random_access() {
        let tensor = ones::<f32>(&[10, 2]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        let indices = [0, 5, 2, 8, 1];
        for &i in &indices {
            let _ = profiled.get(i).unwrap();
        }
        let stats = profiled.stats();
        assert_eq!(stats.total_accesses, 5);
        assert_eq!(stats.sequential_accesses, 0);
        assert_eq!(stats.sequential_ratio, 0.0);
    }
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_hints() {
        let tensor = ones::<f32>(&[100, 2]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..20 {
            let _ = profiled.get(i).unwrap();
        }
        let hints = profiled.hints();
        assert!(!hints.is_empty());
        assert!(hints
            .iter()
            .any(|h| h.contains("sequential") || h.contains("good")));
    }
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_reset() {
        let tensor = ones::<f32>(&[10, 2]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..5 {
            let _ = profiled.get(i).unwrap();
        }
        assert_eq!(profiled.stats().total_accesses, 5);
        profiled.profiler().reset();
        assert_eq!(profiled.stats().total_accesses, 0);
    }
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_display() {
        let tensor = ones::<f32>(&[10, 2]).unwrap();
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..5 {
            let _ = profiled.get(i).unwrap();
        }
        let stats_string = format!("{}", profiled.stats());
        assert!(stats_string.contains("Dataset Profile Statistics"));
        assert!(stats_string.contains("Total Accesses: 5"));
    }
    // std of [1..5] with population formula: sqrt(2) ~= 1.4142.
    #[test]
    fn test_feature_stats() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = FeatureStats::from_data(&data);
        assert_eq!(stats.count, 5);
        assert_eq!(stats.mean, 3.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert!((stats.std - 1.4142).abs() < 0.01);
    }
    #[test]
    fn test_feature_stats_empty() {
        let data: Vec<f32> = vec![];
        let stats = FeatureStats::from_data(&data);
        assert_eq!(stats.count, 0);
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.std, 0.0);
    }
    // One FeatureStats per feature column, each seeing every row.
    #[test]
    fn test_dataset_statistics() {
        let data = torsh_tensor::creation::randn::<f32>(&[10, 3]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let stats = dataset_statistics(&dataset).unwrap();
        assert_eq!(stats.len(), 3);
        for stat in &stats {
            assert_eq!(stat.count, 10);
            assert!(stat.min <= stat.mean);
            assert!(stat.mean <= stat.max);
            assert!(stat.std >= 0.0);
        }
    }
    #[test]
    fn test_dataset_statistics_empty() {
        let data = torsh_tensor::creation::zeros::<f32>(&[0, 3]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let stats = dataset_statistics(&dataset).unwrap();
        assert_eq!(stats.len(), 0);
    }
    // 100 samples / 5 folds => 20 validation, 80 training per fold, disjoint.
    #[test]
    fn test_kfold_basic() {
        let kfold = KFold::new(5, false, Some(42));
        let folds = kfold.split(100);
        assert_eq!(folds.len(), 5);
        for (fold_idx, (train_indices, val_indices)) in folds.iter().enumerate() {
            assert_eq!(val_indices.len(), 20);
            assert_eq!(train_indices.len(), 80);
            for &val_idx in val_indices {
                assert!(!train_indices.contains(&val_idx));
            }
            for &idx in train_indices.iter().chain(val_indices.iter()) {
                assert!(idx < 100);
            }
            println!(
                "Fold {}: train={}, val={}",
                fold_idx,
                train_indices.len(),
                val_indices.len()
            );
        }
    }
    // Unshuffled fold 0 validation is exactly 0..10; shuffled must differ.
    #[test]
    fn test_kfold_shuffle() {
        let kfold_shuffled = KFold::new(3, true, Some(42));
        let kfold_unshuffled = KFold::new(3, false, None);
        let folds_shuffled = kfold_shuffled.split(30);
        let folds_unshuffled = kfold_unshuffled.split(30);
        assert_eq!(folds_shuffled.len(), folds_unshuffled.len());
        let shuffled_val = &folds_shuffled[0].1;
        let unshuffled_val = &folds_unshuffled[0].1;
        assert_eq!(unshuffled_val, &(0..10).collect::<Vec<_>>());
        assert_ne!(shuffled_val, unshuffled_val);
    }
    // 10 samples over 3 folds: remainder goes to the last fold (3, 3, 4).
    #[test]
    fn test_kfold_uneven_split() {
        let kfold = KFold::new(3, false, None);
        let folds = kfold.split(10);
        assert_eq!(folds.len(), 3);
        assert_eq!(folds[0].1.len(), 3);
        assert_eq!(folds[1].1.len(), 3);
        assert_eq!(folds[2].1.len(), 4);
        let all_val_samples: usize = folds.iter().map(|(_, val)| val.len()).sum();
        assert_eq!(all_val_samples, 10);
    }
    #[test]
    #[should_panic(expected = "n_splits must be at least 2")]
    fn test_kfold_invalid_splits() {
        KFold::new(1, false, None);
    }
    // 60/20/20 split of 100 balanced binary samples; val is the 3rd tuple slot.
    #[test]
    fn test_stratified_split_binary() {
        let data = ones::<f32>(&[100, 5]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..100).map(|i| if i < 50 { 0 } else { 1 }).collect();
        let (train, test, val) =
            stratified_split(dataset, &labels, 0.6, Some(0.2), Some(42)).unwrap();
        assert_eq!(train.len(), 60);
        assert!(val.is_some());
        assert_eq!(val.as_ref().unwrap().len(), 20);
        assert_eq!(test.len(), 20);
        println!(
            "Stratified split: train={}, val={}, test={}",
            train.len(),
            val.as_ref().unwrap().len(),
            test.len()
        );
    }
    // 3 classes of 30 samples each, 70/30: 21 train + 9 test per class.
    #[test]
    fn test_stratified_split_multi_class() {
        let data = ones::<f32>(&[90, 5]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..90).map(|i| i / 30).collect();
        let (train, test, _val) = stratified_split(dataset, &labels, 0.7, None, Some(42)).unwrap();
        assert_eq!(train.len(), 63);
        assert_eq!(test.len(), 27);
        println!(
            "Multi-class split: train={}, test={}",
            train.len(),
            test.len()
        );
    }
    #[test]
    fn test_stratified_split_no_val() {
        let data = ones::<f32>(&[50, 3]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..50).map(|i| i % 2).collect();
        let (train, test, val) = stratified_split(dataset, &labels, 0.8, None, Some(42)).unwrap();
        assert_eq!(train.len(), 40);
        assert_eq!(test.len(), 10);
        assert!(val.is_none());
    }
    // train_ratio == 1.0 and train + val > 1 are both rejected.
    #[test]
    fn test_stratified_split_invalid_ratio() {
        let data = ones::<f32>(&[50, 3]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..50).map(|i| i % 2).collect();
        let result = stratified_split(dataset.clone(), &labels, 1.0, None, None);
        assert!(result.is_err());
        let result = stratified_split(dataset, &labels, 0.7, Some(0.4), None);
        assert!(result.is_err());
    }
    #[test]
    fn test_stratified_split_mismatched_labels() {
        let data = ones::<f32>(&[50, 3]).unwrap();
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = vec![0, 1];
        let result = stratified_split(dataset, &labels, 0.8, None, None);
        assert!(result.is_err());
    }
    // Identical seeds must produce identical folds, element for element.
    #[test]
    fn test_kfold_reproducibility() {
        let kfold1 = KFold::new(5, true, Some(42));
        let kfold2 = KFold::new(5, true, Some(42));
        let folds1 = kfold1.split(50);
        let folds2 = kfold2.split(50);
        for (f1, f2) in folds1.iter().zip(folds2.iter()) {
            assert_eq!(f1.0, f2.0);
            assert_eq!(f1.1, f2.1);
        }
    }
    // NOTE(review): only subset *lengths* are compared here, so this does not
    // catch nondeterministic index ordering across runs — see stratified_split.
    #[test]
    fn test_stratified_split_reproducibility() {
        let data = ones::<f32>(&[100, 5]).unwrap();
        let labels: Vec<usize> = (0..100).map(|i| i % 3).collect();
        let (train1, test1, _) = stratified_split(
            TensorDataset::from_tensor(data.clone()),
            &labels,
            0.7,
            None,
            Some(42),
        )
        .unwrap();
        let (train2, test2, _) = stratified_split(
            TensorDataset::from_tensor(data),
            &labels,
            0.7,
            None,
            Some(42),
        )
        .unwrap();
        assert_eq!(train1.len(), train2.len());
        assert_eq!(test1.len(), test2.len());
    }
}