1pub mod advanced_analytics;
9pub mod balancing;
10pub mod dataset;
11pub mod enhanced_analytics;
12pub mod extensions;
13pub mod feature_engineering;
14pub mod sampling;
15pub mod scaling;
16pub mod serialization;
17pub mod splitting;
18
19pub use dataset::Dataset;
23pub use serialization::*;
24
25pub use splitting::{
27 k_fold_split, stratified_k_fold_split, time_series_split, train_test_split,
28 CrossValidationFolds,
29};
30
31pub use sampling::{
33 bootstrap_sample, importance_sample, multiple_bootstrap_samples, random_sample,
34 stratified_sample,
35};
36
37pub use balancing::{
39 create_balanced_dataset, generate_synthetic_samples, random_oversample, random_undersample,
40 BalancingStrategy,
41};
42
43pub use scaling::{min_max_scale, normalize, robust_scale, StatsExt};
45
46pub use feature_engineering::{
48 create_binned_features, polynomial_features, statistical_features, BinningStrategy,
49};
50
51pub use advanced_analytics::{
53 analyze_dataset_advanced, quick_quality_assessment, AdvancedDatasetAnalyzer,
54 AdvancedQualityMetrics, CorrelationInsights, NormalityAssessment,
55};
56
57pub type Array1<T> = scirs2_core::ndarray::Array1<T>;
64pub type Array2<T> = scirs2_core::ndarray::Array2<T>;
66
#[cfg(test)]
mod tests {
    // Smoke tests: check that the items re-exported by this module can be used
    // together through the flat glob import below.
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Runs one function from each re-exported area (dataset, splitting, sampling,
    /// balancing, scaling, feature engineering) against a small 6x2 dataset.
    #[test]
    fn test_module_integration() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        // Deliberately imbalanced: two samples of class 0, four of class 1.
        let target = scirs2_core::ndarray::Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        let dataset = Dataset::new(data.clone(), Some(target.clone()));
        assert_eq!(dataset.n_samples(), 6);
        assert_eq!(dataset.n_features(), 2);

        // Train and test sizes must add up to the full dataset.
        let (train, test) = train_test_split(&dataset, 0.3, Some(42)).unwrap();
        assert_eq!(train.n_samples() + test.n_samples(), 6);

        let indices = random_sample(6, 3, false, Some(42)).unwrap();
        assert_eq!(indices.len(), 3);

        // With imbalanced targets, oversampling is expected to add rows.
        let (balanced_data, _balanced_targets) =
            random_oversample(&data, &target, Some(42)).unwrap();
        assert!(balanced_data.nrows() > data.nrows());

        // `min_max_scale` mutates its argument, so scale a copy and keep `data` intact
        // for the polynomial-feature check below.
        let mut scaled_data = data.clone();
        min_max_scale(&mut scaled_data, (0.0, 1.0));
        assert!(scaled_data.iter().all(|&x| (0.0..=1.0).contains(&x)));

        let poly_features = polynomial_features(&data, 2, true).unwrap();
        assert!(poly_features.ncols() > data.ncols());
    }

    /// Checks that the same helpers are also reachable through the `crate::utils`
    /// path (presumably the path older callers import from — confirm).
    #[test]
    fn test_backward_compatibility() {
        use crate::utils::*;

        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let targets = scirs2_core::ndarray::Array1::from(vec![0.0, 0.0, 1.0, 1.0]);

        let dataset = Dataset::new(data.clone(), Some(targets.clone()));
        let folds = k_fold_split(4, 2, false, Some(42)).unwrap();
        let sample_indices = stratified_sample(&targets, 2, Some(42)).unwrap();
        let (bal_data, _bal_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .unwrap();

        assert_eq!(dataset.n_samples(), 4);
        assert_eq!(folds.len(), 2);
        assert_eq!(sample_indices.len(), 2);
        // Targets are already balanced here, so only "no rows lost" is required.
        assert!(bal_data.nrows() >= data.nrows());
    }

    /// Checks that each cross-validation splitter returns the requested number of folds.
    #[test]
    fn test_cross_validation_compatibility() {
        let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
        // Labels cycle 0, 1, 2 so every class appears for the stratified split.
        let targets =
            scirs2_core::ndarray::Array1::from((0..10).map(|x| (x % 3) as f64).collect::<Vec<_>>());

        let dataset = Dataset::new(data, Some(targets.clone()));

        let folds = k_fold_split(dataset.n_samples(), 5, true, Some(42)).unwrap();
        assert_eq!(folds.len(), 5);

        let stratified_folds = stratified_k_fold_split(&targets, 3, true, Some(42)).unwrap();
        assert_eq!(stratified_folds.len(), 3);

        let ts_folds = time_series_split(dataset.n_samples(), 3, 2, 1).unwrap();
        assert_eq!(ts_folds.len(), 3);
    }
}