1use torsh_core::error::Result;
6
7use super::types::{FeatureStats, Subset, TensorDataset};
8
/// A map-style dataset: a fixed-size collection with random (indexed) access.
///
/// `Send + Sync` is required so a dataset can be shared across data-loading
/// worker threads.
pub trait Dataset: Send + Sync {
    /// The type of a single sample returned by [`Dataset::get`].
    type Item;
    /// Total number of samples in the dataset.
    fn len(&self) -> usize;
    /// Returns `true` when the dataset contains no samples.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Fetches the sample at `index`.
    ///
    /// # Errors
    /// Returns an error for an out-of-range index or a failed retrieval.
    fn get(&self, index: usize) -> Result<Self::Item>;
}
/// An iterable-style dataset: samples are produced by an iterator rather than
/// accessed by index (suitable for streams whose length is unknown up front).
pub trait IterableDataset: Send + Sync {
    /// The type of a single sample yielded by the iterator.
    type Item;
    /// Iterator over samples; each element is a `Result` so per-item failures
    /// can be reported without aborting the whole iteration.
    type Iter: Iterator<Item = Result<Self::Item>> + Send;
    /// Creates a fresh iterator over the dataset's samples.
    fn iter(&self) -> Self::Iter;
}
36pub fn random_split<D>(
38 dataset: D,
39 lengths: &[usize],
40 generator: Option<u64>,
41) -> Result<Vec<Subset<D>>>
42where
43 D: Dataset + Clone,
44{
45 let total_length: usize = lengths.iter().sum();
46 if total_length != dataset.len() {
47 return Err(torsh_core::error::TorshError::InvalidArgument(format!(
48 "Sum of lengths {} does not equal dataset length {}",
49 total_length,
50 dataset.len()
51 )));
52 }
53 let mut indices: Vec<usize> = (0..dataset.len()).collect();
54 if let Some(_seed) = generator {
55 use scirs2_core::random::prelude::*;
56 use scirs2_core::random::seq::ScientificSliceRandom;
57 let mut rng = thread_rng();
58 indices.scientific_shuffle(&mut rng);
59 }
60 let mut subsets = Vec::with_capacity(lengths.len());
61 let mut offset = 0;
62 for &length in lengths {
63 let subset_indices = indices[offset..offset + length].to_vec();
64 subsets.push(Subset::new(dataset.clone(), subset_indices));
65 offset += length;
66 }
67 Ok(subsets)
68}
/// A streaming dataset: samples arrive through a (possibly unbounded) stream.
///
/// Unlike [`IterableDataset`], implementors may report whether more data is
/// expected and may support resetting the stream to its start.
pub trait StreamingDataset: Send + Sync {
    /// The type of a single streamed sample.
    type Item;
    /// Stream of samples; each element is a `Result` so per-item failures
    /// can be surfaced without terminating the stream.
    type Stream: Iterator<Item = Result<Self::Item>> + Send;
    /// Creates a new stream over the dataset's samples.
    fn stream(&self) -> Self::Stream;
    /// Whether more data may still be produced. Defaults to `true`
    /// (i.e. an endless stream) — finite sources should override this.
    fn has_more(&self) -> bool {
        true
    }
    /// Resets the stream to its beginning. Default is a no-op success;
    /// sources that cannot rewind may override and return an error.
    fn reset(&self) -> Result<()> {
        Ok(())
    }
}
89pub fn dataset_statistics(dataset: &TensorDataset<f32>) -> Result<Vec<FeatureStats>> {
94 if dataset.len() == 0 {
95 return Ok(Vec::new());
96 }
97 let first_item = dataset.get(0)?;
98 if first_item.is_empty() {
99 return Ok(Vec::new());
100 }
101 let features_tensor = &first_item[0];
102 let n_features = features_tensor.numel();
103 let mut feature_data: Vec<Vec<f32>> = vec![Vec::with_capacity(dataset.len()); n_features];
104 for i in 0..dataset.len() {
105 let item = dataset.get(i)?;
106 if item.is_empty() {
107 continue;
108 }
109 let features = &item[0];
110 for feat_idx in 0..n_features.min(features.numel()) {
111 if let Ok(indices) = torsh_tensor::Tensor::from_vec(vec![feat_idx as i64], &[1]) {
112 if let Ok(value_tensor) = features.index_select(0, &indices) {
113 if let Ok(value) = value_tensor.item() {
114 feature_data[feat_idx].push(value);
115 }
116 }
117 }
118 }
119 }
120 Ok(feature_data
121 .iter()
122 .map(|data| FeatureStats::from_data(data))
123 .collect())
124}
125pub fn stratified_split<D>(
130 dataset: D,
131 labels: &[usize],
132 train_ratio: f32,
133 val_ratio: Option<f32>,
134 random_seed: Option<u64>,
135) -> Result<(Subset<D>, Subset<D>, Option<Subset<D>>)>
136where
137 D: Dataset + Clone,
138{
139 if train_ratio <= 0.0 || train_ratio >= 1.0 {
140 return Err(torsh_core::error::TorshError::InvalidArgument(
141 "train_ratio must be between 0 and 1".to_string(),
142 ));
143 }
144 let has_val = val_ratio.is_some();
145 let val_r = val_ratio.unwrap_or(0.0);
146 if has_val && (train_ratio + val_r >= 1.0) {
147 return Err(torsh_core::error::TorshError::InvalidArgument(
148 "train_ratio + val_ratio must be less than 1".to_string(),
149 ));
150 }
151 if labels.len() != dataset.len() {
152 return Err(torsh_core::error::TorshError::InvalidArgument(
153 "labels length must equal dataset length".to_string(),
154 ));
155 }
156 let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
157 std::collections::HashMap::new();
158 for (idx, &label) in labels.iter().enumerate() {
159 class_indices.entry(label).or_default().push(idx);
160 }
161 use scirs2_core::random::prelude::*;
162 use scirs2_core::random::seq::ScientificSliceRandom;
163 use scirs2_core::random::SeedableRng;
164 let mut rng = if let Some(seed) = random_seed {
165 StdRng::seed_from_u64(seed)
166 } else {
167 use std::time::SystemTime;
168 let seed = SystemTime::now()
169 .duration_since(SystemTime::UNIX_EPOCH)
170 .expect("time should be after UNIX_EPOCH")
171 .as_secs();
172 StdRng::seed_from_u64(seed)
173 };
174 let mut train_indices = Vec::new();
175 let mut val_indices = Vec::new();
176 let mut test_indices = Vec::new();
177 for (_class, mut indices) in class_indices {
178 indices.scientific_shuffle(&mut rng);
179 let n_train = (indices.len() as f32 * train_ratio).round() as usize;
180 let n_val = if has_val {
181 (indices.len() as f32 * val_r).round() as usize
182 } else {
183 0
184 };
185 train_indices.extend_from_slice(&indices[0..n_train]);
186 if has_val {
187 val_indices.extend_from_slice(&indices[n_train..n_train + n_val]);
188 test_indices.extend_from_slice(&indices[n_train + n_val..]);
189 } else {
190 test_indices.extend_from_slice(&indices[n_train..]);
191 }
192 }
193 let train_subset = Subset::new(dataset.clone(), train_indices);
194 let test_subset = Subset::new(dataset.clone(), test_indices);
195 let val_subset = if has_val {
196 Some(Subset::new(dataset, val_indices))
197 } else {
198 None
199 };
200 Ok((train_subset, test_subset, val_subset))
201}
#[cfg(test)]
mod tests {
    // Unit tests for the dataset traits, adapters, streaming wrappers,
    // profiler, statistics helpers, and the split utilities defined in this
    // module and in `crate::dataset::types`.
    use super::*;
    use crate::dataset::types::*;
    use torsh_tensor::creation::*;

    // --- Map-style datasets (TensorDataset / ConcatDataset / Subset) ---

    #[test]
    fn test_tensor_dataset() {
        let data = ones::<f32>(&[10, 3]).expect("operation should succeed");
        let labels = zeros::<f32>(&[10]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensors(vec![data, labels]);
        assert_eq!(dataset.len(), 10);
        let item = dataset
            .get(0)
            .expect("element retrieval should succeed for valid index");
        // One entry per source tensor (data + labels).
        assert_eq!(item.len(), 2);
    }

    #[test]
    fn test_concat_dataset() {
        let ds1 = TensorDataset::from_tensor(
            ones::<f32>(&[5, 3]).expect("Tensor Dataset should succeed"),
        );
        let ds2 = TensorDataset::from_tensor(
            zeros::<f32>(&[3, 3]).expect("Tensor Dataset should succeed"),
        );
        let concat = ConcatDataset::new(vec![ds1, ds2]);
        assert_eq!(concat.len(), 8);
        // dataset_idx maps a global index to (dataset, local index).
        assert_eq!(concat.dataset_idx(0), Some((0, 0)));
        assert_eq!(concat.dataset_idx(4), Some((0, 4)));
        assert_eq!(concat.dataset_idx(5), Some((1, 0)));
        assert_eq!(concat.dataset_idx(7), Some((1, 2)));
        // Out-of-range global index yields None.
        assert_eq!(concat.dataset_idx(8), None);
    }

    #[test]
    fn test_subset() {
        let dataset = TensorDataset::from_tensor(
            ones::<f32>(&[10, 3]).expect("Tensor Dataset should succeed"),
        );
        let subset = Subset::new(dataset, vec![0, 2, 4, 6, 8]);
        assert_eq!(subset.len(), 5);
        assert!(subset.get(0).is_ok());
        // Index 5 is past the subset's 5 selected indices.
        assert!(subset.get(5).is_err());
    }

    // --- Iterable-style datasets (ChainDataset) ---

    // Minimal IterableDataset used to exercise ChainDataset below.
    #[derive(Clone)]
    struct SimpleIterableDataset {
        data: Vec<i32>,
    }
    impl IterableDataset for SimpleIterableDataset {
        type Item = i32;
        type Iter = std::iter::Map<std::vec::IntoIter<i32>, fn(i32) -> Result<i32>>;
        fn iter(&self) -> Self::Iter {
            self.data.clone().into_iter().map(|x| Ok(x) as Result<i32>)
        }
    }

    #[test]
    fn test_chain_dataset() {
        let ds1 = SimpleIterableDataset {
            data: vec![1, 2, 3],
        };
        let ds2 = SimpleIterableDataset {
            data: vec![4, 5, 6],
        };
        let ds3 = SimpleIterableDataset {
            data: vec![7, 8, 9],
        };
        let chain = ChainDataset::new(vec![ds1, ds2, ds3]);
        let collected: Result<Vec<_>> = chain.iter().collect();
        assert!(collected.is_ok());
        let values = collected.expect("operation should succeed");
        // Chained iteration preserves dataset order.
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_chain_dataset_empty() {
        // A chain over zero datasets yields an empty iterator.
        let chain: ChainDataset<SimpleIterableDataset> = ChainDataset::new(vec![]);
        let collected: Result<Vec<_>> = chain.iter().collect();
        assert!(collected.is_ok());
        let values = collected.expect("operation should succeed");
        assert_eq!(values, Vec::<i32>::new());
    }

    #[test]
    fn test_chain_dataset_with_empty_datasets() {
        // Empty datasets interleaved with non-empty ones are skipped cleanly.
        let ds1 = SimpleIterableDataset { data: vec![] };
        let ds2 = SimpleIterableDataset {
            data: vec![1, 2, 3],
        };
        let ds3 = SimpleIterableDataset { data: vec![] };
        let ds4 = SimpleIterableDataset { data: vec![4, 5] };
        let chain = ChainDataset::new(vec![ds1, ds2, ds3, ds4]);
        let collected: Result<Vec<_>> = chain.iter().collect();
        assert!(collected.is_ok());
        let values = collected.expect("operation should succeed");
        assert_eq!(values, vec![1, 2, 3, 4, 5]);
    }

    // --- Streaming datasets ---

    #[test]
    fn test_infinite_dataset() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        // Generator closure yields 0, 1, 2, ... via a shared counter.
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let dataset = InfiniteDataset::new(move || {
            let val = counter_clone.fetch_add(1, Ordering::SeqCst);
            Ok(val)
        });
        assert!(dataset.has_more());
        let mut stream = dataset.stream();
        assert_eq!(
            stream
                .next()
                .expect("iterator should have a next element")
                .expect("operation should succeed"),
            0
        );
        assert_eq!(
            stream
                .next()
                .expect("iterator should have a next element")
                .expect("operation should succeed"),
            1
        );
        assert_eq!(
            stream
                .next()
                .expect("iterator should have a next element")
                .expect("operation should succeed"),
            2
        );
    }

    #[test]
    fn test_buffered_streaming_dataset() {
        let dataset = InfiniteDataset::new(|| Ok(42i32));
        // Buffer of 5 with prefetch enabled; values pass through unchanged.
        let buffered = BufferedStreamingDataset::new(dataset, 5).with_prefetch(true);
        assert!(buffered.has_more());
        let mut stream = buffered.stream();
        for _ in 0..10 {
            assert_eq!(
                stream
                    .next()
                    .expect("iterator should have a next element")
                    .expect("operation should succeed"),
                42
            );
        }
    }

    // --- Data pipelines ---

    #[test]
    fn test_data_pipeline() {
        // Transforms apply in insertion order: (5 * 2) + 1 = 11.
        let pipeline = DataPipeline::new()
            .add_transform(|x: i32| Ok(x * 2))
            .add_transform(|x: i32| Ok(x + 1));
        let result = pipeline.apply(5).expect("apply operation should succeed");
        assert_eq!(result, 11);
    }

    #[test]
    fn test_pipeline_streaming_dataset() {
        let dataset = InfiniteDataset::new(|| Ok(5i32));
        let pipeline = DataPipeline::new()
            .add_transform(|x: i32| Ok(x * 2))
            .add_transform(|x: i32| Ok(x + 1));
        let pipeline_dataset = PipelineStreamingDataset::new(dataset, pipeline);
        assert!(pipeline_dataset.has_more());
        let mut stream = pipeline_dataset.stream();
        // Every streamed element is transformed: (5 * 2) + 1 = 11.
        for _ in 0..5 {
            assert_eq!(
                stream
                    .next()
                    .expect("iterator should have a next element")
                    .expect("operation should succeed"),
                11
            );
        }
    }

    #[test]
    fn test_real_time_dataset() {
        let (dataset, _receiver) = RealTimeDataset::<i32>::new();
        let sender = dataset.sender();
        {
            // Push a few items through the channel before streaming.
            let sender_lock = sender.lock().expect("lock should not be poisoned");
            sender_lock.send(1).expect("channel send should succeed");
            sender_lock.send(2).expect("channel send should succeed");
            sender_lock.send(3).expect("channel send should succeed");
        }
        assert!(dataset.has_more());
        let _stream = dataset.stream();
    }

    #[test]
    fn test_dataset_to_streaming() {
        let tensor = ones::<f32>(&[5, 3]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let streaming = DatasetToStreaming::new(dataset);
        assert!(streaming.has_more());
        let stream = streaming.stream();
        let mut count = 0;
        for result in stream {
            assert!(result.is_ok());
            count += 1;
            if count >= 5 {
                break;
            }
        }
        // Without repeat, the stream covers each of the 5 samples once.
        assert_eq!(count, 5);
    }

    #[test]
    fn test_dataset_to_streaming_repeat() {
        let tensor = ones::<f32>(&[3, 2]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        // With repeat(), the 3-sample dataset cycles indefinitely.
        let streaming = DatasetToStreaming::new(dataset).repeat();
        assert!(streaming.has_more());
        let stream = streaming.stream();
        let mut count = 0;
        for result in stream {
            assert!(result.is_ok());
            count += 1;
            if count >= 10 {
                break;
            }
        }
        assert_eq!(count, 10);
    }

    #[test]
    fn test_streaming_dataset_reset() {
        let dataset = InfiniteDataset::new(|| Ok(42i32));
        let buffered = BufferedStreamingDataset::new(dataset, 3);
        assert!(buffered.reset().is_ok());
    }

    // --- Dataset access profiler ---

    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_sequential_access() {
        use std::thread;
        use std::time::Duration;
        let tensor = ones::<f32>(&[10, 2]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..10 {
            let _ = profiled
                .get(i)
                .expect("element retrieval should succeed for valid index");
            // Small sleep so per-access timing is measurable.
            thread::sleep(Duration::from_micros(100));
        }
        let stats = profiled.stats();
        assert_eq!(stats.total_accesses, 10);
        // 10 in-order accesses produce 9 consecutive-index transitions.
        assert_eq!(stats.sequential_accesses, 9);
        assert!(stats.sequential_ratio > 0.8);
        assert!(stats.avg_access_time_us > 0.0);
        assert!(stats.throughput_accesses_per_sec > 0.0);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_random_access() {
        let tensor = ones::<f32>(&[10, 2]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        // No two consecutive indices differ by exactly +1.
        let indices = [0, 5, 2, 8, 1];
        for &i in &indices {
            let _ = profiled
                .get(i)
                .expect("element retrieval should succeed for valid index");
        }
        let stats = profiled.stats();
        assert_eq!(stats.total_accesses, 5);
        assert_eq!(stats.sequential_accesses, 0);
        assert_eq!(stats.sequential_ratio, 0.0);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_hints() {
        let tensor = ones::<f32>(&[100, 2]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..20 {
            let _ = profiled
                .get(i)
                .expect("element retrieval should succeed for valid index");
        }
        // Sequential access should surface a matching optimization hint.
        let hints = profiled.hints();
        assert!(!hints.is_empty());
        assert!(hints
            .iter()
            .any(|h| h.contains("sequential") || h.contains("good")));
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_reset() {
        let tensor = ones::<f32>(&[10, 2]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..5 {
            let _ = profiled
                .get(i)
                .expect("element retrieval should succeed for valid index");
        }
        assert_eq!(profiled.stats().total_accesses, 5);
        // reset() clears accumulated counters.
        profiled.profiler().reset();
        assert_eq!(profiled.stats().total_accesses, 0);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_dataset_profiler_display() {
        let tensor = ones::<f32>(&[10, 2]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let profiled = ProfiledDataset::new(dataset);
        for i in 0..5 {
            let _ = profiled
                .get(i)
                .expect("element retrieval should succeed for valid index");
        }
        // Stats implement Display with a human-readable report.
        let stats_string = format!("{}", profiled.stats());
        assert!(stats_string.contains("Dataset Profile Statistics"));
        assert!(stats_string.contains("Total Accesses: 5"));
    }

    // --- Feature statistics ---

    #[test]
    fn test_feature_stats() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = FeatureStats::from_data(&data);
        assert_eq!(stats.count, 5);
        assert_eq!(stats.mean, 3.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        // Population std of 1..=5 is sqrt(2) ~= 1.4142.
        assert!((stats.std - 1.4142).abs() < 0.01);
    }

    #[test]
    fn test_feature_stats_empty() {
        // Empty input produces zeroed statistics rather than NaN/panic.
        let data: Vec<f32> = vec![];
        let stats = FeatureStats::from_data(&data);
        assert_eq!(stats.count, 0);
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.std, 0.0);
    }

    #[test]
    fn test_dataset_statistics() {
        let data =
            torsh_tensor::creation::randn::<f32>(&[10, 3]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        let stats = dataset_statistics(&dataset).expect("dataset statistics should succeed");
        // One FeatureStats per feature column.
        assert_eq!(stats.len(), 3);
        for stat in &stats {
            assert_eq!(stat.count, 10);
            // Sanity: min <= mean <= max and non-negative std.
            assert!(stat.min <= stat.mean);
            assert!(stat.mean <= stat.max);
            assert!(stat.std >= 0.0);
        }
    }

    #[test]
    fn test_dataset_statistics_empty() {
        let data = torsh_tensor::creation::zeros::<f32>(&[0, 3]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        let stats = dataset_statistics(&dataset).expect("dataset statistics should succeed");
        assert_eq!(stats.len(), 0);
    }

    // --- K-fold cross-validation splitting ---

    #[test]
    fn test_kfold_basic() {
        let kfold = KFold::new(5, false, Some(42));
        let folds = kfold.split(100);
        assert_eq!(folds.len(), 5);
        for (fold_idx, (train_indices, val_indices)) in folds.iter().enumerate() {
            assert_eq!(val_indices.len(), 20);
            assert_eq!(train_indices.len(), 80);
            // Train and validation sets must be disjoint.
            for &val_idx in val_indices {
                assert!(!train_indices.contains(&val_idx));
            }
            // All indices stay within the dataset range.
            for &idx in train_indices.iter().chain(val_indices.iter()) {
                assert!(idx < 100);
            }
            println!(
                "Fold {}: train={}, val={}",
                fold_idx,
                train_indices.len(),
                val_indices.len()
            );
        }
    }

    #[test]
    fn test_kfold_shuffle() {
        let kfold_shuffled = KFold::new(3, true, Some(42));
        let kfold_unshuffled = KFold::new(3, false, None);
        let folds_shuffled = kfold_shuffled.split(30);
        let folds_unshuffled = kfold_unshuffled.split(30);
        assert_eq!(folds_shuffled.len(), folds_unshuffled.len());
        let shuffled_val = &folds_shuffled[0].1;
        let unshuffled_val = &folds_unshuffled[0].1;
        // Without shuffling the first fold's validation set is simply 0..10.
        assert_eq!(unshuffled_val, &(0..10).collect::<Vec<_>>());
        assert_ne!(shuffled_val, unshuffled_val);
    }

    #[test]
    fn test_kfold_uneven_split() {
        // 10 samples over 3 folds: the remainder goes to the last fold.
        let kfold = KFold::new(3, false, None);
        let folds = kfold.split(10);
        assert_eq!(folds.len(), 3);
        assert_eq!(folds[0].1.len(), 3);
        assert_eq!(folds[1].1.len(), 3);
        assert_eq!(folds[2].1.len(), 4);
        let all_val_samples: usize = folds.iter().map(|(_, val)| val.len()).sum();
        assert_eq!(all_val_samples, 10);
    }

    #[test]
    #[should_panic(expected = "n_splits must be at least 2")]
    fn test_kfold_invalid_splits() {
        KFold::new(1, false, None);
    }

    // --- Stratified splitting ---

    #[test]
    fn test_stratified_split_binary() {
        let data = ones::<f32>(&[100, 5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        // Two balanced classes of 50 samples each.
        let labels: Vec<usize> = (0..100).map(|i| if i < 50 { 0 } else { 1 }).collect();
        let (train, test, val) = stratified_split(dataset, &labels, 0.6, Some(0.2), Some(42))
            .expect("operation should succeed");
        assert_eq!(train.len(), 60);
        assert!(val.is_some());
        assert_eq!(val.as_ref().expect("value should be available").len(), 20);
        assert_eq!(test.len(), 20);
        println!(
            "Stratified split: train={}, val={}, test={}",
            train.len(),
            val.as_ref().expect("value should be available").len(),
            test.len()
        );
    }

    #[test]
    fn test_stratified_split_multi_class() {
        let data = ones::<f32>(&[90, 5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        // Three balanced classes of 30 samples each.
        let labels: Vec<usize> = (0..90).map(|i| i / 30).collect();
        let (train, test, _val) = stratified_split(dataset, &labels, 0.7, None, Some(42))
            .expect("operation should succeed");
        assert_eq!(train.len(), 63);
        assert_eq!(test.len(), 27);
        println!(
            "Multi-class split: train={}, test={}",
            train.len(),
            test.len()
        );
    }

    #[test]
    fn test_stratified_split_no_val() {
        let data = ones::<f32>(&[50, 3]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..50).map(|i| i % 2).collect();
        let (train, test, val) = stratified_split(dataset, &labels, 0.8, None, Some(42))
            .expect("operation should succeed");
        assert_eq!(train.len(), 40);
        assert_eq!(test.len(), 10);
        // val_ratio = None yields no validation subset.
        assert!(val.is_none());
    }

    #[test]
    fn test_stratified_split_invalid_ratio() {
        let data = ones::<f32>(&[50, 3]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        let labels: Vec<usize> = (0..50).map(|i| i % 2).collect();
        // train_ratio must be strictly inside (0, 1).
        let result = stratified_split(dataset.clone(), &labels, 1.0, None, None);
        assert!(result.is_err());
        // train_ratio + val_ratio must stay below 1.
        let result = stratified_split(dataset, &labels, 0.7, Some(0.4), None);
        assert!(result.is_err());
    }

    #[test]
    fn test_stratified_split_mismatched_labels() {
        let data = ones::<f32>(&[50, 3]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(data);
        // 2 labels for 50 samples must be rejected.
        let labels: Vec<usize> = vec![0, 1];
        let result = stratified_split(dataset, &labels, 0.8, None, None);
        assert!(result.is_err());
    }

    // --- Reproducibility with fixed seeds ---

    #[test]
    fn test_kfold_reproducibility() {
        let kfold1 = KFold::new(5, true, Some(42));
        let kfold2 = KFold::new(5, true, Some(42));
        let folds1 = kfold1.split(50);
        let folds2 = kfold2.split(50);
        // Same seed -> identical folds, element for element.
        for (f1, f2) in folds1.iter().zip(folds2.iter()) {
            assert_eq!(f1.0, f2.0);
            assert_eq!(f1.1, f2.1);
        }
    }

    #[test]
    fn test_stratified_split_reproducibility() {
        let data = ones::<f32>(&[100, 5]).expect("operation should succeed");
        let labels: Vec<usize> = (0..100).map(|i| i % 3).collect();
        let (train1, test1, _) = stratified_split(
            TensorDataset::from_tensor(data.clone()),
            &labels,
            0.7,
            None,
            Some(42),
        )
        .expect("operation should succeed");
        let (train2, test2, _) = stratified_split(
            TensorDataset::from_tensor(data),
            &labels,
            0.7,
            None,
            Some(42),
        )
        .expect("operation should succeed");
        // NOTE(review): this only checks sizes, not the actual index
        // contents, so it does not fully pin down seed-reproducibility.
        assert_eq!(train1.len(), train2.len());
        assert_eq!(test1.len(), test2.len());
    }
}