sklears_model_selection/threshold_tuning.rs

//! Threshold tuning for probabilistic classifiers
//!
//! This module provides utilities for tuning decision thresholds in binary classification
//! to optimize specific metrics or constraints beyond the default 0.5 threshold.
//!
//! # Key Components
//!
//! - **FixedThresholdClassifier**: Applies a fixed decision threshold
//! - **TunedThresholdClassifierCV**: Automatically tunes the threshold via cross-validation
//! - **optimize_threshold**: Standalone threshold optimization for a chosen metric
//!
//! # Use Cases
//!
//! - Imbalanced classification (optimize F1, precision, recall)
//! - Cost-sensitive learning (minimize false positives/negatives)
//! - Meeting specific precision/recall requirements
//! - ROC curve analysis and threshold selection
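//!
//! # Example
//!
//! A minimal usage sketch; `my_estimator`, `x`, and `y` are placeholders for
//! any estimator and data compatible with this module's `Fit`/`PredictProba`
//! traits:
//!
//! ```ignore
//! let clf = FixedThresholdClassifier::new(my_estimator, 0.3);
//! let trained = clf.fit(&x.view(), &y.view())?;
//! let predictions = trained.predict(&x.view())?;
//! ```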

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Fit, Predict, PredictProba},
    types::FloatBounds,
};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/// Metric to optimize when tuning threshold
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum OptimizationMetric {
    /// Maximize F1 score
    F1,
    /// Maximize F-beta score
    FBeta(f64),
    /// Maximize precision
    Precision,
    /// Maximize recall
    Recall,
    /// Maximize balanced accuracy
    BalancedAccuracy,
    /// Minimize cost (weighted FP/FN)
    Cost {
        /// Cost of false positive
        fp_cost: f64,
        /// Cost of false negative
        fn_cost: f64,
    },
    /// Maximize Jaccard score
    Jaccard,
    /// Maximize Matthews correlation coefficient
    Matthews,
}

/// Fixed threshold classifier wrapper
///
/// Applies a fixed decision threshold to a probabilistic classifier's predictions.
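///
/// A minimal construction sketch (`estimator` stands in for any wrapped
/// probabilistic classifier):
///
/// ```ignore
/// let clf = FixedThresholdClassifier::new(estimator, 0.3).pos_label_idx(1);
/// assert_eq!(clf.get_threshold(), 0.3);
/// ```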
#[derive(Debug, Clone)]
pub struct FixedThresholdClassifier<E> {
    /// Base estimator (must implement PredictProba)
    estimator: E,
    /// Decision threshold (default: 0.5)
    threshold: f64,
    /// Index of the probability column treated as the positive class (default: 1)
    pos_label_idx: usize,
}

impl<E> FixedThresholdClassifier<E> {
    /// Create a new fixed threshold classifier
    ///
    /// Panics if `threshold` is outside `[0.0, 1.0]`.
    pub fn new(estimator: E, threshold: f64) -> Self {
        Self {
            estimator,
            threshold: 0.5,
            pos_label_idx: 1,
        }
        .threshold(threshold)
    }

    /// Set the threshold value
    ///
    /// Panics if `threshold` is outside `[0.0, 1.0]`.
    pub fn threshold(mut self, threshold: f64) -> Self {
        if !(0.0..=1.0).contains(&threshold) {
            panic!("Threshold must be between 0.0 and 1.0");
        }
        self.threshold = threshold;
        self
    }

    /// Set the positive label index for multiclass probabilities
    pub fn pos_label_idx(mut self, idx: usize) -> Self {
        self.pos_label_idx = idx;
        self
    }

    /// Get the threshold value
    pub fn get_threshold(&self) -> f64 {
        self.threshold
    }

    /// Get the base estimator
    pub fn estimator(&self) -> &E {
        &self.estimator
    }
}

impl<'a, E, F: FloatBounds> Fit<ArrayView2<'a, F>, ArrayView1<'a, usize>>
    for FixedThresholdClassifier<E>
where
    E: Fit<ArrayView2<'a, F>, ArrayView1<'a, usize>>,
{
    type Fitted = FixedThresholdClassifier<E::Fitted>;

    fn fit(self, x: &ArrayView2<'a, F>, y: &ArrayView1<'a, usize>) -> Result<Self::Fitted> {
        let trained_estimator = self.estimator.fit(x, y)?;
        Ok(FixedThresholdClassifier {
            estimator: trained_estimator,
            threshold: self.threshold,
            pos_label_idx: self.pos_label_idx,
        })
    }
}

impl<'a, E, F: FloatBounds> Predict<ArrayView2<'a, F>, Array1<usize>>
    for FixedThresholdClassifier<E>
where
    E: PredictProba<ArrayView2<'a, F>, Array2<F>>,
{
    fn predict(&self, x: &ArrayView2<'a, F>) -> Result<Array1<usize>> {
        let probas = self.estimator.predict_proba(x)?;

        // Apply threshold to positive class probability
        let predictions = probas.map_axis(Axis(1), |row| {
            if row.len() <= self.pos_label_idx {
                return 0;
            }
            if row[self.pos_label_idx].to_f64().unwrap_or(0.0) >= self.threshold {
                1
            } else {
                0
            }
        });

        Ok(predictions)
    }
}

impl<'a, E, F: FloatBounds> PredictProba<ArrayView2<'a, F>, Array2<F>>
    for FixedThresholdClassifier<E>
where
    E: PredictProba<ArrayView2<'a, F>, Array2<F>>,
{
    fn predict_proba(&self, x: &ArrayView2<'a, F>) -> Result<Array2<F>> {
        // Just pass through the base estimator's probabilities
        self.estimator.predict_proba(x)
    }
}

/// Tuned threshold classifier with cross-validation
///
/// Automatically finds the optimal decision threshold by maximizing a specified
/// metric through cross-validation on the training data.
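///
/// A minimal builder sketch (`estimator` and `cv` are placeholders for a base
/// classifier and a cross-validation splitter):
///
/// ```ignore
/// let tuned = TunedThresholdClassifierCV::new(estimator, cv)
///     .scoring(OptimizationMetric::FBeta(2.0))
///     .n_thresholds(200)
///     .threshold_range(0.1, 0.9);
/// ```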
#[derive(Debug)]
pub struct TunedThresholdClassifierCV<E, C> {
    /// Base estimator
    estimator: E,
    /// Cross-validation splitter
    cv: C,
    /// Metric to optimize
    scoring: OptimizationMetric,
    /// Number of threshold values to try (linearly spaced over the threshold range)
    n_thresholds: usize,
    /// Minimum threshold to consider
    min_threshold: f64,
    /// Maximum threshold to consider
    max_threshold: f64,
    /// Positive label index
    pos_label_idx: usize,
}

impl<E, C> TunedThresholdClassifierCV<E, C> {
    /// Create a new tuned threshold classifier
    pub fn new(estimator: E, cv: C) -> Self {
        Self {
            estimator,
            cv,
            scoring: OptimizationMetric::F1,
            n_thresholds: 100,
            min_threshold: 0.0,
            max_threshold: 1.0,
            pos_label_idx: 1,
        }
    }

    /// Set the metric to optimize
    pub fn scoring(mut self, metric: OptimizationMetric) -> Self {
        self.scoring = metric;
        self
    }

    /// Set the number of thresholds to try
    pub fn n_thresholds(mut self, n: usize) -> Self {
        self.n_thresholds = n;
        self
    }

    /// Set the threshold range
    ///
    /// Panics unless `0.0 <= min <= max <= 1.0`.
    pub fn threshold_range(mut self, min: f64, max: f64) -> Self {
        if !(0.0..=1.0).contains(&min) || !(0.0..=1.0).contains(&max) || min > max {
            panic!("Threshold range must satisfy 0.0 <= min <= max <= 1.0");
        }
        self.min_threshold = min;
        self.max_threshold = max;
        self
    }
}

/// Trained tuned threshold classifier
#[derive(Debug)]
pub struct TunedThresholdClassifierCVTrained<E> {
    /// Trained base estimator
    estimator: E,
    /// Optimal threshold found via CV
    best_threshold_: f64,
    /// Best score achieved
    best_score_: f64,
    /// All thresholds tried
    thresholds_: Vec<f64>,
    /// Scores for each threshold
    scores_: Vec<f64>,
    /// Positive label index
    pos_label_idx: usize,
}

impl<E> TunedThresholdClassifierCVTrained<E> {
    /// Get the optimal threshold
    pub fn best_threshold(&self) -> f64 {
        self.best_threshold_
    }

    /// Get the best score achieved
    pub fn best_score(&self) -> f64 {
        self.best_score_
    }

    /// Get all thresholds tried
    pub fn thresholds(&self) -> &[f64] {
        &self.thresholds_
    }

    /// Get scores for each threshold
    pub fn scores(&self) -> &[f64] {
        &self.scores_
    }
}

impl<'a, E, F: FloatBounds> Predict<ArrayView2<'a, F>, Array1<usize>>
    for TunedThresholdClassifierCVTrained<E>
where
    E: PredictProba<ArrayView2<'a, F>, Array2<F>>,
{
    fn predict(&self, x: &ArrayView2<'a, F>) -> Result<Array1<usize>> {
        let probas = self.estimator.predict_proba(x)?;

        let predictions = probas.map_axis(Axis(1), |row| {
            if row.len() <= self.pos_label_idx {
                return 0;
            }
            if row[self.pos_label_idx].to_f64().unwrap_or(0.0) >= self.best_threshold_ {
                1
            } else {
                0
            }
        });

        Ok(predictions)
    }
}

impl<'a, E, F: FloatBounds> PredictProba<ArrayView2<'a, F>, Array2<F>>
    for TunedThresholdClassifierCVTrained<E>
where
    E: PredictProba<ArrayView2<'a, F>, Array2<F>>,
{
    fn predict_proba(&self, x: &ArrayView2<'a, F>) -> Result<Array2<F>> {
        self.estimator.predict_proba(x)
    }
}

/// Helper functions for computing metrics
impl OptimizationMetric {
    /// Compute the metric value given predictions and true labels
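    ///
    /// All metrics are "higher is better"; `Cost` is negated internally so that
    /// maximizing the returned score minimizes the weighted cost. A small worked
    /// sketch (labels are illustrative):
    ///
    /// ```ignore
    /// // TP = 2, FP = 0, FN = 1 -> precision = 1.0, recall = 2/3, F1 = 0.8
    /// let score = OptimizationMetric::F1.compute(&[1, 0, 1, 1], &[1, 0, 0, 1]);
    /// assert!((score - 0.8).abs() < 1e-9);
    /// ```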
    pub fn compute(&self, y_true: &[usize], y_pred: &[usize]) -> f64 {
        match self {
            OptimizationMetric::F1 => compute_f1(y_true, y_pred),
            OptimizationMetric::FBeta(beta) => compute_fbeta(y_true, y_pred, *beta),
            OptimizationMetric::Precision => compute_precision(y_true, y_pred),
            OptimizationMetric::Recall => compute_recall(y_true, y_pred),
            OptimizationMetric::BalancedAccuracy => compute_balanced_accuracy(y_true, y_pred),
            OptimizationMetric::Cost { fp_cost, fn_cost } => {
                -compute_cost(y_true, y_pred, *fp_cost, *fn_cost)
            }
            OptimizationMetric::Jaccard => compute_jaccard(y_true, y_pred),
            OptimizationMetric::Matthews => compute_matthews(y_true, y_pred),
        }
    }
}

/// Compute binary confusion matrix components (TP, TN, FP, FN)
///
/// Label values other than 0 and 1 are ignored.
fn confusion_matrix_binary(y_true: &[usize], y_pred: &[usize]) -> (usize, usize, usize, usize) {
    let mut tp = 0;
    let mut tn = 0;
    let mut fp = 0;
    let mut fn_count = 0;

    for (&true_label, &pred_label) in y_true.iter().zip(y_pred.iter()) {
        match (true_label, pred_label) {
            (1, 1) => tp += 1,
            (0, 0) => tn += 1,
            (0, 1) => fp += 1,
            (1, 0) => fn_count += 1,
            _ => {}
        }
    }

    (tp, tn, fp, fn_count)
}

fn compute_precision(y_true: &[usize], y_pred: &[usize]) -> f64 {
    let (tp, _, fp, _) = confusion_matrix_binary(y_true, y_pred);
    if tp + fp > 0 {
        tp as f64 / (tp + fp) as f64
    } else {
        0.0
    }
}

fn compute_recall(y_true: &[usize], y_pred: &[usize]) -> f64 {
    let (tp, _, _, fn_count) = confusion_matrix_binary(y_true, y_pred);
    if tp + fn_count > 0 {
        tp as f64 / (tp + fn_count) as f64
    } else {
        0.0
    }
}

fn compute_f1(y_true: &[usize], y_pred: &[usize]) -> f64 {
    let precision = compute_precision(y_true, y_pred);
    let recall = compute_recall(y_true, y_pred);

    if precision + recall > 0.0 {
        2.0 * (precision * recall) / (precision + recall)
    } else {
        0.0
    }
}

fn compute_fbeta(y_true: &[usize], y_pred: &[usize], beta: f64) -> f64 {
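    // F_beta = (1 + beta^2) * P * R / (beta^2 * P + R); beta > 1 weights recall
    // more heavily, beta < 1 weights precision more heavily.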
    let precision = compute_precision(y_true, y_pred);
    let recall = compute_recall(y_true, y_pred);
    let beta_sq = beta * beta;

    if precision + recall > 0.0 {
        (1.0 + beta_sq) * (precision * recall) / (beta_sq * precision + recall)
    } else {
        0.0
    }
}

fn compute_balanced_accuracy(y_true: &[usize], y_pred: &[usize]) -> f64 {
    let (tp, tn, fp, fn_count) = confusion_matrix_binary(y_true, y_pred);

    let sensitivity = if tp + fn_count > 0 {
        tp as f64 / (tp + fn_count) as f64
    } else {
        0.0
    };

    let specificity = if tn + fp > 0 {
        tn as f64 / (tn + fp) as f64
    } else {
        0.0
    };

    (sensitivity + specificity) / 2.0
}

fn compute_cost(y_true: &[usize], y_pred: &[usize], fp_cost: f64, fn_cost: f64) -> f64 {
    let (_, _, fp, fn_count) = confusion_matrix_binary(y_true, y_pred);
    (fp as f64 * fp_cost) + (fn_count as f64 * fn_cost)
}

fn compute_jaccard(y_true: &[usize], y_pred: &[usize]) -> f64 {
    let (tp, _, fp, fn_count) = confusion_matrix_binary(y_true, y_pred);
    if tp + fp + fn_count > 0 {
        tp as f64 / (tp + fp + fn_count) as f64
    } else {
        0.0
    }
}
fn compute_matthews(y_true: &[usize], y_pred: &[usize]) -> f64 {
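    // MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)), in [-1, 1];
    // 0.0 is returned when any marginal sum is zero and the score is undefined.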
    let (tp, tn, fp, fn_count) = confusion_matrix_binary(y_true, y_pred);

    let numerator = tp as f64 * tn as f64 - fp as f64 * fn_count as f64;
    // Cast each factor before multiplying to avoid integer overflow on large inputs
    let denominator =
        (tp + fp) as f64 * (tp + fn_count) as f64 * (tn + fp) as f64 * (tn + fn_count) as f64;

    if denominator > 0.0 {
        numerator / denominator.sqrt()
    } else {
        0.0
    }
}

/// Threshold optimization results
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ThresholdOptimizationResult {
    /// Optimal threshold found
    pub best_threshold: f64,
    /// Best metric value achieved
    pub best_score: f64,
    /// All thresholds evaluated
    pub thresholds: Vec<f64>,
    /// Scores for each threshold
    pub scores: Vec<f64>,
}

/// Optimize the decision threshold for a given metric
///
/// Evaluates `n_thresholds` linearly spaced thresholds over `[0.0, 1.0]` and
/// returns the first one achieving the best score (ties keep the lowest threshold).
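///
/// A minimal sketch (`y_true` and `y_proba` are illustrative inputs):
///
/// ```ignore
/// let result = optimize_threshold(&y_true, &y_proba, OptimizationMetric::F1, 101, 1)?;
/// println!("best threshold: {:.2}", result.best_threshold);
/// ```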
pub fn optimize_threshold<F: FloatBounds>(
    y_true: &[usize],
    y_proba: &Array2<F>,
    metric: OptimizationMetric,
    n_thresholds: usize,
    pos_label_idx: usize,
) -> Result<ThresholdOptimizationResult> {
    if y_true.len() != y_proba.nrows() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_proba must have the same number of samples".to_string(),
        ));
    }
    // The linear spacing below needs at least two points: n_thresholds == 1
    // would divide by zero and n_thresholds == 0 would underflow.
    if n_thresholds < 2 {
        return Err(SklearsError::InvalidInput(
            "n_thresholds must be at least 2".to_string(),
        ));
    }

    let mut best_threshold = 0.5;
    let mut best_score = f64::NEG_INFINITY;
    let mut thresholds = Vec::with_capacity(n_thresholds);
    let mut scores = Vec::with_capacity(n_thresholds);

    // Try n_thresholds linearly spaced thresholds covering [0.0, 1.0]
    for i in 0..n_thresholds {
        let threshold = i as f64 / (n_thresholds - 1) as f64;
        thresholds.push(threshold);

        // Apply threshold to get predictions
        let y_pred: Vec<usize> = y_proba
            .outer_iter()
            .map(|row| {
                if row.len() <= pos_label_idx {
                    return 0;
                }
                if row[pos_label_idx].to_f64().unwrap_or(0.0) >= threshold {
                    1
                } else {
                    0
                }
            })
            .collect();

        // Compute metric; the strict comparison keeps the first (lowest)
        // threshold achieving the best score
        let score = metric.compute(y_true, &y_pred);
        scores.push(score);

        if score > best_score {
            best_score = score;
            best_threshold = threshold;
        }
    }

    Ok(ThresholdOptimizationResult {
        best_threshold,
        best_score,
        thresholds,
        scores,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Simple mock classifier for testing
    #[derive(Debug, Clone)]
    struct MockClassifier;

    impl<'a> Fit<ArrayView2<'a, f64>, ArrayView1<'a, usize>> for MockClassifier {
        type Fitted = MockClassifierTrained;
        fn fit(self, _x: &ArrayView2<'a, f64>, _y: &ArrayView1<'a, usize>) -> Result<Self::Fitted> {
            Ok(MockClassifierTrained {
                probas: array![[0.2, 0.8], [0.7, 0.3], [0.4, 0.6], [0.9, 0.1]],
            })
        }
    }

    #[derive(Debug, Clone)]
    struct MockClassifierTrained {
        probas: Array2<f64>,
    }

    impl<'a> PredictProba<ArrayView2<'a, f64>, Array2<f64>> for MockClassifierTrained {
        fn predict_proba(&self, _x: &ArrayView2<'a, f64>) -> Result<Array2<f64>> {
            Ok(self.probas.clone())
        }
    }

    #[test]
    fn test_fixed_threshold_classifier() {
        let mock = MockClassifier;
        let fixed = FixedThresholdClassifier::new(mock, 0.5);

        assert_eq!(fixed.get_threshold(), 0.5);
    }

    #[test]
    fn test_fixed_threshold_custom() {
        let mock = MockClassifier;
        let fixed = FixedThresholdClassifier::new(mock, 0.7).threshold(0.3);

        assert_eq!(fixed.get_threshold(), 0.3);
    }

    #[test]
    fn test_fixed_threshold_prediction() {
        let mock = MockClassifier;
        let fixed = FixedThresholdClassifier::new(mock, 0.5);

        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![1, 0, 1, 0];

        let trained = fixed.fit(&x.view(), &y.view()).unwrap();
        let predictions = trained.predict(&x.view()).unwrap();

        // With threshold 0.5:
        // Row 0: [0.2, 0.8] -> 0.8 >= 0.5 -> 1
        // Row 1: [0.7, 0.3] -> 0.3 < 0.5 -> 0
        // Row 2: [0.4, 0.6] -> 0.6 >= 0.5 -> 1
        // Row 3: [0.9, 0.1] -> 0.1 < 0.5 -> 0
        assert_eq!(predictions[0], 1);
        assert_eq!(predictions[1], 0);
        assert_eq!(predictions[2], 1);
        assert_eq!(predictions[3], 0);
    }

    #[test]
    fn test_fixed_threshold_high() {
        let mock = MockClassifier;
        let fixed = FixedThresholdClassifier::new(mock, 0.7);

        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![1, 0, 1, 0];

        let trained = fixed.fit(&x.view(), &y.view()).unwrap();
        let predictions = trained.predict(&x.view()).unwrap();

        // With threshold 0.7:
        // Row 0: [0.2, 0.8] -> 0.8 >= 0.7 -> 1
        // Row 1: [0.7, 0.3] -> 0.3 < 0.7 -> 0
        // Row 2: [0.4, 0.6] -> 0.6 < 0.7 -> 0
        // Row 3: [0.9, 0.1] -> 0.1 < 0.7 -> 0
        assert_eq!(predictions[0], 1);
        assert_eq!(predictions[1], 0);
        assert_eq!(predictions[2], 0);
        assert_eq!(predictions[3], 0);
    }

    #[test]
    fn test_confusion_matrix() {
        let y_true = vec![1, 1, 0, 0, 1, 0, 1, 0];
        let y_pred = vec![1, 0, 0, 1, 1, 0, 0, 1];

        let (tp, tn, fp, fn_count) = confusion_matrix_binary(&y_true, &y_pred);

        assert_eq!(tp, 2); // Correctly predicted positive
        assert_eq!(tn, 2); // Correctly predicted negative
        assert_eq!(fp, 2); // False positives
        assert_eq!(fn_count, 2); // False negatives
    }

    #[test]
    fn test_precision_recall() {
        let y_true = vec![1, 1, 0, 0, 1, 0];
        let y_pred = vec![1, 0, 0, 1, 1, 0];

        let precision = compute_precision(&y_true, &y_pred);
        let recall = compute_recall(&y_true, &y_pred);

        // TP=2, FP=1, FN=1
        assert!((precision - 0.666).abs() < 0.01); // 2/3
        assert!((recall - 0.666).abs() < 0.01); // 2/3
    }

    #[test]
    fn test_f1_score() {
        let y_true = vec![1, 1, 0, 0, 1, 0];
        let y_pred = vec![1, 0, 0, 1, 1, 0];

        let f1 = compute_f1(&y_true, &y_pred);
        assert!((f1 - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_balanced_accuracy() {
        let y_true = vec![1, 1, 1, 0, 0, 0];
        let y_pred = vec![1, 1, 0, 0, 0, 1];

        let balanced_acc = compute_balanced_accuracy(&y_true, &y_pred);
        // Sensitivity (TPR) = 2/3, Specificity (TNR) = 2/3
        // Balanced = (2/3 + 2/3) / 2 = 2/3
        assert!((balanced_acc - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_cost_computation() {
        let y_true = vec![1, 1, 0, 0];
        let y_pred = vec![1, 0, 1, 0];
        // FP=1, FN=1

        let cost = compute_cost(&y_true, &y_pred, 10.0, 5.0);
        assert_eq!(cost, 15.0); // 1*10 + 1*5
    }

    #[test]
    fn test_jaccard_score() {
        let y_true = vec![1, 1, 0, 0, 1];
        let y_pred = vec![1, 0, 0, 1, 1];
        // TP=2, FP=1, FN=1
        // Jaccard = TP / (TP + FP + FN) = 2 / 4 = 0.5

        let jaccard = compute_jaccard(&y_true, &y_pred);
        assert_eq!(jaccard, 0.5);
    }

    #[test]
    fn test_matthews_correlation() {
        let y_true = vec![1, 1, 0, 0];
        let y_pred = vec![1, 0, 0, 1];
        // TP=1, TN=1, FP=1, FN=1

        let mcc = compute_matthews(&y_true, &y_pred);
        assert_eq!(mcc, 0.0); // Chance-level agreement: TP*TN - FP*FN = 0
    }

    #[test]
    fn test_optimize_threshold() {
        // Create probability predictions
        let y_proba = array![
            [0.3, 0.7],
            [0.8, 0.2],
            [0.4, 0.6],
            [0.9, 0.1],
            [0.2, 0.8],
            [0.6, 0.4],
        ];
        let y_true = vec![1, 0, 1, 0, 1, 0];

        let result = optimize_threshold(&y_true, &y_proba, OptimizationMetric::F1, 20, 1).unwrap();

        assert!(result.best_threshold >= 0.0 && result.best_threshold <= 1.0);
        assert!(result.best_score >= 0.0 && result.best_score <= 1.0);
        assert_eq!(result.thresholds.len(), 20);
        assert_eq!(result.scores.len(), 20);
    }

    #[test]
    fn test_optimize_threshold_precision() {
        // More realistic data with a clear optimal threshold for precision
        // Class 1 probabilities: [0.9, 0.6, 0.7, 0.4, 0.5, 0.3]
        // True labels:           [1,   0,   1,   0,   1,   0]
        let y_proba = array![
            [0.1, 0.9], // true=1, proba=0.9
            [0.4, 0.6], // true=0, proba=0.6 (FP for any threshold <= 0.6)
            [0.3, 0.7], // true=1, proba=0.7
            [0.6, 0.4], // true=0, proba=0.4
            [0.5, 0.5], // true=1, proba=0.5
            [0.7, 0.3], // true=0, proba=0.3
        ];
        let y_true = vec![1, 0, 1, 0, 1, 0];

        let result =
            optimize_threshold(&y_true, &y_proba, OptimizationMetric::Precision, 50, 1).unwrap();

        // Thresholds above 0.6 give precision = 1.0 (no false positives);
        // thresholds <= 0.6 count sample 1 as a FP, reducing precision.
        // The optimizer should therefore pick a threshold above 0.6.
        assert!(
            result.best_threshold >= 0.6,
            "Expected threshold >= 0.6 for precision, got {}",
            result.best_threshold
        );
    }

    #[test]
    fn test_optimize_threshold_recall() {
        let y_proba = array![[0.3, 0.7], [0.8, 0.2], [0.4, 0.6], [0.1, 0.9]];
        let y_true = vec![1, 0, 1, 1];

        let result =
            optimize_threshold(&y_true, &y_proba, OptimizationMetric::Recall, 50, 1).unwrap();

        // Low threshold should favor recall
        assert!(result.best_threshold <= 0.5);
    }

    #[test]
    fn test_fbeta_optimization() {
        let y_proba = array![[0.2, 0.8], [0.7, 0.3], [0.5, 0.5], [0.3, 0.7]];
        let y_true = vec![1, 0, 1, 1];

        // F2 score (beta=2, favors recall)
        let result =
            optimize_threshold(&y_true, &y_proba, OptimizationMetric::FBeta(2.0), 50, 1).unwrap();

        assert!(result.best_score >= 0.0);
        assert!(result.best_score <= 1.0);
    }

    #[test]
    fn test_cost_sensitive_optimization() {
        // More realistic data with a clear cost-optimal threshold
        // Class 1 probabilities: [0.9, 0.6, 0.7, 0.4, 0.8]
        // True labels:           [1,   0,   1,   0,   1]
        let y_proba = array![
            [0.1, 0.9], // true=1, proba=0.9
            [0.4, 0.6], // true=0, proba=0.6 (costly FP for any threshold <= 0.6)
            [0.3, 0.7], // true=1, proba=0.7
            [0.6, 0.4], // true=0, proba=0.4
            [0.2, 0.8], // true=1, proba=0.8
        ];
        let y_true = vec![1, 0, 1, 0, 1];

        // High FP cost should push the threshold higher
        let result = optimize_threshold(
            &y_true,
            &y_proba,
            OptimizationMetric::Cost {
                fp_cost: 10.0,
                fn_cost: 1.0,
            },
            50,
            1,
        )
        .unwrap();

        // Threshold <= 0.6: sample 1 (proba 0.6) is a FP → cost = 10
        // Threshold in (0.6, 0.7]: no FP, no FN → cost = 0
        // Threshold in (0.7, 0.8]: sample 2 (proba 0.7) becomes a FN → cost = 1
        // The search keeps the first threshold achieving the best score, so it
        // should land just above 0.6.
        assert!(
            result.best_threshold >= 0.6,
            "Expected threshold >= 0.6, got {}",
            result.best_threshold
        );
        assert!(
            result.best_score >= -0.1,
            "Expected near-zero cost (score >= -0.1), got {}",
            result.best_score
        );
    }
}