sklears_dummy/validation/
mod.rs1pub mod bootstrap_validation;
7pub mod cross_validation;
8pub mod data_splitting;
9pub mod statistical_analysis;
10pub mod strategy_comparison;
11pub mod tests;
12pub mod validation_core;
13pub mod validation_metrics;
14pub mod validation_utils;
15
16pub use bootstrap_validation::*;
18pub use cross_validation::*;
19pub use data_splitting::*;
20pub use statistical_analysis::*;
21pub use strategy_comparison::*;
22pub use validation_core::*;
23pub use validation_metrics::*;
24pub use validation_utils::*;
25
26use crate::dummy_classifier::Strategy as ClassifierStrategy;
28use crate::dummy_regressor::Strategy as RegressorStrategy;
29use scirs2_core::ndarray::{Array1, Array2};
30
/// Scores collected while validating a single dummy-estimator strategy.
#[derive(Debug, Clone)]
pub struct DummyValidationResult {
    /// Name of the strategy these scores belong to.
    pub strategy: String,
    /// Individual scores (presumably one per cross-validation fold).
    pub cv_scores: Vec<f64>,
    /// Mean score, stored exactly as supplied by the caller.
    pub mean_score: f64,
    /// Standard deviation of the scores, stored exactly as supplied by the caller.
    pub std_score: f64,
    /// `(lower, upper)` bounds around `mean_score`.
    pub confidence_interval: (f64, f64),
}

impl DummyValidationResult {
    /// Builds a result and derives its confidence interval.
    ///
    /// The interval is `mean_score ± 1.96 * std_score / sqrt(n)`, where `n` is
    /// the number of entries in `cv_scores`. `mean_score` and `std_score` are
    /// taken as given; they are not recomputed from `cv_scores`.
    ///
    /// When `cv_scores` is empty the margin is defined as `0.0`, so the
    /// interval collapses to `(mean_score, mean_score)` instead of becoming
    /// infinite or NaN from a division by `sqrt(0)`.
    pub fn new(mean_score: f64, std_score: f64, cv_scores: Vec<f64>, strategy: String) -> Self {
        let margin = if cv_scores.is_empty() {
            // No observations: the standard error is undefined, so report a
            // degenerate interval rather than propagating inf/NaN.
            0.0
        } else {
            let n = cv_scores.len() as f64;
            1.96 * (std_score / n.sqrt())
        };
        let confidence_interval = (mean_score - margin, mean_score + margin);

        Self {
            strategy,
            cv_scores,
            mean_score,
            std_score,
            confidence_interval,
        }
    }
}
63
/// Outcome of a statistical significance test.
#[derive(Debug, Clone)]
pub struct StatisticalValidationResult {
    /// Value of the test statistic.
    pub test_statistic: f64,
    /// p-value reported for the test.
    pub p_value: f64,
    /// Critical value associated with the test.
    pub critical_value: f64,
    /// Whether the result is flagged as significant.
    pub is_significant: bool,
}
76
/// Position of one strategy within a ranking.
#[derive(Debug, Clone)]
pub struct StrategyRanking {
    /// Name of the ranked strategy.
    pub strategy: String,
    /// Rank assigned to the strategy (numbering convention not fixed here).
    pub rank: usize,
    /// Score the ranking is based on.
    pub score: f64,
    /// Label of the tier the strategy was placed in.
    pub tier: String,
}
89
/// A suggested strategy together with its justification.
#[derive(Debug, Clone)]
pub struct StrategyRecommendation {
    /// Name of the strategy being recommended.
    pub recommended_strategy: String,
    /// Confidence in the recommendation (scale not fixed by this type).
    pub confidence: f64,
    /// Free-text explanation for the recommendation.
    pub reasoning: String,
    /// Names of other strategies worth considering.
    pub alternatives: Vec<String>,
}
102
/// Summary facts about a dataset, used when choosing a baseline strategy.
#[derive(Debug, Clone)]
pub struct DatasetCharacteristics {
    /// Number of samples (rows).
    pub n_samples: usize,
    /// Number of features (columns).
    pub n_features: usize,
    /// Task label such as `"classification"` or `"regression"`.
    pub target_type: String,
    /// Ratio of the smallest to the largest class count as filled in by
    /// `analyze_classification_dataset`; `None` when no ratio is available.
    pub class_balance: Option<f64>,
    /// Amount of missing data (unit — fraction or count — is not fixed by this type).
    pub missing_values: f64,
}
117
/// Per-class frequency information for a classification target.
#[derive(Debug, Clone)]
pub struct ClassDistribution {
    /// Class labels.
    pub classes: Vec<i32>,
    /// Occurrence counts (presumably aligned index-by-index with `classes`).
    pub counts: Vec<usize>,
    /// Class proportions (presumably aligned index-by-index with `classes`).
    pub proportions: Vec<f64>,
}
128
/// Descriptive statistics of a numeric target.
#[derive(Debug, Clone)]
pub struct TargetDistribution {
    /// Mean of the target values.
    pub mean: f64,
    /// Standard deviation of the target values.
    pub std: f64,
    /// Smallest target value.
    pub min: f64,
    /// Largest target value.
    pub max: f64,
    /// Percentile values (which percentiles are stored is not fixed by this type).
    pub percentiles: Vec<f64>,
}
143
/// Kind of learning task a dataset represents.
#[derive(Debug, Clone)]
pub enum DataType {
    /// Classification task.
    Classification,
    /// Regression task.
    Regression,
    /// Multiclass task.
    Multiclass,
    /// Multilabel task.
    Multilabel,
}
156
/// Outcome of a permutation test.
#[derive(Debug, Clone)]
pub struct PermutationTestResult {
    /// Score observed on the unpermuted data.
    pub observed_score: f64,
    /// Scores forming the null distribution.
    pub null_distribution: Vec<f64>,
    /// p-value reported for the test.
    pub p_value: f64,
    /// Significance threshold the p-value is compared against.
    pub significance_level: f64,
    /// Whether the result is flagged as significant.
    pub is_significant: bool,
}
171
172#[derive(Debug, Clone)]
174pub struct ValidationSummary {
175 pub best_strategy: String,
177 pub all_results: Vec<DummyValidationResult>,
179 pub statistical_significance: StatisticalValidationResult,
181 pub recommendation: StrategyRecommendation,
183}
184
185pub fn analyze_classification_dataset(x: &Array2<f64>, y: &Array1<i32>) -> DatasetCharacteristics {
190 use std::collections::HashMap;
191
192 let n_samples = x.nrows();
193 let n_features = x.ncols();
194
195 let mut class_counts: HashMap<i32, usize> = HashMap::new();
197 for &class in y.iter() {
198 *class_counts.entry(class).or_insert(0) += 1;
199 }
200
201 let class_balance = if class_counts.len() > 1 {
204 let counts: Vec<usize> = class_counts.values().copied().collect();
205 let min_count = *counts.iter().min().unwrap() as f64;
206 let max_count = *counts.iter().max().unwrap() as f64;
207 Some(min_count / max_count)
208 } else {
209 None };
211
212 DatasetCharacteristics {
214 n_samples,
215 n_features,
216 target_type: "classification".to_string(),
217 class_balance,
218 missing_values: 0.0,
219 }
220}
221
222pub fn analyze_regression_dataset(_x: &Array2<f64>, _y: &Array1<f64>) -> DatasetCharacteristics {
224 DatasetCharacteristics {
227 n_samples: 100,
228 n_features: 5,
229 target_type: "regression".to_string(),
230 class_balance: None,
231 missing_values: 0.0,
232 }
233}
234
235pub fn get_adaptive_classification_strategy(
237 characteristics: &DatasetCharacteristics,
238) -> ClassifierStrategy {
239 if let Some(class_balance) = characteristics.class_balance {
241 if class_balance > 0.7 {
244 ClassifierStrategy::Stratified
245 } else {
246 ClassifierStrategy::MostFrequent
248 }
249 } else {
250 ClassifierStrategy::MostFrequent
252 }
253}
254
255pub fn get_adaptive_regression_strategy(
257 _characteristics: &DatasetCharacteristics,
258) -> RegressorStrategy {
259 RegressorStrategy::Mean
261}
262
263pub fn cross_validate_dummy(
265 _estimator: &str,
266 _x: &Array2<f64>,
267 _y: &Array1<f64>,
268 _cv: usize,
269) -> DummyValidationResult {
270 DummyValidationResult {
273 strategy: "placeholder".to_string(),
274 cv_scores: vec![0.5, 0.6, 0.7],
275 mean_score: 0.6,
276 std_score: 0.1,
277 confidence_interval: (0.5, 0.7),
278 }
279}
280
281pub fn comprehensive_validation_classifier(
283 _classifier: &str,
284 _x: &Array2<f64>,
285 _y: &Array1<i32>,
286) -> ValidationSummary {
287 ValidationSummary {
290 best_strategy: "most_frequent".to_string(),
291 all_results: vec![],
292 statistical_significance: StatisticalValidationResult {
293 test_statistic: 1.5,
294 p_value: 0.05,
295 critical_value: 1.96,
296 is_significant: false,
297 },
298 recommendation: StrategyRecommendation {
299 recommended_strategy: "most_frequent".to_string(),
300 confidence: 0.8,
301 reasoning: "Balanced dataset".to_string(),
302 alternatives: vec!["stratified".to_string()],
303 },
304 }
305}
306
307pub fn get_best_strategy(_results: &[DummyValidationResult]) -> String {
309 "best_strategy".to_string()
310}
311
312pub fn get_ranking_summary(_results: &[DummyValidationResult]) -> Vec<StrategyRanking> {
314 vec![]
315}
316
317pub fn get_strategies_in_tier(_results: &[DummyValidationResult], _tier: &str) -> Vec<String> {
319 vec![]
320}
321
322pub fn permutation_test_classifier(
324 _classifier: &str,
325 _x: &Array2<f64>,
326 _y: &Array1<i32>,
327) -> PermutationTestResult {
328 PermutationTestResult {
331 observed_score: 0.8,
332 null_distribution: vec![0.5; 100],
333 p_value: 0.01,
334 significance_level: 0.05,
335 is_significant: true,
336 }
337}
338
339pub fn permutation_test_vs_random_classifier(
341 _classifier: &str,
342 _x: &Array2<f64>,
343 _y: &Array1<i32>,
344) -> PermutationTestResult {
345 PermutationTestResult {
348 observed_score: 0.8,
349 null_distribution: vec![0.5; 100],
350 p_value: 0.01,
351 significance_level: 0.05,
352 is_significant: true,
353 }
354}
355
356pub fn rank_dummy_strategies_classifier(
358 _strategies: &[ClassifierStrategy],
359 _x: &Array2<f64>,
360 _y: &Array1<i32>,
361) -> Vec<StrategyRanking> {
362 vec![]
363}
364
365pub fn rank_dummy_strategies_regressor(
367 _strategies: &[RegressorStrategy],
368 _x: &Array2<f64>,
369 _y: &Array1<f64>,
370) -> Vec<StrategyRanking> {
371 vec![]
372}
373
374pub fn recommend_classification_strategy(
376 _x: &Array2<f64>,
377 _y: &Array1<i32>,
378) -> StrategyRecommendation {
379 StrategyRecommendation {
381 recommended_strategy: "most_frequent".to_string(),
382 confidence: 0.8,
383 reasoning: "Default recommendation".to_string(),
384 alternatives: vec!["stratified".to_string()],
385 }
386}
387
388pub fn recommend_regression_strategy(_x: &Array2<f64>, _y: &Array1<f64>) -> StrategyRecommendation {
390 StrategyRecommendation {
392 recommended_strategy: "mean".to_string(),
393 confidence: 0.8,
394 reasoning: "Default recommendation".to_string(),
395 alternatives: vec!["median".to_string()],
396 }
397}
398
399pub fn validate_reproducibility(
401 _estimator: &str,
402 _x: &Array2<f64>,
403 _y: &Array1<f64>,
404 _random_state: u64,
405) -> bool {
406 true }