1use scirs2_core::ndarray::Array1;
4use scirs2_core::random::rngs::StdRng;
5use scirs2_core::random::SeedableRng;
6use scirs2_core::RngExt;
7use sklears_core::{
8 error::{Result, SklearsError},
9 types::Float,
11};
12use sklears_metrics::{
13 classification::{accuracy_score, f1_score, precision_score, recall_score},
14 regression::{explained_variance_score, mean_absolute_error, mean_squared_error, r2_score},
15};
16use std::collections::HashMap;
17use std::sync::Arc;
18
/// A user-defined scoring metric that can be plugged into model evaluation.
///
/// Implementations must be `Send + Sync` so scorers can be shared across
/// threads (they are stored behind `Arc` in the scorer registry), and `Debug`
/// so containing types can derive `Debug`.
pub trait CustomScorer: Send + Sync + std::fmt::Debug {
    /// Computes the score for the given ground-truth and predicted values.
    fn score(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<f64>;
    /// The name under which this scorer is registered and looked up.
    fn name(&self) -> &str;
    /// Whether larger scores indicate better performance (e.g. accuracy),
    /// as opposed to smaller ones (e.g. an error metric).
    fn higher_is_better(&self) -> bool;
}
28
/// A [`CustomScorer`] backed by an arbitrary closure.
pub struct ClosureScorer {
    // Name reported by `CustomScorer::name`.
    name: String,
    // The scoring function; kept behind `Arc` so the scorer can be shared.
    scorer_fn: Arc<dyn Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync>,
    // See `CustomScorer::higher_is_better`.
    higher_is_better: bool,
}
35
36impl ClosureScorer {
37 pub fn new<F>(name: String, scorer_fn: F, higher_is_better: bool) -> Self
39 where
40 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync + 'static,
41 {
42 Self {
43 name,
44 scorer_fn: Arc::new(scorer_fn),
45 higher_is_better,
46 }
47 }
48}
49
50impl std::fmt::Debug for ClosureScorer {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("ClosureScorer")
53 .field("name", &self.name)
54 .field("higher_is_better", &self.higher_is_better)
55 .finish()
56 }
57}
58
impl CustomScorer for ClosureScorer {
    /// Delegates scoring to the wrapped closure.
    fn score(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<f64> {
        (self.scorer_fn)(y_true, y_pred)
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn higher_is_better(&self) -> bool {
        self.higher_is_better
    }
}
72
/// A registry of custom scorers, keyed by scorer name.
#[derive(Debug, Clone)]
pub struct ScorerRegistry {
    // Scorers are stored behind `Arc` so the registry can be cloned cheaply
    // and shared between evaluation runs.
    custom_scorers: HashMap<String, Arc<dyn CustomScorer>>,
}
78
impl Default for ScorerRegistry {
    /// Equivalent to [`ScorerRegistry::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
84
85impl ScorerRegistry {
86 pub fn new() -> Self {
88 Self {
89 custom_scorers: HashMap::new(),
90 }
91 }
92
93 pub fn register_scorer(&mut self, scorer: Arc<dyn CustomScorer>) {
95 self.custom_scorers
96 .insert(scorer.name().to_string(), scorer);
97 }
98
99 pub fn register_closure_scorer<F>(&mut self, name: String, scorer_fn: F, higher_is_better: bool)
101 where
102 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync + 'static,
103 {
104 let scorer = Arc::new(ClosureScorer::new(name, scorer_fn, higher_is_better));
105 self.register_scorer(scorer);
106 }
107
108 pub fn get_scorer(&self, name: &str) -> Option<&Arc<dyn CustomScorer>> {
110 self.custom_scorers.get(name)
111 }
112
113 pub fn list_scorers(&self) -> Vec<&str> {
115 self.custom_scorers.keys().map(|s| s.as_str()).collect()
116 }
117
118 pub fn has_scorer(&self, name: &str) -> bool {
120 self.custom_scorers.contains_key(name)
121 }
122}
123
/// Configuration for scoring predictions with [`EnhancedScorer`].
#[derive(Debug, Clone)]
pub struct ScoringConfig {
    /// Name of the primary metric (built-in or registered custom scorer).
    pub primary: String,
    /// Names of additional metrics computed alongside the primary one.
    pub additional: Vec<String>,
    /// Whether to compute bootstrap confidence intervals for each metric.
    pub confidence_intervals: bool,
    /// Confidence level for the intervals (e.g. 0.95 for a 95% CI).
    pub confidence_level: f64,
    /// Number of bootstrap resamples used for the intervals.
    pub n_bootstrap: usize,
    /// Seed for the bootstrap RNG; `None` falls back to a fixed default
    /// seed, so results are deterministic either way.
    pub random_state: Option<u64>,
    /// Registry of custom scorers, consulted before built-in metric names.
    pub scorer_registry: ScorerRegistry,
}
142
impl Default for ScoringConfig {
    /// Defaults: "accuracy" as the primary metric, no additional metrics,
    /// confidence intervals disabled (95% level / 1000 resamples if later
    /// enabled), no explicit random seed, and an empty scorer registry.
    fn default() -> Self {
        Self {
            primary: "accuracy".to_string(),
            additional: vec![],
            confidence_intervals: false,
            confidence_level: 0.95,
            n_bootstrap: 1000,
            random_state: None,
            scorer_registry: ScorerRegistry::new(),
        }
    }
}
156
157impl ScoringConfig {
158 pub fn new(primary: &str) -> Self {
160 Self {
161 primary: primary.to_string(),
162 ..Default::default()
163 }
164 }
165
166 pub fn with_additional_metrics(mut self, metrics: Vec<String>) -> Self {
168 self.additional = metrics;
169 self
170 }
171
172 pub fn with_confidence_intervals(mut self, level: f64, n_bootstrap: usize) -> Self {
174 self.confidence_intervals = true;
175 self.confidence_level = level;
176 self.n_bootstrap = n_bootstrap;
177 self
178 }
179
180 pub fn with_random_state(mut self, random_state: u64) -> Self {
182 self.random_state = Some(random_state);
183 self
184 }
185
186 pub fn with_custom_scorer(mut self, scorer: Arc<dyn CustomScorer>) -> Self {
188 self.scorer_registry.register_scorer(scorer);
189 self
190 }
191
192 pub fn with_closure_scorer<F>(
194 mut self,
195 name: String,
196 scorer_fn: F,
197 higher_is_better: bool,
198 ) -> Self
199 where
200 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<f64> + Send + Sync + 'static,
201 {
202 self.scorer_registry
203 .register_closure_scorer(name, scorer_fn, higher_is_better);
204 self
205 }
206
207 pub fn scorer_registry_mut(&mut self) -> &mut ScorerRegistry {
209 &mut self.scorer_registry
210 }
211
212 pub fn scorer_registry(&self) -> &ScorerRegistry {
214 &self.scorer_registry
215 }
216}
217
/// Result of scoring predictions across cross-validation splits.
#[derive(Debug, Clone)]
pub struct ScoringResult {
    /// Primary-metric score for each split.
    pub primary_scores: Array1<f64>,
    /// Per-split scores for each additional metric, keyed by metric name.
    pub additional_scores: HashMap<String, Array1<f64>>,
    /// Bootstrap (lower, upper) confidence interval for the primary metric,
    /// if confidence intervals were requested.
    pub confidence_interval: Option<(f64, f64)>,
    /// Bootstrap confidence intervals for the additional metrics.
    pub additional_confidence_intervals: HashMap<String, (f64, f64)>,
    /// Mean score per metric; the primary metric is stored under "primary".
    pub mean_scores: HashMap<String, f64>,
    /// Sample standard deviation (ddof = 1) per metric, same keys as
    /// `mean_scores`.
    pub std_scores: HashMap<String, f64>,
}
234
235impl ScoringResult {
236 pub fn primary_mean(&self) -> f64 {
238 self.mean_scores.get("primary").copied().unwrap_or(0.0)
239 }
240
241 pub fn mean_score(&self, metric: &str) -> Option<f64> {
243 self.mean_scores.get(metric).copied()
244 }
245
246 pub fn all_mean_scores(&self) -> &HashMap<String, f64> {
248 &self.mean_scores
249 }
250}
251
/// Scorer driven by a [`ScoringConfig`]: computes a primary metric, optional
/// additional metrics, and optional bootstrap confidence intervals.
pub struct EnhancedScorer {
    // The full scoring configuration (metrics, CI settings, custom scorers).
    config: ScoringConfig,
}
256
257impl EnhancedScorer {
258 pub fn new(config: ScoringConfig) -> Self {
260 Self { config }
261 }
262
263 pub fn score_predictions(
265 &self,
266 y_true_splits: &[Array1<Float>],
267 y_pred_splits: &[Array1<Float>],
268 task_type: TaskType,
269 ) -> Result<ScoringResult> {
270 if y_true_splits.len() != y_pred_splits.len() {
271 return Err(SklearsError::InvalidInput(
272 "Number of true and predicted splits must match".to_string(),
273 ));
274 }
275
276 let n_splits = y_true_splits.len();
277 let mut primary_scores = Vec::with_capacity(n_splits);
278 let mut additional_scores: HashMap<String, Vec<f64>> = HashMap::new();
279
280 for metric in &self.config.additional {
282 additional_scores.insert(metric.clone(), Vec::with_capacity(n_splits));
283 }
284
285 for (y_true, y_pred) in y_true_splits.iter().zip(y_pred_splits.iter()) {
287 let primary_score =
289 self.compute_metric_score(&self.config.primary, y_true, y_pred, task_type)?;
290 primary_scores.push(primary_score);
291
292 for metric in &self.config.additional {
294 let score = self.compute_metric_score(metric, y_true, y_pred, task_type)?;
295 additional_scores
296 .get_mut(metric)
297 .expect("operation should succeed")
298 .push(score);
299 }
300 }
301
302 let primary_scores_array = Array1::from_vec(primary_scores.clone());
304 let mut additional_scores_arrays = HashMap::new();
305 for (metric, scores) in additional_scores.iter() {
306 additional_scores_arrays.insert(metric.clone(), Array1::from_vec(scores.clone()));
307 }
308
309 let confidence_interval = if self.config.confidence_intervals {
311 Some(self.bootstrap_confidence_interval(&primary_scores)?)
312 } else {
313 None
314 };
315
316 let mut additional_confidence_intervals = HashMap::new();
317 if self.config.confidence_intervals {
318 for (metric, scores) in &additional_scores {
319 let ci = self.bootstrap_confidence_interval(scores)?;
320 additional_confidence_intervals.insert(metric.clone(), ci);
321 }
322 }
323
324 let mut mean_scores = HashMap::new();
326 let mut std_scores = HashMap::new();
327
328 mean_scores.insert(
329 "primary".to_string(),
330 primary_scores_array
331 .mean()
332 .expect("operation should succeed"),
333 );
334 std_scores.insert("primary".to_string(), primary_scores_array.std(1.0));
335
336 for (metric, scores) in &additional_scores_arrays {
337 mean_scores.insert(
338 metric.clone(),
339 scores.mean().expect("operation should succeed"),
340 );
341 std_scores.insert(metric.clone(), scores.std(1.0));
342 }
343
344 Ok(ScoringResult {
345 primary_scores: primary_scores_array,
346 additional_scores: additional_scores_arrays,
347 confidence_interval,
348 additional_confidence_intervals,
349 mean_scores,
350 std_scores,
351 })
352 }
353
354 fn compute_metric_score(
356 &self,
357 metric: &str,
358 y_true: &Array1<Float>,
359 y_pred: &Array1<Float>,
360 task_type: TaskType,
361 ) -> Result<f64> {
362 if let Some(custom_scorer) = self.config.scorer_registry.get_scorer(metric) {
364 return custom_scorer.score(y_true, y_pred);
365 }
366
367 match task_type {
369 TaskType::Classification => self.compute_classification_score(metric, y_true, y_pred),
370 TaskType::Regression => self.compute_regression_score(metric, y_true, y_pred),
371 }
372 }
373
374 fn compute_classification_score(
375 &self,
376 metric: &str,
377 y_true: &Array1<Float>,
378 y_pred: &Array1<Float>,
379 ) -> Result<f64> {
380 let y_true_int: Array1<i32> = y_true.mapv(|x| x as i32);
382 let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
383
384 let score = match metric {
385 "accuracy" => accuracy_score(&y_true_int, &y_pred_int)
386 .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
387 "precision" => precision_score(&y_true_int, &y_pred_int, None)
388 .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
389 "recall" => recall_score(&y_true_int, &y_pred_int, None)
390 .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
391 "f1" => f1_score(&y_true_int, &y_pred_int, None)
392 .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
393 _ => {
394 return Err(SklearsError::InvalidInput(format!(
395 "Unknown classification metric: {}",
396 metric
397 )))
398 }
399 };
400
401 Ok(score)
402 }
403
404 fn compute_regression_score(
405 &self,
406 metric: &str,
407 y_true: &Array1<Float>,
408 y_pred: &Array1<Float>,
409 ) -> Result<f64> {
410 let score = match metric {
411 "r2" | "r2_score" => {
412 r2_score(y_true, y_pred).map_err(|e| SklearsError::InvalidInput(e.to_string()))?
413 }
414 "neg_mean_squared_error" => -mean_squared_error(y_true, y_pred)
415 .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
416 "neg_mean_absolute_error" => -mean_absolute_error(y_true, y_pred)
417 .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
418 "explained_variance" => explained_variance_score(y_true, y_pred)
419 .map_err(|e| SklearsError::InvalidInput(e.to_string()))?,
420 _ => {
421 return Err(SklearsError::InvalidInput(format!(
422 "Unknown regression metric: {}",
423 metric
424 )))
425 }
426 };
427
428 Ok(score)
429 }
430
431 fn bootstrap_confidence_interval(&self, scores: &[f64]) -> Result<(f64, f64)> {
433 let mut rng = match self.config.random_state {
434 Some(seed) => StdRng::seed_from_u64(seed),
435 None => StdRng::seed_from_u64(42),
436 };
437
438 let n_scores = scores.len();
439 let mut bootstrap_means = Vec::with_capacity(self.config.n_bootstrap);
440
441 for _ in 0..self.config.n_bootstrap {
442 let mut bootstrap_sample = Vec::with_capacity(n_scores);
443 for _ in 0..n_scores {
444 let idx = rng.random_range(0..n_scores);
445 bootstrap_sample.push(scores[idx]);
446 }
447
448 let mean = bootstrap_sample.iter().sum::<f64>() / n_scores as f64;
449 bootstrap_means.push(mean);
450 }
451
452 bootstrap_means.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
453
454 let alpha = 1.0 - self.config.confidence_level;
455 let lower_idx = ((alpha / 2.0) * self.config.n_bootstrap as f64) as usize;
456 let upper_idx = ((1.0 - alpha / 2.0) * self.config.n_bootstrap as f64) as usize;
457
458 let lower = bootstrap_means[lower_idx.min(self.config.n_bootstrap - 1)];
459 let upper = bootstrap_means[upper_idx.min(self.config.n_bootstrap - 1)];
460
461 Ok((lower, upper))
462 }
463}
464
/// The kind of supervised learning task, which selects the set of built-in
/// metrics used by [`EnhancedScorer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskType {
    /// Classification task (accuracy, precision, recall, f1).
    Classification,
    /// Regression task (r2, negated MSE/MAE, explained variance).
    Regression,
}
473
/// Outcome of a statistical significance test comparing two score samples.
#[derive(Debug, Clone)]
pub struct SignificanceTestResult {
    /// Test statistic (t statistic or Wilcoxon W+, depending on the test).
    pub statistic: f64,
    /// Two-sided p-value.
    pub p_value: f64,
    /// Whether `p_value < alpha`.
    pub is_significant: bool,
    /// Significance threshold used for the decision.
    pub alpha: f64,
    /// Human-readable name of the test that produced this result.
    pub test_name: String,
}
488
489pub fn paired_ttest(
491 scores1: &Array1<f64>,
492 scores2: &Array1<f64>,
493 alpha: f64,
494) -> Result<SignificanceTestResult> {
495 if scores1.len() != scores2.len() {
496 return Err(SklearsError::InvalidInput(
497 "Score arrays must have the same length".to_string(),
498 ));
499 }
500
501 let n = scores1.len() as f64;
502 if n < 2.0 {
503 return Err(SklearsError::InvalidInput(
504 "Need at least 2 samples for t-test".to_string(),
505 ));
506 }
507
508 let differences: Array1<f64> = scores1 - scores2;
510 let mean_diff = differences.mean().expect("operation should succeed");
511 let std_diff = differences.std(1.0);
512
513 if std_diff == 0.0 {
514 return Err(SklearsError::InvalidInput(
515 "Standard deviation of differences is zero".to_string(),
516 ));
517 }
518
519 let t_stat = mean_diff * (n.sqrt()) / std_diff;
521
522 let df = n - 1.0;
525 let p_value = 2.0 * (1.0 - student_t_cdf(t_stat.abs(), df));
526
527 Ok(SignificanceTestResult {
528 statistic: t_stat,
529 p_value,
530 is_significant: p_value < alpha,
531 alpha,
532 test_name: "Paired t-test".to_string(),
533 })
534}
535
/// CDF of Student's t distribution with `df` degrees of freedom.
///
/// Uses the substitution `u = sqrt(df) * tan(theta)`, under which the t
/// density becomes proportional to `cos(theta)^(df - 1)` on (-pi/2, pi/2):
///
/// `P(T <= t) = I(-pi/2, atan(t / sqrt(df))) / I(-pi/2, pi/2)`
///
/// where `I` integrates `cos(theta)^(df - 1)` (composite Simpson's rule).
///
/// This replaces the previous closed-form approximation
/// `0.5 + 0.5 x (1 - x^2 / 3)` with `x = t / sqrt(df + t^2)`, which
/// saturated at ~0.833 as `t -> inf`, so two-sided p-values could never
/// drop below ~0.333. The quadrature converges to 1 correctly and is
/// accurate for all `df >= 1`, so the former normal-approximation shortcut
/// for `df > 30` is no longer needed.
fn student_t_cdf(t: f64, df: f64) -> f64 {
    use std::f64::consts::FRAC_PI_2;

    // Composite Simpson's rule over [a, b] with an even number `n` of
    // subintervals.
    fn simpson(f: impl Fn(f64) -> f64, a: f64, b: f64, n: usize) -> f64 {
        let h = (b - a) / n as f64;
        let mut acc = f(a) + f(b);
        for i in 1..n {
            let weight = if i % 2 == 0 { 2.0 } else { 4.0 };
            acc += weight * f(a + i as f64 * h);
        }
        acc * h / 3.0
    }

    // The integrand narrows like 1/sqrt(df); scale the (even) step count so
    // a roughly constant number of samples falls across its mass.
    let steps = 2 * (1000 + (50.0 * df.sqrt()) as usize);
    let kernel = |theta: f64| theta.cos().powf(df - 1.0);
    let upper = (t / df.sqrt()).atan();
    let mass_below = simpson(&kernel, -FRAC_PI_2, upper, steps);
    let mass_total = simpson(&kernel, -FRAC_PI_2, FRAC_PI_2, steps);
    // Guard against tiny quadrature error pushing the ratio outside [0, 1].
    (mass_below / mass_total).clamp(0.0, 1.0)
}
547
/// CDF of the standard normal distribution via the error function:
/// Phi(x) = (1 + erf(x / sqrt(2))) / 2.
fn standard_normal_cdf(x: f64) -> f64 {
    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
}
552
/// Error function approximation (Abramowitz & Stegun formula 7.1.26;
/// maximum absolute error about 1.5e-7).
fn erf(x: f64) -> f64 {
    // Polynomial coefficients of the A&S 7.1.26 fit.
    const A1: f64 = 0.254829592;
    const A2: f64 = -0.284496736;
    const A3: f64 = 1.421413741;
    const A4: f64 = -1.453152027;
    const A5: f64 = 1.061405429;
    const P: f64 = 0.3275911;

    // The fit is defined for non-negative arguments; use erf(-x) = -erf(x).
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + P * x);
    // Horner evaluation of the degree-5 polynomial in t.
    let poly = (((((A5 * t + A4) * t) + A3) * t + A2) * t + A1) * t;
    sign * (1.0 - poly * (-x * x).exp())
}
571
572pub fn wilcoxon_signed_rank_test(
574 scores1: &Array1<f64>,
575 scores2: &Array1<f64>,
576 alpha: f64,
577) -> Result<SignificanceTestResult> {
578 if scores1.len() != scores2.len() {
579 return Err(SklearsError::InvalidInput(
580 "Score arrays must have the same length".to_string(),
581 ));
582 }
583
584 let differences: Vec<f64> = scores1
585 .iter()
586 .zip(scores2.iter())
587 .map(|(a, b)| a - b)
588 .filter(|&d| d != 0.0) .collect();
590
591 let n = differences.len();
592 if n < 5 {
593 return Err(SklearsError::InvalidInput(
594 "Need at least 5 non-zero differences for Wilcoxon test".to_string(),
595 ));
596 }
597
598 let mut abs_diffs_with_indices: Vec<(f64, usize, f64)> = differences
600 .iter()
601 .enumerate()
602 .map(|(i, &d)| (d.abs(), i, d))
603 .collect();
604
605 abs_diffs_with_indices.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("operation should succeed"));
606
607 let mut ranks = vec![0.0; n];
608 let mut i = 0;
609 while i < n {
610 let mut j = i;
611 while j < n && abs_diffs_with_indices[j].0 == abs_diffs_with_indices[i].0 {
612 j += 1;
613 }
614
615 let rank = (i + j + 1) as f64 / 2.0; for k in i..j {
617 ranks[abs_diffs_with_indices[k].1] = rank;
618 }
619 i = j;
620 }
621
622 let w_plus: f64 = differences
624 .iter()
625 .zip(&ranks)
626 .filter(|(&d, _)| d > 0.0)
627 .map(|(_, &rank)| rank)
628 .sum();
629
630 let expected = n as f64 * (n + 1) as f64 / 4.0;
632 let variance = n as f64 * (n + 1) as f64 * (2 * n + 1) as f64 / 24.0;
633
634 let z = if w_plus > expected {
636 (w_plus - 0.5 - expected) / variance.sqrt()
637 } else {
638 (w_plus + 0.5 - expected) / variance.sqrt()
639 };
640
641 let p_value = 2.0 * (1.0 - standard_normal_cdf(z.abs()));
642
643 Ok(SignificanceTestResult {
644 statistic: w_plus,
645 p_value,
646 is_significant: p_value < alpha,
647 alpha,
648 test_name: "Wilcoxon signed-rank test".to_string(),
649 })
650}