1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
20use sklears_core::{
21 error::{Result, SklearsError},
22 traits::{Fit, Predict, PredictProba},
23 types::FloatBounds,
24};
25
26#[cfg(feature = "serde")]
27use serde::{Deserialize, Serialize};
28
/// Metric maximized when searching for the best decision threshold.
///
/// All variants are "higher is better" from the optimizer's point of view;
/// `Cost` is negated inside [`OptimizationMetric::compute`] to keep that
/// convention.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum OptimizationMetric {
    /// F1 score: harmonic mean of precision and recall.
    F1,
    /// F-beta score; `beta > 1` weights recall more, `beta < 1` precision more.
    FBeta(f64),
    /// Positive predictive value: TP / (TP + FP).
    Precision,
    /// True positive rate: TP / (TP + FN).
    Recall,
    /// Mean of sensitivity and specificity.
    BalancedAccuracy,
    /// Total misclassification cost (score is the negated cost).
    Cost {
        /// Cost charged per false positive.
        fp_cost: f64,
        /// Cost charged per false negative.
        fn_cost: f64,
    },
    /// Jaccard index: TP / (TP + FP + FN).
    Jaccard,
    /// Matthews correlation coefficient.
    Matthews,
}
55
/// Binary classifier wrapper that applies a fixed decision threshold to the
/// positive-class probability produced by an inner estimator.
#[derive(Debug, Clone)]
pub struct FixedThresholdClassifier<E> {
    /// The wrapped estimator (unfitted or fitted, depending on the type state).
    estimator: E,
    /// Probability cutoff; probabilities `>= threshold` are predicted as class 1.
    threshold: f64,
    /// Column of the probability matrix treated as the positive class (default 1).
    pos_label_idx: usize,
}
68
69impl<E> FixedThresholdClassifier<E> {
70 pub fn new(estimator: E, threshold: f64) -> Self {
72 Self {
73 estimator,
74 threshold,
75 pos_label_idx: 1,
76 }
77 }
78
79 pub fn threshold(mut self, threshold: f64) -> Self {
81 if threshold < 0.0 || threshold > 1.0 {
82 panic!("Threshold must be between 0.0 and 1.0");
83 }
84 self.threshold = threshold;
85 self
86 }
87
88 pub fn pos_label_idx(mut self, idx: usize) -> Self {
90 self.pos_label_idx = idx;
91 self
92 }
93
94 pub fn get_threshold(&self) -> f64 {
96 self.threshold
97 }
98
99 pub fn estimator(&self) -> &E {
101 &self.estimator
102 }
103}
104
/// Fitting delegates to the inner estimator; the threshold and positive-label
/// settings carry over unchanged into the fitted wrapper (note the type-state
/// change from `FixedThresholdClassifier<E>` to
/// `FixedThresholdClassifier<E::Fitted>`).
impl<'a, E, F: FloatBounds> Fit<ArrayView2<'a, F>, ArrayView1<'a, usize>>
    for FixedThresholdClassifier<E>
where
    E: Fit<ArrayView2<'a, F>, ArrayView1<'a, usize>>,
{
    type Fitted = FixedThresholdClassifier<E::Fitted>;

    fn fit(self, x: &ArrayView2<'a, F>, y: &ArrayView1<'a, usize>) -> Result<Self::Fitted> {
        // Train the wrapped estimator; `?` propagates its error unchanged.
        let trained_estimator = self.estimator.fit(x, y)?;
        Ok(FixedThresholdClassifier {
            estimator: trained_estimator,
            threshold: self.threshold,
            pos_label_idx: self.pos_label_idx,
        })
    }
}
121
122impl<'a, E, F: FloatBounds> Predict<ArrayView2<'a, F>, Array1<usize>>
123 for FixedThresholdClassifier<E>
124where
125 E: PredictProba<ArrayView2<'a, F>, Array2<F>>,
126{
127 fn predict(&self, x: &ArrayView2<'a, F>) -> Result<Array1<usize>> {
128 let probas = self.estimator.predict_proba(x)?;
129
130 let predictions = probas.map_axis(Axis(1), |row| {
132 if row.len() <= self.pos_label_idx {
133 return 0;
134 }
135 if row[self.pos_label_idx].to_f64().unwrap_or(0.0) >= self.threshold {
136 1
137 } else {
138 0
139 }
140 });
141
142 Ok(predictions)
143 }
144}
145
/// Probability output is pass-through: the threshold only affects `predict`,
/// never the underlying probability estimates.
impl<'a, E, F: FloatBounds> PredictProba<ArrayView2<'a, F>, Array2<F>>
    for FixedThresholdClassifier<E>
where
    E: PredictProba<ArrayView2<'a, F>, Array2<F>>,
{
    fn predict_proba(&self, x: &ArrayView2<'a, F>) -> Result<Array2<F>> {
        self.estimator.predict_proba(x)
    }
}
156
157use scirs2_core::ndarray::Axis;
158
/// Configuration for cross-validated decision-threshold tuning: wraps an
/// estimator and a CV splitter plus the search settings.
#[derive(Debug)]
pub struct TunedThresholdClassifierCV<E, C> {
    /// Estimator whose probability output is thresholded.
    estimator: E,
    /// Cross-validation splitter used during tuning.
    cv: C,
    /// Metric maximized by the threshold search (default: F1).
    scoring: OptimizationMetric,
    /// Number of candidate thresholds on the search grid (default: 100).
    n_thresholds: usize,
    /// Lower bound of the threshold search range (default: 0.0).
    min_threshold: f64,
    /// Upper bound of the threshold search range (default: 1.0).
    max_threshold: f64,
    /// Column of the probability matrix treated as the positive class (default 1).
    pos_label_idx: usize,
}
180
181impl<E, C> TunedThresholdClassifierCV<E, C> {
182 pub fn new(estimator: E, cv: C) -> Self {
184 Self {
185 estimator,
186 cv,
187 scoring: OptimizationMetric::F1,
188 n_thresholds: 100,
189 min_threshold: 0.0,
190 max_threshold: 1.0,
191 pos_label_idx: 1,
192 }
193 }
194
195 pub fn scoring(mut self, metric: OptimizationMetric) -> Self {
197 self.scoring = metric;
198 self
199 }
200
201 pub fn n_thresholds(mut self, n: usize) -> Self {
203 self.n_thresholds = n;
204 self
205 }
206
207 pub fn threshold_range(mut self, min: f64, max: f64) -> Self {
209 self.min_threshold = min;
210 self.max_threshold = max;
211 self
212 }
213}
214
/// Result of cross-validated threshold tuning: the fitted estimator together
/// with the winning threshold and the full search curve.
#[derive(Debug)]
pub struct TunedThresholdClassifierCVTrained<E> {
    /// The fitted inner estimator.
    estimator: E,
    /// Threshold that achieved the best cross-validated score.
    best_threshold_: f64,
    /// Best metric value found during the search.
    best_score_: f64,
    /// Every candidate threshold evaluated, in search order.
    thresholds_: Vec<f64>,
    /// Score for each candidate, parallel to `thresholds_`.
    scores_: Vec<f64>,
    /// Column of the probability matrix treated as the positive class.
    pos_label_idx: usize,
}
231
232impl<E> TunedThresholdClassifierCVTrained<E> {
233 pub fn best_threshold(&self) -> f64 {
235 self.best_threshold_
236 }
237
238 pub fn best_score(&self) -> f64 {
240 self.best_score_
241 }
242
243 pub fn thresholds(&self) -> &[f64] {
245 &self.thresholds_
246 }
247
248 pub fn scores(&self) -> &[f64] {
250 &self.scores_
251 }
252}
253
254impl<'a, E, F: FloatBounds> Predict<ArrayView2<'a, F>, Array1<usize>>
255 for TunedThresholdClassifierCVTrained<E>
256where
257 E: PredictProba<ArrayView2<'a, F>, Array2<F>>,
258{
259 fn predict(&self, x: &ArrayView2<'a, F>) -> Result<Array1<usize>> {
260 let probas = self.estimator.predict_proba(x)?;
261
262 let predictions = probas.map_axis(Axis(1), |row| {
263 if row.len() <= self.pos_label_idx {
264 return 0;
265 }
266 if row[self.pos_label_idx].to_f64().unwrap_or(0.0) >= self.best_threshold_ {
267 1
268 } else {
269 0
270 }
271 });
272
273 Ok(predictions)
274 }
275}
276
/// Probability output is pass-through: the tuned threshold only affects
/// `predict`, never the underlying probability estimates.
impl<'a, E, F: FloatBounds> PredictProba<ArrayView2<'a, F>, Array2<F>>
    for TunedThresholdClassifierCVTrained<E>
where
    E: PredictProba<ArrayView2<'a, F>, Array2<F>>,
{
    fn predict_proba(&self, x: &ArrayView2<'a, F>) -> Result<Array2<F>> {
        self.estimator.predict_proba(x)
    }
}
286
287impl OptimizationMetric {
289 pub fn compute(&self, y_true: &[usize], y_pred: &[usize]) -> f64 {
291 match self {
292 OptimizationMetric::F1 => compute_f1(y_true, y_pred),
293 OptimizationMetric::FBeta(beta) => compute_fbeta(y_true, y_pred, *beta),
294 OptimizationMetric::Precision => compute_precision(y_true, y_pred),
295 OptimizationMetric::Recall => compute_recall(y_true, y_pred),
296 OptimizationMetric::BalancedAccuracy => compute_balanced_accuracy(y_true, y_pred),
297 OptimizationMetric::Cost { fp_cost, fn_cost } => {
298 -compute_cost(y_true, y_pred, *fp_cost, *fn_cost)
299 }
300 OptimizationMetric::Jaccard => compute_jaccard(y_true, y_pred),
301 OptimizationMetric::Matthews => compute_matthews(y_true, y_pred),
302 }
303 }
304}
305
/// Counts the binary confusion-matrix cells, returned as (TP, TN, FP, FN).
///
/// Labels other than 0/1 are ignored, matching the binary-only contract of
/// this module; extra elements of the longer slice are ignored via `zip`.
fn confusion_matrix_binary(y_true: &[usize], y_pred: &[usize]) -> (usize, usize, usize, usize) {
    let count = |t: usize, p: usize| {
        y_true
            .iter()
            .zip(y_pred)
            .filter(|&(&actual, &predicted)| actual == t && predicted == p)
            .count()
    };
    (count(1, 1), count(0, 0), count(0, 1), count(1, 0))
}
325
326fn compute_precision(y_true: &[usize], y_pred: &[usize]) -> f64 {
327 let (tp, _, fp, _) = confusion_matrix_binary(y_true, y_pred);
328 if tp + fp > 0 {
329 tp as f64 / (tp + fp) as f64
330 } else {
331 0.0
332 }
333}
334
335fn compute_recall(y_true: &[usize], y_pred: &[usize]) -> f64 {
336 let (tp, _, _, fn_count) = confusion_matrix_binary(y_true, y_pred);
337 if tp + fn_count > 0 {
338 tp as f64 / (tp + fn_count) as f64
339 } else {
340 0.0
341 }
342}
343
344fn compute_f1(y_true: &[usize], y_pred: &[usize]) -> f64 {
345 let precision = compute_precision(y_true, y_pred);
346 let recall = compute_recall(y_true, y_pred);
347
348 if precision + recall > 0.0 {
349 2.0 * (precision * recall) / (precision + recall)
350 } else {
351 0.0
352 }
353}
354
355fn compute_fbeta(y_true: &[usize], y_pred: &[usize], beta: f64) -> f64 {
356 let precision = compute_precision(y_true, y_pred);
357 let recall = compute_recall(y_true, y_pred);
358 let beta_sq = beta * beta;
359
360 if precision + recall > 0.0 {
361 (1.0 + beta_sq) * (precision * recall) / (beta_sq * precision + recall)
362 } else {
363 0.0
364 }
365}
366
367fn compute_balanced_accuracy(y_true: &[usize], y_pred: &[usize]) -> f64 {
368 let (tp, tn, fp, fn_count) = confusion_matrix_binary(y_true, y_pred);
369
370 let sensitivity = if tp + fn_count > 0 {
371 tp as f64 / (tp + fn_count) as f64
372 } else {
373 0.0
374 };
375
376 let specificity = if tn + fp > 0 {
377 tn as f64 / (tn + fp) as f64
378 } else {
379 0.0
380 };
381
382 (sensitivity + specificity) / 2.0
383}
384
385fn compute_cost(y_true: &[usize], y_pred: &[usize], fp_cost: f64, fn_cost: f64) -> f64 {
386 let (_, _, fp, fn_count) = confusion_matrix_binary(y_true, y_pred);
387 (fp as f64 * fp_cost) + (fn_count as f64 * fn_cost)
388}
389
390fn compute_jaccard(y_true: &[usize], y_pred: &[usize]) -> f64 {
391 let (tp, _, fp, fn_count) = confusion_matrix_binary(y_true, y_pred);
392 if tp + fp + fn_count > 0 {
393 tp as f64 / (tp + fp + fn_count) as f64
394 } else {
395 0.0
396 }
397}
398
399fn compute_matthews(y_true: &[usize], y_pred: &[usize]) -> f64 {
400 let (tp, tn, fp, fn_count) = confusion_matrix_binary(y_true, y_pred);
401
402 let numerator = (tp * tn) as f64 - (fp * fn_count) as f64;
403 let denominator = ((tp + fp) * (tp + fn_count) * (tn + fp) * (tn + fn_count)) as f64;
404
405 if denominator > 0.0 {
406 numerator / denominator.sqrt()
407 } else {
408 0.0
409 }
410}
411
/// Outcome of a threshold grid search: the winning point plus the full
/// threshold/score curve for inspection or plotting.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ThresholdOptimizationResult {
    /// Threshold achieving the highest score.
    pub best_threshold: f64,
    /// Highest score observed during the search.
    pub best_score: f64,
    /// Every candidate threshold evaluated, in order.
    pub thresholds: Vec<f64>,
    /// Score for each candidate, parallel to `thresholds`.
    pub scores: Vec<f64>,
}
425
426pub fn optimize_threshold<F: FloatBounds>(
428 y_true: &[usize],
429 y_proba: &Array2<F>,
430 metric: OptimizationMetric,
431 n_thresholds: usize,
432 pos_label_idx: usize,
433) -> Result<ThresholdOptimizationResult> {
434 if y_true.len() != y_proba.nrows() {
435 return Err(SklearsError::InvalidInput(
436 "y_true and y_proba must have same length".to_string(),
437 ));
438 }
439
440 let mut best_threshold = 0.5;
441 let mut best_score = f64::NEG_INFINITY;
442 let mut thresholds = Vec::with_capacity(n_thresholds);
443 let mut scores = Vec::with_capacity(n_thresholds);
444
445 for i in 0..n_thresholds {
447 let threshold = i as f64 / (n_thresholds - 1) as f64;
448 thresholds.push(threshold);
449
450 let y_pred: Vec<usize> = y_proba
452 .outer_iter()
453 .map(|row| {
454 if row.len() <= pos_label_idx {
455 return 0;
456 }
457 if row[pos_label_idx].to_f64().unwrap_or(0.0) >= threshold {
458 1
459 } else {
460 0
461 }
462 })
463 .collect();
464
465 let score = metric.compute(y_true, &y_pred);
467 scores.push(score);
468
469 if score > best_score {
470 best_score = score;
471 best_threshold = threshold;
472 }
473 }
474
475 Ok(ThresholdOptimizationResult {
476 best_threshold,
477 best_score,
478 thresholds,
479 scores,
480 })
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Stand-in estimator: `fit` ignores its inputs and yields a trained mock
    /// that always returns a fixed probability matrix.
    #[derive(Debug, Clone)]
    struct MockClassifier;

    impl<'a> Fit<ArrayView2<'a, f64>, ArrayView1<'a, usize>> for MockClassifier {
        type Fitted = MockClassifierTrained;
        fn fit(self, _x: &ArrayView2<'a, f64>, _y: &ArrayView1<'a, usize>) -> Result<Self::Fitted> {
            Ok(MockClassifierTrained {
                probas: array![[0.2, 0.8], [0.7, 0.3], [0.4, 0.6], [0.9, 0.1]],
            })
        }
    }

    /// Trained half of the mock: predicts the stored probabilities verbatim.
    #[derive(Debug, Clone)]
    struct MockClassifierTrained {
        probas: Array2<f64>,
    }

    impl<'a> PredictProba<ArrayView2<'a, f64>, Array2<f64>> for MockClassifierTrained {
        fn predict_proba(&self, _x: &ArrayView2<'a, f64>) -> Result<Array2<f64>> {
            Ok(self.probas.clone())
        }
    }

    #[test]
    fn test_fixed_threshold_classifier() {
        let clf = FixedThresholdClassifier::new(MockClassifier, 0.5);
        assert_eq!(clf.get_threshold(), 0.5);
    }

    #[test]
    fn test_fixed_threshold_custom() {
        // The builder setter overrides the threshold passed to `new`.
        let clf = FixedThresholdClassifier::new(MockClassifier, 0.7).threshold(0.3);
        assert_eq!(clf.get_threshold(), 0.3);
    }

    #[test]
    fn test_fixed_threshold_prediction() {
        let clf = FixedThresholdClassifier::new(MockClassifier, 0.5);

        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![1, 0, 1, 0];

        let trained = clf.fit(&x.view(), &y.view()).unwrap();
        let predictions = trained.predict(&x.view()).unwrap();

        // Mock column-1 probabilities are 0.8, 0.3, 0.6, 0.1 vs cutoff 0.5.
        assert_eq!(predictions[0], 1);
        assert_eq!(predictions[1], 0);
        assert_eq!(predictions[2], 1);
        assert_eq!(predictions[3], 0);
    }

    #[test]
    fn test_fixed_threshold_high() {
        let clf = FixedThresholdClassifier::new(MockClassifier, 0.7);

        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![1, 0, 1, 0];

        let trained = clf.fit(&x.view(), &y.view()).unwrap();
        let predictions = trained.predict(&x.view()).unwrap();

        // Raising the cutoff to 0.7 drops the 0.6-probability row to class 0.
        assert_eq!(predictions[0], 1);
        assert_eq!(predictions[1], 0);
        assert_eq!(predictions[2], 0);
        assert_eq!(predictions[3], 0);
    }

    #[test]
    fn test_confusion_matrix() {
        let y_true = vec![1, 1, 0, 0, 1, 0, 1, 0];
        let y_pred = vec![1, 0, 0, 1, 1, 0, 0, 1];

        let (tp, tn, fp, fn_count) = confusion_matrix_binary(&y_true, &y_pred);

        assert_eq!(tp, 2);
        assert_eq!(tn, 2);
        assert_eq!(fp, 2);
        assert_eq!(fn_count, 2);
    }

    #[test]
    fn test_precision_recall() {
        let y_true = vec![1, 1, 0, 0, 1, 0];
        let y_pred = vec![1, 0, 0, 1, 1, 0];

        // tp = 2, fp = 1, fn = 1 -> precision = recall = 2/3.
        assert!((compute_precision(&y_true, &y_pred) - 0.666).abs() < 0.01);
        assert!((compute_recall(&y_true, &y_pred) - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_f1_score() {
        let y_true = vec![1, 1, 0, 0, 1, 0];
        let y_pred = vec![1, 0, 0, 1, 1, 0];

        // Precision == recall == 2/3, so F1 is also 2/3.
        assert!((compute_f1(&y_true, &y_pred) - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_balanced_accuracy() {
        let y_true = vec![1, 1, 1, 0, 0, 0];
        let y_pred = vec![1, 1, 0, 0, 0, 1];

        // Sensitivity = specificity = 2/3.
        assert!((compute_balanced_accuracy(&y_true, &y_pred) - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_cost_computation() {
        // One false positive (cost 10) plus one false negative (cost 5).
        let cost = compute_cost(&[1, 1, 0, 0], &[1, 0, 1, 0], 10.0, 5.0);
        assert_eq!(cost, 15.0);
    }

    #[test]
    fn test_jaccard_score() {
        // tp = 2, fp = 1, fn = 1 -> 2 / 4.
        let jaccard = compute_jaccard(&[1, 1, 0, 0, 1], &[1, 0, 0, 1, 1]);
        assert_eq!(jaccard, 0.5);
    }

    #[test]
    fn test_matthews_correlation() {
        // Balanced mistakes: tp*tn == fp*fn, so the MCC numerator is zero.
        let mcc = compute_matthews(&[1, 1, 0, 0], &[1, 0, 0, 1]);
        assert_eq!(mcc, 0.0);
    }

    #[test]
    fn test_optimize_threshold() {
        let y_proba = array![
            [0.3, 0.7],
            [0.8, 0.2],
            [0.4, 0.6],
            [0.9, 0.1],
            [0.2, 0.8],
            [0.6, 0.4],
        ];
        let y_true = vec![1, 0, 1, 0, 1, 0];

        let result = optimize_threshold(&y_true, &y_proba, OptimizationMetric::F1, 20, 1).unwrap();

        assert!(result.best_threshold >= 0.0 && result.best_threshold <= 1.0);
        assert!(result.best_score >= 0.0 && result.best_score <= 1.0);
        assert_eq!(result.thresholds.len(), 20);
        assert_eq!(result.scores.len(), 20);
    }

    #[test]
    fn test_optimize_threshold_precision() {
        let y_proba = array![
            [0.1, 0.9],
            [0.4, 0.6],
            [0.3, 0.7],
            [0.6, 0.4],
            [0.5, 0.5],
            [0.7, 0.3],
        ];
        let y_true = vec![1, 0, 1, 0, 1, 0];

        let result =
            optimize_threshold(&y_true, &y_proba, OptimizationMetric::Precision, 50, 1).unwrap();

        // Precision rewards a conservative (high) cutoff on this data.
        assert!(
            result.best_threshold >= 0.6,
            "Expected threshold >= 0.6 for precision, got {}",
            result.best_threshold
        );
    }

    #[test]
    fn test_optimize_threshold_recall() {
        let y_proba = array![[0.3, 0.7], [0.8, 0.2], [0.4, 0.6], [0.1, 0.9]];
        let y_true = vec![1, 0, 1, 1];

        let result =
            optimize_threshold(&y_true, &y_proba, OptimizationMetric::Recall, 50, 1).unwrap();

        // Recall favors low thresholds that capture every positive.
        assert!(result.best_threshold <= 0.5);
    }

    #[test]
    fn test_fbeta_optimization() {
        let y_proba = array![[0.2, 0.8], [0.7, 0.3], [0.5, 0.5], [0.3, 0.7]];
        let y_true = vec![1, 0, 1, 1];

        let result =
            optimize_threshold(&y_true, &y_proba, OptimizationMetric::FBeta(2.0), 50, 1).unwrap();

        assert!(result.best_score >= 0.0);
        assert!(result.best_score <= 1.0);
    }

    #[test]
    fn test_cost_sensitive_optimization() {
        let y_proba = array![
            [0.1, 0.9],
            [0.4, 0.6],
            [0.3, 0.7],
            [0.6, 0.4],
            [0.2, 0.8],
        ];
        let y_true = vec![1, 0, 1, 0, 1];

        let result = optimize_threshold(
            &y_true,
            &y_proba,
            OptimizationMetric::Cost {
                fp_cost: 10.0,
                fn_cost: 1.0,
            },
            50,
            1,
        )
        .unwrap();

        // Expensive false positives push the optimum toward high thresholds.
        assert!(
            result.best_threshold >= 0.6,
            "Expected threshold >= 0.6, got {}",
            result.best_threshold
        );
        assert!(
            result.best_score >= -0.1,
            "Expected near-zero cost (score >= -0.1), got {}",
            result.best_score
        );
    }
}