sklears_model_selection/scoring.rs

//! Enhanced scoring utilities for model selection

use scirs2_core::ndarray::Array1;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::Rng;
use scirs2_core::random::SeedableRng;
use sklears_core::{
    error::{Result, SklearsError},
    // traits::Score,
    types::Float,
};
use sklears_metrics::{
    classification::{accuracy_score, f1_score, precision_score, recall_score},
    regression::{explained_variance_score, mean_absolute_error, mean_squared_error, r2_score},
};
use std::collections::HashMap;
use std::sync::Arc;
/// Custom scoring function trait
pub trait CustomScorer: Send + Sync + std::fmt::Debug {
    /// Compute score given true and predicted values
    fn score(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<f64>;
    /// Get the name of this custom scorer
    fn name(&self) -> &str;
    /// Whether higher scores are better (true) or lower scores are better (false)
    fn higher_is_better(&self) -> bool;
}

/// Custom scoring function wrapper for closures
pub struct ClosureScorer {
    name: String,
    scorer_fn: Arc<dyn Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync>,
    higher_is_better: bool,
}

impl ClosureScorer {
    /// Create a new custom scorer from a closure
    pub fn new<F>(name: String, scorer_fn: F, higher_is_better: bool) -> Self
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync + 'static,
    {
        Self {
            name,
            scorer_fn: Arc::new(scorer_fn),
            higher_is_better,
        }
    }
}

impl std::fmt::Debug for ClosureScorer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClosureScorer")
            .field("name", &self.name)
            .field("higher_is_better", &self.higher_is_better)
            .finish()
    }
}

impl CustomScorer for ClosureScorer {
    fn score(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<f64> {
        (self.scorer_fn)(y_true, y_pred)
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn higher_is_better(&self) -> bool {
        self.higher_is_better
    }
}
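
// Example (hedged sketch): wrapping a closure as a `CustomScorer`. The metric
// below is a hand-rolled mean absolute error used purely for illustration; it
// assumes `Float` is an f64-like alias and that equal-length arrays are passed.
#[cfg(test)]
mod closure_scorer_example {
    use super::*;

    #[test]
    fn closure_scorer_round_trip() {
        let scorer = ClosureScorer::new(
            "mae".to_string(),
            |y_true: &Array1<Float>, y_pred: &Array1<Float>| {
                // Average absolute deviation between predictions and targets
                let n = y_true.len() as f64;
                let total: f64 = y_true
                    .iter()
                    .zip(y_pred.iter())
                    .map(|(t, p)| (t - p).abs())
                    .sum();
                Ok(total / n)
            },
            false, // lower error is better
        );

        let y_true = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y_pred = Array1::from_vec(vec![1.0, 2.0, 4.0]);
        assert_eq!(scorer.name(), "mae");
        assert!(!scorer.higher_is_better());
        assert!((scorer.score(&y_true, &y_pred).unwrap() - 1.0 / 3.0).abs() < 1e-12);
    }
}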

/// Scorer registry for built-in and custom scorers
#[derive(Debug, Clone)]
pub struct ScorerRegistry {
    custom_scorers: HashMap<String, Arc<dyn CustomScorer>>,
}

impl Default for ScorerRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl ScorerRegistry {
    /// Create a new scorer registry
    pub fn new() -> Self {
        Self {
            custom_scorers: HashMap::new(),
        }
    }

    /// Register a custom scorer
    pub fn register_scorer(&mut self, scorer: Arc<dyn CustomScorer>) {
        self.custom_scorers
            .insert(scorer.name().to_string(), scorer);
    }

    /// Register a custom scorer from a closure
    pub fn register_closure_scorer<F>(&mut self, name: String, scorer_fn: F, higher_is_better: bool)
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync + 'static,
    {
        let scorer = Arc::new(ClosureScorer::new(name, scorer_fn, higher_is_better));
        self.register_scorer(scorer);
    }

    /// Get a custom scorer by name
    pub fn get_scorer(&self, name: &str) -> Option<&Arc<dyn CustomScorer>> {
        self.custom_scorers.get(name)
    }

    /// List all registered custom scorers
    pub fn list_scorers(&self) -> Vec<&str> {
        self.custom_scorers.keys().map(|s| s.as_str()).collect()
    }

    /// Check if a scorer is registered
    pub fn has_scorer(&self, name: &str) -> bool {
        self.custom_scorers.contains_key(name)
    }
}
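
// Example (hedged sketch): registering and looking up a custom scorer by name.
// The "within_tolerance" metric is invented for illustration only.
#[cfg(test)]
mod scorer_registry_example {
    use super::*;

    #[test]
    fn register_and_lookup() {
        let mut registry = ScorerRegistry::new();
        registry.register_closure_scorer(
            "within_tolerance".to_string(),
            |y_true, y_pred| {
                // Fraction of predictions within 0.5 of the target
                let hits = y_true
                    .iter()
                    .zip(y_pred.iter())
                    .filter(|(t, p)| (*t - *p).abs() <= 0.5)
                    .count();
                Ok(hits as f64 / y_true.len() as f64)
            },
            true, // higher fraction is better
        );

        assert!(registry.has_scorer("within_tolerance"));
        assert!(registry.list_scorers().contains(&"within_tolerance"));
        assert!(registry.get_scorer("missing").is_none());
    }
}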

/// Enhanced scoring configuration
#[derive(Debug, Clone)]
pub struct ScoringConfig {
    /// Primary scoring metric
    pub primary: String,
    /// Additional metrics to compute
    pub additional: Vec<String>,
    /// Whether to compute confidence intervals
    pub confidence_intervals: bool,
    /// Confidence level for intervals (0.95 = 95%)
    pub confidence_level: f64,
    /// Number of bootstrap samples for confidence intervals
    pub n_bootstrap: usize,
    /// Random state for bootstrap sampling
    pub random_state: Option<u64>,
    /// Custom scorer registry
    pub scorer_registry: ScorerRegistry,
}

impl Default for ScoringConfig {
    fn default() -> Self {
        Self {
            primary: "accuracy".to_string(),
            additional: vec![],
            confidence_intervals: false,
            confidence_level: 0.95,
            n_bootstrap: 1000,
            random_state: None,
            scorer_registry: ScorerRegistry::new(),
        }
    }
}

impl ScoringConfig {
    /// Create a new scoring configuration with the given primary metric
    pub fn new(primary: &str) -> Self {
        Self {
            primary: primary.to_string(),
            ..Default::default()
        }
    }

    /// Add additional metrics
    pub fn with_additional_metrics(mut self, metrics: Vec<String>) -> Self {
        self.additional = metrics;
        self
    }

    /// Enable confidence intervals
    pub fn with_confidence_intervals(mut self, level: f64, n_bootstrap: usize) -> Self {
        self.confidence_intervals = true;
        self.confidence_level = level;
        self.n_bootstrap = n_bootstrap;
        self
    }

    /// Set random state for reproducibility
    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Register a custom scorer
    pub fn with_custom_scorer(mut self, scorer: Arc<dyn CustomScorer>) -> Self {
        self.scorer_registry.register_scorer(scorer);
        self
    }

    /// Register a custom scorer from a closure
    pub fn with_closure_scorer<F>(
        mut self,
        name: String,
        scorer_fn: F,
        higher_is_better: bool,
    ) -> Self
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync + 'static,
    {
        self.scorer_registry
            .register_closure_scorer(name, scorer_fn, higher_is_better);
        self
    }

    /// Get a mutable reference to the scorer registry
    pub fn scorer_registry_mut(&mut self) -> &mut ScorerRegistry {
        &mut self.scorer_registry
    }

    /// Get a reference to the scorer registry
    pub fn scorer_registry(&self) -> &ScorerRegistry {
        &self.scorer_registry
    }
}
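
// Example (hedged sketch): assembling a configuration with the builder API.
// Metric names mirror the built-in scorers handled further below.
#[cfg(test)]
mod scoring_config_example {
    use super::*;

    #[test]
    fn builder_chain() {
        let config = ScoringConfig::new("r2")
            .with_additional_metrics(vec![
                "neg_mean_squared_error".to_string(),
                "explained_variance".to_string(),
            ])
            .with_confidence_intervals(0.95, 500)
            .with_random_state(42);

        assert_eq!(config.primary, "r2");
        assert_eq!(config.additional.len(), 2);
        assert!(config.confidence_intervals);
        assert_eq!(config.n_bootstrap, 500);
        assert_eq!(config.random_state, Some(42));
    }
}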

/// Scoring result with confidence intervals and multiple metrics
#[derive(Debug, Clone)]
pub struct ScoringResult {
    /// Primary metric scores
    pub primary_scores: Array1<f64>,
    /// Additional metric scores
    pub additional_scores: HashMap<String, Array1<f64>>,
    /// Confidence interval for the primary metric
    pub confidence_interval: Option<(f64, f64)>,
    /// Confidence intervals for additional metrics
    pub additional_confidence_intervals: HashMap<String, (f64, f64)>,
    /// Mean scores keyed by metric name ("primary" for the primary metric)
    pub mean_scores: HashMap<String, f64>,
    /// Standard deviations keyed by metric name
    pub std_scores: HashMap<String, f64>,
}

impl ScoringResult {
    /// Get the primary metric mean score
    pub fn primary_mean(&self) -> f64 {
        self.mean_scores.get("primary").copied().unwrap_or(0.0)
    }

    /// Get mean score for a specific metric
    pub fn mean_score(&self, metric: &str) -> Option<f64> {
        self.mean_scores.get(metric).copied()
    }

    /// Get all mean scores
    pub fn all_mean_scores(&self) -> &HashMap<String, f64> {
        &self.mean_scores
    }
}

/// Enhanced scorer that supports multiple metrics and confidence intervals
pub struct EnhancedScorer {
    config: ScoringConfig,
}

impl EnhancedScorer {
    /// Create a new enhanced scorer
    pub fn new(config: ScoringConfig) -> Self {
        Self { config }
    }

    /// Score predictions with multiple metrics and confidence intervals
    pub fn score_predictions(
        &self,
        y_true_splits: &[Array1<Float>],
        y_pred_splits: &[Array1<Float>],
        task_type: TaskType,
    ) -> Result<ScoringResult> {
        if y_true_splits.len() != y_pred_splits.len() {
            return Err(SklearsError::InvalidInput(
                "Number of true and predicted splits must match".to_string(),
            ));
        }
        if y_true_splits.is_empty() {
            return Err(SklearsError::InvalidInput(
                "At least one split is required for scoring".to_string(),
            ));
        }

        let n_splits = y_true_splits.len();
        let mut primary_scores = Vec::with_capacity(n_splits);
        let mut additional_scores: HashMap<String, Vec<f64>> = HashMap::new();

        // Initialize additional scores storage
        for metric in &self.config.additional {
            additional_scores.insert(metric.clone(), Vec::with_capacity(n_splits));
        }

        // Compute scores for each split
        for (y_true, y_pred) in y_true_splits.iter().zip(y_pred_splits.iter()) {
            // Primary metric
            let primary_score =
                self.compute_metric_score(&self.config.primary, y_true, y_pred, task_type)?;
            primary_scores.push(primary_score);

            // Additional metrics
            for metric in &self.config.additional {
                let score = self.compute_metric_score(metric, y_true, y_pred, task_type)?;
                additional_scores.get_mut(metric).unwrap().push(score);
            }
        }

        // Convert to arrays
        let primary_scores_array = Array1::from_vec(primary_scores.clone());
        let mut additional_scores_arrays = HashMap::new();
        for (metric, scores) in additional_scores.iter() {
            additional_scores_arrays.insert(metric.clone(), Array1::from_vec(scores.clone()));
        }

        // Compute confidence intervals if requested
        let confidence_interval = if self.config.confidence_intervals {
            Some(self.bootstrap_confidence_interval(&primary_scores)?)
        } else {
            None
        };

        let mut additional_confidence_intervals = HashMap::new();
        if self.config.confidence_intervals {
            for (metric, scores) in &additional_scores {
                let ci = self.bootstrap_confidence_interval(scores)?;
                additional_confidence_intervals.insert(metric.clone(), ci);
            }
        }

        // Compute mean and std
        let mut mean_scores = HashMap::new();
        let mut std_scores = HashMap::new();

        mean_scores.insert("primary".to_string(), primary_scores_array.mean().unwrap());
        std_scores.insert("primary".to_string(), primary_scores_array.std(1.0));

        for (metric, scores) in &additional_scores_arrays {
            mean_scores.insert(metric.clone(), scores.mean().unwrap());
            std_scores.insert(metric.clone(), scores.std(1.0));
        }

        Ok(ScoringResult {
            primary_scores: primary_scores_array,
            additional_scores: additional_scores_arrays,
            confidence_interval,
            additional_confidence_intervals,
            mean_scores,
            std_scores,
        })
    }

    /// Compute score for a specific metric
    fn compute_metric_score(
        &self,
        metric: &str,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
        task_type: TaskType,
    ) -> Result<f64> {
        // First check if it's a custom scorer
        if let Some(custom_scorer) = self.config.scorer_registry.get_scorer(metric) {
            return custom_scorer.score(y_true, y_pred);
        }

        // Otherwise use built-in scorers
        match task_type {
            TaskType::Classification => self.compute_classification_score(metric, y_true, y_pred),
            TaskType::Regression => self.compute_regression_score(metric, y_true, y_pred),
        }
    }

    fn compute_classification_score(
        &self,
        metric: &str,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<f64> {
        // Convert float arrays to integer arrays for classification metrics
        let y_true_int: Array1<i32> = y_true.mapv(|x| x as i32);
        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);

        let score = match metric {
            "accuracy" => accuracy_score(&y_true_int, &y_pred_int)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            "precision" => precision_score(&y_true_int, &y_pred_int, None)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            "recall" => recall_score(&y_true_int, &y_pred_int, None)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            "f1" => f1_score(&y_true_int, &y_pred_int, None)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown classification metric: {}",
                    metric
                )))
            }
        };

        Ok(score)
    }

    fn compute_regression_score(
        &self,
        metric: &str,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<f64> {
        let score = match metric {
            "r2" | "r2_score" => {
                r2_score(y_true, y_pred).map_err(|e| SklearsError::InvalidInput(e.to_string()))?
            }
            "neg_mean_squared_error" => -mean_squared_error(y_true, y_pred)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            "neg_mean_absolute_error" => -mean_absolute_error(y_true, y_pred)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            "explained_variance" => explained_variance_score(y_true, y_pred)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown regression metric: {}",
                    metric
                )))
            }
        };

        Ok(score)
    }

    /// Compute bootstrap confidence interval
    fn bootstrap_confidence_interval(&self, scores: &[f64]) -> Result<(f64, f64)> {
        // Note: without an explicit random_state we still fall back to a fixed
        // seed so that repeated runs produce identical intervals.
        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::seed_from_u64(42),
        };

        let n_scores = scores.len();
        let mut bootstrap_means = Vec::with_capacity(self.config.n_bootstrap);

        for _ in 0..self.config.n_bootstrap {
            // Resample the per-split scores with replacement
            let mut bootstrap_sample = Vec::with_capacity(n_scores);
            for _ in 0..n_scores {
                let idx = rng.gen_range(0..n_scores);
                bootstrap_sample.push(scores[idx]);
            }

            let mean = bootstrap_sample.iter().sum::<f64>() / n_scores as f64;
            bootstrap_means.push(mean);
        }

        bootstrap_means.sort_by(|a, b| a.partial_cmp(b).unwrap());

        // Percentile bootstrap: take the alpha/2 and 1 - alpha/2 quantiles
        let alpha = 1.0 - self.config.confidence_level;
        let lower_idx = ((alpha / 2.0) * self.config.n_bootstrap as f64) as usize;
        let upper_idx = ((1.0 - alpha / 2.0) * self.config.n_bootstrap as f64) as usize;

        let lower = bootstrap_means[lower_idx.min(self.config.n_bootstrap - 1)];
        let upper = bootstrap_means[upper_idx.min(self.config.n_bootstrap - 1)];

        Ok((lower, upper))
    }
}
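
// Example (hedged sketch): scoring two regression splits end to end. Uses only
// built-in metric names from `compute_regression_score`; exact score values
// depend on the sklears_metrics implementations, so only structure is checked.
#[cfg(test)]
mod enhanced_scorer_example {
    use super::*;

    #[test]
    fn score_regression_splits() {
        let config = ScoringConfig::new("r2")
            .with_additional_metrics(vec!["neg_mean_absolute_error".to_string()])
            .with_confidence_intervals(0.95, 200)
            .with_random_state(0);
        let scorer = EnhancedScorer::new(config);

        let y_true_splits = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]),
            Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0]),
        ];
        let y_pred_splits = vec![
            Array1::from_vec(vec![1.1, 1.9, 3.2, 3.8]),
            Array1::from_vec(vec![2.1, 3.9, 6.2, 7.8]),
        ];

        let result = scorer
            .score_predictions(&y_true_splits, &y_pred_splits, TaskType::Regression)
            .unwrap();
        assert_eq!(result.primary_scores.len(), 2);
        assert!(result.confidence_interval.is_some());
        assert!(result.mean_score("neg_mean_absolute_error").is_some());
    }
}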

/// Task type for scoring
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskType {
    /// Classification
    Classification,
    /// Regression
    Regression,
}

/// Statistical significance test result
#[derive(Debug, Clone)]
pub struct SignificanceTestResult {
    /// Test statistic
    pub statistic: f64,
    /// P-value
    pub p_value: f64,
    /// Whether the result is significant at the given alpha level
    pub is_significant: bool,
    /// Alpha level used
    pub alpha: f64,
    /// Test name
    pub test_name: String,
}

/// Perform a paired t-test for comparing two sets of CV scores
pub fn paired_ttest(
    scores1: &Array1<f64>,
    scores2: &Array1<f64>,
    alpha: f64,
) -> Result<SignificanceTestResult> {
    if scores1.len() != scores2.len() {
        return Err(SklearsError::InvalidInput(
            "Score arrays must have the same length".to_string(),
        ));
    }

    let n = scores1.len() as f64;
    if n < 2.0 {
        return Err(SklearsError::InvalidInput(
            "Need at least 2 samples for t-test".to_string(),
        ));
    }

    // Compute per-fold differences
    let differences: Array1<f64> = scores1 - scores2;
    let mean_diff = differences.mean().unwrap();
    let std_diff = differences.std(1.0);

    if std_diff == 0.0 {
        return Err(SklearsError::InvalidInput(
            "Standard deviation of differences is zero".to_string(),
        ));
    }

    // t-statistic: t = mean(d) / (s_d / sqrt(n))
    let t_stat = mean_diff * n.sqrt() / std_diff;

    // Two-tailed p-value from an approximate t-distribution CDF
    let df = n - 1.0;
    let p_value = 2.0 * (1.0 - student_t_cdf(t_stat.abs(), df));

    Ok(SignificanceTestResult {
        statistic: t_stat,
        p_value,
        is_significant: p_value < alpha,
        alpha,
        test_name: "Paired t-test".to_string(),
    })
}
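
// Example (hedged sketch): comparing two models' CV scores. The numbers are
// synthetic; model A is better by 0.03-0.04 on every fold, which should come
// out significant, though the exact p-value depends on the CDF approximation.
#[cfg(test)]
mod paired_ttest_example {
    use super::*;

    #[test]
    fn consistent_improvement() {
        let model_a = Array1::from_vec(vec![0.82, 0.86, 0.84, 0.87, 0.83]);
        let model_b = Array1::from_vec(vec![0.79, 0.82, 0.81, 0.83, 0.80]);

        let result = paired_ttest(&model_a, &model_b, 0.05).unwrap();
        assert_eq!(result.test_name, "Paired t-test");
        // Model A is uniformly better, so the statistic is positive
        assert!(result.statistic > 0.0);
        assert!(result.p_value < 0.05);
    }
}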

/// Approximate CDF of Student's t-distribution
fn student_t_cdf(t: f64, df: f64) -> f64 {
    // For large df the t-distribution is close to the standard normal
    if df > 30.0 {
        return standard_normal_cdf(t);
    }

    // For small df, map t onto an approximately standard-normal deviate via a
    // classic normalizing transform; unlike a low-order polynomial in t, this
    // behaves sensibly in the tails.
    let z = t * (1.0 - 1.0 / (4.0 * df)) / (1.0 + t * t / (2.0 * df)).sqrt();
    standard_normal_cdf(z)
}

/// Standard normal CDF approximation
fn standard_normal_cdf(x: f64) -> f64 {
    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
}

/// Error function approximation
fn erf(x: f64) -> f64 {
    // Abramowitz and Stegun approximation (maximum error ~1.5e-7)
    let a1 = 0.254829592;
    let a2 = -0.284496736;
    let a3 = 1.421413741;
    let a4 = -1.453152027;
    let a5 = 1.061405429;
    let p = 0.3275911;

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

    sign * y
}
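
// Example (hedged sketch): sanity checks on the CDF approximations above. The
// erf approximation is accurate to roughly 1.5e-7, so loose tolerances are used.
#[cfg(test)]
mod cdf_approximation_example {
    use super::*;

    #[test]
    fn cdf_sanity() {
        // erf is odd and erf(0) = 0
        assert!(erf(0.0).abs() < 1e-7);
        assert!((erf(1.0) + erf(-1.0)).abs() < 1e-7);
        // The standard normal CDF is 0.5 at zero and ~0.8413 at one
        assert!((standard_normal_cdf(0.0) - 0.5).abs() < 1e-7);
        assert!((standard_normal_cdf(1.0) - 0.8413).abs() < 1e-3);
        // For large df the t CDF falls back to the normal CDF exactly
        assert!((student_t_cdf(1.0, 100.0) - standard_normal_cdf(1.0)).abs() < 1e-12);
    }
}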

/// Perform a Wilcoxon signed-rank test (non-parametric alternative to the t-test)
pub fn wilcoxon_signed_rank_test(
    scores1: &Array1<f64>,
    scores2: &Array1<f64>,
    alpha: f64,
) -> Result<SignificanceTestResult> {
    if scores1.len() != scores2.len() {
        return Err(SklearsError::InvalidInput(
            "Score arrays must have the same length".to_string(),
        ));
    }

    let differences: Vec<f64> = scores1
        .iter()
        .zip(scores2.iter())
        .map(|(a, b)| a - b)
        .filter(|&d| d != 0.0) // Remove zero differences
        .collect();

    let n = differences.len();
    if n < 5 {
        return Err(SklearsError::InvalidInput(
            "Need at least 5 non-zero differences for Wilcoxon test".to_string(),
        ));
    }

    // Rank the absolute differences, averaging ranks across ties
    let mut abs_diffs_with_indices: Vec<(f64, usize, f64)> = differences
        .iter()
        .enumerate()
        .map(|(i, &d)| (d.abs(), i, d))
        .collect();

    abs_diffs_with_indices.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    let mut ranks = vec![0.0; n];
    let mut i = 0;
    while i < n {
        // Find the run [i, j) of tied absolute differences
        let mut j = i;
        while j < n && abs_diffs_with_indices[j].0 == abs_diffs_with_indices[i].0 {
            j += 1;
        }

        let rank = (i + j + 1) as f64 / 2.0; // Average rank for ties
        for k in i..j {
            ranks[abs_diffs_with_indices[k].1] = rank;
        }
        i = j;
    }

    // Sum of ranks belonging to positive differences
    let w_plus: f64 = differences
        .iter()
        .zip(&ranks)
        .filter(|(&d, _)| d > 0.0)
        .map(|(_, &rank)| rank)
        .sum();

    // Expected value and variance of W+ under the null hypothesis
    let expected = n as f64 * (n + 1) as f64 / 4.0;
    let variance = n as f64 * (n + 1) as f64 * (2 * n + 1) as f64 / 24.0;

    // Normal approximation with continuity correction
    let z = if w_plus > expected {
        (w_plus - 0.5 - expected) / variance.sqrt()
    } else {
        (w_plus + 0.5 - expected) / variance.sqrt()
    };

    let p_value = 2.0 * (1.0 - standard_normal_cdf(z.abs()));

    Ok(SignificanceTestResult {
        statistic: w_plus,
        p_value,
        is_significant: p_value < alpha,
        alpha,
        test_name: "Wilcoxon signed-rank test".to_string(),
    })
}
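
// Example (hedged sketch): the non-parametric test on the same kind of paired CV
// scores as above. With only 6 folds the normal approximation is rough, so the
// example checks the structure of the result rather than an exact p-value.
#[cfg(test)]
mod wilcoxon_example {
    use super::*;

    #[test]
    fn paired_cv_scores() {
        let model_a = Array1::from_vec(vec![0.82, 0.86, 0.84, 0.87, 0.83, 0.85]);
        let model_b = Array1::from_vec(vec![0.79, 0.82, 0.81, 0.83, 0.80, 0.86]);

        let result = wilcoxon_signed_rank_test(&model_a, &model_b, 0.05).unwrap();
        assert_eq!(result.test_name, "Wilcoxon signed-rank test");
        // W+ is a rank sum, so it lies within [0, n(n+1)/2] = [0, 21] for n = 6
        assert!(result.statistic >= 0.0 && result.statistic <= 21.0);
        assert!(result.p_value > 0.0 && result.p_value <= 1.0);
    }
}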