Skip to main content

sklears_model_selection/
scoring.rs

1//! Enhanced scoring utilities for model selection
2
3use scirs2_core::ndarray::Array1;
4use scirs2_core::random::rngs::StdRng;
5use scirs2_core::random::SeedableRng;
6use scirs2_core::RngExt;
7use sklears_core::{
8    error::{Result, SklearsError},
9    // traits::Score,
10    types::Float,
11};
12use sklears_metrics::{
13    classification::{accuracy_score, f1_score, precision_score, recall_score},
14    regression::{explained_variance_score, mean_absolute_error, mean_squared_error, r2_score},
15};
16use std::collections::HashMap;
17use std::sync::Arc;
18
/// Custom scoring function trait.
///
/// Implementors map a pair of target vectors (true values, predictions) to a
/// scalar score. `Send + Sync` is required so scorers can be shared across
/// threads, and `Debug` keeps configurations containing them printable.
pub trait CustomScorer: Send + Sync + std::fmt::Debug {
    /// Compute score given true and predicted values.
    fn score(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<f64>;
    /// Get the name of this custom scorer (used as its registry key).
    fn name(&self) -> &str;
    /// Whether higher scores are better (true) or lower scores are better (false)
    fn higher_is_better(&self) -> bool;
}
28
/// Custom scoring function wrapper for closures.
///
/// `Debug` is implemented by hand (below) because the boxed closure itself is
/// not `Debug`.
pub struct ClosureScorer {
    /// Name reported via [`CustomScorer::name`]; used as the registry key.
    name: String,
    /// The scoring closure; `Arc` keeps the wrapper cheap to clone/share.
    scorer_fn: Arc<dyn Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync>,
    /// Metric orientation: `true` means larger scores are better.
    higher_is_better: bool,
}
35
36impl ClosureScorer {
37    /// Create a new custom scorer from a closure
38    pub fn new<F>(name: String, scorer_fn: F, higher_is_better: bool) -> Self
39    where
40        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync + 'static,
41    {
42        Self {
43            name,
44            scorer_fn: Arc::new(scorer_fn),
45            higher_is_better,
46        }
47    }
48}
49
50impl std::fmt::Debug for ClosureScorer {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("ClosureScorer")
53            .field("name", &self.name)
54            .field("higher_is_better", &self.higher_is_better)
55            .finish()
56    }
57}
58
59impl CustomScorer for ClosureScorer {
60    fn score(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<f64> {
61        (self.scorer_fn)(y_true, y_pred)
62    }
63
64    fn name(&self) -> &str {
65        &self.name
66    }
67
68    fn higher_is_better(&self) -> bool {
69        self.higher_is_better
70    }
71}
72
/// Scorer registry for built-in and custom scorers.
///
/// Only custom scorers are stored here; built-in metrics are dispatched by
/// name inside [`EnhancedScorer`] after the registry lookup misses.
#[derive(Debug, Clone)]
pub struct ScorerRegistry {
    /// Custom scorers keyed by their reported name.
    custom_scorers: HashMap<String, Arc<dyn CustomScorer>>,
}
78
79impl Default for ScorerRegistry {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl ScorerRegistry {
86    /// Create a new scorer registry
87    pub fn new() -> Self {
88        Self {
89            custom_scorers: HashMap::new(),
90        }
91    }
92
93    /// Register a custom scorer
94    pub fn register_scorer(&mut self, scorer: Arc<dyn CustomScorer>) {
95        self.custom_scorers
96            .insert(scorer.name().to_string(), scorer);
97    }
98
99    /// Register a custom scorer from a closure
100    pub fn register_closure_scorer<F>(&mut self, name: String, scorer_fn: F, higher_is_better: bool)
101    where
102        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync + 'static,
103    {
104        let scorer = Arc::new(ClosureScorer::new(name, scorer_fn, higher_is_better));
105        self.register_scorer(scorer);
106    }
107
108    /// Get a custom scorer by name
109    pub fn get_scorer(&self, name: &str) -> Option<&Arc<dyn CustomScorer>> {
110        self.custom_scorers.get(name)
111    }
112
113    /// List all registered custom scorers
114    pub fn list_scorers(&self) -> Vec<&str> {
115        self.custom_scorers.keys().map(|s| s.as_str()).collect()
116    }
117
118    /// Check if a scorer is registered
119    pub fn has_scorer(&self, name: &str) -> bool {
120        self.custom_scorers.contains_key(name)
121    }
122}
123
/// Enhanced scoring configuration.
///
/// Built incrementally with the `with_*` builder methods on [`ScoringConfig`].
#[derive(Debug, Clone)]
pub struct ScoringConfig {
    /// Primary scoring metric (a built-in name or a registered custom scorer).
    pub primary: String,
    /// Additional metrics to compute alongside the primary one.
    pub additional: Vec<String>,
    /// Whether to compute bootstrap confidence intervals.
    pub confidence_intervals: bool,
    /// Confidence level for intervals (0.95 = 95%)
    pub confidence_level: f64,
    /// Number of bootstrap samples for confidence intervals
    pub n_bootstrap: usize,
    /// Random state for bootstrap sampling; `None` falls back to a fixed seed.
    pub random_state: Option<u64>,
    /// Custom scorer registry, consulted before the built-in metrics.
    pub scorer_registry: ScorerRegistry,
}
142
143impl Default for ScoringConfig {
144    fn default() -> Self {
145        Self {
146            primary: "accuracy".to_string(),
147            additional: vec![],
148            confidence_intervals: false,
149            confidence_level: 0.95,
150            n_bootstrap: 1000,
151            random_state: None,
152            scorer_registry: ScorerRegistry::new(),
153        }
154    }
155}
156
157impl ScoringConfig {
158    /// Create a new scoring configuration with primary metric
159    pub fn new(primary: &str) -> Self {
160        Self {
161            primary: primary.to_string(),
162            ..Default::default()
163        }
164    }
165
166    /// Add additional metrics
167    pub fn with_additional_metrics(mut self, metrics: Vec<String>) -> Self {
168        self.additional = metrics;
169        self
170    }
171
172    /// Enable confidence intervals
173    pub fn with_confidence_intervals(mut self, level: f64, n_bootstrap: usize) -> Self {
174        self.confidence_intervals = true;
175        self.confidence_level = level;
176        self.n_bootstrap = n_bootstrap;
177        self
178    }
179
180    /// Set random state for reproducibility
181    pub fn with_random_state(mut self, random_state: u64) -> Self {
182        self.random_state = Some(random_state);
183        self
184    }
185
186    /// Register a custom scorer
187    pub fn with_custom_scorer(mut self, scorer: Arc<dyn CustomScorer>) -> Self {
188        self.scorer_registry.register_scorer(scorer);
189        self
190    }
191
192    /// Register a custom scorer from a closure
193    pub fn with_closure_scorer<F>(
194        mut self,
195        name: String,
196        scorer_fn: F,
197        higher_is_better: bool,
198    ) -> Self
199    where
200        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync + 'static,
201    {
202        self.scorer_registry
203            .register_closure_scorer(name, scorer_fn, higher_is_better);
204        self
205    }
206
207    /// Get a mutable reference to the scorer registry
208    pub fn scorer_registry_mut(&mut self) -> &mut ScorerRegistry {
209        &mut self.scorer_registry
210    }
211
212    /// Get a reference to the scorer registry
213    pub fn scorer_registry(&self) -> &ScorerRegistry {
214        &self.scorer_registry
215    }
216}
217
/// Scoring result with confidence intervals and multiple metrics
#[derive(Debug, Clone)]
pub struct ScoringResult {
    /// Primary metric score, one entry per CV split.
    pub primary_scores: Array1<f64>,
    /// Additional metric scores keyed by metric name, one entry per split.
    pub additional_scores: HashMap<String, Array1<f64>>,
    /// Bootstrap (lower, upper) interval for the primary metric; present only
    /// when intervals were requested in the configuration.
    pub confidence_interval: Option<(f64, f64)>,
    /// Confidence intervals for additional metrics, keyed by metric name.
    pub additional_confidence_intervals: HashMap<String, (f64, f64)>,
    /// Mean scores keyed by metric name ("primary" for the primary metric).
    pub mean_scores: HashMap<String, f64>,
    /// Sample standard deviations (ddof = 1), keyed like `mean_scores`.
    pub std_scores: HashMap<String, f64>,
}
234
235impl ScoringResult {
236    /// Get the primary metric mean score
237    pub fn primary_mean(&self) -> f64 {
238        self.mean_scores.get("primary").copied().unwrap_or(0.0)
239    }
240
241    /// Get mean score for a specific metric
242    pub fn mean_score(&self, metric: &str) -> Option<f64> {
243        self.mean_scores.get(metric).copied()
244    }
245
246    /// Get all mean scores
247    pub fn all_mean_scores(&self) -> &HashMap<String, f64> {
248        &self.mean_scores
249    }
250}
251
/// Enhanced scorer that supports multiple metrics and confidence intervals
pub struct EnhancedScorer {
    /// Scoring configuration: metrics, bootstrap settings, custom scorers.
    config: ScoringConfig,
}
256
257impl EnhancedScorer {
258    /// Create a new enhanced scorer
259    pub fn new(config: ScoringConfig) -> Self {
260        Self { config }
261    }
262
263    /// Score predictions with multiple metrics and confidence intervals
264    pub fn score_predictions(
265        &self,
266        y_true_splits: &[Array1<Float>],
267        y_pred_splits: &[Array1<Float>],
268        task_type: TaskType,
269    ) -> Result<ScoringResult> {
270        if y_true_splits.len() != y_pred_splits.len() {
271            return Err(SklearsError::InvalidInput(
272                "Number of true and predicted splits must match".to_string(),
273            ));
274        }
275
276        let n_splits = y_true_splits.len();
277        let mut primary_scores = Vec::with_capacity(n_splits);
278        let mut additional_scores: HashMap<String, Vec<f64>> = HashMap::new();
279
280        // Initialize additional scores storage
281        for metric in &self.config.additional {
282            additional_scores.insert(metric.clone(), Vec::with_capacity(n_splits));
283        }
284
285        // Compute scores for each split
286        for (y_true, y_pred) in y_true_splits.iter().zip(y_pred_splits.iter()) {
287            // Primary metric
288            let primary_score =
289                self.compute_metric_score(&self.config.primary, y_true, y_pred, task_type)?;
290            primary_scores.push(primary_score);
291
292            // Additional metrics
293            for metric in &self.config.additional {
294                let score = self.compute_metric_score(metric, y_true, y_pred, task_type)?;
295                additional_scores
296                    .get_mut(metric)
297                    .expect("operation should succeed")
298                    .push(score);
299            }
300        }
301
302        // Convert to arrays
303        let primary_scores_array = Array1::from_vec(primary_scores.clone());
304        let mut additional_scores_arrays = HashMap::new();
305        for (metric, scores) in additional_scores.iter() {
306            additional_scores_arrays.insert(metric.clone(), Array1::from_vec(scores.clone()));
307        }
308
309        // Compute confidence intervals if requested
310        let confidence_interval = if self.config.confidence_intervals {
311            Some(self.bootstrap_confidence_interval(&primary_scores)?)
312        } else {
313            None
314        };
315
316        let mut additional_confidence_intervals = HashMap::new();
317        if self.config.confidence_intervals {
318            for (metric, scores) in &additional_scores {
319                let ci = self.bootstrap_confidence_interval(scores)?;
320                additional_confidence_intervals.insert(metric.clone(), ci);
321            }
322        }
323
324        // Compute mean and std
325        let mut mean_scores = HashMap::new();
326        let mut std_scores = HashMap::new();
327
328        mean_scores.insert(
329            "primary".to_string(),
330            primary_scores_array
331                .mean()
332                .expect("operation should succeed"),
333        );
334        std_scores.insert("primary".to_string(), primary_scores_array.std(1.0));
335
336        for (metric, scores) in &additional_scores_arrays {
337            mean_scores.insert(
338                metric.clone(),
339                scores.mean().expect("operation should succeed"),
340            );
341            std_scores.insert(metric.clone(), scores.std(1.0));
342        }
343
344        Ok(ScoringResult {
345            primary_scores: primary_scores_array,
346            additional_scores: additional_scores_arrays,
347            confidence_interval,
348            additional_confidence_intervals,
349            mean_scores,
350            std_scores,
351        })
352    }
353
354    /// Compute score for a specific metric
355    fn compute_metric_score(
356        &self,
357        metric: &str,
358        y_true: &Array1<Float>,
359        y_pred: &Array1<Float>,
360        task_type: TaskType,
361    ) -> Result<f64> {
362        // First check if it's a custom scorer
363        if let Some(custom_scorer) = self.config.scorer_registry.get_scorer(metric) {
364            return custom_scorer.score(y_true, y_pred);
365        }
366
367        // Otherwise use built-in scorers
368        match task_type {
369            TaskType::Classification => self.compute_classification_score(metric, y_true, y_pred),
370            TaskType::Regression => self.compute_regression_score(metric, y_true, y_pred),
371        }
372    }
373
374    fn compute_classification_score(
375        &self,
376        metric: &str,
377        y_true: &Array1<Float>,
378        y_pred: &Array1<Float>,
379    ) -> Result<f64> {
380        // Convert float arrays to integer arrays for classification metrics
381        let y_true_int: Array1<i32> = y_true.mapv(|x| x as i32);
382        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
383
384        let score = match metric {
385            "accuracy" => accuracy_score(&y_true_int, &y_pred_int)
386                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
387            "precision" => precision_score(&y_true_int, &y_pred_int, None)
388                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
389            "recall" => recall_score(&y_true_int, &y_pred_int, None)
390                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
391            "f1" => f1_score(&y_true_int, &y_pred_int, None)
392                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
393            _ => {
394                return Err(SklearsError::InvalidInput(format!(
395                    "Unknown classification metric: {}",
396                    metric
397                )))
398            }
399        };
400
401        Ok(score)
402    }
403
404    fn compute_regression_score(
405        &self,
406        metric: &str,
407        y_true: &Array1<Float>,
408        y_pred: &Array1<Float>,
409    ) -> Result<f64> {
410        let score = match metric {
411            "r2" | "r2_score" => {
412                r2_score(y_true, y_pred).map_err(|e| SklearsError::InvalidInput(e.to_string()))?
413            }
414            "neg_mean_squared_error" => -mean_squared_error(y_true, y_pred)
415                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
416            "neg_mean_absolute_error" => -mean_absolute_error(y_true, y_pred)
417                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
418            "explained_variance" => explained_variance_score(y_true, y_pred)
419                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
420            _ => {
421                return Err(SklearsError::InvalidInput(format!(
422                    "Unknown regression metric: {}",
423                    metric
424                )))
425            }
426        };
427
428        Ok(score)
429    }
430
431    /// Compute bootstrap confidence interval
432    fn bootstrap_confidence_interval(&self, scores: &[f64]) -> Result<(f64, f64)> {
433        let mut rng = match self.config.random_state {
434            Some(seed) => StdRng::seed_from_u64(seed),
435            None => StdRng::seed_from_u64(42),
436        };
437
438        let n_scores = scores.len();
439        let mut bootstrap_means = Vec::with_capacity(self.config.n_bootstrap);
440
441        for _ in 0..self.config.n_bootstrap {
442            let mut bootstrap_sample = Vec::with_capacity(n_scores);
443            for _ in 0..n_scores {
444                let idx = rng.random_range(0..n_scores);
445                bootstrap_sample.push(scores[idx]);
446            }
447
448            let mean = bootstrap_sample.iter().sum::<f64>() / n_scores as f64;
449            bootstrap_means.push(mean);
450        }
451
452        bootstrap_means.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
453
454        let alpha = 1.0 - self.config.confidence_level;
455        let lower_idx = ((alpha / 2.0) * self.config.n_bootstrap as f64) as usize;
456        let upper_idx = ((1.0 - alpha / 2.0) * self.config.n_bootstrap as f64) as usize;
457
458        let lower = bootstrap_means[lower_idx.min(self.config.n_bootstrap - 1)];
459        let upper = bootstrap_means[upper_idx.min(self.config.n_bootstrap - 1)];
460
461        Ok((lower, upper))
462    }
463}
464
/// Task type for scoring: selects which family of built-in metrics applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskType {
    /// Classification (accuracy, precision, recall, f1)
    Classification,
    /// Regression (r2, negated error metrics, explained variance)
    Regression,
}
473
/// Statistical significance test result
#[derive(Debug, Clone)]
pub struct SignificanceTestResult {
    /// Test statistic (t for the paired t-test, W+ for the Wilcoxon test).
    pub statistic: f64,
    /// Two-tailed p-value.
    pub p_value: f64,
    /// Whether the result is significant at alpha level (p_value < alpha).
    pub is_significant: bool,
    /// Alpha level used for the significance decision.
    pub alpha: f64,
    /// Human-readable test name.
    pub test_name: String,
}
488
489/// Perform paired t-test for comparing two sets of CV scores
490pub fn paired_ttest(
491    scores1: &Array1<f64>,
492    scores2: &Array1<f64>,
493    alpha: f64,
494) -> Result<SignificanceTestResult> {
495    if scores1.len() != scores2.len() {
496        return Err(SklearsError::InvalidInput(
497            "Score arrays must have the same length".to_string(),
498        ));
499    }
500
501    let n = scores1.len() as f64;
502    if n < 2.0 {
503        return Err(SklearsError::InvalidInput(
504            "Need at least 2 samples for t-test".to_string(),
505        ));
506    }
507
508    // Compute differences
509    let differences: Array1<f64> = scores1 - scores2;
510    let mean_diff = differences.mean().expect("operation should succeed");
511    let std_diff = differences.std(1.0);
512
513    if std_diff == 0.0 {
514        return Err(SklearsError::InvalidInput(
515            "Standard deviation of differences is zero".to_string(),
516        ));
517    }
518
519    // Compute t-statistic
520    let t_stat = mean_diff * (n.sqrt()) / std_diff;
521
522    // Compute p-value (two-tailed test)
523    // Using approximation for t-distribution
524    let df = n - 1.0;
525    let p_value = 2.0 * (1.0 - student_t_cdf(t_stat.abs(), df));
526
527    Ok(SignificanceTestResult {
528        statistic: t_stat,
529        p_value,
530        is_significant: p_value < alpha,
531        alpha,
532        test_name: "Paired t-test".to_string(),
533    })
534}
535
536/// Approximate CDF of Student's t-distribution
537fn student_t_cdf(t: f64, df: f64) -> f64 {
538    // Simple approximation using normal distribution for large df
539    if df > 30.0 {
540        return standard_normal_cdf(t);
541    }
542
543    // Basic approximation for small df
544    let x = t / (df + t * t).sqrt();
545    0.5 + 0.5 * x * (1.0 - x * x / 3.0)
546}
547
548/// Standard normal CDF approximation
549fn standard_normal_cdf(x: f64) -> f64 {
550    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
551}
552
/// Error function via the Abramowitz & Stegun 7.1.26 rational approximation
/// (maximum absolute error about 1.5e-7).
fn erf(x: f64) -> f64 {
    const A1: f64 = 0.254829592;
    const A2: f64 = -0.284496736;
    const A3: f64 = 1.421413741;
    const A4: f64 = -1.453152027;
    const A5: f64 = 1.061405429;
    const P: f64 = 0.3275911;

    // erf is odd: evaluate on |x| and restore the sign at the end.
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + P * x);
    let poly = ((((A5 * t + A4) * t) + A3) * t + A2) * t + A1;
    let y = 1.0 - poly * t * (-x * x).exp();

    sign * y
}
571
/// Perform Wilcoxon signed-rank test (non-parametric alternative to t-test)
///
/// Compares two matched score sets via ranks of the absolute pairwise
/// differences and a normal approximation of the W+ statistic.
///
/// # Errors
///
/// Fails when the arrays differ in length or when fewer than 5 non-zero
/// differences remain after dropping exact ties.
pub fn wilcoxon_signed_rank_test(
    scores1: &Array1<f64>,
    scores2: &Array1<f64>,
    alpha: f64,
) -> Result<SignificanceTestResult> {
    if scores1.len() != scores2.len() {
        return Err(SklearsError::InvalidInput(
            "Score arrays must have the same length".to_string(),
        ));
    }

    let differences: Vec<f64> = scores1
        .iter()
        .zip(scores2.iter())
        .map(|(a, b)| a - b)
        .filter(|&d| d != 0.0) // Remove zero differences
        .collect();

    let n = differences.len();
    if n < 5 {
        return Err(SklearsError::InvalidInput(
            "Need at least 5 non-zero differences for Wilcoxon test".to_string(),
        ));
    }

    // Rank absolute differences: build (|d|, original index, signed d)
    // triples and sort them ascending by |d|.
    let mut abs_diffs_with_indices: Vec<(f64, usize, f64)> = differences
        .iter()
        .enumerate()
        .map(|(i, &d)| (d.abs(), i, d))
        .collect();

    abs_diffs_with_indices.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("operation should succeed"));

    // Assign 1-based ranks; a run of equal |d| values all receive the average
    // of the ranks the run spans (standard mid-rank treatment of ties).
    let mut ranks = vec![0.0; n];
    let mut i = 0;
    while i < n {
        // Advance j to one past the end of the tie run starting at i.
        let mut j = i;
        while j < n && abs_diffs_with_indices[j].0 == abs_diffs_with_indices[i].0 {
            j += 1;
        }

        // Sorted positions i..j (0-based) carry ranks i+1..=j; their mean is
        // (i + j + 1) / 2.
        let rank = (i + j + 1) as f64 / 2.0; // Average rank for ties
        for k in i..j {
            ranks[abs_diffs_with_indices[k].1] = rank;
        }
        i = j;
    }

    // W+ = sum of ranks attached to positive differences.
    let w_plus: f64 = differences
        .iter()
        .zip(&ranks)
        .filter(|(&d, _)| d > 0.0)
        .map(|(_, &rank)| rank)
        .sum();

    // Expected value and variance of W+ under the null hypothesis
    // (no tie correction is applied to the variance).
    let expected = n as f64 * (n + 1) as f64 / 4.0;
    let variance = n as f64 * (n + 1) as f64 * (2 * n + 1) as f64 / 24.0;

    // Z-statistic with continuity correction (shift 0.5 toward the mean).
    let z = if w_plus > expected {
        (w_plus - 0.5 - expected) / variance.sqrt()
    } else {
        (w_plus + 0.5 - expected) / variance.sqrt()
    };

    // Two-tailed p-value from the normal approximation.
    let p_value = 2.0 * (1.0 - standard_normal_cdf(z.abs()));

    Ok(SignificanceTestResult {
        statistic: w_plus,
        p_value,
        is_significant: p_value < alpha,
        alpha,
        test_name: "Wilcoxon signed-rank test".to_string(),
    })
}