// sklears_dummy/validation/validation_core.rs

1use scirs2_core::ndarray::Array1;
2use scirs2_core::random::{thread_rng, RngCore, SeedableRng};
3use sklears_core::error::{Result, SklearsError};
4use sklears_core::types::Float;
5use std::cmp::Ordering;
6
7/// Result of dummy estimator validation
8#[derive(Debug, Clone)]
9pub struct DummyValidationResult {
10    /// Mean score across folds
11    pub mean_score: Float,
12    /// Standard deviation of scores across folds
13    pub std_score: Float,
14    /// Individual fold scores
15    pub fold_scores: Vec<Float>,
16    /// Strategy that was evaluated
17    pub strategy: String,
18}
19
20impl DummyValidationResult {
21    pub fn new(
22        mean_score: Float,
23        std_score: Float,
24        fold_scores: Vec<Float>,
25        strategy: String,
26    ) -> Self {
27        Self {
28            mean_score,
29            std_score,
30            fold_scores,
31            strategy,
32        }
33    }
34
35    pub fn confidence_interval(&self, confidence_level: Float) -> (Float, Float) {
36        let n = self.fold_scores.len() as Float;
37        let sem = self.std_score / n.sqrt();
38
39        // Approximate t-value for common confidence levels
40        let t_value = match confidence_level {
41            0.90 => 1.645,
42            0.95 => 1.96,
43            0.99 => 2.576,
44            _ => 1.96, // Default to 95%
45        };
46
47        let margin = t_value * sem;
48        (self.mean_score - margin, self.mean_score + margin)
49    }
50
51    pub fn is_significantly_better_than(
52        &self,
53        other: &DummyValidationResult,
54        alpha: Float,
55    ) -> bool {
56        // Simple t-test approximation
57        let pooled_std = ((self.std_score.powi(2) + other.std_score.powi(2)) / 2.0).sqrt();
58        let n1 = self.fold_scores.len() as Float;
59        let n2 = other.fold_scores.len() as Float;
60        let se_diff = pooled_std * ((1.0 / n1) + (1.0 / n2)).sqrt();
61
62        if se_diff == 0.0 {
63            return self.mean_score > other.mean_score;
64        }
65
66        let t_stat = (self.mean_score - other.mean_score) / se_diff;
67        let t_critical = match alpha {
68            0.01 => 2.576,
69            0.05 => 1.96,
70            0.10 => 1.645,
71            _ => 1.96,
72        };
73
74        t_stat > t_critical
75    }
76}
77
78/// Configuration for validation procedures
79#[derive(Debug, Clone)]
80pub struct ValidationConfig {
81    /// cv_folds
82    pub cv_folds: usize,
83    /// random_state
84    pub random_state: Option<u64>,
85    /// shuffle
86    pub shuffle: bool,
87    /// stratify
88    pub stratify: bool,
89    /// scoring_metric
90    pub scoring_metric: String,
91    /// bootstrap_samples
92    pub bootstrap_samples: usize,
93    /// confidence_level
94    pub confidence_level: Float,
95}
96
97impl Default for ValidationConfig {
98    fn default() -> Self {
99        Self {
100            cv_folds: 5,
101            random_state: None,
102            shuffle: true,
103            stratify: false,
104            scoring_metric: "accuracy".to_string(),
105            bootstrap_samples: 1000,
106            confidence_level: 0.95,
107        }
108    }
109}
110
111impl ValidationConfig {
112    pub fn new() -> Self {
113        Self::default()
114    }
115
116    pub fn cv_folds(mut self, folds: usize) -> Self {
117        self.cv_folds = folds;
118        self
119    }
120
121    pub fn random_state(mut self, seed: u64) -> Self {
122        self.random_state = Some(seed);
123        self
124    }
125
126    pub fn shuffle(mut self, shuffle: bool) -> Self {
127        self.shuffle = shuffle;
128        self
129    }
130
131    pub fn stratify(mut self, stratify: bool) -> Self {
132        self.stratify = stratify;
133        self
134    }
135
136    pub fn scoring_metric(mut self, metric: String) -> Self {
137        self.scoring_metric = metric;
138        self
139    }
140
141    pub fn bootstrap_samples(mut self, samples: usize) -> Self {
142        self.bootstrap_samples = samples;
143        self
144    }
145
146    pub fn confidence_level(mut self, level: Float) -> Self {
147        self.confidence_level = level;
148        self
149    }
150}
151
152/// Comprehensive validation result with additional statistics
153#[derive(Debug, Clone)]
154pub struct ComprehensiveValidationResult {
155    /// validation_result
156    pub validation_result: DummyValidationResult,
157    /// fold_details
158    pub fold_details: Vec<FoldResult>,
159    /// statistical_summary
160    pub statistical_summary: StatisticalSummary,
161    /// config
162    pub config: ValidationConfig,
163}
164
165/// Result for individual fold
166#[derive(Debug, Clone)]
167pub struct FoldResult {
168    /// fold_index
169    pub fold_index: usize,
170    /// train_size
171    pub train_size: usize,
172    /// test_size
173    pub test_size: usize,
174    /// score
175    pub score: Float,
176    /// fit_time
177    pub fit_time: Float,
178    /// predict_time
179    pub predict_time: Float,
180}
181
182/// Statistical summary of validation results
183#[derive(Debug, Clone)]
184pub struct StatisticalSummary {
185    /// mean
186    pub mean: Float,
187    /// std
188    pub std: Float,
189    /// min
190    pub min: Float,
191    /// max
192    pub max: Float,
193    /// median
194    pub median: Float,
195    /// q25
196    pub q25: Float,
197    /// q75
198    pub q75: Float,
199    /// skewness
200    pub skewness: Float,
201    /// kurtosis
202    pub kurtosis: Float,
203}
204
205impl StatisticalSummary {
206    pub fn from_scores(scores: &[Float]) -> Self {
207        if scores.is_empty() {
208            return Self::default();
209        }
210
211        let mut sorted_scores = scores.to_vec();
212        sorted_scores.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
213
214        let n = scores.len() as Float;
215        let mean = scores.iter().sum::<Float>() / n;
216        let variance = scores.iter().map(|&x| (x - mean).powi(2)).sum::<Float>() / n;
217        let std = variance.sqrt();
218        let min = sorted_scores[0];
219        let max = sorted_scores[sorted_scores.len() - 1];
220
221        let median = if sorted_scores.len() % 2 == 0 {
222            let mid = sorted_scores.len() / 2;
223            (sorted_scores[mid - 1] + sorted_scores[mid]) / 2.0
224        } else {
225            sorted_scores[sorted_scores.len() / 2]
226        };
227
228        let q25_idx = (sorted_scores.len() as Float * 0.25) as usize;
229        let q75_idx = (sorted_scores.len() as Float * 0.75) as usize;
230        let q25 = sorted_scores[q25_idx.min(sorted_scores.len() - 1)];
231        let q75 = sorted_scores[q75_idx.min(sorted_scores.len() - 1)];
232
233        // Calculate skewness and kurtosis
234        let m3 = scores.iter().map(|&x| (x - mean).powi(3)).sum::<Float>() / n;
235        let m4 = scores.iter().map(|&x| (x - mean).powi(4)).sum::<Float>() / n;
236        let skewness = if std > 0.0 { m3 / std.powi(3) } else { 0.0 };
237        let kurtosis = if std > 0.0 {
238            m4 / std.powi(4) - 3.0
239        } else {
240            0.0
241        };
242
243        Self {
244            mean,
245            std,
246            min,
247            max,
248            median,
249            q25,
250            q75,
251            skewness,
252            kurtosis,
253        }
254    }
255}
256
257impl Default for StatisticalSummary {
258    fn default() -> Self {
259        Self {
260            mean: 0.0,
261            std: 0.0,
262            min: 0.0,
263            max: 0.0,
264            median: 0.0,
265            q25: 0.0,
266            q75: 0.0,
267            skewness: 0.0,
268            kurtosis: 0.0,
269        }
270    }
271}
272
/// Validation error types specific to dummy estimators.
///
/// Every variant carries a free-form message; `Display` prefixes it with a
/// fixed label for the variant.
#[derive(Debug, Clone)]
pub enum ValidationError {
    /// Displayed as `Insufficient data: <message>`.
    InsufficientData(String),
    /// Displayed as `Invalid fold configuration: <message>`.
    InvalidFolds(String),
    /// Displayed as `Stratification error: <message>`.
    StratificationError(String),
    /// Displayed as `Scoring error: <message>`.
    ScoringError(String),
    /// Displayed as `Configuration error: <message>`.
    ConfigurationError(String),
}

impl std::fmt::Display for ValidationError {
    /// Writes the variant's label followed by its message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InsufficientData(detail) => write!(f, "Insufficient data: {}", detail),
            Self::InvalidFolds(detail) => write!(f, "Invalid fold configuration: {}", detail),
            Self::StratificationError(detail) => write!(f, "Stratification error: {}", detail),
            Self::ScoringError(detail) => write!(f, "Scoring error: {}", detail),
            Self::ConfigurationError(detail) => write!(f, "Configuration error: {}", detail),
        }
    }
}

impl std::error::Error for ValidationError {}
301
302/// Utility function to validate common validation parameters
303pub fn validate_cv_params(n_samples: usize, cv_folds: usize) -> Result<()> {
304    if cv_folds < 2 {
305        return Err(SklearsError::InvalidInput(
306            "Cross-validation folds must be at least 2".to_string(),
307        ));
308    }
309
310    if n_samples < cv_folds {
311        return Err(SklearsError::InvalidInput(
312            "Number of samples must be at least equal to cv folds".to_string(),
313        ));
314    }
315
316    Ok(())
317}
318
319/// Determine if the target variable represents a classification task
320pub fn is_classification_task(y: &Array1<Float>) -> bool {
321    if y.is_empty() {
322        return false;
323    }
324
325    // Check for NaN or infinite values
326    if y.iter().any(|&val| val.is_nan() || val.is_infinite()) {
327        return false;
328    }
329
330    // Check if all values are integers
331    let all_integers = y.iter().all(|&val| val.fract() == 0.0);
332    if !all_integers {
333        return false;
334    }
335
336    // Check the number of unique values
337    let mut unique_values: Vec<Float> = y.iter().copied().collect();
338    unique_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
339    unique_values.dedup();
340
341    // Typically classification tasks have fewer than 50 unique classes
342    // and more than 1 class
343    unique_values.len() > 1 && unique_values.len() < 50
344}
345
346/// Create a random number generator with optional seed
347pub fn create_rng(random_state: Option<u64>) -> Box<dyn RngCore> {
348    match random_state {
349        Some(seed) => Box::new(scirs2_core::random::rngs::StdRng::seed_from_u64(seed)),
350        None => Box::new(thread_rng()),
351    }
352}