Skip to main content

scirs2_datasets/generators/
classification.rs

1//! Advanced classification dataset generators
2//!
3//! Provides sklearn-style synthetic classification generators including
4//! multi-label classification, Hastie et al. binary classification,
5//! and enhanced n-class classification with informative, redundant,
6//! and noise features.
7
8use crate::error::{DatasetsError, Result};
9use crate::utils::Dataset;
10use scirs2_core::ndarray::{Array1, Array2};
11use scirs2_core::random::prelude::*;
12use scirs2_core::random::rand_distributions::Distribution;
13
14/// Helper to create an RNG from an optional seed
15fn create_rng(randomseed: Option<u64>) -> StdRng {
16    match randomseed {
17        Some(seed) => StdRng::seed_from_u64(seed),
18        None => {
19            let mut r = thread_rng();
20            StdRng::seed_from_u64(r.next_u64())
21        }
22    }
23}
24
/// Configuration for [`make_classification_enhanced`].
///
/// Override only the fields you need with struct-update syntax, e.g.
/// `ClassificationConfig { n_samples: 500, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct ClassificationConfig {
    /// Number of samples (rows); must be > 0
    pub n_samples: usize,
    /// Total number of features (columns); must be > 0 and at least
    /// `n_informative + n_redundant + n_repeated` (the remainder are noise)
    pub n_features: usize,
    /// Number of informative features; must be > 0 and >= `n_classes`
    pub n_informative: usize,
    /// Number of redundant features (random linear combinations of the
    /// informative features)
    pub n_redundant: usize,
    /// Number of repeated features: copies of the informative and redundant
    /// columns, taken cyclically in that order
    pub n_repeated: usize,
    /// Number of classes; must be >= 2
    pub n_classes: usize,
    /// Number of clusters per class; must be > 0
    pub n_clusters_per_class: usize,
    /// Probability in `[0, 1]` that each sample's label is replaced by a
    /// different, randomly chosen class (label noise)
    pub flip_y: f64,
    /// Scale of the cluster centroids: every centroid coordinate is
    /// `class_sep * u` with `u` uniform in `[-1, 1)`, so larger values spread
    /// the clusters (and thus the classes) further apart
    pub class_sep: f64,
    /// Whether to shuffle the order of the samples (rows together with their
    /// labels). Feature columns are never shuffled: they always appear as
    /// informative, redundant, repeated, then noise.
    pub shuffle: bool,
    /// Optional random seed for reproducibility
    pub random_state: Option<u64>,
}
51
52impl Default for ClassificationConfig {
53    fn default() -> Self {
54        Self {
55            n_samples: 100,
56            n_features: 20,
57            n_informative: 2,
58            n_redundant: 2,
59            n_repeated: 0,
60            n_classes: 2,
61            n_clusters_per_class: 2,
62            flip_y: 0.01,
63            class_sep: 1.0,
64            shuffle: true,
65            random_state: None,
66        }
67    }
68}
69
70/// Generate an enhanced random n-class classification problem
71///
72/// This is a more featureful version of `make_classification` in `basic.rs`,
73/// following the sklearn interface more closely. It creates a dataset with
74/// informative features, redundant features (linear combinations of informative
75/// features), repeated features (duplicates), and pure noise features.
76///
77/// # Arguments
78///
79/// * `config` - Classification configuration specifying all parameters
80///
81/// # Returns
82///
83/// A `Dataset` with n_samples rows and n_features columns, plus target labels
84///
85/// # Examples
86///
87/// ```rust
88/// use scirs2_datasets::generators::classification::{make_classification_enhanced, ClassificationConfig};
89///
90/// let config = ClassificationConfig {
91///     n_samples: 200,
92///     n_features: 20,
93///     n_informative: 5,
94///     n_redundant: 3,
95///     n_repeated: 2,
96///     n_classes: 3,
97///     random_state: Some(42),
98///     ..Default::default()
99/// };
100/// let ds = make_classification_enhanced(config).expect("should succeed");
101/// assert_eq!(ds.n_samples(), 200);
102/// assert_eq!(ds.n_features(), 20);
103/// ```
104pub fn make_classification_enhanced(config: ClassificationConfig) -> Result<Dataset> {
105    // Validate parameters
106    if config.n_samples == 0 {
107        return Err(DatasetsError::InvalidFormat(
108            "n_samples must be > 0".to_string(),
109        ));
110    }
111    if config.n_features == 0 {
112        return Err(DatasetsError::InvalidFormat(
113            "n_features must be > 0".to_string(),
114        ));
115    }
116    if config.n_informative == 0 {
117        return Err(DatasetsError::InvalidFormat(
118            "n_informative must be > 0".to_string(),
119        ));
120    }
121    if config.n_classes < 2 {
122        return Err(DatasetsError::InvalidFormat(
123            "n_classes must be >= 2".to_string(),
124        ));
125    }
126    if config.n_clusters_per_class == 0 {
127        return Err(DatasetsError::InvalidFormat(
128            "n_clusters_per_class must be > 0".to_string(),
129        ));
130    }
131    let total_useful = config.n_informative + config.n_redundant + config.n_repeated;
132    if total_useful > config.n_features {
133        return Err(DatasetsError::InvalidFormat(format!(
134            "n_informative ({}) + n_redundant ({}) + n_repeated ({}) = {} must be <= n_features ({})",
135            config.n_informative,
136            config.n_redundant,
137            config.n_repeated,
138            total_useful,
139            config.n_features
140        )));
141    }
142    if config.n_informative < config.n_classes {
143        return Err(DatasetsError::InvalidFormat(format!(
144            "n_informative ({}) must be >= n_classes ({})",
145            config.n_informative, config.n_classes
146        )));
147    }
148    if config.flip_y < 0.0 || config.flip_y > 1.0 {
149        return Err(DatasetsError::InvalidFormat(
150            "flip_y must be in [0, 1]".to_string(),
151        ));
152    }
153
154    let mut rng = create_rng(config.random_state);
155
156    let normal = scirs2_core::random::Normal::new(0.0, 1.0).map_err(|e| {
157        DatasetsError::ComputationError(format!("Failed to create normal dist: {e}"))
158    })?;
159
160    let n_noise = config.n_features - config.n_informative - config.n_redundant - config.n_repeated;
161
162    // Step 1: Generate informative features by creating centroids per cluster
163    let n_centroids = config.n_classes * config.n_clusters_per_class;
164    let mut centroids = Array2::zeros((n_centroids, config.n_informative));
165
166    for i in 0..n_centroids {
167        for j in 0..config.n_informative {
168            centroids[[i, j]] = config.class_sep * (2.0 * rng.random::<f64>() - 1.0);
169        }
170    }
171
172    // Step 2: Generate samples around centroids
173    let mut informative = Array2::zeros((config.n_samples, config.n_informative));
174    let mut target = Array1::zeros(config.n_samples);
175
176    let samples_per_class = config.n_samples / config.n_classes;
177    let remainder = config.n_samples % config.n_classes;
178    let mut idx = 0;
179
180    for class_idx in 0..config.n_classes {
181        let n_samples_class = if class_idx < remainder {
182            samples_per_class + 1
183        } else {
184            samples_per_class
185        };
186        let spc = n_samples_class / config.n_clusters_per_class;
187        let spc_rem = n_samples_class % config.n_clusters_per_class;
188
189        for cluster_idx in 0..config.n_clusters_per_class {
190            let n_cluster = if cluster_idx < spc_rem { spc + 1 } else { spc };
191            let centroid_idx = class_idx * config.n_clusters_per_class + cluster_idx;
192
193            for _ in 0..n_cluster {
194                for j in 0..config.n_informative {
195                    informative[[idx, j]] =
196                        centroids[[centroid_idx, j]] + 0.5 * normal.sample(&mut rng);
197                }
198                target[idx] = class_idx as f64;
199                idx += 1;
200            }
201        }
202    }
203
204    // Step 3: Generate redundant features as linear combinations of informative features
205    let mut redundant = Array2::zeros((config.n_samples, config.n_redundant));
206    if config.n_redundant > 0 {
207        // Create a random mixing matrix
208        let mut mixing = Array2::zeros((config.n_informative, config.n_redundant));
209        for i in 0..config.n_informative {
210            for j in 0..config.n_redundant {
211                mixing[[i, j]] = normal.sample(&mut rng);
212            }
213        }
214        // redundant = informative @ mixing
215        for i in 0..config.n_samples {
216            for j in 0..config.n_redundant {
217                let mut val = 0.0;
218                for k in 0..config.n_informative {
219                    val += informative[[i, k]] * mixing[[k, j]];
220                }
221                redundant[[i, j]] = val;
222            }
223        }
224    }
225
226    // Step 4: Generate repeated features (copies of informative + redundant)
227    let mut repeated = Array2::zeros((config.n_samples, config.n_repeated));
228    if config.n_repeated > 0 {
229        let source_cols = config.n_informative + config.n_redundant;
230        for j in 0..config.n_repeated {
231            let src_j = j % source_cols;
232            for i in 0..config.n_samples {
233                if src_j < config.n_informative {
234                    repeated[[i, j]] = informative[[i, src_j]];
235                } else {
236                    repeated[[i, j]] = redundant[[i, src_j - config.n_informative]];
237                }
238            }
239        }
240    }
241
242    // Step 5: Generate noise features
243    let mut noise_features = Array2::zeros((config.n_samples, n_noise));
244    for i in 0..config.n_samples {
245        for j in 0..n_noise {
246            noise_features[[i, j]] = normal.sample(&mut rng);
247        }
248    }
249
250    // Step 6: Assemble the full feature matrix
251    let mut data = Array2::zeros((config.n_samples, config.n_features));
252    for i in 0..config.n_samples {
253        let mut col = 0;
254        for j in 0..config.n_informative {
255            data[[i, col]] = informative[[i, j]];
256            col += 1;
257        }
258        for j in 0..config.n_redundant {
259            data[[i, col]] = redundant[[i, j]];
260            col += 1;
261        }
262        for j in 0..config.n_repeated {
263            data[[i, col]] = repeated[[i, j]];
264            col += 1;
265        }
266        for j in 0..n_noise {
267            data[[i, col]] = noise_features[[i, j]];
268            col += 1;
269        }
270    }
271
272    // Step 7: Flip labels with probability flip_y
273    if config.flip_y > 0.0 {
274        let uniform = scirs2_core::random::Uniform::new(0.0, 1.0).map_err(|e| {
275            DatasetsError::ComputationError(format!("Failed to create uniform dist: {e}"))
276        })?;
277        for i in 0..config.n_samples {
278            if uniform.sample(&mut rng) < config.flip_y {
279                // Assign a random different class
280                let current = target[i] as usize;
281                let mut new_class = rng.random_range(0..config.n_classes);
282                while new_class == current && config.n_classes > 1 {
283                    new_class = rng.random_range(0..config.n_classes);
284                }
285                target[i] = new_class as f64;
286            }
287        }
288    }
289
290    // Step 8: Shuffle if requested
291    if config.shuffle {
292        let n = config.n_samples;
293        // Fisher-Yates shuffle
294        for i in (1..n).rev() {
295            let j = rng.random_range(0..=i);
296            if i != j {
297                // Swap rows in data
298                for col in 0..config.n_features {
299                    let tmp = data[[i, col]];
300                    data[[i, col]] = data[[j, col]];
301                    data[[j, col]] = tmp;
302                }
303                // Swap targets
304                let tmp = target[i];
305                target[i] = target[j];
306                target[j] = tmp;
307            }
308        }
309    }
310
311    // Build feature names
312    let mut feature_names = Vec::with_capacity(config.n_features);
313    for j in 0..config.n_informative {
314        feature_names.push(format!("informative_{j}"));
315    }
316    for j in 0..config.n_redundant {
317        feature_names.push(format!("redundant_{j}"));
318    }
319    for j in 0..config.n_repeated {
320        feature_names.push(format!("repeated_{j}"));
321    }
322    for j in 0..n_noise {
323        feature_names.push(format!("noise_{j}"));
324    }
325
326    let class_names: Vec<String> = (0..config.n_classes)
327        .map(|i| format!("class_{i}"))
328        .collect();
329
330    let dataset = Dataset::new(data, Some(target))
331        .with_featurenames(feature_names)
332        .with_targetnames(class_names)
333        .with_description(format!(
334            "Enhanced classification dataset: {} samples, {} features ({} informative, {} redundant, {} repeated, {} noise), {} classes",
335            config.n_samples, config.n_features, config.n_informative,
336            config.n_redundant, config.n_repeated, n_noise, config.n_classes
337        ))
338        .with_metadata("n_informative", &config.n_informative.to_string())
339        .with_metadata("n_redundant", &config.n_redundant.to_string())
340        .with_metadata("n_repeated", &config.n_repeated.to_string())
341        .with_metadata("n_noise", &n_noise.to_string())
342        .with_metadata("class_sep", &config.class_sep.to_string())
343        .with_metadata("flip_y", &config.flip_y.to_string());
344
345    Ok(dataset)
346}
347
/// Configuration for [`make_multilabel_classification`].
///
/// Override only the fields you need with struct-update syntax, e.g.
/// `MultilabelConfig { n_classes: 8, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct MultilabelConfig {
    /// Number of samples; must be > 0
    pub n_samples: usize,
    /// Number of features; must be > 0
    pub n_features: usize,
    /// Number of classes (labels); must be > 0
    pub n_classes: usize,
    /// Exact number of distinct labels assigned to every sample; must be in
    /// `[1, n_classes]`
    pub n_labels: usize,
    /// Whether samples without any label are permitted. Because every sample
    /// receives exactly `n_labels >= 1` labels, no sample is ever unlabeled and
    /// this flag currently has no effect on the generated data.
    pub allow_unlabeled: bool,
    /// Optional random seed
    pub random_state: Option<u64>,
}
364
365impl Default for MultilabelConfig {
366    fn default() -> Self {
367        Self {
368            n_samples: 100,
369            n_features: 20,
370            n_classes: 5,
371            n_labels: 2,
372            allow_unlabeled: true,
373            random_state: None,
374        }
375    }
376}
377
378/// Result type for multi-label classification datasets
379///
380/// Multi-label datasets have a target matrix instead of a target vector,
381/// where each column represents a binary label.
382#[derive(Debug, Clone)]
383pub struct MultilabelDataset {
384    /// Feature matrix (n_samples x n_features)
385    pub data: Array2<f64>,
386    /// Target indicator matrix (n_samples x n_classes), binary entries
387    pub target: Array2<f64>,
388    /// Feature names
389    pub feature_names: Vec<String>,
390    /// Class/label names
391    pub class_names: Vec<String>,
392    /// Description
393    pub description: String,
394}
395
396/// Generate a random multi-label classification problem
397///
398/// Each sample can belong to multiple classes simultaneously. The target is
399/// an indicator matrix where `target[i,j] = 1` if sample i has label j.
400///
401/// The generation process:
402/// 1. Create class centers in feature space
403/// 2. For each sample, generate features near one or more class centers
404/// 3. Assign labels based on proximity to class centers
405///
406/// # Arguments
407///
408/// * `config` - Multi-label configuration
409///
410/// # Returns
411///
412/// A `MultilabelDataset` with feature matrix and binary indicator target matrix
413///
414/// # Examples
415///
416/// ```rust
417/// use scirs2_datasets::generators::classification::{make_multilabel_classification, MultilabelConfig};
418///
419/// let config = MultilabelConfig {
420///     n_samples: 100,
421///     n_features: 10,
422///     n_classes: 4,
423///     n_labels: 2,
424///     random_state: Some(42),
425///     ..Default::default()
426/// };
427/// let ds = make_multilabel_classification(config).expect("should succeed");
428/// assert_eq!(ds.data.nrows(), 100);
429/// assert_eq!(ds.data.ncols(), 10);
430/// assert_eq!(ds.target.ncols(), 4);
431/// ```
432pub fn make_multilabel_classification(config: MultilabelConfig) -> Result<MultilabelDataset> {
433    if config.n_samples == 0 {
434        return Err(DatasetsError::InvalidFormat(
435            "n_samples must be > 0".to_string(),
436        ));
437    }
438    if config.n_features == 0 {
439        return Err(DatasetsError::InvalidFormat(
440            "n_features must be > 0".to_string(),
441        ));
442    }
443    if config.n_classes == 0 {
444        return Err(DatasetsError::InvalidFormat(
445            "n_classes must be > 0".to_string(),
446        ));
447    }
448    if config.n_labels == 0 || config.n_labels > config.n_classes {
449        return Err(DatasetsError::InvalidFormat(format!(
450            "n_labels ({}) must be in [1, n_classes ({})]",
451            config.n_labels, config.n_classes
452        )));
453    }
454
455    let mut rng = create_rng(config.random_state);
456
457    let normal = scirs2_core::random::Normal::new(0.0, 1.0).map_err(|e| {
458        DatasetsError::ComputationError(format!("Failed to create normal dist: {e}"))
459    })?;
460
461    // Generate class centers
462    let mut centers = Array2::zeros((config.n_classes, config.n_features));
463    for i in 0..config.n_classes {
464        for j in 0..config.n_features {
465            centers[[i, j]] = 3.0 * normal.sample(&mut rng);
466        }
467    }
468
469    // Generate samples and assign multiple labels
470    let mut data = Array2::zeros((config.n_samples, config.n_features));
471    let mut target_matrix = Array2::zeros((config.n_samples, config.n_classes));
472
473    for i in 0..config.n_samples {
474        // Select n_labels random classes for this sample
475        let mut labels: Vec<usize> = Vec::with_capacity(config.n_labels);
476        while labels.len() < config.n_labels {
477            let candidate = rng.random_range(0..config.n_classes);
478            if !labels.contains(&candidate) {
479                labels.push(candidate);
480            }
481        }
482
483        // If !allow_unlabeled, ensure at least one label
484        if !config.allow_unlabeled && labels.is_empty() {
485            labels.push(rng.random_range(0..config.n_classes));
486        }
487
488        // Generate features as a mixture of the selected class centers
489        for j in 0..config.n_features {
490            let mut val = 0.0;
491            for &label in &labels {
492                val += centers[[label, j]];
493            }
494            val /= labels.len() as f64;
495            val += normal.sample(&mut rng); // Add noise
496            data[[i, j]] = val;
497        }
498
499        // Set target indicators
500        for &label in &labels {
501            target_matrix[[i, label]] = 1.0;
502        }
503    }
504
505    let feature_names: Vec<String> = (0..config.n_features)
506        .map(|j| format!("feature_{j}"))
507        .collect();
508    let class_names: Vec<String> = (0..config.n_classes)
509        .map(|j| format!("label_{j}"))
510        .collect();
511
512    Ok(MultilabelDataset {
513        data,
514        target: target_matrix,
515        feature_names,
516        class_names,
517        description: format!(
518            "Multi-label classification dataset: {} samples, {} features, {} classes, ~{} labels per sample",
519            config.n_samples, config.n_features, config.n_classes, config.n_labels
520        ),
521    })
522}
523
524/// Generate the Hastie et al. 10-dimensional binary classification dataset
525///
526/// Generates data from the 10-dimensional standard normal distribution.
527/// The target is defined as:
528///   y = 1 if sum(x_i^2) > chi-squared median (9.34), else -1
529///
530/// This is the dataset used in:
531/// Hastie, T., Tibshirani, R., Friedman, J. (2009).
532/// The Elements of Statistical Learning, 2nd Edition, Example 10.2.
533///
534/// # Arguments
535///
536/// * `n_samples` - Number of samples (default 12000 in sklearn, split 2000/10000 train/test)
537/// * `random_state` - Optional random seed
538///
539/// # Returns
540///
541/// A `Dataset` with 10 features and binary target {-1, 1}
542///
543/// # Examples
544///
545/// ```rust
546/// use scirs2_datasets::generators::classification::make_hastie_10_2;
547///
548/// let ds = make_hastie_10_2(12000, Some(42)).expect("should succeed");
549/// assert_eq!(ds.n_samples(), 12000);
550/// assert_eq!(ds.n_features(), 10);
551/// ```
552pub fn make_hastie_10_2(n_samples: usize, random_state: Option<u64>) -> Result<Dataset> {
553    if n_samples == 0 {
554        return Err(DatasetsError::InvalidFormat(
555            "n_samples must be > 0".to_string(),
556        ));
557    }
558
559    let mut rng = create_rng(random_state);
560
561    let normal = scirs2_core::random::Normal::new(0.0, 1.0).map_err(|e| {
562        DatasetsError::ComputationError(format!("Failed to create normal dist: {e}"))
563    })?;
564
565    let n_features = 10;
566    // Chi-squared(10) median is approximately 9.3418
567    let chi2_median = 9.3418;
568
569    let mut data = Array2::zeros((n_samples, n_features));
570    let mut target = Array1::zeros(n_samples);
571
572    for i in 0..n_samples {
573        let mut sum_sq = 0.0;
574        for j in 0..n_features {
575            let val = normal.sample(&mut rng);
576            data[[i, j]] = val;
577            sum_sq += val * val;
578        }
579
580        target[i] = if sum_sq > chi2_median { 1.0 } else { -1.0 };
581    }
582
583    let feature_names: Vec<String> = (0..n_features).map(|j| format!("x_{j}")).collect();
584
585    let dataset = Dataset::new(data, Some(target))
586        .with_featurenames(feature_names)
587        .with_targetnames(vec!["-1".to_string(), "1".to_string()])
588        .with_description(
589            "Hastie et al. 10.2 binary classification dataset. \
590             Features are standard normal; y=1 if sum(x_i^2) > 9.34 (chi2(10) median), else y=-1. \
591             Reference: Hastie, Tibshirani, Friedman (2009) The Elements of Statistical Learning."
592                .to_string(),
593        )
594        .with_metadata("chi2_median_threshold", &chi2_median.to_string())
595        .with_metadata("n_features", &n_features.to_string());
596
597    Ok(dataset)
598}
599
600#[cfg(test)]
601mod tests {
602    use super::*;
603
604    // =========================================================================
605    // make_classification_enhanced tests
606    // =========================================================================
607
608    #[test]
609    fn test_classification_enhanced_basic() {
610        let config = ClassificationConfig {
611            n_samples: 200,
612            n_features: 20,
613            n_informative: 5,
614            n_redundant: 3,
615            n_repeated: 2,
616            n_classes: 3,
617            random_state: Some(42),
618            ..Default::default()
619        };
620        let ds = make_classification_enhanced(config).expect("should succeed");
621        assert_eq!(ds.n_samples(), 200);
622        assert_eq!(ds.n_features(), 20);
623        assert!(ds.target.is_some());
624        let target = ds.target.as_ref().expect("target present");
625        assert_eq!(target.len(), 200);
626        // All labels should be in [0, 3)
627        for &val in target.iter() {
628            assert!((0.0..3.0).contains(&val), "Invalid class label: {val}");
629        }
630    }
631
632    #[test]
633    fn test_classification_enhanced_feature_names() {
634        let config = ClassificationConfig {
635            n_samples: 50,
636            n_features: 10,
637            n_informative: 3,
638            n_redundant: 2,
639            n_repeated: 1,
640            n_classes: 2,
641            random_state: Some(42),
642            ..Default::default()
643        };
644        let ds = make_classification_enhanced(config).expect("should succeed");
645        let names = ds.featurenames.as_ref().expect("names present");
646        assert_eq!(names.len(), 10);
647        assert!(names[0].starts_with("informative_"));
648        assert!(names[3].starts_with("redundant_"));
649        assert!(names[5].starts_with("repeated_"));
650        assert!(names[6].starts_with("noise_"));
651    }
652
653    #[test]
654    fn test_classification_enhanced_reproducibility() {
655        let make = || {
656            let config = ClassificationConfig {
657                n_samples: 50,
658                n_features: 10,
659                n_informative: 3,
660                n_redundant: 2,
661                n_repeated: 0,
662                n_classes: 2,
663                flip_y: 0.0,
664                shuffle: false,
665                random_state: Some(123),
666                ..Default::default()
667            };
668            make_classification_enhanced(config).expect("should succeed")
669        };
670        let ds1 = make();
671        let ds2 = make();
672        for i in 0..50 {
673            for j in 0..10 {
674                assert!(
675                    (ds1.data[[i, j]] - ds2.data[[i, j]]).abs() < 1e-15,
676                    "Reproducibility failed at ({i},{j})"
677                );
678            }
679        }
680    }
681
682    #[test]
683    fn test_classification_enhanced_validation() {
684        // n_samples = 0
685        let cfg = ClassificationConfig {
686            n_samples: 0,
687            ..Default::default()
688        };
689        assert!(make_classification_enhanced(cfg).is_err());
690
691        // n_informative > n_features
692        let cfg = ClassificationConfig {
693            n_features: 5,
694            n_informative: 3,
695            n_redundant: 2,
696            n_repeated: 2,
697            ..Default::default()
698        };
699        assert!(make_classification_enhanced(cfg).is_err());
700
701        // n_classes > n_informative
702        let cfg = ClassificationConfig {
703            n_informative: 2,
704            n_classes: 5,
705            ..Default::default()
706        };
707        assert!(make_classification_enhanced(cfg).is_err());
708    }
709
710    #[test]
711    fn test_classification_enhanced_redundant_correlation() {
712        // Redundant features should be correlated with informative
713        let config = ClassificationConfig {
714            n_samples: 500,
715            n_features: 10,
716            n_informative: 5,
717            n_redundant: 3,
718            n_repeated: 0,
719            n_classes: 2,
720            flip_y: 0.0,
721            shuffle: false,
722            random_state: Some(42),
723            ..Default::default()
724        };
725        let ds = make_classification_enhanced(config).expect("should succeed");
726
727        // Compute variance of redundant feature (col 5)
728        let col5: Vec<f64> = (0..500).map(|i| ds.data[[i, 5]]).collect();
729        let mean5: f64 = col5.iter().sum::<f64>() / 500.0;
730        let var5: f64 = col5.iter().map(|x| (x - mean5).powi(2)).sum::<f64>() / 499.0;
731        // Redundant features should have non-trivial variance (not just noise)
732        assert!(var5 > 0.01, "Redundant feature variance too low: {var5}");
733    }
734
735    #[test]
736    fn test_classification_enhanced_flip_y() {
737        // With flip_y = 1.0, all labels should be flipped randomly
738        let config = ClassificationConfig {
739            n_samples: 1000,
740            n_features: 5,
741            n_informative: 3,
742            n_redundant: 0,
743            n_repeated: 0,
744            n_classes: 2,
745            flip_y: 0.0,
746            shuffle: false,
747            random_state: Some(42),
748            ..Default::default()
749        };
750        let ds_no_flip = make_classification_enhanced(config).expect("should succeed");
751
752        let config_flip = ClassificationConfig {
753            n_samples: 1000,
754            n_features: 5,
755            n_informative: 3,
756            n_redundant: 0,
757            n_repeated: 0,
758            n_classes: 2,
759            flip_y: 0.5,
760            shuffle: false,
761            random_state: Some(42),
762            ..Default::default()
763        };
764        let ds_flip = make_classification_enhanced(config_flip).expect("should succeed");
765
766        // With 50% flip rate, some labels should differ
767        let n_different = (0..1000)
768            .filter(|&i| {
769                let t1 = ds_no_flip.target.as_ref().expect("target")[i];
770                let t2 = ds_flip.target.as_ref().expect("target")[i];
771                (t1 - t2).abs() > 0.5
772            })
773            .count();
774        // The no-flip targets are the SAME RNG state initially, but flip_y draws random
775        // numbers differently, so we just check the flipped version differs
776        // from a version with no flipping
777        assert!(
778            n_different > 0,
779            "Expected some labels to differ with flip_y=0.5"
780        );
781    }
782
783    // =========================================================================
784    // make_multilabel_classification tests
785    // =========================================================================
786
787    #[test]
788    fn test_multilabel_basic() {
789        let config = MultilabelConfig {
790            n_samples: 100,
791            n_features: 10,
792            n_classes: 5,
793            n_labels: 2,
794            random_state: Some(42),
795            ..Default::default()
796        };
797        let ds = make_multilabel_classification(config).expect("should succeed");
798        assert_eq!(ds.data.nrows(), 100);
799        assert_eq!(ds.data.ncols(), 10);
800        assert_eq!(ds.target.nrows(), 100);
801        assert_eq!(ds.target.ncols(), 5);
802    }
803
804    #[test]
805    fn test_multilabel_binary_targets() {
806        let config = MultilabelConfig {
807            n_samples: 50,
808            n_features: 5,
809            n_classes: 3,
810            n_labels: 2,
811            random_state: Some(42),
812            ..Default::default()
813        };
814        let ds = make_multilabel_classification(config).expect("should succeed");
815        // All target entries should be 0 or 1
816        for i in 0..50 {
817            for j in 0..3 {
818                let val = ds.target[[i, j]];
819                assert!(
820                    val == 0.0 || val == 1.0,
821                    "Target entry at ({i},{j}) should be binary, got {val}"
822                );
823            }
824        }
825    }
826
827    #[test]
828    fn test_multilabel_labels_per_sample() {
829        let config = MultilabelConfig {
830            n_samples: 200,
831            n_features: 5,
832            n_classes: 6,
833            n_labels: 3,
834            random_state: Some(42),
835            ..Default::default()
836        };
837        let ds = make_multilabel_classification(config).expect("should succeed");
838        // Each sample should have exactly n_labels = 3 labels
839        for i in 0..200 {
840            let label_count: f64 = (0..6).map(|j| ds.target[[i, j]]).sum();
841            assert_eq!(
842                label_count, 3.0,
843                "Sample {i} should have 3 labels, got {label_count}"
844            );
845        }
846    }
847
848    #[test]
849    fn test_multilabel_reproducibility() {
850        let make = || {
851            let config = MultilabelConfig {
852                n_samples: 30,
853                n_features: 5,
854                n_classes: 3,
855                n_labels: 1,
856                random_state: Some(77),
857                ..Default::default()
858            };
859            make_multilabel_classification(config).expect("should succeed")
860        };
861        let ds1 = make();
862        let ds2 = make();
863        for i in 0..30 {
864            for j in 0..5 {
865                assert!(
866                    (ds1.data[[i, j]] - ds2.data[[i, j]]).abs() < 1e-15,
867                    "Reproducibility failed at ({i},{j})"
868                );
869            }
870        }
871    }
872
873    #[test]
874    fn test_multilabel_validation() {
875        let cfg = MultilabelConfig {
876            n_samples: 0,
877            ..Default::default()
878        };
879        assert!(make_multilabel_classification(cfg).is_err());
880
881        let cfg = MultilabelConfig {
882            n_labels: 0,
883            ..Default::default()
884        };
885        assert!(make_multilabel_classification(cfg).is_err());
886
887        let cfg = MultilabelConfig {
888            n_labels: 10,
889            n_classes: 3,
890            ..Default::default()
891        };
892        assert!(make_multilabel_classification(cfg).is_err());
893    }
894
895    // =========================================================================
896    // make_hastie_10_2 tests
897    // =========================================================================
898
899    #[test]
900    fn test_hastie_basic() {
901        let ds = make_hastie_10_2(1000, Some(42)).expect("should succeed");
902        assert_eq!(ds.n_samples(), 1000);
903        assert_eq!(ds.n_features(), 10);
904        assert!(ds.target.is_some());
905    }
906
907    #[test]
908    fn test_hastie_binary_labels() {
909        let ds = make_hastie_10_2(500, Some(42)).expect("should succeed");
910        let target = ds.target.as_ref().expect("target present");
911        for &val in target.iter() {
912            assert!(
913                val == -1.0 || val == 1.0,
914                "Hastie labels should be -1 or 1, got {val}"
915            );
916        }
917    }
918
919    #[test]
920    fn test_hastie_balanced_classes() {
921        // With enough samples, classes should be roughly balanced
922        let ds = make_hastie_10_2(10000, Some(42)).expect("should succeed");
923        let target = ds.target.as_ref().expect("target present");
924        let n_positive = target.iter().filter(|&&v| v > 0.0).count();
925        let n_negative = target.len() - n_positive;
926        // Chi-squared(10) median divides the distribution roughly in half
927        let ratio = n_positive as f64 / n_negative as f64;
928        assert!(
929            ratio > 0.7 && ratio < 1.4,
930            "Classes should be roughly balanced, got ratio {ratio} (pos={n_positive}, neg={n_negative})"
931        );
932    }
933
934    #[test]
935    fn test_hastie_feature_stats() {
936        // Features should be standard normal
937        let ds = make_hastie_10_2(5000, Some(42)).expect("should succeed");
938        for j in 0..10 {
939            let col: Vec<f64> = (0..5000).map(|i| ds.data[[i, j]]).collect();
940            let mean: f64 = col.iter().sum::<f64>() / 5000.0;
941            let var: f64 = col.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / 4999.0;
942            assert!(
943                mean.abs() < 0.1,
944                "Feature {j} mean should be ~0, got {mean}"
945            );
946            assert!(
947                (var - 1.0).abs() < 0.15,
948                "Feature {j} variance should be ~1, got {var}"
949            );
950        }
951    }
952
953    #[test]
954    fn test_hastie_reproducibility() {
955        let ds1 = make_hastie_10_2(100, Some(99)).expect("should succeed");
956        let ds2 = make_hastie_10_2(100, Some(99)).expect("should succeed");
957        for i in 0..100 {
958            for j in 0..10 {
959                assert!(
960                    (ds1.data[[i, j]] - ds2.data[[i, j]]).abs() < 1e-15,
961                    "Reproducibility failed at ({i},{j})"
962                );
963            }
964        }
965    }
966
967    #[test]
968    fn test_hastie_validation() {
969        assert!(make_hastie_10_2(0, None).is_err());
970    }
971
972    #[test]
973    fn test_hastie_description() {
974        let ds = make_hastie_10_2(100, Some(42)).expect("should succeed");
975        assert!(ds.description.is_some());
976        let desc = ds.description.as_ref().expect("desc present");
977        assert!(desc.contains("Hastie"));
978    }
979}