//! Enhanced scoring utilities: pluggable custom scorers, multi-metric
//! evaluation with bootstrap confidence intervals, and paired significance
//! tests for comparing models.

use scirs2_core::ndarray::Array1;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::Rng;
use scirs2_core::random::SeedableRng;
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};
use sklears_metrics::{
    classification::{accuracy_score, f1_score, precision_score, recall_score},
    regression::{explained_variance_score, mean_absolute_error, mean_squared_error, r2_score},
};
use std::collections::HashMap;
use std::sync::Arc;

/// A user-defined scoring metric.
///
/// Implementations compare true targets against predictions and return a
/// single scalar score.
pub trait CustomScorer: Send + Sync + std::fmt::Debug {
    /// Computes the score for the given true and predicted values.
    fn score(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<f64>;

    /// Returns the name under which this scorer is registered.
    fn name(&self) -> &str;

    /// Whether larger scores indicate better performance.
    fn higher_is_better(&self) -> bool;
}

/// A [`CustomScorer`] backed by a closure.
pub struct ClosureScorer {
    name: String,
    scorer_fn: Arc<dyn Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync>,
    higher_is_better: bool,
}

impl ClosureScorer {
    /// Creates a scorer from a name, a scoring closure, and its orientation.
    pub fn new<F>(name: String, scorer_fn: F, higher_is_better: bool) -> Self
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync + 'static,
    {
        Self {
            name,
            scorer_fn: Arc::new(scorer_fn),
            higher_is_better,
        }
    }
}

// Manual `Debug`: the scoring closure itself is not `Debug`, so only the
// name and orientation are shown.
impl std::fmt::Debug for ClosureScorer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClosureScorer")
            .field("name", &self.name)
            .field("higher_is_better", &self.higher_is_better)
            .finish()
    }
}

impl CustomScorer for ClosureScorer {
    fn score(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<f64> {
        (self.scorer_fn)(y_true, y_pred)
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn higher_is_better(&self) -> bool {
        self.higher_is_better
    }
}
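
// A minimal usage sketch for `ClosureScorer`, kept behind `cfg(test)`. The
// MAPE formula here is illustrative and assumes `Float` is `f64`; it is not
// part of the library's built-in metric set.
#[cfg(test)]
mod closure_scorer_sketch {
    use super::*;

    #[test]
    fn mape_closure_scorer() {
        // Mean absolute percentage error: lower is better, so
        // `higher_is_better` is false.
        let scorer = ClosureScorer::new(
            "mape".to_string(),
            |y_true, y_pred| {
                let total: f64 = y_true
                    .iter()
                    .zip(y_pred.iter())
                    .map(|(t, p)| ((t - p) / t).abs())
                    .sum();
                Ok(total / y_true.len() as f64)
            },
            false,
        );

        let y_true = Array1::from_vec(vec![1.0, 2.0, 4.0]);
        let y_pred = Array1::from_vec(vec![1.0, 2.0, 2.0]);
        let score = scorer.score(&y_true, &y_pred).unwrap();

        // Only the last prediction is off, by 50% of the true value.
        assert!((score - 0.5 / 3.0).abs() < 1e-12);
        assert!(!scorer.higher_is_better());
    }
}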

/// A registry of custom scorers, keyed by name.
#[derive(Debug, Clone)]
pub struct ScorerRegistry {
    custom_scorers: HashMap<String, Arc<dyn CustomScorer>>,
}

impl Default for ScorerRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl ScorerRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            custom_scorers: HashMap::new(),
        }
    }

    /// Registers a scorer under its own name, replacing any existing entry.
    pub fn register_scorer(&mut self, scorer: Arc<dyn CustomScorer>) {
        self.custom_scorers
            .insert(scorer.name().to_string(), scorer);
    }

    /// Convenience wrapper that registers a closure as a scorer.
    pub fn register_closure_scorer<F>(&mut self, name: String, scorer_fn: F, higher_is_better: bool)
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync + 'static,
    {
        let scorer = Arc::new(ClosureScorer::new(name, scorer_fn, higher_is_better));
        self.register_scorer(scorer);
    }

    /// Looks up a scorer by name.
    pub fn get_scorer(&self, name: &str) -> Option<&Arc<dyn CustomScorer>> {
        self.custom_scorers.get(name)
    }

    /// Lists the names of all registered scorers.
    pub fn list_scorers(&self) -> Vec<&str> {
        self.custom_scorers.keys().map(|s| s.as_str()).collect()
    }

    /// Returns `true` if a scorer with the given name is registered.
    pub fn has_scorer(&self, name: &str) -> bool {
        self.custom_scorers.contains_key(name)
    }
}
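
// A registry round-trip sketch (test-gated): register a closure scorer, then
// look it up by name. The "max_error" metric below is illustrative only.
#[cfg(test)]
mod scorer_registry_sketch {
    use super::*;

    #[test]
    fn register_and_lookup() {
        let mut registry = ScorerRegistry::new();
        registry.register_closure_scorer(
            "max_error".to_string(),
            |y_true, y_pred| {
                // Largest absolute deviation between truth and prediction.
                Ok(y_true
                    .iter()
                    .zip(y_pred.iter())
                    .map(|(t, p)| (t - p).abs() as f64)
                    .fold(0.0, f64::max))
            },
            false,
        );

        assert!(registry.has_scorer("max_error"));
        assert!(registry.get_scorer("unregistered").is_none());
        assert_eq!(registry.list_scorers(), vec!["max_error"]);
    }
}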

/// Configuration for [`EnhancedScorer`].
#[derive(Debug, Clone)]
pub struct ScoringConfig {
    /// Name of the primary metric.
    pub primary: String,
    /// Names of additional metrics to compute alongside the primary one.
    pub additional: Vec<String>,
    /// Whether to compute bootstrap confidence intervals.
    pub confidence_intervals: bool,
    /// Confidence level for the intervals (e.g. 0.95).
    pub confidence_level: f64,
    /// Number of bootstrap resamples.
    pub n_bootstrap: usize,
    /// Optional seed for reproducible bootstrapping.
    pub random_state: Option<u64>,
    /// Registry consulted for custom metrics before the built-in ones.
    pub scorer_registry: ScorerRegistry,
}

impl Default for ScoringConfig {
    fn default() -> Self {
        Self {
            primary: "accuracy".to_string(),
            additional: vec![],
            confidence_intervals: false,
            confidence_level: 0.95,
            n_bootstrap: 1000,
            random_state: None,
            scorer_registry: ScorerRegistry::new(),
        }
    }
}

impl ScoringConfig {
    /// Creates a configuration with the given primary metric and defaults elsewhere.
    pub fn new(primary: &str) -> Self {
        Self {
            primary: primary.to_string(),
            ..Default::default()
        }
    }

    /// Sets the additional metrics to compute.
    pub fn with_additional_metrics(mut self, metrics: Vec<String>) -> Self {
        self.additional = metrics;
        self
    }

    /// Enables bootstrap confidence intervals at the given level and resample count.
    pub fn with_confidence_intervals(mut self, level: f64, n_bootstrap: usize) -> Self {
        self.confidence_intervals = true;
        self.confidence_level = level;
        self.n_bootstrap = n_bootstrap;
        self
    }

    /// Sets the random seed used for bootstrapping.
    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Registers a custom scorer on this configuration.
    pub fn with_custom_scorer(mut self, scorer: Arc<dyn CustomScorer>) -> Self {
        self.scorer_registry.register_scorer(scorer);
        self
    }

    /// Registers a closure-based scorer on this configuration.
    pub fn with_closure_scorer<F>(
        mut self,
        name: String,
        scorer_fn: F,
        higher_is_better: bool,
    ) -> Self
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync + 'static,
    {
        self.scorer_registry
            .register_closure_scorer(name, scorer_fn, higher_is_better);
        self
    }

    /// Mutable access to the scorer registry.
    pub fn scorer_registry_mut(&mut self) -> &mut ScorerRegistry {
        &mut self.scorer_registry
    }

    /// Shared access to the scorer registry.
    pub fn scorer_registry(&self) -> &ScorerRegistry {
        &self.scorer_registry
    }
}
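
// Builder-chain sketch (test-gated): a regression configuration with an extra
// metric and bootstrap confidence intervals. Values are illustrative.
#[cfg(test)]
mod scoring_config_sketch {
    use super::*;

    #[test]
    fn builder_chain() {
        let config = ScoringConfig::new("r2")
            .with_additional_metrics(vec!["neg_mean_squared_error".to_string()])
            .with_confidence_intervals(0.95, 500)
            .with_random_state(0);

        assert_eq!(config.primary, "r2");
        assert_eq!(config.additional.len(), 1);
        assert!(config.confidence_intervals);
        assert_eq!(config.n_bootstrap, 500);
        assert_eq!(config.random_state, Some(0));
    }
}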

/// Per-split scores plus summary statistics produced by [`EnhancedScorer`].
#[derive(Debug, Clone)]
pub struct ScoringResult {
    /// Primary-metric score for each split.
    pub primary_scores: Array1<f64>,
    /// Per-split scores for each additional metric.
    pub additional_scores: HashMap<String, Array1<f64>>,
    /// Bootstrap confidence interval for the primary metric, if requested.
    pub confidence_interval: Option<(f64, f64)>,
    /// Bootstrap confidence intervals for the additional metrics.
    pub additional_confidence_intervals: HashMap<String, (f64, f64)>,
    /// Mean score per metric; the primary metric is stored under "primary".
    pub mean_scores: HashMap<String, f64>,
    /// Sample standard deviation per metric (ddof = 1).
    pub std_scores: HashMap<String, f64>,
}

impl ScoringResult {
    /// Mean of the primary metric across splits (0.0 if absent).
    pub fn primary_mean(&self) -> f64 {
        self.mean_scores.get("primary").copied().unwrap_or(0.0)
    }

    /// Mean score for the named metric, if it was computed.
    pub fn mean_score(&self, metric: &str) -> Option<f64> {
        self.mean_scores.get(metric).copied()
    }

    /// All mean scores, keyed by metric name.
    pub fn all_mean_scores(&self) -> &HashMap<String, f64> {
        &self.mean_scores
    }
}

/// Scores cross-validation predictions against a [`ScoringConfig`].
pub struct EnhancedScorer {
    config: ScoringConfig,
}

impl EnhancedScorer {
    /// Creates a scorer from the given configuration.
    pub fn new(config: ScoringConfig) -> Self {
        Self { config }
    }

    /// Scores per-split predictions, computing the primary and additional
    /// metrics and, if configured, bootstrap confidence intervals.
    pub fn score_predictions(
        &self,
        y_true_splits: &[Array1<Float>],
        y_pred_splits: &[Array1<Float>],
        task_type: TaskType,
    ) -> Result<ScoringResult> {
        if y_true_splits.len() != y_pred_splits.len() {
            return Err(SklearsError::InvalidInput(
                "Number of true and predicted splits must match".to_string(),
            ));
        }

        // An empty input would make the mean and bootstrap steps below panic.
        if y_true_splits.is_empty() {
            return Err(SklearsError::InvalidInput(
                "At least one split is required".to_string(),
            ));
        }

        let n_splits = y_true_splits.len();
        let mut primary_scores = Vec::with_capacity(n_splits);
        let mut additional_scores: HashMap<String, Vec<f64>> = HashMap::new();

        for metric in &self.config.additional {
            additional_scores.insert(metric.clone(), Vec::with_capacity(n_splits));
        }

        // Score every split on the primary and additional metrics.
        for (y_true, y_pred) in y_true_splits.iter().zip(y_pred_splits.iter()) {
            let primary_score =
                self.compute_metric_score(&self.config.primary, y_true, y_pred, task_type)?;
            primary_scores.push(primary_score);

            for metric in &self.config.additional {
                let score = self.compute_metric_score(metric, y_true, y_pred, task_type)?;
                additional_scores.get_mut(metric).unwrap().push(score);
            }
        }

        let primary_scores_array = Array1::from_vec(primary_scores.clone());
        let mut additional_scores_arrays = HashMap::new();
        for (metric, scores) in additional_scores.iter() {
            additional_scores_arrays.insert(metric.clone(), Array1::from_vec(scores.clone()));
        }

        let confidence_interval = if self.config.confidence_intervals {
            Some(self.bootstrap_confidence_interval(&primary_scores)?)
        } else {
            None
        };

        let mut additional_confidence_intervals = HashMap::new();
        if self.config.confidence_intervals {
            for (metric, scores) in &additional_scores {
                let ci = self.bootstrap_confidence_interval(scores)?;
                additional_confidence_intervals.insert(metric.clone(), ci);
            }
        }

        // Summary statistics: means and sample standard deviations (ddof = 1).
        let mut mean_scores = HashMap::new();
        let mut std_scores = HashMap::new();

        mean_scores.insert("primary".to_string(), primary_scores_array.mean().unwrap());
        std_scores.insert("primary".to_string(), primary_scores_array.std(1.0));

        for (metric, scores) in &additional_scores_arrays {
            mean_scores.insert(metric.clone(), scores.mean().unwrap());
            std_scores.insert(metric.clone(), scores.std(1.0));
        }

        Ok(ScoringResult {
            primary_scores: primary_scores_array,
            additional_scores: additional_scores_arrays,
            confidence_interval,
            additional_confidence_intervals,
            mean_scores,
            std_scores,
        })
    }

    /// Computes a single metric, trying the custom-scorer registry first and
    /// falling back to the built-in metrics for the task type.
    fn compute_metric_score(
        &self,
        metric: &str,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
        task_type: TaskType,
    ) -> Result<f64> {
        // Custom scorers take precedence over built-in metrics of the same name.
        if let Some(custom_scorer) = self.config.scorer_registry.get_scorer(metric) {
            return custom_scorer.score(y_true, y_pred);
        }

        match task_type {
            TaskType::Classification => self.compute_classification_score(metric, y_true, y_pred),
            TaskType::Regression => self.compute_regression_score(metric, y_true, y_pred),
        }
    }

    fn compute_classification_score(
        &self,
        metric: &str,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<f64> {
        // Class labels arrive as floats; truncate them back to integer labels.
        let y_true_int: Array1<i32> = y_true.mapv(|x| x as i32);
        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);

        let score = match metric {
            "accuracy" => accuracy_score(&y_true_int, &y_pred_int)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            "precision" => precision_score(&y_true_int, &y_pred_int, None)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            "recall" => recall_score(&y_true_int, &y_pred_int, None)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            "f1" => f1_score(&y_true_int, &y_pred_int, None)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown classification metric: {}",
                    metric
                )))
            }
        };

        Ok(score)
    }

    fn compute_regression_score(
        &self,
        metric: &str,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<f64> {
        let score = match metric {
            "r2" | "r2_score" => {
                r2_score(y_true, y_pred).map_err(|e| SklearsError::InvalidInput(e.to_string()))?
            }
            // Error metrics are negated so that higher scores are always better.
            "neg_mean_squared_error" => -mean_squared_error(y_true, y_pred)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            "neg_mean_absolute_error" => -mean_absolute_error(y_true, y_pred)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            "explained_variance" => explained_variance_score(y_true, y_pred)
                .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown regression metric: {}",
                    metric
                )))
            }
        };

        Ok(score)
    }

    /// Percentile bootstrap confidence interval for the mean of `scores`.
    fn bootstrap_confidence_interval(&self, scores: &[f64]) -> Result<(f64, f64)> {
        // Fall back to a fixed seed so unseeded runs are still reproducible.
        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::seed_from_u64(42),
        };

        let n_scores = scores.len();
        let mut bootstrap_means = Vec::with_capacity(self.config.n_bootstrap);

        // Resample with replacement and record each resample's mean.
        for _ in 0..self.config.n_bootstrap {
            let mut bootstrap_sample = Vec::with_capacity(n_scores);
            for _ in 0..n_scores {
                let idx = rng.gen_range(0..n_scores);
                bootstrap_sample.push(scores[idx]);
            }

            let mean = bootstrap_sample.iter().sum::<f64>() / n_scores as f64;
            bootstrap_means.push(mean);
        }

        bootstrap_means.sort_by(|a, b| a.partial_cmp(b).unwrap());

        // Take the alpha/2 and 1 - alpha/2 percentiles of the bootstrap means.
        let alpha = 1.0 - self.config.confidence_level;
        let lower_idx = ((alpha / 2.0) * self.config.n_bootstrap as f64) as usize;
        let upper_idx = ((1.0 - alpha / 2.0) * self.config.n_bootstrap as f64) as usize;

        let lower = bootstrap_means[lower_idx.min(self.config.n_bootstrap - 1)];
        let upper = bootstrap_means[upper_idx.min(self.config.n_bootstrap - 1)];

        Ok((lower, upper))
    }
}
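
// End-to-end sketch (test-gated): score two regression "folds" with r2 as the
// primary metric. Perfect predictions should give r2 == 1.0 on every fold;
// the data are illustrative and `Float` is assumed to be `f64`.
#[cfg(test)]
mod enhanced_scorer_sketch {
    use super::*;

    #[test]
    fn perfect_regression_folds() {
        let scorer = EnhancedScorer::new(ScoringConfig::new("r2"));
        let y_true = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![2.0, 4.0, 6.0]),
        ];
        // Predicting the truth exactly maximizes r2 on each fold.
        let y_pred = y_true.clone();

        let result = scorer
            .score_predictions(&y_true, &y_pred, TaskType::Regression)
            .unwrap();

        assert_eq!(result.primary_scores.len(), 2);
        assert!((result.primary_mean() - 1.0).abs() < 1e-12);
    }
}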

/// The kind of supervised learning task being scored.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskType {
    /// Discrete class labels.
    Classification,
    /// Continuous targets.
    Regression,
}

/// Outcome of a paired statistical significance test.
#[derive(Debug, Clone)]
pub struct SignificanceTestResult {
    /// The test statistic (t for the t-test, W+ for the Wilcoxon test).
    pub statistic: f64,
    /// Two-sided p-value.
    pub p_value: f64,
    /// Whether the p-value falls below `alpha`.
    pub is_significant: bool,
    /// The significance level the test was run at.
    pub alpha: f64,
    /// Human-readable name of the test.
    pub test_name: String,
}

/// Paired t-test on two sets of per-split scores (e.g. from two models
/// evaluated on the same cross-validation folds).
pub fn paired_ttest(
    scores1: &Array1<f64>,
    scores2: &Array1<f64>,
    alpha: f64,
) -> Result<SignificanceTestResult> {
    if scores1.len() != scores2.len() {
        return Err(SklearsError::InvalidInput(
            "Score arrays must have the same length".to_string(),
        ));
    }

    let n = scores1.len() as f64;
    if n < 2.0 {
        return Err(SklearsError::InvalidInput(
            "Need at least 2 samples for t-test".to_string(),
        ));
    }

    let differences: Array1<f64> = scores1 - scores2;
    let mean_diff = differences.mean().unwrap();
    let std_diff = differences.std(1.0);

    if std_diff == 0.0 {
        return Err(SklearsError::InvalidInput(
            "Standard deviation of differences is zero".to_string(),
        ));
    }

    // t = mean(d) * sqrt(n) / std(d), with n - 1 degrees of freedom.
    let t_stat = mean_diff * n.sqrt() / std_diff;

    let df = n - 1.0;
    let p_value = 2.0 * (1.0 - student_t_cdf(t_stat.abs(), df));

    Ok(SignificanceTestResult {
        statistic: t_stat,
        p_value,
        is_significant: p_value < alpha,
        alpha,
        test_name: "Paired t-test".to_string(),
    })
}
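
// Sketch (test-gated): compare two models' per-fold scores with the paired
// t-test. With a large, consistent gap the test should reject the null at
// alpha = 0.05; the scores below are illustrative.
#[cfg(test)]
mod paired_ttest_sketch {
    use super::*;

    #[test]
    fn detects_consistent_difference() {
        // Model A beats model B on every fold by roughly 0.10.
        let a = Array1::from_vec(vec![0.90, 0.92, 0.89, 0.93, 0.91]);
        let b = Array1::from_vec(vec![0.80, 0.81, 0.79, 0.82, 0.80]);

        let result = paired_ttest(&a, &b, 0.05).unwrap();
        assert!(result.statistic > 0.0);
        assert!(result.is_significant);
    }
}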

/// Approximate CDF of Student's t-distribution with `df` degrees of freedom.
fn student_t_cdf(t: f64, df: f64) -> f64 {
    // For large df, the t-distribution is essentially standard normal.
    if df > 30.0 {
        return standard_normal_cdf(t);
    }

    // For small df, use the classic normal-deviate approximation
    // z = t * (1 - 1/(4*df)) / sqrt(1 + t^2 / (2*df))
    // and evaluate the standard normal CDF at z.
    let z = t * (1.0 - 1.0 / (4.0 * df)) / (1.0 + t * t / (2.0 * df)).sqrt();
    standard_normal_cdf(z)
}

/// CDF of the standard normal distribution, via the error function.
fn standard_normal_cdf(x: f64) -> f64 {
    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
}

/// Error function, using the Abramowitz & Stegun 7.1.26 polynomial
/// approximation (maximum absolute error about 1.5e-7).
fn erf(x: f64) -> f64 {
    let a1 = 0.254829592;
    let a2 = -0.284496736;
    let a3 = 1.421413741;
    let a4 = -1.453152027;
    let a5 = 1.061405429;
    let p = 0.3275911;

    // erf is odd: compute on |x| and restore the sign at the end.
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

    sign * y
}
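
// Sanity-check sketch (test-gated) for the polynomial erf approximation,
// against the reference value erf(1) ≈ 0.8427007929.
#[cfg(test)]
mod erf_sketch {
    use super::*;

    #[test]
    fn matches_reference_values() {
        assert!(erf(0.0).abs() < 1e-7);
        assert!((erf(1.0) - 0.8427007929).abs() < 1e-5);
        // erf is odd by construction.
        assert!((erf(-1.0) + erf(1.0)).abs() < 1e-12);
    }
}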

/// Wilcoxon signed-rank test on two sets of paired scores, using the normal
/// approximation with a continuity correction.
pub fn wilcoxon_signed_rank_test(
    scores1: &Array1<f64>,
    scores2: &Array1<f64>,
    alpha: f64,
) -> Result<SignificanceTestResult> {
    if scores1.len() != scores2.len() {
        return Err(SklearsError::InvalidInput(
            "Score arrays must have the same length".to_string(),
        ));
    }

    // Zero differences carry no sign information and are discarded.
    let differences: Vec<f64> = scores1
        .iter()
        .zip(scores2.iter())
        .map(|(a, b)| a - b)
        .filter(|&d| d != 0.0)
        .collect();

    let n = differences.len();
    if n < 5 {
        return Err(SklearsError::InvalidInput(
            "Need at least 5 non-zero differences for Wilcoxon test".to_string(),
        ));
    }

    // Sort by absolute difference, remembering each difference's original index.
    let mut abs_diffs_with_indices: Vec<(f64, usize, f64)> = differences
        .iter()
        .enumerate()
        .map(|(i, &d)| (d.abs(), i, d))
        .collect();

    abs_diffs_with_indices.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    // Assign ranks, giving tied absolute differences their average rank.
    let mut ranks = vec![0.0; n];
    let mut i = 0;
    while i < n {
        let mut j = i;
        while j < n && abs_diffs_with_indices[j].0 == abs_diffs_with_indices[i].0 {
            j += 1;
        }

        // Average of the 1-based ranks i+1..=j for the tie group.
        let rank = (i + j + 1) as f64 / 2.0;
        for k in i..j {
            ranks[abs_diffs_with_indices[k].1] = rank;
        }
        i = j;
    }

    // W+ is the sum of the ranks of the positive differences.
    let w_plus: f64 = differences
        .iter()
        .zip(&ranks)
        .filter(|(&d, _)| d > 0.0)
        .map(|(_, &rank)| rank)
        .sum();

    // Mean and variance of W+ under the null hypothesis.
    let expected = n as f64 * (n + 1) as f64 / 4.0;
    let variance = n as f64 * (n + 1) as f64 * (2 * n + 1) as f64 / 24.0;

    // Normal approximation with a 0.5 continuity correction toward the mean.
    let z = if w_plus > expected {
        (w_plus - 0.5 - expected) / variance.sqrt()
    } else {
        (w_plus + 0.5 - expected) / variance.sqrt()
    };

    let p_value = 2.0 * (1.0 - standard_normal_cdf(z.abs()));

    Ok(SignificanceTestResult {
        statistic: w_plus,
        p_value,
        is_significant: p_value < alpha,
        alpha,
        test_name: "Wilcoxon signed-rank test".to_string(),
    })
}
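
// Sketch (test-gated): the Wilcoxon signed-rank test on paired scores with a
// consistent positive shift. The scores are illustrative.
#[cfg(test)]
mod wilcoxon_sketch {
    use super::*;

    #[test]
    fn consistent_shift_is_significant() {
        let a = Array1::from_vec(vec![0.90, 0.92, 0.89, 0.93, 0.91, 0.94, 0.88, 0.95]);
        let b = Array1::from_vec(vec![0.80, 0.81, 0.79, 0.82, 0.80, 0.83, 0.78, 0.84]);

        let result = wilcoxon_signed_rank_test(&a, &b, 0.05).unwrap();

        // Every difference is positive, so W+ is the full rank sum n(n+1)/2.
        assert_eq!(result.statistic, 8.0 * 9.0 / 2.0);
        assert!(result.is_significant);
    }
}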