scirs2_datasets/utils/
mod.rs

//! Utility functions and data structures for datasets.
//!
//! This module collects utilities for dataset manipulation: data serialization,
//! dataset structures, splitting, sampling, balancing, scaling, feature
//! engineering, dataset quality analytics, and trait extensions.
6
7// Import all submodules
8pub mod advanced_analytics;
9pub mod balancing;
10pub mod dataset;
11pub mod enhanced_analytics;
12pub mod extensions;
13pub mod feature_engineering;
14pub mod sampling;
15pub mod scaling;
16pub mod serialization;
17pub mod splitting;
18
19// Re-export main types and functions for backward compatibility
20
21// Dataset and serialization
22pub use dataset::Dataset;
23pub use serialization::*;
24
25// Data splitting
26pub use splitting::{
27    k_fold_split, stratified_k_fold_split, time_series_split, train_test_split,
28    CrossValidationFolds,
29};
30
31// Data sampling
32pub use sampling::{
33    bootstrap_sample, importance_sample, multiple_bootstrap_samples, random_sample,
34    stratified_sample,
35};
36
37// Data balancing
38pub use balancing::{
39    create_balanced_dataset, generate_synthetic_samples, random_oversample, random_undersample,
40    BalancingStrategy,
41};
42
43// Data scaling and normalization
44pub use scaling::{min_max_scale, normalize, robust_scale, StatsExt};
45
46// Feature engineering
47pub use feature_engineering::{
48    create_binned_features, polynomial_features, statistical_features, BinningStrategy,
49};
50
51// Advanced analytics
52pub use advanced_analytics::{
53    analyze_dataset_advanced, quick_quality_assessment, AdvancedDatasetAnalyzer,
54    AdvancedQualityMetrics, CorrelationInsights, NormalityAssessment,
55};
56
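// Trait extensions: the `extensions` module is declared above but not glob
// re-exported (`pub use extensions::*;` is disabled), so its items must be
// imported via `extensions::...`.
// NOTE(review): confirm why the glob re-export is disabled before enabling it.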
59
60// Type aliases for convenience
61
62/// Convenience alias for ndarray 1D array
63pub type Array1<T> = scirs2_core::ndarray::Array1<T>;
64/// Convenience alias for ndarray 2D array
65pub type Array2<T> = scirs2_core::ndarray::Array2<T>;
66
#[cfg(test)]
mod tests {
    //! Cross-module smoke tests confirming that the re-exports declared in this
    //! file are reachable and work together.

    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    #[test]
    fn test_module_integration() {
        // Exercise one representative entry point from each re-exported area.
        // Row-major 6x2 matrix holding the values 1.0 ..= 12.0.
        let features =
            Array2::from_shape_vec((6, 2), (1..=12).map(|v| v as f64).collect()).unwrap();
        // Two samples labelled 0.0 and four labelled 1.0.
        let labels = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        // Dataset construction.
        let ds = Dataset::new(features.clone(), Some(labels.clone()));
        assert_eq!(ds.n_samples(), 6);
        assert_eq!(ds.n_features(), 2);

        // Splitting: the two partitions together must account for every sample.
        let (train_part, test_part) = train_test_split(&ds, 0.3, Some(42)).unwrap();
        assert_eq!(train_part.n_samples() + test_part.n_samples(), 6);

        // Sampling three indices without replacement.
        let picked = random_sample(6, 3, false, Some(42)).unwrap();
        assert_eq!(picked.len(), 3);

        // Balancing: oversampling is expected to yield more rows than the input.
        let (oversampled, _oversampled_labels) =
            random_oversample(&features, &labels, Some(42)).unwrap();
        assert!(oversampled.nrows() > features.nrows());

        // Scaling in place into the closed interval [0, 1].
        let mut scaled = features.clone();
        min_max_scale(&mut scaled, (0.0, 1.0));
        assert!(scaled.iter().all(|value| (0.0..=1.0).contains(value)));

        // Feature engineering: a degree-2 expansion is expected to add columns.
        let expanded = polynomial_features(&features, 2, true).unwrap();
        assert!(expanded.ncols() > features.ncols());
    }

    #[test]
    fn test_backward_compatibility() {
        // Import through the public module path, as pre-refactor callers did;
        // every call below must keep compiling against the re-exports.
        use crate::utils::*;

        let matrix =
            Array2::from_shape_vec((4, 2), (1..=8).map(|v| v as f64).collect()).unwrap();
        let classes = scirs2_core::ndarray::Array1::from(vec![0.0, 0.0, 1.0, 1.0]);

        let ds = Dataset::new(matrix.clone(), Some(classes.clone()));
        let fold_list = k_fold_split(4, 2, false, Some(42)).unwrap();
        let chosen = stratified_sample(&classes, 2, Some(42)).unwrap();
        let (balanced, _balanced_classes) = create_balanced_dataset(
            &matrix,
            &classes,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .unwrap();

        assert_eq!(ds.n_samples(), 4);
        assert_eq!(fold_list.len(), 2);
        assert_eq!(chosen.len(), 2);
        assert!(balanced.nrows() >= matrix.nrows());
    }

    #[test]
    fn test_cross_validation_compatibility() {
        // Cross-validation helpers that combine the dataset and splitting modules.
        let grid: Vec<f64> = (0..30).map(|x| x as f64).collect();
        let features = Array2::from_shape_vec((10, 3), grid).unwrap();
        // Labels cycle 0, 1, 2, 0, 1, 2, ... across the ten samples.
        let labels = Array1::from((0..10).map(|i| (i % 3) as f64).collect::<Vec<_>>());

        let ds = Dataset::new(features, Some(labels.clone()));
        let total = ds.n_samples();

        // Plain k-fold splitting.
        let kfolds = k_fold_split(total, 5, true, Some(42)).unwrap();
        assert_eq!(kfolds.len(), 5);

        // Stratified splitting on the cyclic labels.
        let strat_folds = stratified_k_fold_split(&labels, 3, true, Some(42)).unwrap();
        assert_eq!(strat_folds.len(), 3);

        // Time-series splitting.
        let ts_folds = time_series_split(total, 3, 2, 1).unwrap();
        assert_eq!(ts_folds.len(), 3);
    }
}
160}