// scirs2_metrics/classification/threshold_analyzer.rs

1//! Threshold analysis for binary classification
2//!
3//! This module provides tools for analyzing binary classification performance
4//! across different thresholds and finding optimal thresholds based on various
5//! metrics and strategies.
6
7use scirs2_core::ndarray::{ArrayBase, Data, Dimension, Ix1};
8use std::collections::HashMap;
9use std::hash::{Hash, Hasher};
10
11use crate::classification::curves::roc_curve;
12use crate::error::{MetricsError, Result};
13
/// Metrics calculated at a specific threshold
///
/// Rate-style fields fall back to `0.0` when their denominator is zero
/// (see `ThresholdAnalyzer::calculate_metrics`). A sample is predicted
/// positive when its score is `>=` the threshold.
#[derive(Debug, Clone)]
pub struct ThresholdMetrics {
    /// Classification threshold value (scores >= threshold predict positive)
    pub threshold: f64,
    /// True positive rate (sensitivity, recall): TP / (TP + FN)
    pub tpr: f64,
    /// False positive rate (1 - specificity): FP / (FP + TN)
    pub fpr: f64,
    /// Precision (positive predictive value): TP / (TP + FP)
    pub precision: f64,
    /// F1 score (harmonic mean of precision and recall)
    pub f1_score: f64,
    /// Accuracy (correct predictions / total predictions)
    pub accuracy: f64,
    /// Specificity (true negative rate): TN / (TN + FP)
    pub specificity: f64,
    /// Negative predictive value: TN / (TN + FN)
    pub npv: f64,
    /// Matthews correlation coefficient
    pub mcc: f64,
    /// Cohen's kappa coefficient (chance-corrected agreement)
    pub kappa: f64,
    /// Youden's J statistic (sensitivity + specificity - 1)
    pub youdens_j: f64,
    /// Balanced accuracy: (sensitivity + specificity) / 2
    pub balanced_accuracy: f64,
    /// Count of true positives
    pub tp: usize,
    /// Count of false positives
    pub fp: usize,
    /// Count of true negatives
    pub tn: usize,
    /// Count of false negatives (trailing underscore because `fn` is a Rust keyword)
    pub fn_: usize,
}
50
/// Strategies for finding optimal threshold
///
/// Used with `ThresholdAnalyzer::find_optimal_threshold`. "Max" strategies
/// select the threshold maximizing the named metric; "Balanced" strategies
/// minimize the absolute gap between the two named quantities.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OptimalThresholdStrategy {
    /// Maximize F1 score
    MaxF1,
    /// Maximize Youden's J statistic (sensitivity + specificity - 1)
    YoudensJ,
    /// Maximize accuracy
    MaxAccuracy,
    /// Maximize Matthews correlation coefficient
    MaxMCC,
    /// Maximize Cohen's kappa
    MaxKappa,
    /// Balance sensitivity and specificity (minimize |sensitivity - specificity|)
    BalancedSensSpec,
    /// Balance precision and recall (minimize |precision - recall|)
    BalancedPrecRecall,
    /// Minimize distance to perfect classifier (0,1) in ROC space
    MinDistanceToOptimal,
    /// Use a specific threshold value (the closest available threshold is chosen)
    Manual(f64),
}
73
74impl Hash for OptimalThresholdStrategy {
75    fn hash<H: Hasher>(&self, state: &mut H) {
76        match self {
77            OptimalThresholdStrategy::MaxF1 => 0.hash(state),
78            OptimalThresholdStrategy::YoudensJ => 1.hash(state),
79            OptimalThresholdStrategy::MaxAccuracy => 2.hash(state),
80            OptimalThresholdStrategy::MaxMCC => 3.hash(state),
81            OptimalThresholdStrategy::MaxKappa => 4.hash(state),
82            OptimalThresholdStrategy::BalancedSensSpec => 5.hash(state),
83            OptimalThresholdStrategy::BalancedPrecRecall => 6.hash(state),
84            OptimalThresholdStrategy::MinDistanceToOptimal => 7.hash(state),
85            OptimalThresholdStrategy::Manual(val) => {
86                8.hash(state);
87                val.to_bits().hash(state);
88            }
89        }
90    }
91}
92
// NOTE(review): `Eq` is asserted on top of the derived `PartialEq` even though
// `Manual` carries an `f64`. `Manual(f64::NAN) != Manual(f64::NAN)` under the
// derived comparison, which breaks `Eq`'s reflexivity guarantee — this is only
// sound as long as callers never construct `Manual` with a NaN threshold.
impl Eq for OptimalThresholdStrategy {}
94
/// Analyzer for binary classification thresholds
///
/// This struct provides tools for analyzing binary classification performance
/// across different thresholds and finding optimal thresholds based on various
/// metrics and strategies.
///
/// Per-threshold metrics are computed lazily on first use and cached in
/// `metrics`; optima found per strategy are cached in `optimal_thresholds`.
#[derive(Debug)]
pub struct ThresholdAnalyzer {
    /// True positive rates (parallel to `thresholds`)
    tpr: Vec<f64>,
    /// False positive rates (parallel to `thresholds`)
    fpr: Vec<f64>,
    /// Candidate classification thresholds
    thresholds: Vec<f64>,
    /// Raw true labels (expected to be 0.0 or 1.0)
    y_true: Vec<f64>,
    /// Raw score predictions
    y_score: Vec<f64>,
    /// Metrics at each threshold (`None` until first computed)
    metrics: Option<Vec<ThresholdMetrics>>,
    /// Optimal threshold index (into `thresholds`) cached per strategy
    optimal_thresholds: HashMap<OptimalThresholdStrategy, usize>,
}
117
118impl ThresholdAnalyzer {
119    /// Create a new threshold analyzer from true labels and scores
120    ///
121    /// # Arguments
122    ///
123    /// * `y_true` - Binary true labels (0 or 1)
124    /// * `y_score` - Predicted scores (probabilities)
125    ///
126    /// # Returns
127    ///
128    /// * `Result<ThresholdAnalyzer>` - The analyzer or error
129    pub fn new<D1, D2, S1, S2>(
130        y_true: &ArrayBase<S1, D1>,
131        y_score: &ArrayBase<S2, D2>,
132    ) -> Result<Self>
133    where
134        S1: Data,
135        S2: Data,
136        D1: Dimension,
137        D2: Dimension,
138        S1::Elem: Clone + Into<f64> + PartialEq,
139        S2::Elem: Clone + Into<f64> + PartialOrd,
140    {
141        // Compute ROC curve
142        let (fpr, tpr, thresholds) = roc_curve(y_true, y_score)?;
143
144        // Convert arrays to vectors for storage
145        let fpr = fpr.to_vec();
146        let tpr = tpr.to_vec();
147        let thresholds = thresholds.to_vec();
148
149        // Store original data
150        let y_true = y_true
151            .iter()
152            .map(|x| x.clone().into())
153            .collect::<Vec<f64>>();
154        let y_score = y_score
155            .iter()
156            .map(|x| x.clone().into())
157            .collect::<Vec<f64>>();
158
159        Ok(Self {
160            tpr,
161            fpr,
162            thresholds,
163            y_true,
164            y_score,
165            metrics: None,
166            optimal_thresholds: HashMap::new(),
167        })
168    }
169
170    /// Create a new threshold analyzer from pre-computed ROC curve
171    ///
172    /// # Arguments
173    ///
174    /// * `fpr` - False positive rates
175    /// * `tpr` - True positive rates
176    /// * `thresholds` - Thresholds
177    /// * `y_true` - Binary true labels (0 or 1)
178    /// * `y_score` - Predicted scores (probabilities)
179    ///
180    /// # Returns
181    ///
182    /// * `Result<ThresholdAnalyzer>` - The analyzer or error
183    pub fn from_roc_curve<D1, D2, S1, S2, S3, S4, S5, D3, D4, D5>(
184        fpr: &ArrayBase<S1, D1>,
185        tpr: &ArrayBase<S2, D2>,
186        thresholds: &ArrayBase<S3, D3>,
187        y_true: &ArrayBase<S4, D4>,
188        y_score: &ArrayBase<S5, D5>,
189    ) -> Result<Self>
190    where
191        S1: Data<Elem = f64>,
192        S2: Data<Elem = f64>,
193        S3: Data<Elem = f64>,
194        S4: Data,
195        S5: Data,
196        D1: Dimension,
197        D2: Dimension,
198        D3: Dimension,
199        D4: Dimension,
200        D5: Dimension,
201        S4::Elem: Clone + Into<f64>,
202        S5::Elem: Clone + Into<f64>,
203    {
204        // Convert arrays to vectors for storage
205        let fpr = fpr.iter().cloned().collect::<Vec<f64>>();
206        let tpr = tpr.iter().cloned().collect::<Vec<f64>>();
207        let thresholds = thresholds.iter().cloned().collect::<Vec<f64>>();
208
209        // Store original data
210        let y_true = y_true
211            .iter()
212            .map(|x| x.clone().into())
213            .collect::<Vec<f64>>();
214        let y_score = y_score
215            .iter()
216            .map(|x| x.clone().into())
217            .collect::<Vec<f64>>();
218
219        // Ensure proper shape
220        if fpr.len() != tpr.len() || fpr.len() != thresholds.len() {
221            return Err(MetricsError::ShapeMismatch {
222                shape1: format!("fpr: {}", fpr.len()),
223                shape2: format!("tpr: {}, thresholds: {}", tpr.len(), thresholds.len()),
224            });
225        }
226
227        Ok(Self {
228            tpr,
229            fpr,
230            thresholds,
231            y_true,
232            y_score,
233            metrics: None,
234            optimal_thresholds: HashMap::new(),
235        })
236    }
237
238    /// Calculate metrics at all thresholds
239    ///
240    /// # Returns
241    ///
242    /// * `Result<&[ThresholdMetrics]>` - Metrics at all thresholds
243    pub fn calculate_metrics(&mut self) -> Result<&[ThresholdMetrics]> {
244        // If metrics are already calculated, return them
245        if let Some(ref metrics) = self.metrics {
246            return Ok(metrics);
247        }
248
249        // Calculate metrics for each threshold
250        let mut metrics = Vec::with_capacity(self.thresholds.len());
251
252        for &threshold in self.thresholds.iter() {
253            // Count TP, FP, TN, FN
254            let mut tp = 0;
255            let mut fp = 0;
256            let mut tn = 0;
257            let mut fn_ = 0;
258
259            for (&true_val, &score) in self.y_true.iter().zip(&self.y_score) {
260                let pred = if score >= threshold { 1.0 } else { 0.0 };
261
262                match (true_val, pred) {
263                    (1.0, 1.0) => tp += 1,
264                    (0.0, 1.0) => fp += 1,
265                    (0.0, 0.0) => tn += 1,
266                    (1.0, 0.0) => fn_ += 1,
267                    _ => {
268                        return Err(MetricsError::InvalidArgument(format!(
269                            "Invalid true value: {true_val}"
270                        )));
271                    }
272                }
273            }
274
275            // Calculate metrics
276            let tpr = if tp + fn_ > 0 {
277                tp as f64 / (tp + fn_) as f64
278            } else {
279                0.0
280            };
281            let fpr = if fp + tn > 0 {
282                fp as f64 / (fp + tn) as f64
283            } else {
284                0.0
285            };
286            let precision = if tp + fp > 0 {
287                tp as f64 / (tp + fp) as f64
288            } else {
289                0.0
290            };
291            let f1_score = if precision + tpr > 0.0 {
292                2.0 * precision * tpr / (precision + tpr)
293            } else {
294                0.0
295            };
296            let accuracy = (tp + tn) as f64 / (tp + fp + tn + fn_) as f64;
297            let specificity = if tn + fp > 0 {
298                tn as f64 / (tn + fp) as f64
299            } else {
300                0.0
301            };
302            let npv = if tn + fn_ > 0 {
303                tn as f64 / (tn + fn_) as f64
304            } else {
305                0.0
306            };
307            let youdens_j = tpr + specificity - 1.0;
308            let balanced_accuracy = (tpr + specificity) / 2.0;
309
310            // Matthews correlation coefficient
311            let mcc_numerator = (tp * tn) as f64 - (fp * fn_) as f64;
312            let mcc_denominator = ((tp + fp) * (tp + fn_) * (tn + fp) * (tn + fn_)) as f64;
313            let mcc = if mcc_denominator > 0.0 {
314                mcc_numerator / mcc_denominator.sqrt()
315            } else {
316                0.0
317            };
318
319            // Cohen's kappa
320            let p_o = accuracy;
321            let p_e = (((tp + fp) as f64 / (tp + fp + tn + fn_) as f64)
322                * ((tp + fn_) as f64 / (tp + fp + tn + fn_) as f64))
323                + (((tn + fn_) as f64 / (tp + fp + tn + fn_) as f64)
324                    * ((tn + fp) as f64 / (tp + fp + tn + fn_) as f64));
325            let kappa = if p_e < 1.0 {
326                (p_o - p_e) / (1.0 - p_e)
327            } else {
328                0.0
329            };
330
331            metrics.push(ThresholdMetrics {
332                threshold,
333                tpr,
334                fpr,
335                precision,
336                f1_score,
337                accuracy,
338                specificity,
339                npv,
340                mcc,
341                kappa,
342                youdens_j,
343                balanced_accuracy,
344                tp,
345                fp,
346                tn,
347                fn_,
348            });
349        }
350
351        self.metrics = Some(metrics);
352        Ok(self.metrics.as_ref().unwrap())
353    }
354
355    /// Find optimal threshold based on a given strategy
356    ///
357    /// # Arguments
358    ///
359    /// * `strategy` - Strategy for finding optimal threshold
360    ///
361    /// # Returns
362    ///
363    /// * `Result<(f64, ThresholdMetrics)>` - Optimal threshold and its metrics
364    pub fn find_optimal_threshold(
365        &mut self,
366        strategy: OptimalThresholdStrategy,
367    ) -> Result<(f64, ThresholdMetrics)> {
368        // Check if optimal threshold for this strategy is already calculated first
369        if let Some(&idx) = self.optimal_thresholds.get(&strategy) {
370            self.calculate_metrics()?;
371            let threshold = self.thresholds[idx];
372            let metrics = self.metrics.as_ref().unwrap();
373            return Ok((threshold, metrics[idx].clone()));
374        }
375
376        // Calculate metrics for finding optimal
377        self.calculate_metrics()?;
378        let metrics = self.metrics.as_ref().unwrap();
379
380        // Find optimal threshold based on strategy
381        let optimal_idx = match strategy {
382            OptimalThresholdStrategy::MaxF1 => metrics
383                .iter()
384                .enumerate()
385                .max_by(|(_, a), (_, b)| a.f1_score.partial_cmp(&b.f1_score).unwrap())
386                .map(|(idx, _)| idx)
387                .unwrap_or(0),
388            OptimalThresholdStrategy::YoudensJ => metrics
389                .iter()
390                .enumerate()
391                .max_by(|(_, a), (_, b)| a.youdens_j.partial_cmp(&b.youdens_j).unwrap())
392                .map(|(idx, _)| idx)
393                .unwrap_or(0),
394            OptimalThresholdStrategy::MaxAccuracy => metrics
395                .iter()
396                .enumerate()
397                .max_by(|(_, a), (_, b)| a.accuracy.partial_cmp(&b.accuracy).unwrap())
398                .map(|(idx, _)| idx)
399                .unwrap_or(0),
400            OptimalThresholdStrategy::MaxMCC => metrics
401                .iter()
402                .enumerate()
403                .max_by(|(_, a), (_, b)| a.mcc.partial_cmp(&b.mcc).unwrap())
404                .map(|(idx, _)| idx)
405                .unwrap_or(0),
406            OptimalThresholdStrategy::MaxKappa => metrics
407                .iter()
408                .enumerate()
409                .max_by(|(_, a), (_, b)| a.kappa.partial_cmp(&b.kappa).unwrap())
410                .map(|(idx, _)| idx)
411                .unwrap_or(0),
412            OptimalThresholdStrategy::BalancedSensSpec => metrics
413                .iter()
414                .enumerate()
415                .min_by(|(_, a), (_, b)| {
416                    let a_diff = (a.tpr - a.specificity).abs();
417                    let b_diff = (b.tpr - b.specificity).abs();
418                    a_diff.partial_cmp(&b_diff).unwrap()
419                })
420                .map(|(idx, _)| idx)
421                .unwrap_or(0),
422            OptimalThresholdStrategy::BalancedPrecRecall => metrics
423                .iter()
424                .enumerate()
425                .min_by(|(_, a), (_, b)| {
426                    let a_diff = (a.precision - a.tpr).abs();
427                    let b_diff = (b.precision - b.tpr).abs();
428                    a_diff.partial_cmp(&b_diff).unwrap()
429                })
430                .map(|(idx, _)| idx)
431                .unwrap_or(0),
432            OptimalThresholdStrategy::MinDistanceToOptimal => metrics
433                .iter()
434                .enumerate()
435                .min_by(|(_, a), (_, b)| {
436                    let a_dist = (a.fpr.powi(2) + (1.0 - a.tpr).powi(2)).sqrt();
437                    let b_dist = (b.fpr.powi(2) + (1.0 - b.tpr).powi(2)).sqrt();
438                    a_dist.partial_cmp(&b_dist).unwrap()
439                })
440                .map(|(idx, _)| idx)
441                .unwrap_or(0),
442            OptimalThresholdStrategy::Manual(threshold) => {
443                // Find the closest threshold
444                metrics
445                    .iter()
446                    .enumerate()
447                    .min_by(|(_, a), (_, b)| {
448                        let a_diff = (a.threshold - threshold).abs();
449                        let b_diff = (b.threshold - threshold).abs();
450                        a_diff.partial_cmp(&b_diff).unwrap()
451                    })
452                    .map(|(idx, _)| idx)
453                    .unwrap_or(0)
454            }
455        };
456
457        // Store the threshold and metrics before borrowing conflicts
458        let threshold = self.thresholds[optimal_idx];
459        let metric = metrics[optimal_idx].clone();
460
461        // Store optimal threshold for this strategy
462        self.optimal_thresholds.insert(strategy, optimal_idx);
463
464        Ok((threshold, metric))
465    }
466
467    /// Get metrics at a specific threshold
468    ///
469    /// # Arguments
470    ///
471    /// * `threshold` - Threshold value
472    ///
473    /// # Returns
474    ///
475    /// * `Result<ThresholdMetrics>` - Metrics at the threshold
476    pub fn get_metrics_at_threshold(&mut self, threshold: f64) -> Result<ThresholdMetrics> {
477        // Calculate metrics if not already calculated
478        self.calculate_metrics()?;
479
480        // Find the closest threshold index
481        let idx = self
482            .thresholds
483            .iter()
484            .enumerate()
485            .min_by(|(_, &a), (_, &b)| {
486                let a_diff = (a - threshold).abs();
487                let b_diff = (b - threshold).abs();
488                a_diff.partial_cmp(&b_diff).unwrap()
489            })
490            .map(|(idx, _)| idx)
491            .unwrap_or(0);
492
493        // Get metrics - we know it's calculated at this point
494        let metrics = self.metrics.as_ref().unwrap();
495        Ok(metrics[idx].clone())
496    }
497
498    /// Get all threshold metrics
499    ///
500    /// # Returns
501    ///
502    /// * `Result<&[ThresholdMetrics]>` - All threshold metrics
503    pub fn get_all_metrics(&mut self) -> Result<&[ThresholdMetrics]> {
504        self.calculate_metrics()
505    }
506
507    /// Get thresholds
508    ///
509    /// # Returns
510    ///
511    /// * `&[f64]` - Thresholds
512    pub fn get_thresholds(&self) -> &[f64] {
513        &self.thresholds
514    }
515
516    /// Get false positive rates
517    ///
518    /// # Returns
519    ///
520    /// * `&[f64]` - False positive rates
521    pub fn get_fpr(&self) -> &[f64] {
522        &self.fpr
523    }
524
525    /// Get true positive rates
526    ///
527    /// # Returns
528    ///
529    /// * `&[f64]` - True positive rates
530    pub fn get_tpr(&self) -> &[f64] {
531        &self.tpr
532    }
533
534    /// Convert metrics to a specific column data structure
535    ///
536    /// # Arguments
537    ///
538    /// * `metric_name` - Name of the metric to extract
539    ///
540    /// # Returns
541    ///
542    /// * `Result<Vec<f64>>` - Values of the specified metric
543    pub fn get_metric_values(&mut self, metricname: &str) -> Result<Vec<f64>> {
544        let metrics = self.calculate_metrics()?;
545
546        let values = match metricname {
547            "threshold" => metrics.iter().map(|m| m.threshold).collect(),
548            "tpr" | "recall" | "sensitivity" => metrics.iter().map(|m| m.tpr).collect(),
549            "fpr" => metrics.iter().map(|m| m.fpr).collect(),
550            "precision" => metrics.iter().map(|m| m.precision).collect(),
551            "f1_score" | "f1" => metrics.iter().map(|m| m.f1_score).collect(),
552            "accuracy" => metrics.iter().map(|m| m.accuracy).collect(),
553            "specificity" => metrics.iter().map(|m| m.specificity).collect(),
554            "npv" => metrics.iter().map(|m| m.npv).collect(),
555            "mcc" => metrics.iter().map(|m| m.mcc).collect(),
556            "kappa" => metrics.iter().map(|m| m.kappa).collect(),
557            "youdens_j" | "j" => metrics.iter().map(|m| m.youdens_j).collect(),
558            "balanced_accuracy" => metrics.iter().map(|m| m.balanced_accuracy).collect(),
559            _ => {
560                return Err(MetricsError::InvalidArgument(format!(
561                    "Unknown metric: {metricname}"
562                )))
563            }
564        };
565
566        Ok(values)
567    }
568
569    /// Get metric names
570    ///
571    /// # Returns
572    ///
573    /// * `Vec<String>` - Names of available metrics
574    pub fn get_metric_names() -> Vec<String> {
575        vec![
576            "threshold".to_string(),
577            "tpr".to_string(),
578            "fpr".to_string(),
579            "precision".to_string(),
580            "f1_score".to_string(),
581            "accuracy".to_string(),
582            "specificity".to_string(),
583            "npv".to_string(),
584            "mcc".to_string(),
585            "kappa".to_string(),
586            "youdens_j".to_string(),
587            "balanced_accuracy".to_string(),
588        ]
589    }
590}
591
592/// Find the optimal threshold for binary classification
593///
594/// # Arguments
595///
596/// * `y_true` - Binary true labels (0 or 1)
597/// * `y_score` - Predicted scores (probabilities)
598/// * `strategy` - Strategy for finding optimal threshold
599///
600/// # Returns
601///
602/// * `Result<(f64, ThresholdMetrics)>` - Optimal threshold and its metrics
603#[allow(dead_code)]
604pub fn find_optimal_threshold<S1, S2>(
605    y_true: &ArrayBase<S1, Ix1>,
606    y_score: &ArrayBase<S2, Ix1>,
607    strategy: OptimalThresholdStrategy,
608) -> Result<(f64, ThresholdMetrics)>
609where
610    S1: Data,
611    S2: Data,
612    S1::Elem: Clone + Into<f64> + PartialEq,
613    S2::Elem: Clone + Into<f64> + PartialOrd,
614{
615    let mut analyzer = ThresholdAnalyzer::new(y_true, y_score)?;
616    let (threshold, metrics) = analyzer.find_optimal_threshold(strategy)?;
617    Ok((threshold, metrics.clone()))
618}
619
620/// Get metrics at a specific threshold
621///
622/// # Arguments
623///
624/// * `y_true` - Binary true labels (0 or 1)
625/// * `y_score` - Predicted scores (probabilities)
626/// * `threshold` - Specific threshold to evaluate
627///
628/// # Returns
629///
630/// * `Result<ThresholdMetrics>` - Metrics at the specified threshold
631#[allow(dead_code)]
632pub fn threshold_metrics<S1, S2>(
633    y_true: &ArrayBase<S1, Ix1>,
634    y_score: &ArrayBase<S2, Ix1>,
635    threshold: f64,
636) -> Result<ThresholdMetrics>
637where
638    S1: Data,
639    S2: Data,
640    S1::Elem: Clone + Into<f64> + PartialEq,
641    S2::Elem: Clone + Into<f64> + PartialOrd,
642{
643    let mut analyzer = ThresholdAnalyzer::new(y_true, y_score)?;
644    let metrics = analyzer.get_metrics_at_threshold(threshold)?;
645    Ok(metrics.clone())
646}
647
648/// Calculate metrics at all possible thresholds
649///
650/// # Arguments
651///
652/// * `y_true` - Binary true labels (0 or 1)
653/// * `y_score` - Predicted scores (probabilities)
654///
655/// # Returns
656///
657/// * `Result<Vec<ThresholdMetrics>>` - Metrics at all thresholds
658#[allow(dead_code)]
659pub fn all_threshold_metrics<S1, S2>(
660    y_true: &ArrayBase<S1, Ix1>,
661    y_score: &ArrayBase<S2, Ix1>,
662) -> Result<Vec<ThresholdMetrics>>
663where
664    S1: Data,
665    S2: Data,
666    S1::Elem: Clone + Into<f64> + PartialEq,
667    S2::Elem: Clone + Into<f64> + PartialOrd,
668{
669    let mut analyzer = ThresholdAnalyzer::new(y_true, y_score)?;
670    let metrics = analyzer.get_all_metrics()?;
671    Ok(metrics.to_vec())
672}