1use scirs2_core::ndarray::{ArrayBase, Data, Dimension, Ix1};
8use std::collections::HashMap;
9use std::hash::{Hash, Hasher};
10
11use crate::classification::curves::roc_curve;
12use crate::error::{MetricsError, Result};
13
/// Confusion-matrix counts together with the classification metrics derived
/// from them, evaluated at one decision threshold.
///
/// When produced by `ThresholdAnalyzer`, a sample counts as predicted
/// positive when its score is `>= threshold`. Ratios whose denominator is
/// zero are reported as `0.0`.
#[derive(Clone, Debug)]
pub struct ThresholdMetrics {
    /// Decision threshold at which every other field was evaluated.
    pub threshold: f64,
    /// True positive rate (recall / sensitivity): `tp / (tp + fn_)`.
    pub tpr: f64,
    /// False positive rate: `fp / (fp + tn)`.
    pub fpr: f64,
    /// Positive predictive value: `tp / (tp + fp)`.
    pub precision: f64,
    /// Harmonic mean of `precision` and `tpr`.
    pub f1_score: f64,
    /// Fraction of correctly classified samples: `(tp + tn) / total`.
    pub accuracy: f64,
    /// True negative rate: `tn / (tn + fp)`.
    pub specificity: f64,
    /// Negative predictive value: `tn / (tn + fn_)`.
    pub npv: f64,
    /// Matthews correlation coefficient.
    pub mcc: f64,
    /// Cohen's kappa (agreement corrected for chance).
    pub kappa: f64,
    /// Youden's J statistic: `tpr + specificity - 1`.
    pub youdens_j: f64,
    /// Mean of `tpr` and `specificity`.
    pub balanced_accuracy: f64,
    /// Number of true positives.
    pub tp: usize,
    /// Number of false positives.
    pub fp: usize,
    /// Number of true negatives.
    pub tn: usize,
    /// Number of false negatives (trailing underscore avoids the `fn` keyword).
    pub fn_: usize,
}
50
/// Strategy used to choose a decision threshold from the candidate
/// thresholds of a ROC curve.
///
/// The `Manual` payload is compared and hashed through a canonical bit
/// pattern: `-0.0` is folded into `0.0` and every NaN is folded into a single
/// NaN. This keeps the `Eq`/`Hash` contract (`a == b` implies equal hashes,
/// and equality is reflexive), so the type can safely key a `HashMap`.
#[derive(Debug, Clone, Copy)]
pub enum OptimalThresholdStrategy {
    /// Maximize the F1 score.
    MaxF1,
    /// Maximize Youden's J statistic (sensitivity + specificity - 1).
    YoudensJ,
    /// Maximize accuracy.
    MaxAccuracy,
    /// Maximize the Matthews correlation coefficient.
    MaxMCC,
    /// Maximize Cohen's kappa.
    MaxKappa,
    /// Minimize |sensitivity - specificity|.
    BalancedSensSpec,
    /// Minimize |precision - recall|.
    BalancedPrecRecall,
    /// Minimize the Euclidean distance to the ideal ROC point (FPR 0, TPR 1).
    MinDistanceToOptimal,
    /// Use the candidate threshold closest to the given value.
    Manual(f64),
}

impl OptimalThresholdStrategy {
    /// Canonical bit pattern of a `Manual` threshold, shared by `eq` and
    /// `hash` so the two can never disagree.
    fn manual_key(value: f64) -> u64 {
        if value.is_nan() {
            // All NaN payloads compare equal to each other (Eq needs reflexivity).
            f64::NAN.to_bits()
        } else if value == 0.0 {
            // `0.0 == -0.0`, so both must hash identically.
            0.0_f64.to_bits()
        } else {
            value.to_bits()
        }
    }
}

impl PartialEq for OptimalThresholdStrategy {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Manual(a), Self::Manual(b)) => Self::manual_key(*a) == Self::manual_key(*b),
            // Payload-free variants are equal exactly when they are the same variant;
            // a Manual vs. non-Manual pair also lands here and compares unequal.
            _ => std::mem::discriminant(self) == std::mem::discriminant(other),
        }
    }
}

impl Hash for OptimalThresholdStrategy {
    fn hash<H: Hasher>(&self, state: &mut H) {
        std::mem::discriminant(self).hash(state);
        if let Self::Manual(value) = self {
            Self::manual_key(*value).hash(state);
        }
    }
}

impl Eq for OptimalThresholdStrategy {}
94
95#[derive(Debug)]
101pub struct ThresholdAnalyzer {
102 tpr: Vec<f64>,
104 fpr: Vec<f64>,
106 thresholds: Vec<f64>,
108 y_true: Vec<f64>,
110 y_score: Vec<f64>,
112 metrics: Option<Vec<ThresholdMetrics>>,
114 optimal_thresholds: HashMap<OptimalThresholdStrategy, usize>,
116}
117
118impl ThresholdAnalyzer {
119 pub fn new<D1, D2, S1, S2>(
130 y_true: &ArrayBase<S1, D1>,
131 y_score: &ArrayBase<S2, D2>,
132 ) -> Result<Self>
133 where
134 S1: Data,
135 S2: Data,
136 D1: Dimension,
137 D2: Dimension,
138 S1::Elem: Clone + Into<f64> + PartialEq,
139 S2::Elem: Clone + Into<f64> + PartialOrd,
140 {
141 let (fpr, tpr, thresholds) = roc_curve(y_true, y_score)?;
143
144 let fpr = fpr.to_vec();
146 let tpr = tpr.to_vec();
147 let thresholds = thresholds.to_vec();
148
149 let y_true = y_true
151 .iter()
152 .map(|x| x.clone().into())
153 .collect::<Vec<f64>>();
154 let y_score = y_score
155 .iter()
156 .map(|x| x.clone().into())
157 .collect::<Vec<f64>>();
158
159 Ok(Self {
160 tpr,
161 fpr,
162 thresholds,
163 y_true,
164 y_score,
165 metrics: None,
166 optimal_thresholds: HashMap::new(),
167 })
168 }
169
170 pub fn from_roc_curve<D1, D2, S1, S2, S3, S4, S5, D3, D4, D5>(
184 fpr: &ArrayBase<S1, D1>,
185 tpr: &ArrayBase<S2, D2>,
186 thresholds: &ArrayBase<S3, D3>,
187 y_true: &ArrayBase<S4, D4>,
188 y_score: &ArrayBase<S5, D5>,
189 ) -> Result<Self>
190 where
191 S1: Data<Elem = f64>,
192 S2: Data<Elem = f64>,
193 S3: Data<Elem = f64>,
194 S4: Data,
195 S5: Data,
196 D1: Dimension,
197 D2: Dimension,
198 D3: Dimension,
199 D4: Dimension,
200 D5: Dimension,
201 S4::Elem: Clone + Into<f64>,
202 S5::Elem: Clone + Into<f64>,
203 {
204 let fpr = fpr.iter().cloned().collect::<Vec<f64>>();
206 let tpr = tpr.iter().cloned().collect::<Vec<f64>>();
207 let thresholds = thresholds.iter().cloned().collect::<Vec<f64>>();
208
209 let y_true = y_true
211 .iter()
212 .map(|x| x.clone().into())
213 .collect::<Vec<f64>>();
214 let y_score = y_score
215 .iter()
216 .map(|x| x.clone().into())
217 .collect::<Vec<f64>>();
218
219 if fpr.len() != tpr.len() || fpr.len() != thresholds.len() {
221 return Err(MetricsError::ShapeMismatch {
222 shape1: format!("fpr: {}", fpr.len()),
223 shape2: format!("tpr: {}, thresholds: {}", tpr.len(), thresholds.len()),
224 });
225 }
226
227 Ok(Self {
228 tpr,
229 fpr,
230 thresholds,
231 y_true,
232 y_score,
233 metrics: None,
234 optimal_thresholds: HashMap::new(),
235 })
236 }
237
238 pub fn calculate_metrics(&mut self) -> Result<&[ThresholdMetrics]> {
244 if let Some(ref metrics) = self.metrics {
246 return Ok(metrics);
247 }
248
249 let mut metrics = Vec::with_capacity(self.thresholds.len());
251
252 for &threshold in self.thresholds.iter() {
253 let mut tp = 0;
255 let mut fp = 0;
256 let mut tn = 0;
257 let mut fn_ = 0;
258
259 for (&true_val, &score) in self.y_true.iter().zip(&self.y_score) {
260 let pred = if score >= threshold { 1.0 } else { 0.0 };
261
262 match (true_val, pred) {
263 (1.0, 1.0) => tp += 1,
264 (0.0, 1.0) => fp += 1,
265 (0.0, 0.0) => tn += 1,
266 (1.0, 0.0) => fn_ += 1,
267 _ => {
268 return Err(MetricsError::InvalidArgument(format!(
269 "Invalid true value: {true_val}"
270 )));
271 }
272 }
273 }
274
275 let tpr = if tp + fn_ > 0 {
277 tp as f64 / (tp + fn_) as f64
278 } else {
279 0.0
280 };
281 let fpr = if fp + tn > 0 {
282 fp as f64 / (fp + tn) as f64
283 } else {
284 0.0
285 };
286 let precision = if tp + fp > 0 {
287 tp as f64 / (tp + fp) as f64
288 } else {
289 0.0
290 };
291 let f1_score = if precision + tpr > 0.0 {
292 2.0 * precision * tpr / (precision + tpr)
293 } else {
294 0.0
295 };
296 let accuracy = (tp + tn) as f64 / (tp + fp + tn + fn_) as f64;
297 let specificity = if tn + fp > 0 {
298 tn as f64 / (tn + fp) as f64
299 } else {
300 0.0
301 };
302 let npv = if tn + fn_ > 0 {
303 tn as f64 / (tn + fn_) as f64
304 } else {
305 0.0
306 };
307 let youdens_j = tpr + specificity - 1.0;
308 let balanced_accuracy = (tpr + specificity) / 2.0;
309
310 let mcc_numerator = (tp * tn) as f64 - (fp * fn_) as f64;
312 let mcc_denominator = ((tp + fp) * (tp + fn_) * (tn + fp) * (tn + fn_)) as f64;
313 let mcc = if mcc_denominator > 0.0 {
314 mcc_numerator / mcc_denominator.sqrt()
315 } else {
316 0.0
317 };
318
319 let p_o = accuracy;
321 let p_e = (((tp + fp) as f64 / (tp + fp + tn + fn_) as f64)
322 * ((tp + fn_) as f64 / (tp + fp + tn + fn_) as f64))
323 + (((tn + fn_) as f64 / (tp + fp + tn + fn_) as f64)
324 * ((tn + fp) as f64 / (tp + fp + tn + fn_) as f64));
325 let kappa = if p_e < 1.0 {
326 (p_o - p_e) / (1.0 - p_e)
327 } else {
328 0.0
329 };
330
331 metrics.push(ThresholdMetrics {
332 threshold,
333 tpr,
334 fpr,
335 precision,
336 f1_score,
337 accuracy,
338 specificity,
339 npv,
340 mcc,
341 kappa,
342 youdens_j,
343 balanced_accuracy,
344 tp,
345 fp,
346 tn,
347 fn_,
348 });
349 }
350
351 self.metrics = Some(metrics);
352 Ok(self.metrics.as_ref().unwrap())
353 }
354
355 pub fn find_optimal_threshold(
365 &mut self,
366 strategy: OptimalThresholdStrategy,
367 ) -> Result<(f64, ThresholdMetrics)> {
368 if let Some(&idx) = self.optimal_thresholds.get(&strategy) {
370 self.calculate_metrics()?;
371 let threshold = self.thresholds[idx];
372 let metrics = self.metrics.as_ref().unwrap();
373 return Ok((threshold, metrics[idx].clone()));
374 }
375
376 self.calculate_metrics()?;
378 let metrics = self.metrics.as_ref().unwrap();
379
380 let optimal_idx = match strategy {
382 OptimalThresholdStrategy::MaxF1 => metrics
383 .iter()
384 .enumerate()
385 .max_by(|(_, a), (_, b)| a.f1_score.partial_cmp(&b.f1_score).unwrap())
386 .map(|(idx, _)| idx)
387 .unwrap_or(0),
388 OptimalThresholdStrategy::YoudensJ => metrics
389 .iter()
390 .enumerate()
391 .max_by(|(_, a), (_, b)| a.youdens_j.partial_cmp(&b.youdens_j).unwrap())
392 .map(|(idx, _)| idx)
393 .unwrap_or(0),
394 OptimalThresholdStrategy::MaxAccuracy => metrics
395 .iter()
396 .enumerate()
397 .max_by(|(_, a), (_, b)| a.accuracy.partial_cmp(&b.accuracy).unwrap())
398 .map(|(idx, _)| idx)
399 .unwrap_or(0),
400 OptimalThresholdStrategy::MaxMCC => metrics
401 .iter()
402 .enumerate()
403 .max_by(|(_, a), (_, b)| a.mcc.partial_cmp(&b.mcc).unwrap())
404 .map(|(idx, _)| idx)
405 .unwrap_or(0),
406 OptimalThresholdStrategy::MaxKappa => metrics
407 .iter()
408 .enumerate()
409 .max_by(|(_, a), (_, b)| a.kappa.partial_cmp(&b.kappa).unwrap())
410 .map(|(idx, _)| idx)
411 .unwrap_or(0),
412 OptimalThresholdStrategy::BalancedSensSpec => metrics
413 .iter()
414 .enumerate()
415 .min_by(|(_, a), (_, b)| {
416 let a_diff = (a.tpr - a.specificity).abs();
417 let b_diff = (b.tpr - b.specificity).abs();
418 a_diff.partial_cmp(&b_diff).unwrap()
419 })
420 .map(|(idx, _)| idx)
421 .unwrap_or(0),
422 OptimalThresholdStrategy::BalancedPrecRecall => metrics
423 .iter()
424 .enumerate()
425 .min_by(|(_, a), (_, b)| {
426 let a_diff = (a.precision - a.tpr).abs();
427 let b_diff = (b.precision - b.tpr).abs();
428 a_diff.partial_cmp(&b_diff).unwrap()
429 })
430 .map(|(idx, _)| idx)
431 .unwrap_or(0),
432 OptimalThresholdStrategy::MinDistanceToOptimal => metrics
433 .iter()
434 .enumerate()
435 .min_by(|(_, a), (_, b)| {
436 let a_dist = (a.fpr.powi(2) + (1.0 - a.tpr).powi(2)).sqrt();
437 let b_dist = (b.fpr.powi(2) + (1.0 - b.tpr).powi(2)).sqrt();
438 a_dist.partial_cmp(&b_dist).unwrap()
439 })
440 .map(|(idx, _)| idx)
441 .unwrap_or(0),
442 OptimalThresholdStrategy::Manual(threshold) => {
443 metrics
445 .iter()
446 .enumerate()
447 .min_by(|(_, a), (_, b)| {
448 let a_diff = (a.threshold - threshold).abs();
449 let b_diff = (b.threshold - threshold).abs();
450 a_diff.partial_cmp(&b_diff).unwrap()
451 })
452 .map(|(idx, _)| idx)
453 .unwrap_or(0)
454 }
455 };
456
457 let threshold = self.thresholds[optimal_idx];
459 let metric = metrics[optimal_idx].clone();
460
461 self.optimal_thresholds.insert(strategy, optimal_idx);
463
464 Ok((threshold, metric))
465 }
466
467 pub fn get_metrics_at_threshold(&mut self, threshold: f64) -> Result<ThresholdMetrics> {
477 self.calculate_metrics()?;
479
480 let idx = self
482 .thresholds
483 .iter()
484 .enumerate()
485 .min_by(|(_, &a), (_, &b)| {
486 let a_diff = (a - threshold).abs();
487 let b_diff = (b - threshold).abs();
488 a_diff.partial_cmp(&b_diff).unwrap()
489 })
490 .map(|(idx, _)| idx)
491 .unwrap_or(0);
492
493 let metrics = self.metrics.as_ref().unwrap();
495 Ok(metrics[idx].clone())
496 }
497
498 pub fn get_all_metrics(&mut self) -> Result<&[ThresholdMetrics]> {
504 self.calculate_metrics()
505 }
506
507 pub fn get_thresholds(&self) -> &[f64] {
513 &self.thresholds
514 }
515
516 pub fn get_fpr(&self) -> &[f64] {
522 &self.fpr
523 }
524
525 pub fn get_tpr(&self) -> &[f64] {
531 &self.tpr
532 }
533
534 pub fn get_metric_values(&mut self, metricname: &str) -> Result<Vec<f64>> {
544 let metrics = self.calculate_metrics()?;
545
546 let values = match metricname {
547 "threshold" => metrics.iter().map(|m| m.threshold).collect(),
548 "tpr" | "recall" | "sensitivity" => metrics.iter().map(|m| m.tpr).collect(),
549 "fpr" => metrics.iter().map(|m| m.fpr).collect(),
550 "precision" => metrics.iter().map(|m| m.precision).collect(),
551 "f1_score" | "f1" => metrics.iter().map(|m| m.f1_score).collect(),
552 "accuracy" => metrics.iter().map(|m| m.accuracy).collect(),
553 "specificity" => metrics.iter().map(|m| m.specificity).collect(),
554 "npv" => metrics.iter().map(|m| m.npv).collect(),
555 "mcc" => metrics.iter().map(|m| m.mcc).collect(),
556 "kappa" => metrics.iter().map(|m| m.kappa).collect(),
557 "youdens_j" | "j" => metrics.iter().map(|m| m.youdens_j).collect(),
558 "balanced_accuracy" => metrics.iter().map(|m| m.balanced_accuracy).collect(),
559 _ => {
560 return Err(MetricsError::InvalidArgument(format!(
561 "Unknown metric: {metricname}"
562 )))
563 }
564 };
565
566 Ok(values)
567 }
568
569 pub fn get_metric_names() -> Vec<String> {
575 vec![
576 "threshold".to_string(),
577 "tpr".to_string(),
578 "fpr".to_string(),
579 "precision".to_string(),
580 "f1_score".to_string(),
581 "accuracy".to_string(),
582 "specificity".to_string(),
583 "npv".to_string(),
584 "mcc".to_string(),
585 "kappa".to_string(),
586 "youdens_j".to_string(),
587 "balanced_accuracy".to_string(),
588 ]
589 }
590}
591
592#[allow(dead_code)]
604pub fn find_optimal_threshold<S1, S2>(
605 y_true: &ArrayBase<S1, Ix1>,
606 y_score: &ArrayBase<S2, Ix1>,
607 strategy: OptimalThresholdStrategy,
608) -> Result<(f64, ThresholdMetrics)>
609where
610 S1: Data,
611 S2: Data,
612 S1::Elem: Clone + Into<f64> + PartialEq,
613 S2::Elem: Clone + Into<f64> + PartialOrd,
614{
615 let mut analyzer = ThresholdAnalyzer::new(y_true, y_score)?;
616 let (threshold, metrics) = analyzer.find_optimal_threshold(strategy)?;
617 Ok((threshold, metrics.clone()))
618}
619
620#[allow(dead_code)]
632pub fn threshold_metrics<S1, S2>(
633 y_true: &ArrayBase<S1, Ix1>,
634 y_score: &ArrayBase<S2, Ix1>,
635 threshold: f64,
636) -> Result<ThresholdMetrics>
637where
638 S1: Data,
639 S2: Data,
640 S1::Elem: Clone + Into<f64> + PartialEq,
641 S2::Elem: Clone + Into<f64> + PartialOrd,
642{
643 let mut analyzer = ThresholdAnalyzer::new(y_true, y_score)?;
644 let metrics = analyzer.get_metrics_at_threshold(threshold)?;
645 Ok(metrics.clone())
646}
647
648#[allow(dead_code)]
659pub fn all_threshold_metrics<S1, S2>(
660 y_true: &ArrayBase<S1, Ix1>,
661 y_score: &ArrayBase<S2, Ix1>,
662) -> Result<Vec<ThresholdMetrics>>
663where
664 S1: Data,
665 S2: Data,
666 S1::Elem: Clone + Into<f64> + PartialEq,
667 S2::Elem: Clone + Into<f64> + PartialOrd,
668{
669 let mut analyzer = ThresholdAnalyzer::new(y_true, y_score)?;
670 let metrics = analyzer.get_all_metrics()?;
671 Ok(metrics.to_vec())
672}