1use crate::error::{DatasetsError, Result};
14use scirs2_core::ndarray::{Array1, Array2};
15
/// Minimal 64-bit linear congruential generator (Knuth's MMIX constants).
///
/// Deterministic and dependency-free; used to make shuffles reproducible
/// across runs and machines.
struct Lcg64 {
    state: u64,
}

impl Lcg64 {
    /// Creates a generator from `seed`, offset by 1 so that seed 0 does not
    /// start the state at zero.
    fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the state once and returns the new raw 64-bit value.
    fn next_u64(&mut self) -> u64 {
        const MUL: u64 = 6_364_136_223_846_793_005;
        const INC: u64 = 1_442_695_040_888_963_407;
        self.state = self.state.wrapping_mul(MUL).wrapping_add(INC);
        self.state
    }

    /// Returns a value in `0..n` (0 when `n == 0`).
    /// Uses simple modulo reduction; a slight modulo bias is acceptable here.
    fn next_usize(&mut self, n: usize) -> usize {
        match n {
            0 => 0,
            _ => (self.next_u64() % n as u64) as usize,
        }
    }

    /// Returns a float in `[0, 1)` built from the top 53 bits of the state.
    fn next_f64(&mut self) -> f64 {
        let mantissa = self.next_u64() >> 11;
        mantissa as f64 / (1u64 << 53) as f64
    }
}
54
/// Strategy used to assign sample indices to shards.
///
/// Marked `#[non_exhaustive]` so new strategies can be added without a
/// breaking change for downstream matches.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Default)]
pub enum ShardStrategy {
    /// Contiguous (optionally shuffled) index ranges; the default.
    #[default]
    Index,
    /// Round-robin assignment by sample position (`i % n_shards`).
    Hash,
    /// Label-aware sharding that preserves class proportions per shard.
    /// Requires `ShardedDataset::new_stratified` rather than `new`.
    Stratified {
        /// Name of the label column to stratify on.
        /// NOTE(review): not consumed by `new_stratified`, which takes
        /// labels directly — confirm where this field is read.
        label_column: String,
    },
    /// Byte-size-targeted shards.
    /// Currently falls back to index sharding in `ShardedDataset::new`.
    Size {
        /// Target shard size in bytes.
        shard_size_bytes: usize,
    },
}
79
/// Configuration for building a [`ShardedDataset`].
#[derive(Debug, Clone)]
pub struct ShardingConfig {
    /// Number of shards to create (must be >= 1).
    pub n_shards: usize,
    /// How samples are assigned to shards.
    pub strategy: ShardStrategy,
    /// Whether to shuffle indices before splitting (used by index-based
    /// strategies; ignored by `Hash`).
    pub shuffle: bool,
    /// Seed for the deterministic shuffle.
    pub seed: u64,
}
92
93impl Default for ShardingConfig {
94 fn default() -> Self {
95 Self {
96 n_shards: 8,
97 strategy: ShardStrategy::default(),
98 shuffle: true,
99 seed: 42,
100 }
101 }
102}
103
/// A single shard holding the sample indices assigned to it.
#[derive(Debug, Clone)]
pub struct DataShard {
    /// Zero-based identifier of this shard.
    pub shard_id: usize,
    /// Total number of shards in the partition this shard belongs to.
    pub n_shards: usize,
    /// Sample indices (into the original dataset) owned by this shard.
    pub indices: Vec<usize>,
    /// Whether this shard is part of the training split.
    pub is_train: bool,
}
116
117impl DataShard {
118 pub fn new(
125 shard_id: usize,
126 total_shards: usize,
127 n_samples: usize,
128 config: &ShardConfig,
129 ) -> Self {
130 let all_shards = shard_by_index(
131 n_samples,
132 total_shards,
133 config.shuffle,
134 config.seed.unwrap_or(0),
135 );
136 match all_shards.into_iter().find(|s| s.shard_id == shard_id) {
138 Some(s) => Self {
139 shard_id: s.shard_id,
140 n_shards: s.n_shards,
141 indices: s.indices,
142 is_train: s.is_train,
143 },
144 None => Self {
145 shard_id,
146 n_shards: total_shards,
147 indices: Vec::new(),
148 is_train: true,
149 },
150 }
151 }
152
153 pub fn apply_2d<T: Clone + Default>(&self, data: &Array2<T>) -> Array2<T> {
162 let n_cols = data.ncols();
163 let valid_indices: Vec<usize> = self
164 .indices
165 .iter()
166 .copied()
167 .filter(|&i| i < data.nrows())
168 .collect();
169 let n_rows = valid_indices.len();
170 if n_rows == 0 || n_cols == 0 {
171 return Array2::default((0, n_cols));
172 }
173 let mut flat = Vec::with_capacity(n_rows * n_cols);
174 for &row_idx in &valid_indices {
175 flat.extend_from_slice(data.row(row_idx).as_slice().unwrap_or(&[]));
176 }
177 if flat.len() != n_rows * n_cols {
179 flat.clear();
180 for &row_idx in &valid_indices {
181 for col in 0..n_cols {
182 flat.push(data[[row_idx, col]].clone());
183 }
184 }
185 }
186 Array2::from_shape_vec((n_rows, n_cols), flat)
187 .unwrap_or_else(|_| Array2::default((0, n_cols)))
188 }
189
190 pub fn apply_1d<T: Clone>(&self, data: &Array1<T>) -> Array1<T> {
199 let selected: Vec<T> = self
200 .indices
201 .iter()
202 .copied()
203 .filter(|&i| i < data.len())
204 .map(|i| data[i].clone())
205 .collect();
206 Array1::from_vec(selected)
207 }
208
209 pub fn len(&self) -> usize {
211 self.indices.len()
212 }
213
214 pub fn is_empty(&self) -> bool {
216 self.indices.is_empty()
217 }
218}
219
/// Lightweight options consumed by [`DataShard::new`].
///
/// NOTE(review): overlaps with [`ShardingConfig`] but carries an optional
/// seed (defaulting to 0) — consider unifying the two config types.
#[derive(Debug, Clone)]
pub struct ShardConfig {
    /// Total number of shards.
    /// NOTE(review): not read by `DataShard::new`, which receives the total
    /// as a separate parameter — confirm intended use.
    pub n_shards: usize,
    /// Whether to shuffle indices before splitting.
    pub shuffle: bool,
    /// Optional shuffle seed; `None` is treated as 0.
    pub seed: Option<u64>,
}
232
/// A dataset partitioned into shards of sample indices.
#[derive(Debug, Clone)]
pub struct ShardedDataset {
    /// The shards, ordered by `shard_id`.
    pub shards: Vec<DataShard>,
    /// Total number of samples across all shards.
    pub total_size: usize,
    /// The configuration this dataset was built with.
    pub config: ShardingConfig,
}
245
246pub fn consistent_shuffle(n: usize, seed: u64) -> Vec<usize> {
254 let mut indices: Vec<usize> = (0..n).collect();
255 let mut rng = Lcg64::new(seed);
256 for i in (1..n).rev() {
258 let j = rng.next_usize(i + 1);
259 indices.swap(i, j);
260 }
261 indices
262}
263
264pub fn shard_by_index(
269 n_samples: usize,
270 n_shards: usize,
271 shuffle: bool,
272 seed: u64,
273) -> Vec<DataShard> {
274 if n_shards == 0 || n_samples == 0 {
275 return Vec::new();
276 }
277
278 let indices = if shuffle {
279 consistent_shuffle(n_samples, seed)
280 } else {
281 (0..n_samples).collect()
282 };
283
284 let base = n_samples / n_shards;
285 let remainder = n_samples % n_shards;
286
287 let mut shards = Vec::with_capacity(n_shards);
288 let mut offset = 0usize;
289
290 for shard_id in 0..n_shards {
291 let extra = if shard_id < remainder { 1 } else { 0 };
292 let size = base + extra;
293 let shard_indices = indices[offset..offset + size].to_vec();
294 shards.push(DataShard {
295 shard_id,
296 n_shards,
297 indices: shard_indices,
298 is_train: true,
299 });
300 offset += size;
301 }
302
303 shards
304}
305
306pub fn shard_by_hash(n_samples: usize, n_shards: usize) -> Vec<DataShard> {
308 if n_shards == 0 || n_samples == 0 {
309 return Vec::new();
310 }
311
312 let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); n_shards];
313 for i in 0..n_samples {
314 buckets[i % n_shards].push(i);
315 }
316
317 buckets
318 .into_iter()
319 .enumerate()
320 .map(|(shard_id, indices)| DataShard {
321 shard_id,
322 n_shards,
323 indices,
324 is_train: true,
325 })
326 .collect()
327}
328
329pub fn shard_stratified(
334 labels: &[usize],
335 n_shards: usize,
336 shuffle: bool,
337 seed: u64,
338) -> Vec<DataShard> {
339 if n_shards == 0 || labels.is_empty() {
340 return Vec::new();
341 }
342
343 let max_class = labels.iter().copied().max().unwrap_or(0);
345 let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); max_class + 1];
346 for (i, &label) in labels.iter().enumerate() {
347 class_indices[label].push(i);
348 }
349
350 if shuffle {
352 for (cls, indices) in class_indices.iter_mut().enumerate() {
353 let class_seed = seed.wrapping_add(cls as u64 * 0x9e37_79b9_7f4a_7c15);
354 let shuffled = consistent_shuffle(indices.len(), class_seed);
355 let original = indices.clone();
356 for (new_pos, &old_pos) in shuffled.iter().enumerate() {
357 indices[new_pos] = original[old_pos];
358 }
359 }
360 }
361
362 let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); n_shards];
364 for class_idx in class_indices {
365 for (pos, sample_idx) in class_idx.into_iter().enumerate() {
367 buckets[pos % n_shards].push(sample_idx);
368 }
369 }
370
371 buckets
372 .into_iter()
373 .enumerate()
374 .map(|(shard_id, indices)| DataShard {
375 shard_id,
376 n_shards,
377 indices,
378 is_train: true,
379 })
380 .collect()
381}
382
383impl ShardedDataset {
388 pub fn new(n_samples: usize, config: ShardingConfig) -> Result<Self> {
394 if config.n_shards == 0 {
395 return Err(DatasetsError::InvalidFormat("n_shards must be >= 1".into()));
396 }
397 if n_samples == 0 {
398 return Err(DatasetsError::InvalidFormat(
399 "n_samples must be >= 1".into(),
400 ));
401 }
402
403 let shards = match &config.strategy {
404 ShardStrategy::Index => {
405 shard_by_index(n_samples, config.n_shards, config.shuffle, config.seed)
406 }
407 ShardStrategy::Hash => shard_by_hash(n_samples, config.n_shards),
408 ShardStrategy::Stratified { .. } => {
409 return Err(DatasetsError::InvalidFormat(
410 "Use ShardedDataset::new_stratified for Stratified strategy".into(),
411 ));
412 }
413 ShardStrategy::Size { shard_size_bytes } => {
414 let _ = shard_size_bytes; shard_by_index(n_samples, config.n_shards, config.shuffle, config.seed)
420 }
421 };
422
423 Ok(Self {
424 shards,
425 total_size: n_samples,
426 config,
427 })
428 }
429
430 pub fn new_stratified(labels: &[usize], config: ShardingConfig) -> Result<Self> {
432 if config.n_shards == 0 {
433 return Err(DatasetsError::InvalidFormat("n_shards must be >= 1".into()));
434 }
435 if labels.is_empty() {
436 return Err(DatasetsError::InvalidFormat(
437 "labels must not be empty".into(),
438 ));
439 }
440
441 let shards = shard_stratified(labels, config.n_shards, config.shuffle, config.seed);
442 let total_size = labels.len();
443
444 Ok(Self {
445 shards,
446 total_size,
447 config,
448 })
449 }
450
451 pub fn get_shard(&self, shard_id: usize) -> Option<&DataShard> {
453 self.shards.get(shard_id)
454 }
455
456 pub fn train_shards(&self, val_fraction: f64) -> (Vec<usize>, Vec<usize>) {
461 let n = self.shards.len();
462 if n == 0 {
463 return (Vec::new(), Vec::new());
464 }
465 let n_val = ((n as f64 * val_fraction).ceil() as usize).min(n);
466 let n_train = n - n_val;
467 let train_ids: Vec<usize> = (0..n_train).collect();
468 let val_ids: Vec<usize> = (n_train..n).collect();
469 (train_ids, val_ids)
470 }
471
472 pub fn shard_iter(&self, shard_id: usize) -> impl Iterator<Item = usize> + '_ {
476 let slice: &[usize] = match self.shards.get(shard_id) {
477 Some(shard) => &shard.indices,
478 None => &[],
479 };
480 slice.iter().copied()
481 }
482
483 pub fn n_shards(&self) -> usize {
485 self.shards.len()
486 }
487
488 pub fn total_samples(&self) -> usize {
490 self.shards.iter().map(|s| s.indices.len()).sum()
491 }
492}
493
/// A shard that can carry materialized data alongside its indices.
///
/// `data`/`labels` may be empty when the shard is index-only (e.g. shards
/// produced by [`ShardedLoader::get_shard`]).
#[derive(Debug, Clone)]
pub struct DatasetShard {
    /// Zero-based identifier of this shard.
    pub shard_id: usize,
    /// Total number of shards in the partition.
    pub total_shards: usize,
    /// Sample indices (into the original dataset) owned by this shard.
    pub indices: Vec<usize>,
    /// Feature rows for this shard, parallel to `indices` when populated.
    pub data: Vec<Vec<f64>>,
    /// Labels for this shard, parallel to `indices` when populated.
    pub labels: Vec<usize>,
}
512
513impl DatasetShard {
514 pub fn len(&self) -> usize {
516 self.indices.len()
517 }
518
519 pub fn is_empty(&self) -> bool {
521 self.indices.is_empty()
522 }
523
524 pub fn apply_f64(&self, data: &[Vec<f64>]) -> Vec<Vec<f64>> {
529 self.indices
530 .iter()
531 .filter(|&&i| i < data.len())
532 .map(|&i| data[i].clone())
533 .collect()
534 }
535
536 pub fn apply_labels(&self, labels: &[usize]) -> Vec<usize> {
541 self.indices
542 .iter()
543 .filter(|&&i| i < labels.len())
544 .map(|&i| labels[i])
545 .collect()
546 }
547}
548
/// Stateless loader that derives shard contents on demand from a single
/// seeded global permutation.
#[derive(Debug, Clone)]
pub struct ShardedLoader {
    /// Total number of samples being sharded.
    pub total_samples: usize,
    /// Number of shards the samples are split into.
    pub n_shards: usize,
    /// Seed for the shared global permutation.
    pub seed: u64,
}
585
586impl ShardedLoader {
587 pub fn new(total_samples: usize, n_shards: usize, seed: u64) -> Self {
593 Self {
594 total_samples,
595 n_shards,
596 seed,
597 }
598 }
599
600 pub fn global_permutation(&self) -> Vec<usize> {
605 consistent_shuffle(self.total_samples, self.seed)
606 }
607
608 pub fn get_shard(&self, shard_id: usize) -> DatasetShard {
613 if self.n_shards == 0 || self.total_samples == 0 || shard_id >= self.n_shards {
614 return DatasetShard {
615 shard_id,
616 total_shards: self.n_shards,
617 indices: Vec::new(),
618 data: Vec::new(),
619 labels: Vec::new(),
620 };
621 }
622
623 let permuted = self.global_permutation();
624 let base = self.total_samples / self.n_shards;
625 let remainder = self.total_samples % self.n_shards;
626
627 let mut offset = 0usize;
629 for id in 0..shard_id {
630 let extra = if id < remainder { 1 } else { 0 };
631 offset += base + extra;
632 }
633 let extra = if shard_id < remainder { 1 } else { 0 };
634 let size = base + extra;
635
636 let indices = permuted[offset..offset + size].to_vec();
637
638 DatasetShard {
639 shard_id,
640 total_shards: self.n_shards,
641 indices,
642 data: Vec::new(),
643 labels: Vec::new(),
644 }
645 }
646
647 pub fn verify_coverage(&self) -> bool {
652 if self.n_shards == 0 || self.total_samples == 0 {
653 return self.total_samples == 0;
654 }
655
656 let mut seen = vec![false; self.total_samples];
657 for shard_id in 0..self.n_shards {
658 let shard = self.get_shard(shard_id);
659 for &idx in &shard.indices {
660 if idx >= self.total_samples || seen[idx] {
661 return false;
662 }
663 seen[idx] = true;
664 }
665 }
666 seen.iter().all(|&v| v)
667 }
668}
669
670pub fn shard_dataset(
679 data: &[Vec<f64>],
680 labels: &[usize],
681 n_shards: usize,
682 seed: u64,
683) -> Result<Vec<DatasetShard>> {
684 let n = data.len();
685 if n != labels.len() {
686 return Err(DatasetsError::InvalidFormat(format!(
687 "data length ({}) != labels length ({})",
688 n,
689 labels.len()
690 )));
691 }
692 if n_shards == 0 {
693 return Err(DatasetsError::InvalidFormat("n_shards must be >= 1".into()));
694 }
695 if n == 0 {
696 return Ok(Vec::new());
697 }
698
699 let index_shards = shard_by_index(n, n_shards, true, seed);
700 Ok(build_dataset_shards(data, labels, &index_shards))
701}
702
703pub fn stratified_shard(
710 data: &[Vec<f64>],
711 labels: &[usize],
712 n_shards: usize,
713) -> Result<Vec<DatasetShard>> {
714 let n = data.len();
715 if n != labels.len() {
716 return Err(DatasetsError::InvalidFormat(format!(
717 "data length ({}) != labels length ({})",
718 n,
719 labels.len()
720 )));
721 }
722 if n_shards == 0 {
723 return Err(DatasetsError::InvalidFormat("n_shards must be >= 1".into()));
724 }
725 if n == 0 {
726 return Ok(Vec::new());
727 }
728
729 let index_shards = shard_stratified(labels, n_shards, false, 0);
730 Ok(build_dataset_shards(data, labels, &index_shards))
731}
732
733pub fn shuffled_shard(
741 data: &[Vec<f64>],
742 labels: &[usize],
743 n_shards: usize,
744 seed: u64,
745) -> Result<Vec<DatasetShard>> {
746 shard_dataset(data, labels, n_shards, seed)
747}
748
749pub fn merge_shards(shards: &[DatasetShard]) -> (Vec<Vec<f64>>, Vec<usize>) {
754 if shards.is_empty() {
755 return (Vec::new(), Vec::new());
756 }
757
758 let mut entries: Vec<(usize, &Vec<f64>, usize)> = Vec::new();
760 for shard in shards {
761 for (pos, &idx) in shard.indices.iter().enumerate() {
762 entries.push((idx, &shard.data[pos], shard.labels[pos]));
763 }
764 }
765
766 entries.sort_by_key(|(idx, _, _)| *idx);
768
769 let data: Vec<Vec<f64>> = entries.iter().map(|(_, d, _)| (*d).clone()).collect();
770 let labels: Vec<usize> = entries.iter().map(|(_, _, l)| *l).collect();
771 (data, labels)
772}
773
774fn build_dataset_shards(
776 data: &[Vec<f64>],
777 labels: &[usize],
778 index_shards: &[DataShard],
779) -> Vec<DatasetShard> {
780 index_shards
781 .iter()
782 .map(|is| {
783 let shard_data: Vec<Vec<f64>> = is.indices.iter().map(|&i| data[i].clone()).collect();
784 let shard_labels: Vec<usize> = is.indices.iter().map(|&i| labels[i]).collect();
785 DatasetShard {
786 shard_id: is.shard_id,
787 total_shards: is.n_shards,
788 indices: is.indices.clone(),
789 data: shard_data,
790 labels: shard_labels,
791 }
792 })
793 .collect()
794}
795
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shard_by_index_no_shuffle() {
        let shards = shard_by_index(100, 4, false, 0);
        assert_eq!(shards.len(), 4);
        for shard in &shards {
            assert_eq!(shard.indices.len(), 25);
        }
        // Every index must appear exactly once across all shards.
        let mut seen = [false; 100];
        for shard in &shards {
            for &i in &shard.indices {
                assert!(!seen[i], "index {i} appears twice");
                seen[i] = true;
            }
        }
        assert!(seen.iter().all(|&v| v));
    }

    #[test]
    fn test_shard_by_index_shuffle() {
        let shards = shard_by_index(100, 4, true, 42);
        assert_eq!(shards.len(), 4);
        let total: usize = shards.iter().map(|s| s.len()).sum();
        assert_eq!(total, 100);
    }

    #[test]
    fn test_consistent_shuffle_determinism() {
        let a = consistent_shuffle(50, 12345);
        let b = consistent_shuffle(50, 12345);
        assert_eq!(a, b);
        // A different seed should produce a different order.
        let c = consistent_shuffle(50, 99999);
        assert_ne!(a, c);
    }

    #[test]
    fn test_consistent_shuffle_permutation() {
        let n = 200;
        let shuffled = consistent_shuffle(n, 7);
        assert_eq!(shuffled.len(), n);
        let mut sorted = shuffled.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..n).collect::<Vec<_>>());
    }

    #[test]
    fn test_shard_by_hash() {
        let shards = shard_by_hash(100, 4);
        assert_eq!(shards.len(), 4);
        assert!(shards[0].indices.iter().all(|&i| i % 4 == 0));
        let total: usize = shards.iter().map(|s| s.len()).sum();
        assert_eq!(total, 100);
    }

    #[test]
    fn test_stratified_class_proportions() {
        // 30 samples of class 0 and 20 of class 1 across 5 shards.
        let mut labels = vec![0usize; 30];
        labels.extend(vec![1usize; 20]);
        let shards = shard_stratified(&labels, 5, false, 0);
        assert_eq!(shards.len(), 5);
        // Each shard should get 6 of class 0 and 4 of class 1.
        for shard in &shards {
            assert_eq!(shard.indices.len(), 10);
        }
    }

    #[test]
    fn test_sharded_dataset_new() {
        let config = ShardingConfig {
            n_shards: 4,
            strategy: ShardStrategy::Index,
            shuffle: false,
            seed: 0,
        };
        let ds = ShardedDataset::new(100, config).expect("should succeed");
        assert_eq!(ds.n_shards(), 4);
        assert_eq!(ds.total_samples(), 100);
    }

    #[test]
    fn test_train_shards_split() {
        let config = ShardingConfig {
            n_shards: 8,
            strategy: ShardStrategy::Index,
            shuffle: false,
            seed: 0,
        };
        let ds = ShardedDataset::new(80, config).expect("should succeed");
        let (train, val) = ds.train_shards(0.25);
        assert_eq!(train.len() + val.len(), 8);
        assert_eq!(val.len(), 2);
    }

    #[test]
    fn test_shard_iter() {
        let config = ShardingConfig {
            n_shards: 4,
            strategy: ShardStrategy::Index,
            shuffle: false,
            seed: 0,
        };
        let ds = ShardedDataset::new(40, config).expect("should succeed");
        let collected: Vec<usize> = ds.shard_iter(0).collect();
        assert_eq!(collected.len(), 10);
        assert_eq!(collected, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_shard_iter_out_of_bounds() {
        let config = ShardingConfig::default();
        let ds = ShardedDataset::new(10, config).expect("should succeed");
        let empty: Vec<usize> = ds.shard_iter(999).collect();
        assert!(empty.is_empty());
    }

    #[test]
    fn test_sharded_dataset_invalid_config() {
        let bad_config = ShardingConfig {
            n_shards: 0,
            ..Default::default()
        };
        assert!(ShardedDataset::new(100, bad_config).is_err());
    }

    #[test]
    fn test_shard_id_assignment() {
        let shards = shard_by_index(100, 4, false, 0);
        for (expected_id, shard) in shards.iter().enumerate() {
            assert_eq!(shard.shard_id, expected_id);
            assert_eq!(shard.n_shards, 4);
        }
    }

    #[test]
    fn test_stratified_new_stratified() {
        let labels: Vec<usize> = (0..60).map(|i| i % 3).collect();
        let config = ShardingConfig {
            n_shards: 3,
            strategy: ShardStrategy::Stratified {
                label_column: "class".into(),
            },
            shuffle: false,
            seed: 0,
        };
        let ds = ShardedDataset::new_stratified(&labels, config).expect("ok");
        assert_eq!(ds.n_shards(), 3);
        assert_eq!(ds.total_samples(), 60);
    }

    /// Builds `n` two-feature rows with labels cycling through 0, 1, 2.
    fn make_test_data(n: usize) -> (Vec<Vec<f64>>, Vec<usize>) {
        let data: Vec<Vec<f64>> = (0..n).map(|i| vec![i as f64, (i * 2) as f64]).collect();
        let labels: Vec<usize> = (0..n).map(|i| i % 3).collect();
        (data, labels)
    }

    #[test]
    fn test_shard_dataset_total_samples() {
        let (data, labels) = make_test_data(100);
        let shards = shard_dataset(&data, &labels, 4, 42).expect("ok");
        assert_eq!(shards.len(), 4);
        let total: usize = shards.iter().map(|s| s.len()).sum();
        assert_eq!(total, 100);
    }

    #[test]
    fn test_stratified_shard_label_proportions() {
        // 60 samples of class 0 and 40 of class 1 across 5 shards.
        let mut labels = vec![0usize; 60];
        labels.extend(vec![1usize; 40]);
        let data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64]).collect();
        let shards = stratified_shard(&data, &labels, 5).expect("ok");
        assert_eq!(shards.len(), 5);
        for shard in &shards {
            let c0 = shard.labels.iter().filter(|&&l| l == 0).count();
            let c1 = shard.labels.iter().filter(|&&l| l == 1).count();
            assert_eq!(c0, 12, "Expected 12 class-0 per shard, got {c0}");
            assert_eq!(c1, 8, "Expected 8 class-1 per shard, got {c1}");
        }
    }

    #[test]
    fn test_merge_shards_recovers_data() {
        let (data, labels) = make_test_data(50);
        let shards = shard_dataset(&data, &labels, 5, 99).expect("ok");
        let (merged_data, merged_labels) = merge_shards(&shards);
        assert_eq!(merged_data.len(), 50);
        assert_eq!(merged_labels.len(), 50);
        // Merging must restore the original order exactly.
        for i in 0..50 {
            assert_eq!(merged_data[i], data[i], "Data mismatch at index {i}");
            assert_eq!(merged_labels[i], labels[i], "Label mismatch at index {i}");
        }
    }

    #[test]
    fn test_shuffled_shard_determinism() {
        let (data, labels) = make_test_data(30);
        let s1 = shuffled_shard(&data, &labels, 3, 42).expect("ok");
        let s2 = shuffled_shard(&data, &labels, 3, 42).expect("ok");
        for (a, b) in s1.iter().zip(s2.iter()) {
            assert_eq!(a.indices, b.indices);
        }
    }

    #[test]
    fn test_shard_dataset_error_on_mismatch() {
        let data = vec![vec![1.0]; 10];
        let labels = vec![0; 5];
        assert!(shard_dataset(&data, &labels, 2, 0).is_err());
    }

    #[test]
    fn test_merge_empty_shards() {
        let (data, labels) = merge_shards(&[]);
        assert!(data.is_empty());
        assert!(labels.is_empty());
    }

    #[test]
    fn test_sharded_loader_verify_coverage() {
        let loader = ShardedLoader::new(100, 4, 42);
        assert!(
            loader.verify_coverage(),
            "all 100 samples should be covered"
        );
    }

    #[test]
    fn test_sharded_loader_balanced_sizes() {
        let loader = ShardedLoader::new(101, 4, 7);
        let sizes: Vec<usize> = (0..4).map(|id| loader.get_shard(id).len()).collect();
        let min_size = *sizes.iter().min().expect("non-empty");
        let max_size = *sizes.iter().max().expect("non-empty");
        assert!(
            max_size - min_size <= 1,
            "shard sizes differ by more than 1: {sizes:?}"
        );
        let total: usize = sizes.iter().sum();
        assert_eq!(total, 101, "total should equal n_samples");
    }

    #[test]
    fn test_sharded_loader_disjoint_shards() {
        let loader = ShardedLoader::new(100, 4, 99);
        let shard0 = loader.get_shard(0);
        let shard1 = loader.get_shard(1);
        for &i in &shard0.indices {
            assert!(
                !shard1.indices.contains(&i),
                "index {i} appears in both shard 0 and shard 1"
            );
        }
    }

    #[test]
    fn test_sharded_loader_same_seed_same_permutation() {
        let loader = ShardedLoader::new(100, 4, 12345);
        let p1 = loader.global_permutation();
        let p2 = loader.global_permutation();
        assert_eq!(p1, p2, "same seed should give same permutation");

        let loader2 = ShardedLoader::new(100, 4, 12345);
        let p3 = loader2.global_permutation();
        assert_eq!(p1, p3, "independent loader with same seed should match");
    }

    #[test]
    fn test_dataset_shard_apply_f64() {
        let data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64, (i * 2) as f64]).collect();
        let loader = ShardedLoader::new(100, 4, 42);
        let shard = loader.get_shard(0);
        let subset = shard.apply_f64(&data);
        assert_eq!(
            subset.len(),
            shard.len(),
            "apply_f64 should return exactly shard.len() rows"
        );
        for row in &subset {
            assert_eq!(row.len(), 2, "each row should have 2 features");
        }
    }

    #[test]
    fn test_dataset_shard_apply_labels() {
        let labels: Vec<usize> = (0..100).map(|i| i % 3).collect();
        let loader = ShardedLoader::new(100, 4, 42);
        let shard = loader.get_shard(2);
        let subset = shard.apply_labels(&labels);
        assert_eq!(
            subset.len(),
            shard.len(),
            "apply_labels should return exactly shard.len() labels"
        );
    }

    #[test]
    fn test_sharded_loader_single_shard_coverage() {
        let loader = ShardedLoader::new(50, 1, 0);
        assert!(loader.verify_coverage());
        let shard = loader.get_shard(0);
        assert_eq!(shard.len(), 50);
    }

    #[test]
    fn test_sharded_loader_out_of_range_shard() {
        let loader = ShardedLoader::new(100, 4, 42);
        let empty_shard = loader.get_shard(99);
        assert!(
            empty_shard.is_empty(),
            "out-of-range shard_id should give empty shard"
        );
    }
}