1use crate::error::{DatasetsError, Result};
9use crate::utils::Dataset;
10use scirs2_core::ndarray::{Array1, Array2};
11use scirs2_core::random::prelude::*;
12use scirs2_core::random::rand_distributions::Distribution;
13
14fn create_rng(randomseed: Option<u64>) -> StdRng {
16 match randomseed {
17 Some(seed) => StdRng::seed_from_u64(seed),
18 None => {
19 let mut r = thread_rng();
20 StdRng::seed_from_u64(r.next_u64())
21 }
22 }
23}
24
/// Configuration for [`make_classification_enhanced`].
#[derive(Debug, Clone)]
pub struct ClassificationConfig {
    /// Total number of samples (rows) to generate.
    pub n_samples: usize,
    /// Total number of feature columns (informative + redundant + repeated + noise).
    pub n_features: usize,
    /// Number of informative features; cluster centroids live in this subspace.
    pub n_informative: usize,
    /// Number of redundant features: random linear combinations of the informative ones.
    pub n_redundant: usize,
    /// Number of features copied verbatim from informative/redundant columns.
    pub n_repeated: usize,
    /// Number of target classes (must be >= 2).
    pub n_classes: usize,
    /// Number of Gaussian clusters generated per class (must be > 0).
    pub n_clusters_per_class: usize,
    /// Fraction of labels randomly reassigned to a different class (label noise), in [0, 1].
    pub flip_y: f64,
    /// Centroids are placed uniformly in [-class_sep, class_sep] per informative
    /// dimension; larger values spread the classes further apart.
    pub class_sep: f64,
    /// Whether to shuffle rows (and labels in lockstep) after generation.
    pub shuffle: bool,
    /// Seed for reproducible generation; `None` draws a seed from thread-local entropy.
    pub random_state: Option<u64>,
}
51
impl Default for ClassificationConfig {
    /// Defaults: 100 samples, 20 features (2 informative, 2 redundant, rest noise),
    /// 2 classes with 2 clusters each, 1% label noise, shuffling on, unseeded RNG.
    fn default() -> Self {
        Self {
            n_samples: 100,
            n_features: 20,
            n_informative: 2,
            n_redundant: 2,
            n_repeated: 0,
            n_classes: 2,
            n_clusters_per_class: 2,
            flip_y: 0.01,
            class_sep: 1.0,
            shuffle: true,
            random_state: None,
        }
    }
}
69
70pub fn make_classification_enhanced(config: ClassificationConfig) -> Result<Dataset> {
105 if config.n_samples == 0 {
107 return Err(DatasetsError::InvalidFormat(
108 "n_samples must be > 0".to_string(),
109 ));
110 }
111 if config.n_features == 0 {
112 return Err(DatasetsError::InvalidFormat(
113 "n_features must be > 0".to_string(),
114 ));
115 }
116 if config.n_informative == 0 {
117 return Err(DatasetsError::InvalidFormat(
118 "n_informative must be > 0".to_string(),
119 ));
120 }
121 if config.n_classes < 2 {
122 return Err(DatasetsError::InvalidFormat(
123 "n_classes must be >= 2".to_string(),
124 ));
125 }
126 if config.n_clusters_per_class == 0 {
127 return Err(DatasetsError::InvalidFormat(
128 "n_clusters_per_class must be > 0".to_string(),
129 ));
130 }
131 let total_useful = config.n_informative + config.n_redundant + config.n_repeated;
132 if total_useful > config.n_features {
133 return Err(DatasetsError::InvalidFormat(format!(
134 "n_informative ({}) + n_redundant ({}) + n_repeated ({}) = {} must be <= n_features ({})",
135 config.n_informative,
136 config.n_redundant,
137 config.n_repeated,
138 total_useful,
139 config.n_features
140 )));
141 }
142 if config.n_informative < config.n_classes {
143 return Err(DatasetsError::InvalidFormat(format!(
144 "n_informative ({}) must be >= n_classes ({})",
145 config.n_informative, config.n_classes
146 )));
147 }
148 if config.flip_y < 0.0 || config.flip_y > 1.0 {
149 return Err(DatasetsError::InvalidFormat(
150 "flip_y must be in [0, 1]".to_string(),
151 ));
152 }
153
154 let mut rng = create_rng(config.random_state);
155
156 let normal = scirs2_core::random::Normal::new(0.0, 1.0).map_err(|e| {
157 DatasetsError::ComputationError(format!("Failed to create normal dist: {e}"))
158 })?;
159
160 let n_noise = config.n_features - config.n_informative - config.n_redundant - config.n_repeated;
161
162 let n_centroids = config.n_classes * config.n_clusters_per_class;
164 let mut centroids = Array2::zeros((n_centroids, config.n_informative));
165
166 for i in 0..n_centroids {
167 for j in 0..config.n_informative {
168 centroids[[i, j]] = config.class_sep * (2.0 * rng.random::<f64>() - 1.0);
169 }
170 }
171
172 let mut informative = Array2::zeros((config.n_samples, config.n_informative));
174 let mut target = Array1::zeros(config.n_samples);
175
176 let samples_per_class = config.n_samples / config.n_classes;
177 let remainder = config.n_samples % config.n_classes;
178 let mut idx = 0;
179
180 for class_idx in 0..config.n_classes {
181 let n_samples_class = if class_idx < remainder {
182 samples_per_class + 1
183 } else {
184 samples_per_class
185 };
186 let spc = n_samples_class / config.n_clusters_per_class;
187 let spc_rem = n_samples_class % config.n_clusters_per_class;
188
189 for cluster_idx in 0..config.n_clusters_per_class {
190 let n_cluster = if cluster_idx < spc_rem { spc + 1 } else { spc };
191 let centroid_idx = class_idx * config.n_clusters_per_class + cluster_idx;
192
193 for _ in 0..n_cluster {
194 for j in 0..config.n_informative {
195 informative[[idx, j]] =
196 centroids[[centroid_idx, j]] + 0.5 * normal.sample(&mut rng);
197 }
198 target[idx] = class_idx as f64;
199 idx += 1;
200 }
201 }
202 }
203
204 let mut redundant = Array2::zeros((config.n_samples, config.n_redundant));
206 if config.n_redundant > 0 {
207 let mut mixing = Array2::zeros((config.n_informative, config.n_redundant));
209 for i in 0..config.n_informative {
210 for j in 0..config.n_redundant {
211 mixing[[i, j]] = normal.sample(&mut rng);
212 }
213 }
214 for i in 0..config.n_samples {
216 for j in 0..config.n_redundant {
217 let mut val = 0.0;
218 for k in 0..config.n_informative {
219 val += informative[[i, k]] * mixing[[k, j]];
220 }
221 redundant[[i, j]] = val;
222 }
223 }
224 }
225
226 let mut repeated = Array2::zeros((config.n_samples, config.n_repeated));
228 if config.n_repeated > 0 {
229 let source_cols = config.n_informative + config.n_redundant;
230 for j in 0..config.n_repeated {
231 let src_j = j % source_cols;
232 for i in 0..config.n_samples {
233 if src_j < config.n_informative {
234 repeated[[i, j]] = informative[[i, src_j]];
235 } else {
236 repeated[[i, j]] = redundant[[i, src_j - config.n_informative]];
237 }
238 }
239 }
240 }
241
242 let mut noise_features = Array2::zeros((config.n_samples, n_noise));
244 for i in 0..config.n_samples {
245 for j in 0..n_noise {
246 noise_features[[i, j]] = normal.sample(&mut rng);
247 }
248 }
249
250 let mut data = Array2::zeros((config.n_samples, config.n_features));
252 for i in 0..config.n_samples {
253 let mut col = 0;
254 for j in 0..config.n_informative {
255 data[[i, col]] = informative[[i, j]];
256 col += 1;
257 }
258 for j in 0..config.n_redundant {
259 data[[i, col]] = redundant[[i, j]];
260 col += 1;
261 }
262 for j in 0..config.n_repeated {
263 data[[i, col]] = repeated[[i, j]];
264 col += 1;
265 }
266 for j in 0..n_noise {
267 data[[i, col]] = noise_features[[i, j]];
268 col += 1;
269 }
270 }
271
272 if config.flip_y > 0.0 {
274 let uniform = scirs2_core::random::Uniform::new(0.0, 1.0).map_err(|e| {
275 DatasetsError::ComputationError(format!("Failed to create uniform dist: {e}"))
276 })?;
277 for i in 0..config.n_samples {
278 if uniform.sample(&mut rng) < config.flip_y {
279 let current = target[i] as usize;
281 let mut new_class = rng.random_range(0..config.n_classes);
282 while new_class == current && config.n_classes > 1 {
283 new_class = rng.random_range(0..config.n_classes);
284 }
285 target[i] = new_class as f64;
286 }
287 }
288 }
289
290 if config.shuffle {
292 let n = config.n_samples;
293 for i in (1..n).rev() {
295 let j = rng.random_range(0..=i);
296 if i != j {
297 for col in 0..config.n_features {
299 let tmp = data[[i, col]];
300 data[[i, col]] = data[[j, col]];
301 data[[j, col]] = tmp;
302 }
303 let tmp = target[i];
305 target[i] = target[j];
306 target[j] = tmp;
307 }
308 }
309 }
310
311 let mut feature_names = Vec::with_capacity(config.n_features);
313 for j in 0..config.n_informative {
314 feature_names.push(format!("informative_{j}"));
315 }
316 for j in 0..config.n_redundant {
317 feature_names.push(format!("redundant_{j}"));
318 }
319 for j in 0..config.n_repeated {
320 feature_names.push(format!("repeated_{j}"));
321 }
322 for j in 0..n_noise {
323 feature_names.push(format!("noise_{j}"));
324 }
325
326 let class_names: Vec<String> = (0..config.n_classes)
327 .map(|i| format!("class_{i}"))
328 .collect();
329
330 let dataset = Dataset::new(data, Some(target))
331 .with_featurenames(feature_names)
332 .with_targetnames(class_names)
333 .with_description(format!(
334 "Enhanced classification dataset: {} samples, {} features ({} informative, {} redundant, {} repeated, {} noise), {} classes",
335 config.n_samples, config.n_features, config.n_informative,
336 config.n_redundant, config.n_repeated, n_noise, config.n_classes
337 ))
338 .with_metadata("n_informative", &config.n_informative.to_string())
339 .with_metadata("n_redundant", &config.n_redundant.to_string())
340 .with_metadata("n_repeated", &config.n_repeated.to_string())
341 .with_metadata("n_noise", &n_noise.to_string())
342 .with_metadata("class_sep", &config.class_sep.to_string())
343 .with_metadata("flip_y", &config.flip_y.to_string());
344
345 Ok(dataset)
346}
347
/// Configuration for [`make_multilabel_classification`].
#[derive(Debug, Clone)]
pub struct MultilabelConfig {
    /// Number of samples (rows) to generate.
    pub n_samples: usize,
    /// Number of feature columns.
    pub n_features: usize,
    /// Number of distinct labels (classes) available.
    pub n_classes: usize,
    /// Number of distinct labels assigned to every sample (must be in [1, n_classes]).
    pub n_labels: usize,
    /// Whether samples with zero labels are permitted.
    /// NOTE(review): currently a no-op — validation enforces `n_labels >= 1`,
    /// so every sample always receives at least one label.
    pub allow_unlabeled: bool,
    /// Seed for reproducible generation; `None` draws from thread-local entropy.
    pub random_state: Option<u64>,
}
364
impl Default for MultilabelConfig {
    /// Defaults: 100 samples, 20 features, 5 classes, 2 labels per sample,
    /// unlabeled samples nominally allowed, unseeded RNG.
    fn default() -> Self {
        Self {
            n_samples: 100,
            n_features: 20,
            n_classes: 5,
            n_labels: 2,
            allow_unlabeled: true,
            random_state: None,
        }
    }
}
377
/// A generated multi-label dataset: dense features plus a binary
/// label-indicator matrix (one column per class).
#[derive(Debug, Clone)]
pub struct MultilabelDataset {
    /// Feature matrix, shape (n_samples, n_features).
    pub data: Array2<f64>,
    /// Binary indicator matrix, shape (n_samples, n_classes); 1.0 marks membership.
    pub target: Array2<f64>,
    /// Feature column names ("feature_0", ...).
    pub feature_names: Vec<String>,
    /// Label names ("label_0", ...).
    pub class_names: Vec<String>,
    /// Free-text summary of the generation parameters.
    pub description: String,
}
395
396pub fn make_multilabel_classification(config: MultilabelConfig) -> Result<MultilabelDataset> {
433 if config.n_samples == 0 {
434 return Err(DatasetsError::InvalidFormat(
435 "n_samples must be > 0".to_string(),
436 ));
437 }
438 if config.n_features == 0 {
439 return Err(DatasetsError::InvalidFormat(
440 "n_features must be > 0".to_string(),
441 ));
442 }
443 if config.n_classes == 0 {
444 return Err(DatasetsError::InvalidFormat(
445 "n_classes must be > 0".to_string(),
446 ));
447 }
448 if config.n_labels == 0 || config.n_labels > config.n_classes {
449 return Err(DatasetsError::InvalidFormat(format!(
450 "n_labels ({}) must be in [1, n_classes ({})]",
451 config.n_labels, config.n_classes
452 )));
453 }
454
455 let mut rng = create_rng(config.random_state);
456
457 let normal = scirs2_core::random::Normal::new(0.0, 1.0).map_err(|e| {
458 DatasetsError::ComputationError(format!("Failed to create normal dist: {e}"))
459 })?;
460
461 let mut centers = Array2::zeros((config.n_classes, config.n_features));
463 for i in 0..config.n_classes {
464 for j in 0..config.n_features {
465 centers[[i, j]] = 3.0 * normal.sample(&mut rng);
466 }
467 }
468
469 let mut data = Array2::zeros((config.n_samples, config.n_features));
471 let mut target_matrix = Array2::zeros((config.n_samples, config.n_classes));
472
473 for i in 0..config.n_samples {
474 let mut labels: Vec<usize> = Vec::with_capacity(config.n_labels);
476 while labels.len() < config.n_labels {
477 let candidate = rng.random_range(0..config.n_classes);
478 if !labels.contains(&candidate) {
479 labels.push(candidate);
480 }
481 }
482
483 if !config.allow_unlabeled && labels.is_empty() {
485 labels.push(rng.random_range(0..config.n_classes));
486 }
487
488 for j in 0..config.n_features {
490 let mut val = 0.0;
491 for &label in &labels {
492 val += centers[[label, j]];
493 }
494 val /= labels.len() as f64;
495 val += normal.sample(&mut rng); data[[i, j]] = val;
497 }
498
499 for &label in &labels {
501 target_matrix[[i, label]] = 1.0;
502 }
503 }
504
505 let feature_names: Vec<String> = (0..config.n_features)
506 .map(|j| format!("feature_{j}"))
507 .collect();
508 let class_names: Vec<String> = (0..config.n_classes)
509 .map(|j| format!("label_{j}"))
510 .collect();
511
512 Ok(MultilabelDataset {
513 data,
514 target: target_matrix,
515 feature_names,
516 class_names,
517 description: format!(
518 "Multi-label classification dataset: {} samples, {} features, {} classes, ~{} labels per sample",
519 config.n_samples, config.n_features, config.n_classes, config.n_labels
520 ),
521 })
522}
523
524pub fn make_hastie_10_2(n_samples: usize, random_state: Option<u64>) -> Result<Dataset> {
553 if n_samples == 0 {
554 return Err(DatasetsError::InvalidFormat(
555 "n_samples must be > 0".to_string(),
556 ));
557 }
558
559 let mut rng = create_rng(random_state);
560
561 let normal = scirs2_core::random::Normal::new(0.0, 1.0).map_err(|e| {
562 DatasetsError::ComputationError(format!("Failed to create normal dist: {e}"))
563 })?;
564
565 let n_features = 10;
566 let chi2_median = 9.3418;
568
569 let mut data = Array2::zeros((n_samples, n_features));
570 let mut target = Array1::zeros(n_samples);
571
572 for i in 0..n_samples {
573 let mut sum_sq = 0.0;
574 for j in 0..n_features {
575 let val = normal.sample(&mut rng);
576 data[[i, j]] = val;
577 sum_sq += val * val;
578 }
579
580 target[i] = if sum_sq > chi2_median { 1.0 } else { -1.0 };
581 }
582
583 let feature_names: Vec<String> = (0..n_features).map(|j| format!("x_{j}")).collect();
584
585 let dataset = Dataset::new(data, Some(target))
586 .with_featurenames(feature_names)
587 .with_targetnames(vec!["-1".to_string(), "1".to_string()])
588 .with_description(
589 "Hastie et al. 10.2 binary classification dataset. \
590 Features are standard normal; y=1 if sum(x_i^2) > 9.34 (chi2(10) median), else y=-1. \
591 Reference: Hastie, Tibshirani, Friedman (2009) The Elements of Statistical Learning."
592 .to_string(),
593 )
594 .with_metadata("chi2_median_threshold", &chi2_median.to_string())
595 .with_metadata("n_features", &n_features.to_string());
596
597 Ok(dataset)
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    // ---------- make_classification_enhanced ----------

    // Shapes, target presence, and label range for a basic 3-class config.
    #[test]
    fn test_classification_enhanced_basic() {
        let config = ClassificationConfig {
            n_samples: 200,
            n_features: 20,
            n_informative: 5,
            n_redundant: 3,
            n_repeated: 2,
            n_classes: 3,
            random_state: Some(42),
            ..Default::default()
        };
        let ds = make_classification_enhanced(config).expect("should succeed");
        assert_eq!(ds.n_samples(), 200);
        assert_eq!(ds.n_features(), 20);
        assert!(ds.target.is_some());
        let target = ds.target.as_ref().expect("target present");
        assert_eq!(target.len(), 200);
        // Labels are class indices 0, 1, 2 encoded as f64.
        for &val in target.iter() {
            assert!((0.0..3.0).contains(&val), "Invalid class label: {val}");
        }
    }

    // Feature names follow the [informative | redundant | repeated | noise]
    // column layout: 3 + 2 + 1 + 4 columns here.
    #[test]
    fn test_classification_enhanced_feature_names() {
        let config = ClassificationConfig {
            n_samples: 50,
            n_features: 10,
            n_informative: 3,
            n_redundant: 2,
            n_repeated: 1,
            n_classes: 2,
            random_state: Some(42),
            ..Default::default()
        };
        let ds = make_classification_enhanced(config).expect("should succeed");
        let names = ds.featurenames.as_ref().expect("names present");
        assert_eq!(names.len(), 10);
        assert!(names[0].starts_with("informative_"));
        assert!(names[3].starts_with("redundant_"));
        assert!(names[5].starts_with("repeated_"));
        assert!(names[6].starts_with("noise_"));
    }

    // Same seed (and no shuffle/flip randomness) must yield identical data.
    #[test]
    fn test_classification_enhanced_reproducibility() {
        let make = || {
            let config = ClassificationConfig {
                n_samples: 50,
                n_features: 10,
                n_informative: 3,
                n_redundant: 2,
                n_repeated: 0,
                n_classes: 2,
                flip_y: 0.0,
                shuffle: false,
                random_state: Some(123),
                ..Default::default()
            };
            make_classification_enhanced(config).expect("should succeed")
        };
        let ds1 = make();
        let ds2 = make();
        for i in 0..50 {
            for j in 0..10 {
                assert!(
                    (ds1.data[[i, j]] - ds2.data[[i, j]]).abs() < 1e-15,
                    "Reproducibility failed at ({i},{j})"
                );
            }
        }
    }

    // Invalid configurations must be rejected: zero samples, feature budget
    // overflow, and n_informative < n_classes.
    #[test]
    fn test_classification_enhanced_validation() {
        let cfg = ClassificationConfig {
            n_samples: 0,
            ..Default::default()
        };
        assert!(make_classification_enhanced(cfg).is_err());

        // 3 + 2 + 2 = 7 > 5 features.
        let cfg = ClassificationConfig {
            n_features: 5,
            n_informative: 3,
            n_redundant: 2,
            n_repeated: 2,
            ..Default::default()
        };
        assert!(make_classification_enhanced(cfg).is_err());

        let cfg = ClassificationConfig {
            n_informative: 2,
            n_classes: 5,
            ..Default::default()
        };
        assert!(make_classification_enhanced(cfg).is_err());
    }

    // Redundant columns are linear combinations of informative columns, so
    // they should carry real (non-degenerate) variance.
    #[test]
    fn test_classification_enhanced_redundant_correlation() {
        let config = ClassificationConfig {
            n_samples: 500,
            n_features: 10,
            n_informative: 5,
            n_redundant: 3,
            n_repeated: 0,
            n_classes: 2,
            flip_y: 0.0,
            shuffle: false,
            random_state: Some(42),
            ..Default::default()
        };
        let ds = make_classification_enhanced(config).expect("should succeed");

        // Column 5 is the first redundant column (after 5 informative).
        let col5: Vec<f64> = (0..500).map(|i| ds.data[[i, 5]]).collect();
        let mean5: f64 = col5.iter().sum::<f64>() / 500.0;
        let var5: f64 = col5.iter().map(|x| (x - mean5).powi(2)).sum::<f64>() / 499.0;
        assert!(var5 > 0.01, "Redundant feature variance too low: {var5}");
    }

    // With flip_y = 0.5 a substantial fraction of labels should differ from
    // the flip_y = 0.0 run generated from the same seed.
    #[test]
    fn test_classification_enhanced_flip_y() {
        let config = ClassificationConfig {
            n_samples: 1000,
            n_features: 5,
            n_informative: 3,
            n_redundant: 0,
            n_repeated: 0,
            n_classes: 2,
            flip_y: 0.0,
            shuffle: false,
            random_state: Some(42),
            ..Default::default()
        };
        let ds_no_flip = make_classification_enhanced(config).expect("should succeed");

        let config_flip = ClassificationConfig {
            n_samples: 1000,
            n_features: 5,
            n_informative: 3,
            n_redundant: 0,
            n_repeated: 0,
            n_classes: 2,
            flip_y: 0.5,
            shuffle: false,
            random_state: Some(42),
            ..Default::default()
        };
        let ds_flip = make_classification_enhanced(config_flip).expect("should succeed");

        let n_different = (0..1000)
            .filter(|&i| {
                let t1 = ds_no_flip.target.as_ref().expect("target")[i];
                let t2 = ds_flip.target.as_ref().expect("target")[i];
                (t1 - t2).abs() > 0.5
            })
            .count();
        assert!(
            n_different > 0,
            "Expected some labels to differ with flip_y=0.5"
        );
    }

    // ---------- make_multilabel_classification ----------

    // Shapes of the feature matrix and the label-indicator matrix.
    #[test]
    fn test_multilabel_basic() {
        let config = MultilabelConfig {
            n_samples: 100,
            n_features: 10,
            n_classes: 5,
            n_labels: 2,
            random_state: Some(42),
            ..Default::default()
        };
        let ds = make_multilabel_classification(config).expect("should succeed");
        assert_eq!(ds.data.nrows(), 100);
        assert_eq!(ds.data.ncols(), 10);
        assert_eq!(ds.target.nrows(), 100);
        assert_eq!(ds.target.ncols(), 5);
    }

    // Every target entry must be exactly 0.0 or 1.0.
    #[test]
    fn test_multilabel_binary_targets() {
        let config = MultilabelConfig {
            n_samples: 50,
            n_features: 5,
            n_classes: 3,
            n_labels: 2,
            random_state: Some(42),
            ..Default::default()
        };
        let ds = make_multilabel_classification(config).expect("should succeed");
        for i in 0..50 {
            for j in 0..3 {
                let val = ds.target[[i, j]];
                assert!(
                    val == 0.0 || val == 1.0,
                    "Target entry at ({i},{j}) should be binary, got {val}"
                );
            }
        }
    }

    // Each sample gets exactly n_labels distinct labels.
    #[test]
    fn test_multilabel_labels_per_sample() {
        let config = MultilabelConfig {
            n_samples: 200,
            n_features: 5,
            n_classes: 6,
            n_labels: 3,
            random_state: Some(42),
            ..Default::default()
        };
        let ds = make_multilabel_classification(config).expect("should succeed");
        for i in 0..200 {
            let label_count: f64 = (0..6).map(|j| ds.target[[i, j]]).sum();
            assert_eq!(
                label_count, 3.0,
                "Sample {i} should have 3 labels, got {label_count}"
            );
        }
    }

    // Same seed must yield identical data.
    #[test]
    fn test_multilabel_reproducibility() {
        let make = || {
            let config = MultilabelConfig {
                n_samples: 30,
                n_features: 5,
                n_classes: 3,
                n_labels: 1,
                random_state: Some(77),
                ..Default::default()
            };
            make_multilabel_classification(config).expect("should succeed")
        };
        let ds1 = make();
        let ds2 = make();
        for i in 0..30 {
            for j in 0..5 {
                assert!(
                    (ds1.data[[i, j]] - ds2.data[[i, j]]).abs() < 1e-15,
                    "Reproducibility failed at ({i},{j})"
                );
            }
        }
    }

    // Invalid configurations: zero samples, zero labels, n_labels > n_classes.
    #[test]
    fn test_multilabel_validation() {
        let cfg = MultilabelConfig {
            n_samples: 0,
            ..Default::default()
        };
        assert!(make_multilabel_classification(cfg).is_err());

        let cfg = MultilabelConfig {
            n_labels: 0,
            ..Default::default()
        };
        assert!(make_multilabel_classification(cfg).is_err());

        let cfg = MultilabelConfig {
            n_labels: 10,
            n_classes: 3,
            ..Default::default()
        };
        assert!(make_multilabel_classification(cfg).is_err());
    }

    // ---------- make_hastie_10_2 ----------

    // Basic shape checks: always 10 features.
    #[test]
    fn test_hastie_basic() {
        let ds = make_hastie_10_2(1000, Some(42)).expect("should succeed");
        assert_eq!(ds.n_samples(), 1000);
        assert_eq!(ds.n_features(), 10);
        assert!(ds.target.is_some());
    }

    // Labels are exactly -1.0 or 1.0.
    #[test]
    fn test_hastie_binary_labels() {
        let ds = make_hastie_10_2(500, Some(42)).expect("should succeed");
        let target = ds.target.as_ref().expect("target present");
        for &val in target.iter() {
            assert!(
                val == -1.0 || val == 1.0,
                "Hastie labels should be -1 or 1, got {val}"
            );
        }
    }

    // Thresholding at the chi2(10) median should give ~50/50 classes.
    #[test]
    fn test_hastie_balanced_classes() {
        let ds = make_hastie_10_2(10000, Some(42)).expect("should succeed");
        let target = ds.target.as_ref().expect("target present");
        let n_positive = target.iter().filter(|&&v| v > 0.0).count();
        let n_negative = target.len() - n_positive;
        let ratio = n_positive as f64 / n_negative as f64;
        assert!(
            ratio > 0.7 && ratio < 1.4,
            "Classes should be roughly balanced, got ratio {ratio} (pos={n_positive}, neg={n_negative})"
        );
    }

    // Features should be approximately standard normal (mean ~0, var ~1).
    #[test]
    fn test_hastie_feature_stats() {
        let ds = make_hastie_10_2(5000, Some(42)).expect("should succeed");
        for j in 0..10 {
            let col: Vec<f64> = (0..5000).map(|i| ds.data[[i, j]]).collect();
            let mean: f64 = col.iter().sum::<f64>() / 5000.0;
            let var: f64 = col.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / 4999.0;
            assert!(
                mean.abs() < 0.1,
                "Feature {j} mean should be ~0, got {mean}"
            );
            assert!(
                (var - 1.0).abs() < 0.15,
                "Feature {j} variance should be ~1, got {var}"
            );
        }
    }

    // Same seed must yield identical data.
    #[test]
    fn test_hastie_reproducibility() {
        let ds1 = make_hastie_10_2(100, Some(99)).expect("should succeed");
        let ds2 = make_hastie_10_2(100, Some(99)).expect("should succeed");
        for i in 0..100 {
            for j in 0..10 {
                assert!(
                    (ds1.data[[i, j]] - ds2.data[[i, j]]).abs() < 1e-15,
                    "Reproducibility failed at ({i},{j})"
                );
            }
        }
    }

    // Zero samples must be rejected.
    #[test]
    fn test_hastie_validation() {
        assert!(make_hastie_10_2(0, None).is_err());
    }

    // The description should cite the source of the construction.
    #[test]
    fn test_hastie_description() {
        let ds = make_hastie_10_2(100, Some(42)).expect("should succeed");
        assert!(ds.description.is_some());
        let desc = ds.description.as_ref().expect("desc present");
        assert!(desc.contains("Hastie"));
    }
}