// torsh_data/builtin.rs
1//! Built-in datasets powered by SciRS2
2//!
3//! This module provides access to toy datasets, synthetic data generators,
4//! and other built-in data sources from the SciRS2 ecosystem.
5
6use crate::error::DataError;
7use scirs2_core::Distribution; // For sample() method on distributions
8use torsh_tensor::Tensor;
9
10// ✅ Direct SciRS2 datasets integration - using stable toy datasets API
11use scirs2_datasets::toy::{
12    load_boston as scirs2_load_boston, load_breast_cancer as scirs2_load_breast_cancer,
13    load_diabetes as scirs2_load_diabetes, load_digits as scirs2_load_digits,
14    load_iris as scirs2_load_iris,
15};
16
/// Built-in dataset types
///
/// Identifiers for the toy datasets exposed by [`load_builtin_dataset`].
/// Fieldless, so the enum is `Copy`; equality/hashing derives make it usable
/// as a map key or in comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BuiltinDataset {
    /// Iris flower classification dataset.
    Iris,
    /// Boston housing regression dataset (simplified scirs2 version).
    Boston,
    /// Diabetes progression regression dataset.
    Diabetes,
    /// Wine classification dataset (synthetic stand-in in this module).
    Wine,
    /// Breast cancer binary classification dataset (simplified scirs2 version).
    BreastCancer,
    /// Handwritten digits classification dataset (simplified scirs2 version).
    Digits,
}
27
/// Synthetic data generation configuration
///
/// NOTE(review): no generator in this module consumes this config (the
/// generators take `RegressionConfig` / `ClassificationConfig` /
/// `ClusteringConfig`); it may be used by other modules — confirm before
/// removing or changing it.
#[derive(Debug, Clone)]
pub struct SyntheticDataConfig {
    /// Number of samples to generate
    pub n_samples: usize,
    /// Number of features
    pub n_features: usize,
    /// Number of classes (for classification)
    pub n_classes: Option<usize>,
    /// Random seed for reproducibility
    pub seed: Option<u64>,
    /// Whether to add noise
    pub noise: Option<f64>,
    /// Feature scaling method
    pub scale: Option<ScalingMethod>,
}
44
/// Feature scaling methods
///
/// NOTE(review): these variants are declared but never applied by any code in
/// this module; the names presumably mirror the scikit-learn scalers of the
/// same names — confirm intended semantics where they are consumed.
#[derive(Debug, Clone)]
pub enum ScalingMethod {
    /// Zero-mean / unit-variance scaling (presumed).
    StandardScaler,
    /// Rescale features to a fixed range, typically [0, 1] (presumed).
    MinMaxScaler,
    /// Quantile-based scaling robust to outliers (presumed).
    RobustScaler,
    /// Per-sample normalization to unit norm (presumed).
    Normalizer,
}
53
/// Regression data generation parameters
///
/// Consumed by [`make_regression`]; see that function for the generative
/// model (`y = X @ coef + bias + noise`).
#[derive(Debug, Clone)]
pub struct RegressionConfig {
    /// Number of samples (rows) to generate.
    pub n_samples: usize,
    /// Total number of features (columns).
    pub n_features: usize,
    /// Features that actually contribute to the target; defaults to
    /// `n_features` when `None`. Must not exceed `n_features`.
    pub n_informative: Option<usize>,
    /// Standard deviation of Gaussian noise added to targets (0.0 when `None`).
    pub noise: Option<f64>,
    /// Constant added to every target (0.0 when `None`).
    pub bias: Option<f64>,
    /// RNG seed for reproducible output; unseeded when `None`.
    pub random_state: Option<u64>,
}
64
/// Classification data generation parameters
///
/// Consumed by [`make_classification`]; features are drawn from Gaussian
/// clusters, one or more per class.
#[derive(Debug, Clone)]
pub struct ClassificationConfig {
    /// Number of samples to generate.
    pub n_samples: usize,
    /// Total number of features.
    pub n_features: usize,
    /// Number of target classes.
    pub n_classes: usize,
    /// Informative features; defaults to `min(n_features, 2)` when `None`.
    pub n_informative: Option<usize>,
    /// Features derived from the informative ones; defaults to 0 when `None`.
    /// `n_informative + n_redundant` must not exceed `n_features`.
    pub n_redundant: Option<usize>,
    /// Gaussian clusters per class; defaults to 1 when `None`.
    pub n_clusters_per_class: Option<usize>,
    /// Class separation multiplier; defaults to 1.0 when `None`
    /// (larger = more separated classes).
    pub class_sep: Option<f64>,
    /// RNG seed for reproducible output; unseeded when `None`.
    pub random_state: Option<u64>,
}
77
/// Clustering data generation parameters
///
/// Consumed by [`make_blobs`]; produces isotropic Gaussian blobs.
#[derive(Debug, Clone)]
pub struct ClusteringConfig {
    /// Number of samples to generate (distributed across clusters).
    pub n_samples: usize,
    /// Number of cluster centers.
    pub centers: usize,
    /// Dimensionality of each sample; defaults to 2 when `None`.
    pub n_features: Option<usize>,
    /// Standard deviation of the Gaussian noise around each center;
    /// defaults to 1.0 when `None`.
    pub cluster_std: Option<f64>,
    /// (min, max) bounding box for the uniformly-drawn cluster centers;
    /// defaults to (-10.0, 10.0) when `None`.
    pub center_box: Option<(f64, f64)>,
    /// RNG seed for reproducible output; unseeded when `None`.
    pub random_state: Option<u64>,
}
88
/// Dataset result containing features and targets
///
/// Common return type for both the built-in loaders and the synthetic
/// generators in this module.
#[derive(Debug, Clone)]
pub struct DatasetResult {
    /// Feature matrix, shape `[n_samples, n_features]`.
    pub features: Tensor,
    /// Target vector, shape `[n_samples]` (an empty tensor when the source
    /// dataset has no targets).
    pub targets: Tensor,
    /// Optional per-column feature names.
    pub feature_names: Option<Vec<String>>,
    /// Optional class/target names.
    pub target_names: Option<Vec<String>>,
    /// Human-readable summary of how the dataset was produced.
    pub description: String,
}
98
99impl Default for SyntheticDataConfig {
100    fn default() -> Self {
101        Self {
102            n_samples: 100,
103            n_features: 2,
104            n_classes: Some(2),
105            seed: None,
106            noise: Some(0.1),
107            scale: Some(ScalingMethod::StandardScaler),
108        }
109    }
110}
111
112/// Load a built-in dataset
113pub fn load_builtin_dataset(dataset: BuiltinDataset) -> Result<DatasetResult, DataError> {
114    match dataset {
115        BuiltinDataset::Iris => load_iris_dataset(),
116        BuiltinDataset::Boston => load_boston_dataset(),
117        BuiltinDataset::Diabetes => load_diabetes_dataset(),
118        BuiltinDataset::Wine => load_wine_dataset(),
119        BuiltinDataset::BreastCancer => load_breast_cancer_dataset(),
120        BuiltinDataset::Digits => load_digits_dataset(),
121    }
122}
123
124/// Generate synthetic regression data
125///
126/// Creates a regression problem with specified characteristics. The targets are generated as:
127/// y = X @ coef + bias + noise
128///
129/// where:
130/// - X contains n_informative features that actually contribute to y
131/// - The remaining (n_features - n_informative) features are random noise
132/// - coef are random coefficients for informative features
133/// - noise is Gaussian noise with standard deviation specified by `noise` parameter
134pub fn make_regression(config: RegressionConfig) -> Result<DatasetResult, DataError> {
135    use scirs2_core::random::{Normal, SeedableRng, StdRng};
136
137    let n_informative = config.n_informative.unwrap_or(config.n_features);
138    let noise_std = config.noise.unwrap_or(0.0);
139    let bias = config.bias.unwrap_or(0.0);
140
141    if n_informative > config.n_features {
142        return Err(DataError::dataset(
143            crate::error::DatasetErrorKind::CorruptedData,
144            format!(
145                "n_informative ({}) cannot exceed n_features ({})",
146                n_informative, config.n_features
147            ),
148        ));
149    }
150
151    // ✅ SciRS2 Policy Compliant - Using scirs2_core::random
152    let mut rng = if let Some(seed) = config.random_state {
153        StdRng::seed_from_u64(seed)
154    } else {
155        let mut thread_rng = scirs2_core::random::thread_rng();
156        StdRng::from_rng(&mut thread_rng)
157    };
158
159    // Generate features from standard normal distribution
160    let normal = Normal::new(0.0, 1.0).expect("valid Normal parameters");
161    let features_data: Vec<f32> = (0..config.n_samples * config.n_features)
162        .map(|_| normal.sample(&mut rng) as f32)
163        .collect();
164
165    let features = Tensor::from_vec(
166        features_data.clone(),
167        &[config.n_samples, config.n_features],
168    )?;
169
170    // Generate random coefficients for informative features
171    let coefficients: Vec<f32> = (0..n_informative)
172        .map(|_| rng.gen_range(-100.0..100.0))
173        .collect();
174
175    // Generate targets as linear combination of informative features
176    let noise_dist = Normal::new(0.0, noise_std).expect("valid Normal parameters");
177    let targets_data: Vec<f32> = (0..config.n_samples)
178        .map(|i| {
179            // Compute linear combination: sum(coef_j * x_ij) for informative features
180            let mut target = bias as f32;
181            for j in 0..n_informative {
182                let idx = i * config.n_features + j;
183                target += coefficients[j] * features_data[idx];
184            }
185
186            // Add Gaussian noise
187            if noise_std > 0.0 {
188                target += noise_dist.sample(&mut rng) as f32;
189            }
190
191            target
192        })
193        .collect();
194
195    let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
196
197    Ok(DatasetResult {
198        features,
199        targets,
200        feature_names: Some(
201            (0..config.n_features)
202                .map(|i| {
203                    if i < n_informative {
204                        format!("informative_{}", i)
205                    } else {
206                        format!("noise_{}", i - n_informative)
207                    }
208                })
209                .collect(),
210        ),
211        target_names: Some(vec!["target".to_string()]),
212        description: format!(
213            "Synthetic regression dataset: {} samples, {} features ({} informative), noise_std={:.2}, bias={:.2}",
214            config.n_samples, config.n_features, n_informative, noise_std, bias
215        ),
216    })
217}
218
219/// Generate synthetic classification data
220///
221/// Creates a classification problem with specified characteristics. Features are generated
222/// by creating Gaussian clusters for each class, with controllable separation.
223///
224/// - `n_informative`: Number of features that are informative for classification
225/// - `n_redundant`: Number of features that are linear combinations of informative features
226/// - `n_clusters_per_class`: Number of Gaussian clusters per class
227/// - `class_sep`: Multiplier for class separation (larger = more separated classes)
228pub fn make_classification(config: ClassificationConfig) -> Result<DatasetResult, DataError> {
229    use scirs2_core::random::{Normal, SeedableRng, StdRng};
230
231    let n_informative = config.n_informative.unwrap_or(config.n_features.min(2));
232    let n_redundant = config.n_redundant.unwrap_or(0);
233    let n_clusters_per_class = config.n_clusters_per_class.unwrap_or(1);
234    let class_sep = config.class_sep.unwrap_or(1.0);
235
236    if n_informative + n_redundant > config.n_features {
237        return Err(DataError::dataset(
238            crate::error::DatasetErrorKind::CorruptedData,
239            format!(
240                "n_informative ({}) + n_redundant ({}) cannot exceed n_features ({})",
241                n_informative, n_redundant, config.n_features
242            ),
243        ));
244    }
245
246    // ✅ SciRS2 Policy Compliant - Using scirs2_core::random
247    let mut rng = if let Some(seed) = config.random_state {
248        StdRng::seed_from_u64(seed)
249    } else {
250        let mut thread_rng = scirs2_core::random::thread_rng();
251        StdRng::from_rng(&mut thread_rng)
252    };
253
254    // Generate cluster centers for each class in the informative feature space
255    let total_clusters = config.n_classes * n_clusters_per_class;
256    let mut cluster_centers: Vec<Vec<f32>> = Vec::new();
257    let mut cluster_labels: Vec<usize> = Vec::new();
258
259    for class_id in 0..config.n_classes {
260        for _ in 0..n_clusters_per_class {
261            let center: Vec<f32> = (0..n_informative)
262                .map(|_| rng.gen_range(-class_sep as f32..class_sep as f32) * 10.0)
263                .collect();
264            cluster_centers.push(center);
265            cluster_labels.push(class_id);
266        }
267    }
268
269    // Distribute samples across clusters
270    let samples_per_cluster = config.n_samples / total_clusters;
271    let remainder = config.n_samples % total_clusters;
272
273    let mut features_data = Vec::new();
274    let mut targets_data = Vec::new();
275
276    let normal = Normal::new(0.0, 1.0).expect("valid Normal parameters");
277
278    for (cluster_idx, (center, &class_label)) in cluster_centers
279        .iter()
280        .zip(cluster_labels.iter())
281        .enumerate()
282    {
283        let n_samples_this_cluster =
284            samples_per_cluster + if cluster_idx < remainder { 1 } else { 0 };
285
286        for _ in 0..n_samples_this_cluster {
287            // Generate informative features from cluster center with Gaussian noise
288            for &center_val in center.iter() {
289                let noise = normal.sample(&mut rng) as f32;
290                features_data.push(center_val + noise);
291            }
292
293            // Generate redundant features as linear combinations of informative features
294            let start_idx = features_data.len() - n_informative;
295            for _ in 0..n_redundant {
296                let mut redundant = 0.0f32;
297                for j in 0..n_informative {
298                    let weight = rng.gen_range(-1.0..1.0);
299                    redundant += weight * features_data[start_idx + j];
300                }
301                features_data.push(redundant);
302            }
303
304            // Generate noise features (truly random)
305            let n_noise = config.n_features - n_informative - n_redundant;
306            for _ in 0..n_noise {
307                features_data.push(rng.gen_range(-10.0..10.0));
308            }
309
310            targets_data.push(class_label as f32);
311        }
312    }
313
314    let features = Tensor::from_vec(features_data, &[config.n_samples, config.n_features])?;
315    let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
316
317    Ok(DatasetResult {
318        features,
319        targets,
320        feature_names: Some(
321            (0..config.n_features)
322                .map(|i| {
323                    if i < n_informative {
324                        format!("informative_{}", i)
325                    } else if i < n_informative + n_redundant {
326                        format!("redundant_{}", i - n_informative)
327                    } else {
328                        format!("noise_{}", i - n_informative - n_redundant)
329                    }
330                })
331                .collect(),
332        ),
333        target_names: Some(
334            (0..config.n_classes)
335                .map(|i| format!("class_{}", i))
336                .collect(),
337        ),
338        description: format!(
339            "Synthetic classification dataset: {} samples, {} features ({} informative, {} redundant), {} classes, class_sep={:.2}",
340            config.n_samples, config.n_features, n_informative, n_redundant, config.n_classes, class_sep
341        ),
342    })
343}
344
345/// Generate synthetic clustering data (blobs)
346///
347/// Creates isotropic Gaussian blobs for clustering. Each blob is a Gaussian distribution
348/// centered at a random location within the bounding box.
349///
350/// - `centers`: Number of cluster centers to generate
351/// - `n_features`: Number of features (dimensions) for each sample
352/// - `cluster_std`: Standard deviation of the Gaussian noise around each cluster center
353/// - `center_box`: Bounding box (min, max) for cluster center locations
354pub fn make_blobs(config: ClusteringConfig) -> Result<DatasetResult, DataError> {
355    use scirs2_core::random::{Normal, SeedableRng, StdRng};
356
357    // ✅ SciRS2 Policy Compliant - Using scirs2_core::random
358    let mut rng = if let Some(seed) = config.random_state {
359        StdRng::seed_from_u64(seed)
360    } else {
361        let mut thread_rng = scirs2_core::random::thread_rng();
362        StdRng::from_rng(&mut thread_rng)
363    };
364
365    let n_features = config.n_features.unwrap_or(2);
366    let cluster_std = config.cluster_std.unwrap_or(1.0);
367    let (box_min, box_max) = config.center_box.unwrap_or((-10.0, 10.0));
368
369    if box_min >= box_max {
370        return Err(DataError::dataset(
371            crate::error::DatasetErrorKind::CorruptedData,
372            format!(
373                "center_box min ({}) must be less than max ({})",
374                box_min, box_max
375            ),
376        ));
377    }
378
379    // Generate cluster centers uniformly within the bounding box
380    let centers: Vec<Vec<f32>> = (0..config.centers)
381        .map(|_| {
382            (0..n_features)
383                .map(|_| rng.gen_range(box_min as f32..box_max as f32))
384                .collect()
385        })
386        .collect();
387
388    // Distribute samples across clusters
389    let samples_per_cluster = config.n_samples / config.centers;
390    let remainder = config.n_samples % config.centers;
391
392    let mut features_data = Vec::new();
393    let mut targets_data = Vec::new();
394
395    // Create Gaussian distribution for sampling around cluster centers
396    let normal = Normal::new(0.0, cluster_std).expect("valid Normal parameters");
397
398    for (cluster_id, center) in centers.iter().enumerate() {
399        let n_samples_this_cluster =
400            samples_per_cluster + if cluster_id < remainder { 1 } else { 0 };
401
402        for _ in 0..n_samples_this_cluster {
403            // Generate point around cluster center using Gaussian noise
404            for &center_coord in center {
405                let noise = normal.sample(&mut rng) as f32;
406                features_data.push(center_coord + noise);
407            }
408            targets_data.push(cluster_id as f32);
409        }
410    }
411
412    let features = Tensor::from_vec(features_data, &[config.n_samples, n_features])?;
413    let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
414
415    Ok(DatasetResult {
416        features,
417        targets,
418        feature_names: Some((0..n_features).map(|i| format!("feature_{}", i)).collect()),
419        target_names: Some(
420            (0..config.centers)
421                .map(|i| format!("cluster_{}", i))
422                .collect(),
423        ),
424        description: format!(
425            "Synthetic clustering dataset (blobs): {} samples, {} features, {} clusters, cluster_std={:.2}",
426            config.n_samples, n_features, config.centers, cluster_std
427        ),
428    })
429}
430
431/// Convert scirs2_datasets::Dataset to torsh-data's DatasetResult
432fn convert_scirs2_dataset(
433    scirs2_dataset: scirs2_datasets::utils::Dataset,
434) -> Result<DatasetResult, DataError> {
435    // Convert features: Array2<f64> -> Tensor
436    let shape = scirs2_dataset.data.shape();
437    let features_data: Vec<f32> = scirs2_dataset.data.iter().map(|&x| x as f32).collect();
438    let features = Tensor::from_vec(features_data, &[shape[0], shape[1]])?;
439
440    // Convert targets: Array1<f64> -> Tensor
441    let targets = if let Some(target_array) = scirs2_dataset.target {
442        let target_data: Vec<f32> = target_array.iter().map(|&x| x as f32).collect();
443        Tensor::from_vec(target_data, &[target_array.len()])?
444    } else {
445        // Create empty tensor if no targets
446        Tensor::from_vec(vec![], &[0])?
447    };
448
449    Ok(DatasetResult {
450        features,
451        targets,
452        feature_names: scirs2_dataset.featurenames,
453        target_names: scirs2_dataset.targetnames,
454        description: scirs2_dataset
455            .description
456            .unwrap_or_else(|| "Dataset loaded from scirs2".to_string()),
457    })
458}
459
460// Built-in dataset implementations using scirs2_datasets
461fn load_iris_dataset() -> Result<DatasetResult, DataError> {
462    // ✅ Using scirs2_datasets::load_iris() for authentic Iris dataset
463    let scirs2_dataset = scirs2_load_iris().map_err(|e| {
464        DataError::dataset(
465            crate::error::DatasetErrorKind::CorruptedData,
466            format!("Failed to load Iris dataset from scirs2_datasets: {}", e),
467        )
468    })?;
469
470    convert_scirs2_dataset(scirs2_dataset)
471}
472
473fn load_boston_dataset() -> Result<DatasetResult, DataError> {
474    // ✅ Using scirs2_datasets::load_boston() for authentic Boston Housing dataset
475    let scirs2_dataset = scirs2_load_boston().map_err(|e| {
476        DataError::dataset(
477            crate::error::DatasetErrorKind::CorruptedData,
478            format!("Failed to load Boston dataset from scirs2_datasets: {}", e),
479        )
480    })?;
481
482    convert_scirs2_dataset(scirs2_dataset)
483}
484
485fn load_diabetes_dataset() -> Result<DatasetResult, DataError> {
486    // ✅ Using scirs2_datasets::load_diabetes() for authentic Diabetes dataset
487    let scirs2_dataset = scirs2_load_diabetes().map_err(|e| {
488        DataError::dataset(
489            crate::error::DatasetErrorKind::CorruptedData,
490            format!(
491                "Failed to load Diabetes dataset from scirs2_datasets: {}",
492                e
493            ),
494        )
495    })?;
496
497    convert_scirs2_dataset(scirs2_dataset)
498}
499
// Wine dataset stand-in.
//
// NOTE(review): unlike the other loaders this does NOT come from
// scirs2_datasets — it generates a synthetic classification problem with
// Wine-like dimensions (178 samples, 13 features, 3 classes) using a fixed
// seed for reproducibility. The resulting description will say "Synthetic
// classification dataset"; confirm this is intentional.
fn load_wine_dataset() -> Result<DatasetResult, DataError> {
    make_classification(ClassificationConfig {
        n_samples: 178,
        n_features: 13,
        n_classes: 3,
        n_informative: Some(13),
        random_state: Some(42),
        ..Default::default()
    })
}
510
511fn load_breast_cancer_dataset() -> Result<DatasetResult, DataError> {
512    // ✅ Using scirs2_datasets::load_breast_cancer() for authentic Breast Cancer dataset
513    let scirs2_dataset = scirs2_load_breast_cancer().map_err(|e| {
514        DataError::dataset(
515            crate::error::DatasetErrorKind::CorruptedData,
516            format!(
517                "Failed to load Breast Cancer dataset from scirs2_datasets: {}",
518                e
519            ),
520        )
521    })?;
522
523    convert_scirs2_dataset(scirs2_dataset)
524}
525
526fn load_digits_dataset() -> Result<DatasetResult, DataError> {
527    // ✅ Using scirs2_datasets::load_digits() for authentic Digits dataset
528    let scirs2_dataset = scirs2_load_digits().map_err(|e| {
529        DataError::dataset(
530            crate::error::DatasetErrorKind::CorruptedData,
531            format!("Failed to load Digits dataset from scirs2_datasets: {}", e),
532        )
533    })?;
534
535    convert_scirs2_dataset(scirs2_dataset)
536}
537
538impl Default for RegressionConfig {
539    fn default() -> Self {
540        Self {
541            n_samples: 100,
542            n_features: 1,
543            n_informative: None,
544            noise: Some(0.1),
545            bias: Some(0.0),
546            random_state: None,
547        }
548    }
549}
550
551impl Default for ClassificationConfig {
552    fn default() -> Self {
553        Self {
554            n_samples: 100,
555            n_features: 2,
556            n_classes: 2,
557            n_informative: None,
558            n_redundant: None,
559            n_clusters_per_class: None,
560            class_sep: Some(1.0),
561            random_state: None,
562        }
563    }
564}
565
566impl Default for ClusteringConfig {
567    fn default() -> Self {
568        Self {
569            n_samples: 100,
570            centers: 3,
571            n_features: Some(2),
572            cluster_std: Some(1.0),
573            center_box: Some((-10.0, 10.0)),
574            random_state: None,
575        }
576    }
577}
578
579/// Dataset registry for managing available datasets
580#[derive(Debug, Default)]
581pub struct DatasetRegistry {
582    builtin_datasets: Vec<BuiltinDataset>,
583}
584
585impl DatasetRegistry {
586    /// Create a new dataset registry
587    pub fn new() -> Self {
588        Self {
589            builtin_datasets: vec![
590                BuiltinDataset::Iris,
591                BuiltinDataset::Boston,
592                BuiltinDataset::Diabetes,
593                BuiltinDataset::Wine,
594                BuiltinDataset::BreastCancer,
595                BuiltinDataset::Digits,
596            ],
597        }
598    }
599
600    /// List all available built-in datasets
601    pub fn list_builtin(&self) -> &[BuiltinDataset] {
602        &self.builtin_datasets
603    }
604
605    /// Load a dataset by name
606    pub fn load_by_name(&self, name: &str) -> Result<DatasetResult, DataError> {
607        let dataset = match name.to_lowercase().as_str() {
608            "iris" => BuiltinDataset::Iris,
609            "boston" => BuiltinDataset::Boston,
610            "diabetes" => BuiltinDataset::Diabetes,
611            "wine" => BuiltinDataset::Wine,
612            "breast_cancer" | "breastcancer" => BuiltinDataset::BreastCancer,
613            "digits" => BuiltinDataset::Digits,
614            _ => {
615                return Err(DataError::dataset(
616                    crate::error::DatasetErrorKind::UnsupportedFormat,
617                    format!("Unknown dataset: {}", name),
618                ))
619            }
620        };
621
622        load_builtin_dataset(dataset)
623    }
624}
625
#[cfg(test)]
mod tests {
    //! Unit tests for the built-in loaders, synthetic generators, and the
    //! name-based registry lookup.
    //!
    //! NOTE(review): the expected shapes (e.g. Boston 30x5, Digits 50x16)
    //! track the simplified toy datasets shipped by `scirs2_datasets`, not
    //! the classic full-size datasets — these asserts will break if scirs2
    //! changes its toy data.
    use super::*;

    #[test]
    fn test_load_iris_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Iris);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Iris has 150 samples, 4 features
        assert_eq!(dataset.features.size(0).unwrap(), 150);
        assert_eq!(dataset.features.size(1).unwrap(), 4);
        assert_eq!(dataset.targets.size(0).unwrap(), 150);

        // Check metadata
        assert!(dataset.feature_names.is_some());
        assert!(dataset.target_names.is_some());
        assert!(!dataset.description.is_empty());

        let feature_names = dataset.feature_names.unwrap();
        assert_eq!(feature_names.len(), 4);
        assert!(feature_names.contains(&"sepal_length".to_string()));

        let target_names = dataset.target_names.unwrap();
        assert_eq!(target_names.len(), 3);
    }

    #[test]
    fn test_load_boston_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Boston);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Boston has 30 samples, 5 features (simplified version from scirs2_datasets)
        assert_eq!(dataset.features.size(0).unwrap(), 30);
        assert_eq!(dataset.features.size(1).unwrap(), 5);
        assert_eq!(dataset.targets.size(0).unwrap(), 30);

        // Check metadata
        assert!(dataset.feature_names.is_some());
        assert!(!dataset.description.is_empty());
    }

    #[test]
    fn test_load_diabetes_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Diabetes);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Diabetes has 442 samples, 10 features (from scirs2_datasets)
        assert_eq!(dataset.features.size(0).unwrap(), 442);
        assert_eq!(dataset.features.size(1).unwrap(), 10);
        assert_eq!(dataset.targets.size(0).unwrap(), 442);

        // Check metadata
        assert!(dataset.feature_names.is_some());
        assert!(!dataset.description.is_empty());

        let feature_names = dataset.feature_names.unwrap();
        assert_eq!(feature_names.len(), 10);
        // Verify expected feature names from scirs2 diabetes dataset
        assert!(feature_names.contains(&"age".to_string()));
        assert!(feature_names.contains(&"bmi".to_string()));
    }

    #[test]
    fn test_load_breast_cancer_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::BreastCancer);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Breast cancer has 30 samples, 5 features (simplified version from scirs2_datasets)
        assert_eq!(dataset.features.size(0).unwrap(), 30);
        assert_eq!(dataset.features.size(1).unwrap(), 5);
        assert_eq!(dataset.targets.size(0).unwrap(), 30);

        // Check metadata
        assert!(dataset.feature_names.is_some());
        assert!(dataset.target_names.is_some());
        assert!(!dataset.description.is_empty());

        let target_names = dataset.target_names.unwrap();
        assert_eq!(target_names.len(), 2); // Binary classification: malignant, benign
        assert!(target_names.contains(&"malignant".to_string()));
        assert!(target_names.contains(&"benign".to_string()));
    }

    #[test]
    fn test_load_digits_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Digits);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Digits has 50 samples, 16 features (4x4 images from scirs2_datasets)
        assert_eq!(dataset.features.size(0).unwrap(), 50);
        assert_eq!(dataset.features.size(1).unwrap(), 16);
        assert_eq!(dataset.targets.size(0).unwrap(), 50);

        // Check metadata
        assert!(dataset.target_names.is_some());
        assert!(!dataset.description.is_empty());

        let target_names = dataset.target_names.unwrap();
        assert_eq!(target_names.len(), 10); // 10 digits (0-9)
    }

    #[test]
    fn test_load_wine_dataset() {
        // Wine is the synthetic stand-in generated by make_classification;
        // only shape invariants are asserted here.
        let result = load_builtin_dataset(BuiltinDataset::Wine);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Wine has 178 samples, 13 features
        assert_eq!(dataset.features.size(0).unwrap(), 178);
        assert_eq!(dataset.features.size(1).unwrap(), 13);
        assert_eq!(dataset.targets.size(0).unwrap(), 178);

        // Check metadata
        assert!(!dataset.description.is_empty());
    }

    #[test]
    fn test_dataset_registry() {
        let registry = DatasetRegistry::new();
        let builtin_datasets = registry.list_builtin();

        // Check all datasets are registered
        assert_eq!(builtin_datasets.len(), 6);
    }

    #[test]
    fn test_load_by_name() {
        let registry = DatasetRegistry::new();

        // Test all dataset names (including aliases)
        assert!(registry.load_by_name("iris").is_ok());
        assert!(registry.load_by_name("boston").is_ok());
        assert!(registry.load_by_name("diabetes").is_ok());
        assert!(registry.load_by_name("wine").is_ok());
        assert!(registry.load_by_name("breast_cancer").is_ok());
        assert!(registry.load_by_name("breastcancer").is_ok()); // Alias
        assert!(registry.load_by_name("digits").is_ok());

        // Test case insensitivity
        assert!(registry.load_by_name("IRIS").is_ok());
        assert!(registry.load_by_name("Diabetes").is_ok());

        // Test unknown dataset
        assert!(registry.load_by_name("unknown").is_err());
    }

    #[test]
    fn test_make_regression() {
        let config = RegressionConfig {
            n_samples: 100,
            n_features: 5,
            n_informative: Some(3),
            noise: Some(0.1),
            bias: Some(1.0),
            random_state: Some(42),
        };

        let result = make_regression(config);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        assert_eq!(dataset.features.size(0).unwrap(), 100);
        assert_eq!(dataset.features.size(1).unwrap(), 5);
        assert_eq!(dataset.targets.size(0).unwrap(), 100);
    }

    #[test]
    fn test_make_classification() {
        let config = ClassificationConfig {
            n_samples: 200,
            n_features: 10,
            n_classes: 3,
            n_informative: Some(5),
            random_state: Some(42),
            ..Default::default()
        };

        let result = make_classification(config);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        assert_eq!(dataset.features.size(0).unwrap(), 200);
        assert_eq!(dataset.features.size(1).unwrap(), 10);
        assert_eq!(dataset.targets.size(0).unwrap(), 200);
    }

    #[test]
    fn test_make_blobs() {
        let config = ClusteringConfig {
            n_samples: 150,
            centers: 3,
            n_features: Some(2),
            cluster_std: Some(0.5),
            random_state: Some(42),
            ..Default::default()
        };

        let result = make_blobs(config);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        assert_eq!(dataset.features.size(0).unwrap(), 150);
        assert_eq!(dataset.features.size(1).unwrap(), 2);
        assert_eq!(dataset.targets.size(0).unwrap(), 150);
    }

    #[test]
    fn test_regression_config_validation() {
        // Test n_informative > n_features
        let config = RegressionConfig {
            n_samples: 100,
            n_features: 5,
            n_informative: Some(10), // More than n_features
            noise: Some(0.1),
            bias: Some(0.0),
            random_state: Some(42),
        };

        let result = make_regression(config);
        assert!(result.is_err());
    }

    #[test]
    fn test_scirs2_integration_diabetes() {
        // Test that diabetes dataset is authentic from scirs2, not synthetic
        let result = load_builtin_dataset(BuiltinDataset::Diabetes);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Verify it has the correct scirs2 diabetes dataset characteristics
        assert_eq!(dataset.features.size(0).unwrap(), 442);
        assert_eq!(dataset.features.size(1).unwrap(), 10);

        // Check that description mentions it's from scirs2 or is realistic
        assert!(
            dataset.description.contains("diabetes") || dataset.description.contains("Diabetes")
        );
    }

    #[test]
    fn test_scirs2_integration_breast_cancer() {
        // Test that breast cancer dataset is authentic from scirs2
        let result = load_builtin_dataset(BuiltinDataset::BreastCancer);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Verify it has the correct scirs2 breast cancer dataset characteristics
        assert_eq!(dataset.features.size(0).unwrap(), 30);
        assert_eq!(dataset.features.size(1).unwrap(), 5);

        // Check metadata is properly populated
        assert!(dataset.feature_names.is_some());
        assert!(dataset.target_names.is_some());
    }

    #[test]
    fn test_scirs2_integration_digits() {
        // Test that digits dataset is authentic from scirs2
        let result = load_builtin_dataset(BuiltinDataset::Digits);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Verify it has the correct scirs2 digits dataset characteristics
        assert_eq!(dataset.features.size(0).unwrap(), 50);
        assert_eq!(dataset.features.size(1).unwrap(), 16); // 4x4 pixels

        // Check that we have 10 target classes (digits 0-9)
        assert!(dataset.target_names.is_some());
        let target_names = dataset.target_names.unwrap();
        assert_eq!(target_names.len(), 10);
    }
}