Skip to main content

torsh_data/dataset/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use torsh_core::error::Result;
6
7use super::types::{FeatureStats, Subset, TensorDataset};
8
/// A map-style dataset
///
/// Represents a dataset that supports random access with a known length.
pub trait Dataset: Send + Sync {
    /// The type of items returned by the dataset
    type Item;
    /// Returns the number of items in the dataset
    fn len(&self) -> usize;
    /// Returns true if the dataset is empty
    ///
    /// Default implementation simply checks `len() == 0`.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Get a single item from the dataset
    ///
    /// # Errors
    /// Returns an error if `index` is out of bounds or the item cannot
    /// be retrieved.
    fn get(&self, index: usize) -> Result<Self::Item>;
}
/// An iterable-style dataset
///
/// Represents a dataset that can be iterated over but may not support
/// random access or have a known length.
pub trait IterableDataset: Send + Sync {
    /// The type of items returned by the dataset
    type Item;
    /// The iterator type
    ///
    /// Yields `Result` so item production itself may fail mid-iteration.
    type Iter: Iterator<Item = Result<Self::Item>> + Send;
    /// Create an iterator over the dataset
    fn iter(&self) -> Self::Iter;
}
36/// Split a dataset into train and validation sets
37pub fn random_split<D>(
38    dataset: D,
39    lengths: &[usize],
40    generator: Option<u64>,
41) -> Result<Vec<Subset<D>>>
42where
43    D: Dataset + Clone,
44{
45    let total_length: usize = lengths.iter().sum();
46    if total_length != dataset.len() {
47        return Err(torsh_core::error::TorshError::InvalidArgument(format!(
48            "Sum of lengths {} does not equal dataset length {}",
49            total_length,
50            dataset.len()
51        )));
52    }
53    let mut indices: Vec<usize> = (0..dataset.len()).collect();
54    if let Some(_seed) = generator {
55        use scirs2_core::random::prelude::*;
56        use scirs2_core::random::seq::ScientificSliceRandom;
57        let mut rng = thread_rng();
58        indices.scientific_shuffle(&mut rng);
59    }
60    let mut subsets = Vec::with_capacity(lengths.len());
61    let mut offset = 0;
62    for &length in lengths {
63        let subset_indices = indices[offset..offset + length].to_vec();
64        subsets.push(Subset::new(dataset.clone(), subset_indices));
65        offset += length;
66    }
67    Ok(subsets)
68}
/// Streaming dataset interface for real-time data processing
///
/// This trait represents datasets that can continuously produce data,
/// potentially from real-time sources or infinite data generators.
pub trait StreamingDataset: Send + Sync {
    /// The type of items returned by the dataset
    type Item;
    /// The streaming iterator type
    type Stream: Iterator<Item = Result<Self::Item>> + Send;
    /// Create a stream over the dataset
    fn stream(&self) -> Self::Stream;
    /// Check if the stream has more data available
    ///
    /// Defaults to `true`, which is appropriate for infinite sources;
    /// finite streams should override this.
    fn has_more(&self) -> bool {
        true
    }
    /// Reset the stream to the beginning (if supported)
    ///
    /// The default is a no-op that always succeeds; implementations that
    /// cannot rewind may override and return an error.
    fn reset(&self) -> Result<()> {
        Ok(())
    }
}
/// Compute statistics for a tensor dataset
///
/// Returns feature statistics for each feature dimension in the dataset.
/// Only works with `TensorDataset<f32>` where the first tensor contains the features.
///
/// The number of feature dimensions is taken from item 0. Items whose
/// tensor list is empty are skipped, and individual feature reads that
/// fail are silently dropped (best-effort), so a feature's `count` may be
/// lower than `dataset.len()`.
///
/// # Errors
/// Propagates errors from `dataset.get` when an item cannot be retrieved.
pub fn dataset_statistics(dataset: &TensorDataset<f32>) -> Result<Vec<FeatureStats>> {
    if dataset.len() == 0 {
        return Ok(Vec::new());
    }
    // Probe the first item to discover how many feature dimensions exist.
    let first_item = dataset.get(0)?;
    if first_item.is_empty() {
        return Ok(Vec::new());
    }
    let features_tensor = &first_item[0];
    let n_features = features_tensor.numel();
    // One column of raw values per feature dimension.
    let mut feature_data: Vec<Vec<f32>> = vec![Vec::with_capacity(dataset.len()); n_features];
    for i in 0..dataset.len() {
        let item = dataset.get(i)?;
        if item.is_empty() {
            continue;
        }
        let features = &item[0];
        // min() guards against later items having fewer features than item 0.
        for feat_idx in 0..n_features.min(features.numel()) {
            // Extract the scalar at feat_idx via a single-element
            // index_select; failures anywhere in the chain just skip
            // this value rather than aborting the whole computation.
            if let Ok(indices) = torsh_tensor::Tensor::from_vec(vec![feat_idx as i64], &[1]) {
                if let Ok(value_tensor) = features.index_select(0, &indices) {
                    if let Ok(value) = value_tensor.item() {
                        feature_data[feat_idx].push(value);
                    }
                }
            }
        }
    }
    Ok(feature_data
        .iter()
        .map(|data| FeatureStats::from_data(data))
        .collect())
}
125/// Stratified split that preserves class distribution
126///
127/// Splits data into train/val/test sets while maintaining the same class distribution
128/// in each split as in the original dataset.
129pub fn stratified_split<D>(
130    dataset: D,
131    labels: &[usize],
132    train_ratio: f32,
133    val_ratio: Option<f32>,
134    random_seed: Option<u64>,
135) -> Result<(Subset<D>, Subset<D>, Option<Subset<D>>)>
136where
137    D: Dataset + Clone,
138{
139    if train_ratio <= 0.0 || train_ratio >= 1.0 {
140        return Err(torsh_core::error::TorshError::InvalidArgument(
141            "train_ratio must be between 0 and 1".to_string(),
142        ));
143    }
144    let has_val = val_ratio.is_some();
145    let val_r = val_ratio.unwrap_or(0.0);
146    if has_val && (train_ratio + val_r >= 1.0) {
147        return Err(torsh_core::error::TorshError::InvalidArgument(
148            "train_ratio + val_ratio must be less than 1".to_string(),
149        ));
150    }
151    if labels.len() != dataset.len() {
152        return Err(torsh_core::error::TorshError::InvalidArgument(
153            "labels length must equal dataset length".to_string(),
154        ));
155    }
156    let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
157        std::collections::HashMap::new();
158    for (idx, &label) in labels.iter().enumerate() {
159        class_indices.entry(label).or_default().push(idx);
160    }
161    use scirs2_core::random::prelude::*;
162    use scirs2_core::random::seq::ScientificSliceRandom;
163    use scirs2_core::random::SeedableRng;
164    let mut rng = if let Some(seed) = random_seed {
165        StdRng::seed_from_u64(seed)
166    } else {
167        use std::time::SystemTime;
168        let seed = SystemTime::now()
169            .duration_since(SystemTime::UNIX_EPOCH)
170            .expect("time should be after UNIX_EPOCH")
171            .as_secs();
172        StdRng::seed_from_u64(seed)
173    };
174    let mut train_indices = Vec::new();
175    let mut val_indices = Vec::new();
176    let mut test_indices = Vec::new();
177    for (_class, mut indices) in class_indices {
178        indices.scientific_shuffle(&mut rng);
179        let n_train = (indices.len() as f32 * train_ratio).round() as usize;
180        let n_val = if has_val {
181            (indices.len() as f32 * val_r).round() as usize
182        } else {
183            0
184        };
185        train_indices.extend_from_slice(&indices[0..n_train]);
186        if has_val {
187            val_indices.extend_from_slice(&indices[n_train..n_train + n_val]);
188            test_indices.extend_from_slice(&indices[n_train + n_val..]);
189        } else {
190            test_indices.extend_from_slice(&indices[n_train..]);
191        }
192    }
193    let train_subset = Subset::new(dataset.clone(), train_indices);
194    let test_subset = Subset::new(dataset.clone(), test_indices);
195    let val_subset = if has_val {
196        Some(Subset::new(dataset, val_indices))
197    } else {
198        None
199    };
200    Ok((train_subset, test_subset, val_subset))
201}
#[cfg(test)]
mod tests {
    //! Unit tests covering dataset containers, streaming adapters,
    //! profiling, statistics, and the split utilities defined above.
    use super::*;
    use crate::dataset::types::*;
    use torsh_tensor::creation::*;
    // Basic length/item access on a multi-tensor dataset.
    #[test]
    fn test_tensor_dataset() {
        let data = ones::<f32>(&[10, 3]).expect("operation should succeed");
        let labels = zeros::<f32>(&[10]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensors(vec![data, labels]);
        assert_eq!(dataset.len(), 10);
        let item = dataset
            .get(0)
            .expect("element retrieval should succeed for valid index");
        assert_eq!(item.len(), 2);
    }
    // Global index -> (dataset, local index) mapping across concatenated sets.
    #[test]
    fn test_concat_dataset() {
        let ds1 = TensorDataset::from_tensor(
            ones::<f32>(&[5, 3]).expect("Tensor Dataset should succeed"),
        );
        let ds2 = TensorDataset::from_tensor(
            zeros::<f32>(&[3, 3]).expect("Tensor Dataset should succeed"),
        );
        let concat = ConcatDataset::new(vec![ds1, ds2]);
        assert_eq!(concat.len(), 8);
        assert_eq!(concat.dataset_idx(0), Some((0, 0)));
        assert_eq!(concat.dataset_idx(4), Some((0, 4)));
        assert_eq!(concat.dataset_idx(5), Some((1, 0)));
        assert_eq!(concat.dataset_idx(7), Some((1, 2)));
        assert_eq!(concat.dataset_idx(8), None);
    }
    // Subset exposes only the selected indices and rejects out-of-range access.
    #[test]
    fn test_subset() {
        let dataset = TensorDataset::from_tensor(
            ones::<f32>(&[10, 3]).expect("Tensor Dataset should succeed"),
        );
        let subset = Subset::new(dataset, vec![0, 2, 4, 6, 8]);
        assert_eq!(subset.len(), 5);
        assert!(subset.get(0).is_ok());
        assert!(subset.get(5).is_err());
    }
    // Minimal IterableDataset impl used by the chain tests below.
    #[derive(Clone)]
    struct SimpleIterableDataset {
        data: Vec<i32>,
    }
    impl IterableDataset for SimpleIterableDataset {
        type Item = i32;
        type Iter = std::iter::Map<std::vec::IntoIter<i32>, fn(i32) -> Result<i32>>;
        fn iter(&self) -> Self::Iter {
            self.data.clone().into_iter().map(|x| Ok(x) as Result<i32>)
        }
    }
    // Chaining preserves the order of the underlying datasets.
    #[test]
    fn test_chain_dataset() {
        let ds1 = SimpleIterableDataset {
            data: vec![1, 2, 3],
        };
        let ds2 = SimpleIterableDataset {
            data: vec![4, 5, 6],
        };
        let ds3 = SimpleIterableDataset {
            data: vec![7, 8, 9],
        };
        let chain = ChainDataset::new(vec![ds1, ds2, ds3]);
        let collected: Result<Vec<_>> = chain.iter().collect();
        assert!(collected.is_ok());
        let values = collected.expect("operation should succeed");
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }
    // A chain over zero datasets yields an empty iteration, not an error.
    #[test]
    fn test_chain_dataset_empty() {
        let chain: ChainDataset<SimpleIterableDataset> = ChainDataset::new(vec![]);
        let collected: Result<Vec<_>> = chain.iter().collect();
        assert!(collected.is_ok());
        let values = collected.expect("operation should succeed");
        assert_eq!(values, Vec::<i32>::new());
    }
    // Empty member datasets are skipped transparently mid-chain.
    #[test]
    fn test_chain_dataset_with_empty_datasets() {
        let ds1 = SimpleIterableDataset { data: vec![] };
        let ds2 = SimpleIterableDataset {
            data: vec![1, 2, 3],
        };
        let ds3 = SimpleIterableDataset { data: vec![] };
        let ds4 = SimpleIterableDataset { data: vec![4, 5] };
        let chain = ChainDataset::new(vec![ds1, ds2, ds3, ds4]);
        let collected: Result<Vec<_>> = chain.iter().collect();
        assert!(collected.is_ok());
        let values = collected.expect("operation should succeed");
        assert_eq!(values, vec![1, 2, 3, 4, 5]);
    }
    // Generator-backed infinite stream produces monotonically increasing values.
    #[test]
    fn test_infinite_dataset() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let dataset = InfiniteDataset::new(move || {
            let val = counter_clone.fetch_add(1, Ordering::SeqCst);
            Ok(val)
        });
        assert!(dataset.has_more());
        let mut stream = dataset.stream();
        assert_eq!(
            stream
                .next()
                .expect("iterator should have a next element")
                .expect("operation should succeed"),
            0
        );
        assert_eq!(
            stream
                .next()
                .expect("iterator should have a next element")
                .expect("operation should succeed"),
            1
        );
        assert_eq!(
            stream
                .next()
                .expect("iterator should have a next element")
                .expect("operation should succeed"),
            2
        );
    }
    // Buffering (with prefetch) must not change the values a stream yields.
    #[test]
    fn test_buffered_streaming_dataset() {
        let dataset = InfiniteDataset::new(|| Ok(42i32));
        let buffered = BufferedStreamingDataset::new(dataset, 5).with_prefetch(true);
        assert!(buffered.has_more());
        let mut stream = buffered.stream();
        for _ in 0..10 {
            assert_eq!(
                stream
                    .next()
                    .expect("iterator should have a next element")
                    .expect("operation should succeed"),
                42
            );
        }
    }
    // Transforms apply in insertion order: (5 * 2) + 1 = 11.
    #[test]
    fn test_data_pipeline() {
        let pipeline = DataPipeline::new()
            .add_transform(|x: i32| Ok(x * 2))
            .add_transform(|x: i32| Ok(x + 1));
        let result = pipeline.apply(5).expect("apply operation should succeed");
        assert_eq!(result, 11);
    }
    // Pipeline composed with a streaming source transforms every item.
    #[test]
    fn test_pipeline_streaming_dataset() {
        let dataset = InfiniteDataset::new(|| Ok(5i32));
        let pipeline = DataPipeline::new()
            .add_transform(|x: i32| Ok(x * 2))
            .add_transform(|x: i32| Ok(x + 1));
        let pipeline_dataset = PipelineStreamingDataset::new(dataset, pipeline);
        assert!(pipeline_dataset.has_more());
        let mut stream = pipeline_dataset.stream();
        for _ in 0..5 {
            assert_eq!(
                stream
                    .next()
                    .expect("iterator should have a next element")
                    .expect("operation should succeed"),
                11
            );
        }
    }
    // Smoke test: items sent through the channel are accepted and a stream opens.
    #[test]
    fn test_real_time_dataset() {
        let (dataset, _receiver) = RealTimeDataset::<i32>::new();
        let sender = dataset.sender();
        {
            let sender_lock = sender.lock().expect("lock should not be poisoned");
            sender_lock.send(1).expect("channel send should succeed");
            sender_lock.send(2).expect("channel send should succeed");
            sender_lock.send(3).expect("channel send should succeed");
        }
        assert!(dataset.has_more());
        let _stream = dataset.stream();
    }
    // Map-style dataset adapted to streaming yields exactly len() items.
    #[test]
    fn test_dataset_to_streaming() {
        let tensor = ones::<f32>(&[5, 3]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let streaming = DatasetToStreaming::new(dataset);
        assert!(streaming.has_more());
        let stream = streaming.stream();
        let mut count = 0;
        for result in stream {
            assert!(result.is_ok());
            count += 1;
            if count >= 5 {
                break;
            }
        }
        assert_eq!(count, 5);
    }
    // With repeat(), the 3-item dataset wraps around past its length.
    #[test]
    fn test_dataset_to_streaming_repeat() {
        let tensor = ones::<f32>(&[3, 2]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let streaming = DatasetToStreaming::new(dataset).repeat();
        assert!(streaming.has_more());
        let stream = streaming.stream();
        let mut count = 0;
        for result in stream {
            assert!(result.is_ok());
            count += 1;
            if count >= 10 {
                break;
            }
        }
        assert_eq!(count, 10);
    }
    // reset() on a buffered stream is a supported (Ok) operation.
    #[test]
    fn test_streaming_dataset_reset() {
        let dataset = InfiniteDataset::new(|| Ok(42i32));
        let buffered = BufferedStreamingDataset::new(dataset, 3);
        assert!(buffered.reset().is_ok());
    }
    // Sequential access pattern: 10 accesses => 9 index-adjacent pairs.
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_sequential_access() {
        use std::thread;
        use std::time::Duration;
        let tensor = ones::<f32>(&[10, 2]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..10 {
            let _ = profiled
                .get(i)
                .expect("element retrieval should succeed for valid index");
            // Short sleep ensures measurable, non-zero access timings.
            thread::sleep(Duration::from_micros(100));
        }
        let stats = profiled.stats();
        assert_eq!(stats.total_accesses, 10);
        assert_eq!(stats.sequential_accesses, 9);
        assert!(stats.sequential_ratio > 0.8);
        assert!(stats.avg_access_time_us > 0.0);
        assert!(stats.throughput_accesses_per_sec > 0.0);
    }
    // Non-adjacent indices register zero sequential accesses.
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_random_access() {
        let tensor = ones::<f32>(&[10, 2]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        let indices = [0, 5, 2, 8, 1];
        for &i in &indices {
            let _ = profiled
                .get(i)
                .expect("element retrieval should succeed for valid index");
        }
        let stats = profiled.stats();
        assert_eq!(stats.total_accesses, 5);
        assert_eq!(stats.sequential_accesses, 0);
        assert_eq!(stats.sequential_ratio, 0.0);
    }
    // Hints text should acknowledge the sequential access pattern.
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_hints() {
        let tensor = ones::<f32>(&[100, 2]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..20 {
            let _ = profiled
                .get(i)
                .expect("element retrieval should succeed for valid index");
        }
        let hints = profiled.hints();
        assert!(!hints.is_empty());
        assert!(hints
            .iter()
            .any(|h| h.contains("sequential") || h.contains("good")));
    }
    // reset() clears accumulated access counters.
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_reset() {
        let tensor = ones::<f32>(&[10, 2]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..5 {
            let _ = profiled
                .get(i)
                .expect("element retrieval should succeed for valid index");
        }
        assert_eq!(profiled.stats().total_accesses, 5);
        profiled.profiler().reset();
        assert_eq!(profiled.stats().total_accesses, 0);
    }
    // Display impl on stats includes the expected header and counters.
    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_display() {
        let tensor = ones::<f32>(&[10, 2]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..5 {
            let _ = profiled
                .get(i)
                .expect("element retrieval should succeed for valid index");
        }
        let stats_string = format!("{}", profiled.stats());
        assert!(stats_string.contains("Dataset Profile Statistics"));
        assert!(stats_string.contains("Total Accesses: 5"));
    }
    // Mean/min/max and population std (sqrt(2) ~ 1.4142) of 1..=5.
    #[test]
    fn test_feature_stats() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = FeatureStats::from_data(&data);
        assert_eq!(stats.count, 5);
        assert_eq!(stats.mean, 3.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert!((stats.std - 1.4142).abs() < 0.01);
    }
    // Empty input yields zeroed statistics rather than NaN or a panic.
    #[test]
    fn test_feature_stats_empty() {
        let data: Vec<f32> = vec![];
        let stats = FeatureStats::from_data(&data);
        assert_eq!(stats.count, 0);
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.std, 0.0);
    }
    // One FeatureStats per feature column, each aggregating all 10 samples.
    #[test]
    fn test_dataset_statistics() {
        let data =
            torsh_tensor::creation::randn::<f32>(&[10, 3]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        let stats = dataset_statistics(&dataset).expect("dataset statistics should succeed");
        assert_eq!(stats.len(), 3);
        for stat in &stats {
            assert_eq!(stat.count, 10);
            assert!(stat.min <= stat.mean);
            assert!(stat.mean <= stat.max);
            assert!(stat.std >= 0.0);
        }
    }
    // Zero-row dataset produces an empty statistics vector.
    #[test]
    fn test_dataset_statistics_empty() {
        let data = torsh_tensor::creation::zeros::<f32>(&[0, 3]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        let stats = dataset_statistics(&dataset).expect("dataset statistics should succeed");
        assert_eq!(stats.len(), 0);
    }
    // 5-fold split of 100 samples: disjoint 80/20 train/val per fold.
    #[test]
    fn test_kfold_basic() {
        let kfold = KFold::new(5, false, Some(42));
        let folds = kfold.split(100);
        assert_eq!(folds.len(), 5);
        for (fold_idx, (train_indices, val_indices)) in folds.iter().enumerate() {
            assert_eq!(val_indices.len(), 20);
            assert_eq!(train_indices.len(), 80);
            for &val_idx in val_indices {
                assert!(!train_indices.contains(&val_idx));
            }
            for &idx in train_indices.iter().chain(val_indices.iter()) {
                assert!(idx < 100);
            }
            println!(
                "Fold {}: train={}, val={}",
                fold_idx,
                train_indices.len(),
                val_indices.len()
            );
        }
    }
    // Shuffled folds must differ from the identity-ordered unshuffled folds.
    #[test]
    fn test_kfold_shuffle() {
        let kfold_shuffled = KFold::new(3, true, Some(42));
        let kfold_unshuffled = KFold::new(3, false, None);
        let folds_shuffled = kfold_shuffled.split(30);
        let folds_unshuffled = kfold_unshuffled.split(30);
        assert_eq!(folds_shuffled.len(), folds_unshuffled.len());
        let shuffled_val = &folds_shuffled[0].1;
        let unshuffled_val = &folds_unshuffled[0].1;
        assert_eq!(unshuffled_val, &(0..10).collect::<Vec<_>>());
        assert_ne!(shuffled_val, unshuffled_val);
    }
    // 10 samples over 3 folds: the remainder goes to the last fold (3/3/4).
    #[test]
    fn test_kfold_uneven_split() {
        let kfold = KFold::new(3, false, None);
        let folds = kfold.split(10);
        assert_eq!(folds.len(), 3);
        assert_eq!(folds[0].1.len(), 3);
        assert_eq!(folds[1].1.len(), 3);
        assert_eq!(folds[2].1.len(), 4);
        let all_val_samples: usize = folds.iter().map(|(_, val)| val.len()).sum();
        assert_eq!(all_val_samples, 10);
    }
    // Fewer than 2 splits is a constructor precondition violation.
    #[test]
    #[should_panic(expected = "n_splits must be at least 2")]
    fn test_kfold_invalid_splits() {
        KFold::new(1, false, None);
    }
    // 60/20/20 split of a balanced binary dataset.
    #[test]
    fn test_stratified_split_binary() {
        let data = ones::<f32>(&[100, 5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..100).map(|i| if i < 50 { 0 } else { 1 }).collect();
        let (train, test, val) = stratified_split(dataset, &labels, 0.6, Some(0.2), Some(42))
            .expect("operation should succeed");
        assert_eq!(train.len(), 60);
        assert!(val.is_some());
        assert_eq!(val.as_ref().expect("value should be available").len(), 20);
        assert_eq!(test.len(), 20);
        println!(
            "Stratified split: train={}, val={}, test={}",
            train.len(),
            val.as_ref().expect("value should be available").len(),
            test.len()
        );
    }
    // Three balanced classes, 70/30 train/test, no validation set.
    #[test]
    fn test_stratified_split_multi_class() {
        let data = ones::<f32>(&[90, 5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..90).map(|i| i / 30).collect();
        let (train, test, _val) = stratified_split(dataset, &labels, 0.7, None, Some(42))
            .expect("operation should succeed");
        assert_eq!(train.len(), 63);
        assert_eq!(test.len(), 27);
        println!(
            "Multi-class split: train={}, test={}",
            train.len(),
            test.len()
        );
    }
    // Omitting val_ratio must yield None for the validation subset.
    #[test]
    fn test_stratified_split_no_val() {
        let data = ones::<f32>(&[50, 3]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..50).map(|i| i % 2).collect();
        let (train, test, val) = stratified_split(dataset, &labels, 0.8, None, Some(42))
            .expect("operation should succeed");
        assert_eq!(train.len(), 40);
        assert_eq!(test.len(), 10);
        assert!(val.is_none());
    }
    // Invalid ratio combinations (train=1.0; train+val>=1) are rejected.
    #[test]
    fn test_stratified_split_invalid_ratio() {
        let data = ones::<f32>(&[50, 3]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..50).map(|i| i % 2).collect();
        let result = stratified_split(dataset.clone(), &labels, 1.0, None, None);
        assert!(result.is_err());
        let result = stratified_split(dataset, &labels, 0.7, Some(0.4), None);
        assert!(result.is_err());
    }
    // Label count must equal the dataset length.
    #[test]
    fn test_stratified_split_mismatched_labels() {
        let data = ones::<f32>(&[50, 3]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = vec![0, 1];
        let result = stratified_split(dataset, &labels, 0.8, None, None);
        assert!(result.is_err());
    }
    // Same seed => identical fold index assignments.
    #[test]
    fn test_kfold_reproducibility() {
        let kfold1 = KFold::new(5, true, Some(42));
        let kfold2 = KFold::new(5, true, Some(42));
        let folds1 = kfold1.split(50);
        let folds2 = kfold2.split(50);
        for (f1, f2) in folds1.iter().zip(folds2.iter()) {
            assert_eq!(f1.0, f2.0);
            assert_eq!(f1.1, f2.1);
        }
    }
    // Same seed => identical split sizes across independent runs.
    #[test]
    fn test_stratified_split_reproducibility() {
        let data = ones::<f32>(&[100, 5]).expect("operation should succeed");
        let labels: Vec<usize> = (0..100).map(|i| i % 3).collect();
        let (train1, test1, _) = stratified_split(
            TensorDataset::from_tensor(data.clone()),
            &labels,
            0.7,
            None,
            Some(42),
        )
        .expect("operation should succeed");
        let (train2, test2, _) = stratified_split(
            TensorDataset::from_tensor(data),
            &labels,
            0.7,
            None,
            Some(42),
        )
        .expect("operation should succeed");
        assert_eq!(train1.len(), train2.len());
        assert_eq!(test1.len(), test2.len());
    }
}