1use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
7use std::collections::HashMap;
8use std::error::Error;
9
10use crate::classification::curves::roc_curve;
11use crate::error::{MetricsError, Result};
12use crate::visualization::{
13 MetricVisualizer, PlotType, VisualizationData, VisualizationMetadata, VisualizationOptions,
14};
15
/// Output of the visualizer's internal ROC computation, in the order
/// `(fpr, tpr, thresholds, auc)`.
pub(crate) type ROCComputeResult = (Vec<f64>, Vec<f64>, Vec<f64>, Option<f64>);

/// Confusion-matrix counts in the order `(tp, fp, tn, fn)`.
pub(crate) type ConfusionMatrixValues = (usize, usize, usize, usize);

/// Builds plot data for an ROC curve with a highlighted "current" threshold
/// and, optionally, classification metrics evaluated at that threshold.
///
/// A visualizer either wraps a precomputed curve (`new`) or computes the curve
/// on demand from ground-truth labels and scores (`from_labels`).
#[derive(Debug, Clone)]
pub struct InteractiveROCVisualizer<'a, T, S>
where
    T: Clone + PartialOrd,
    S: Data<Elem = T>,
{
    /// True positive rates of a precomputed curve; `None` when the curve is
    /// derived from `y_true`/`y_score`.
    tpr: Option<Vec<f64>>,
    /// False positive rates of a precomputed curve; `None` when the curve is
    /// derived from `y_true`/`y_score`.
    fpr: Option<Vec<f64>>,
    /// Thresholds of a precomputed curve (may be absent).
    thresholds: Option<Vec<f64>>,
    /// Area under the curve; when `None` and labels are attached it is
    /// estimated from the computed curve.
    auc: Option<f64>,
    /// Plot title.
    title: String,
    /// Whether the AUC is included in the ROC curve's series name.
    show_auc: bool,
    /// Whether the diagonal "random classifier" baseline is emitted.
    show_baseline: bool,
    /// Ground-truth labels used to compute the curve and the confusion matrix.
    y_true: Option<&'a ArrayBase<S, Ix1>>,
    /// Scores used to compute the curve and the confusion matrix.
    y_score: Option<&'a ArrayBase<S, Ix1>>,
    /// Label counted as positive in the confusion matrix; `None` means `1`.
    pos_label: Option<T>,
    /// Selected threshold index; `None` or an out-of-range value falls back to
    /// the middle threshold.
    current_threshold_idx: Option<usize>,
    /// Whether metrics at the current threshold are attached to the plot data.
    show_metrics: bool,
    /// Rendering options forwarded to the plot data as metadata.
    interactive_options: InteractiveOptions,
}
59
/// Options controlling the interactive rendering of an ROC curve.
///
/// Every field is forwarded to the plot data as auxiliary metadata, so the
/// renderer decides how to honour them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InteractiveOptions {
    /// Requested width of the interactive view (presumably pixels — confirm with the renderer).
    pub width: usize,
    /// Requested height of the interactive view (presumably pixels — confirm with the renderer).
    pub height: usize,
    /// Whether a threshold slider should be shown.
    pub show_threshold_slider: bool,
    /// Whether metric values should be shown.
    pub show_metric_values: bool,
    /// Whether the confusion matrix should be shown.
    pub show_confusion_matrix: bool,
    /// Free-form layout entries, forwarded as `layout_<key>` metadata.
    pub custom_layout: HashMap<String, String>,
}

impl Default for InteractiveOptions {
    /// An 800x600 view with the slider, metric values and confusion matrix
    /// enabled and no custom layout entries.
    fn default() -> Self {
        Self {
            width: 800,
            height: 600,
            show_threshold_slider: true,
            show_metric_values: true,
            show_confusion_matrix: true,
            custom_layout: HashMap::new(),
        }
    }
}
89
90impl<'a, T, S> InteractiveROCVisualizer<'a, T, S>
91where
92 T: Clone + PartialOrd + 'static,
93 S: Data<Elem = T>,
94 f64: From<T>,
95{
96 pub fn new(
109 fpr: Vec<f64>,
110 tpr: Vec<f64>,
111 thresholds: Option<Vec<f64>>,
112 auc: Option<f64>,
113 ) -> Self {
114 InteractiveROCVisualizer {
115 tpr: Some(tpr),
116 fpr: Some(fpr),
117 thresholds,
118 auc,
119 title: "Interactive ROC Curve".to_string(),
120 show_auc: true,
121 show_baseline: true,
122 y_true: None,
123 y_score: None,
124 pos_label: None,
125 current_threshold_idx: None,
126 show_metrics: true,
127 interactive_options: InteractiveOptions::default(),
128 }
129 }
130
131 pub fn from_labels(
143 y_true: &'a ArrayBase<S, Ix1>,
144 y_score: &'a ArrayBase<S, Ix1>,
145 pos_label: Option<T>,
146 ) -> Self {
147 InteractiveROCVisualizer {
148 tpr: None,
149 fpr: None,
150 thresholds: None,
151 auc: None,
152 title: "Interactive ROC Curve".to_string(),
153 show_auc: true,
154 show_baseline: true,
155 y_true: Some(y_true),
156 y_score: Some(y_score),
157 pos_label,
158 current_threshold_idx: None,
159 show_metrics: true,
160 interactive_options: InteractiveOptions::default(),
161 }
162 }
163
164 pub fn with_title(mut self, title: String) -> Self {
174 self.title = title;
175 self
176 }
177
178 pub fn with_show_auc(mut self, showauc: bool) -> Self {
188 self.show_auc = showauc;
189 self
190 }
191
192 pub fn with_show_baseline(mut self, showbaseline: bool) -> Self {
202 self.show_baseline = showbaseline;
203 self
204 }
205
206 pub fn with_auc(mut self, auc: f64) -> Self {
216 self.auc = Some(auc);
217 self
218 }
219
220 pub fn with_show_metrics(mut self, showmetrics: bool) -> Self {
230 self.show_metrics = showmetrics;
231 self
232 }
233
234 pub fn with_interactive_options(mut self, options: InteractiveOptions) -> Self {
244 self.interactive_options = options;
245 self
246 }
247
248 pub fn with_threshold_index(mut self, idx: usize) -> Self {
258 self.current_threshold_idx = Some(idx);
259 self
260 }
261
262 pub fn with_threshold_value(mut self, threshold: f64) -> Result<Self> {
272 let (_, _, thresholds_, _) = self.compute_roc()?;
274
275 if thresholds_.is_empty() {
276 return Err(MetricsError::InvalidInput(
277 "No thresholds available".to_string(),
278 ));
279 }
280
281 let mut closest_idx = 0;
283 let mut min_diff = f64::INFINITY;
284
285 for (i, &t) in thresholds_.iter().enumerate() {
286 let diff = (t - threshold).abs();
287 if diff < min_diff {
288 min_diff = diff;
289 closest_idx = i;
290 }
291 }
292
293 self.current_threshold_idx = Some(closest_idx);
294 Ok(self)
295 }
296
297 fn compute_roc(&self) -> Result<ROCComputeResult> {
303 if self.fpr.is_some() && self.tpr.is_some() {
304 return Ok((
306 self.fpr.clone().unwrap(),
307 self.tpr.clone().unwrap(),
308 self.thresholds.clone().unwrap_or_default(),
309 self.auc,
310 ));
311 }
312
313 if self.y_true.is_none() || self.y_score.is_none() {
314 return Err(MetricsError::InvalidInput(
315 "No data provided for ROC curve computation".to_string(),
316 ));
317 }
318
319 let y_true = self.y_true.unwrap();
320 let y_score = self.y_score.unwrap();
321
322 let (fpr, tpr, thresholds) = roc_curve(y_true, y_score)?;
324
325 let auc = if self.auc.is_none() {
327 let n = fpr.len();
330
331 let mut area = 0.0;
332 for i in 1..n {
333 area += (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2.0;
335 }
336
337 Some(area)
338 } else {
339 self.auc
340 };
341
342 Ok((fpr.to_vec(), tpr.to_vec(), thresholds.to_vec(), auc))
343 }
344
345 pub fn calculate_confusion_matrix(
355 &self,
356 threshold_idx: usize,
357 ) -> Result<ConfusionMatrixValues> {
358 if self.y_true.is_none() || self.y_score.is_none() {
359 return Err(MetricsError::InvalidInput(
360 "Original data required for confusion matrix calculation".to_string(),
361 ));
362 }
363
364 let (_, _, thresholds_, _) = self.compute_roc()?;
365
366 if threshold_idx >= thresholds_.len() {
367 return Err(MetricsError::InvalidArgument(
368 "Threshold index out of range".to_string(),
369 ));
370 }
371
372 let threshold = thresholds_[threshold_idx];
373 let y_true = self.y_true.unwrap();
374 let y_score = self.y_score.unwrap();
375
376 let mut tp = 0;
377 let mut fp = 0;
378 let mut tn = 0;
379 let mut fn_ = 0;
380
381 let pos_label_f64 = match &self.pos_label {
383 Some(label) => f64::from(label.clone()),
384 None => 1.0, };
386
387 for i in 0..y_true.len() {
388 let true_val = f64::from(y_true[i].clone());
389 let score = f64::from(y_score[i].clone());
390
391 let pred = if score >= threshold {
392 pos_label_f64
393 } else {
394 0.0
395 };
396
397 if pred == pos_label_f64 && true_val == pos_label_f64 {
398 tp += 1;
399 } else if pred == pos_label_f64 && true_val != pos_label_f64 {
400 fp += 1;
401 } else if pred != pos_label_f64 && true_val != pos_label_f64 {
402 tn += 1;
403 } else {
404 fn_ += 1;
405 }
406 }
407
408 Ok((tp, fp, tn, fn_))
409 }
410
411 pub fn calculate_metrics(&self, thresholdidx: usize) -> Result<HashMap<String, f64>> {
421 let (tp, fp, tn, fn_) = self.calculate_confusion_matrix(thresholdidx)?;
422
423 let mut metrics = HashMap::new();
424
425 let accuracy = (tp + tn) as f64 / (tp + fp + tn + fn_) as f64;
427 metrics.insert("accuracy".to_string(), accuracy);
428
429 let precision = if tp + fp > 0 {
431 tp as f64 / (tp + fp) as f64
432 } else {
433 0.0
434 };
435 metrics.insert("precision".to_string(), precision);
436
437 let recall = if tp + fn_ > 0 {
439 tp as f64 / (tp + fn_) as f64
440 } else {
441 0.0
442 };
443 metrics.insert("recall".to_string(), recall);
444
445 let specificity = if tn + fp > 0 {
447 tn as f64 / (tn + fp) as f64
448 } else {
449 0.0
450 };
451 metrics.insert("specificity".to_string(), specificity);
452
453 let f1 = if precision + recall > 0.0 {
455 2.0 * precision * recall / (precision + recall)
456 } else {
457 0.0
458 };
459 metrics.insert("f1_score".to_string(), f1);
460
461 let (_, _, thresholds_, _) = self.compute_roc()?;
463 metrics.insert("threshold".to_string(), thresholds_[thresholdidx]);
464
465 Ok(metrics)
466 }
467
468 pub fn get_current_threshold_idx(&self) -> Result<usize> {
474 let (_, _, thresholds_, _) = self.compute_roc()?;
475
476 if thresholds_.is_empty() {
477 return Err(MetricsError::InvalidInput(
478 "No thresholds available".to_string(),
479 ));
480 }
481
482 match self.current_threshold_idx {
483 Some(idx) if idx < thresholds_.len() => Ok(idx),
484 _ => Ok(thresholds_.len() / 2), }
486 }
487}
488
489impl<T, S> MetricVisualizer for InteractiveROCVisualizer<'_, T, S>
490where
491 T: Clone + PartialOrd + 'static,
492 S: Data<Elem = T>,
493 f64: From<T>,
494{
495 fn prepare_data(&self) -> std::result::Result<VisualizationData, Box<dyn Error>> {
496 let (fpr, tpr, thresholds, auc) = self
497 .compute_roc()
498 .map_err(|e| Box::new(e) as Box<dyn Error>)?;
499
500 let mut data = VisualizationData::new();
502
503 data.x = fpr.clone();
505 data.y = tpr.clone();
506
507 data.add_auxiliary_data("thresholds".to_string(), thresholds.clone());
509
510 if let Some(auc_val) = auc {
512 data.add_auxiliary_metadata("auc".to_string(), auc_val.to_string());
513 }
514
515 if let Ok(threshold_idx) = self.get_current_threshold_idx() {
517 let current_point_x = vec![fpr[threshold_idx]];
518 let current_point_y = vec![tpr[threshold_idx]];
519
520 data.add_auxiliary_data("current_point_x".to_string(), current_point_x);
521 data.add_auxiliary_data("current_point_y".to_string(), current_point_y);
522 data.add_auxiliary_metadata(
523 "current_threshold".to_string(),
524 thresholds[threshold_idx].to_string(),
525 );
526
527 if self.show_metrics {
529 if let Ok(metrics) = self.calculate_metrics(threshold_idx) {
530 for (name, value) in metrics {
531 data.add_auxiliary_metadata(format!("metric_{name}"), value.to_string());
532 }
533 }
534 }
535 }
536
537 data.add_auxiliary_metadata(
539 "interactive_width".to_string(),
540 self.interactive_options.width.to_string(),
541 );
542 data.add_auxiliary_metadata(
543 "interactive_height".to_string(),
544 self.interactive_options.height.to_string(),
545 );
546 data.add_auxiliary_metadata(
547 "show_threshold_slider".to_string(),
548 self.interactive_options.show_threshold_slider.to_string(),
549 );
550 data.add_auxiliary_metadata(
551 "show_metric_values".to_string(),
552 self.interactive_options.show_metric_values.to_string(),
553 );
554 data.add_auxiliary_metadata(
555 "show_confusion_matrix".to_string(),
556 self.interactive_options.show_confusion_matrix.to_string(),
557 );
558
559 for (key, value) in &self.interactive_options.custom_layout {
561 data.add_auxiliary_metadata(format!("layout_{key}"), value.clone());
562 }
563
564 if self.show_baseline {
566 data.add_auxiliary_data("baseline_x".to_string(), vec![0.0, 1.0]);
567 data.add_auxiliary_data("baseline_y".to_string(), vec![0.0, 1.0]);
568 }
569
570 let mut series_names = Vec::new();
572
573 if self.show_auc && auc.is_some() {
574 series_names.push(format!("ROC curve (AUC = {:.3})", auc.unwrap()));
575 } else {
576 series_names.push("ROC curve".to_string());
577 }
578
579 if self.show_baseline {
580 series_names.push("Random classifier".to_string());
581 }
582
583 series_names.push("Current threshold".to_string());
585
586 data.add_series_names(series_names);
587
588 Ok(data)
589 }
590
591 fn get_metadata(&self) -> VisualizationMetadata {
592 let mut metadata = VisualizationMetadata::new(self.title.clone());
593 metadata.set_plot_type(PlotType::Line);
594 metadata.set_x_label("False Positive Rate".to_string());
595 metadata.set_y_label("True Positive Rate".to_string());
596 metadata.set_description("Interactive ROC curve showing the trade-off between true positive rate and false positive rate. Adjust the threshold to see performance metrics.".to_string());
597
598 metadata
599 }
600}
601
602#[allow(dead_code)]
615pub fn interactive_roc_curve_visualization(
616 fpr: Vec<f64>,
617 tpr: Vec<f64>,
618 thresholds: Option<Vec<f64>>,
619 auc: Option<f64>,
620) -> InteractiveROCVisualizer<'static, f64, scirs2_core::ndarray::OwnedRepr<f64>> {
621 InteractiveROCVisualizer::new(fpr, tpr, thresholds, auc)
622}
623
624#[allow(dead_code)]
636pub fn interactive_roc_curve_from_labels<'a, T, S>(
637 y_true: &'a ArrayBase<S, Ix1>,
638 y_score: &'a ArrayBase<S, Ix1>,
639 pos_label: Option<T>,
640) -> InteractiveROCVisualizer<'a, T, S>
641where
642 T: Clone + PartialOrd + 'static,
643 S: Data<Elem = T>,
644 f64: From<T>,
645{
646 InteractiveROCVisualizer::from_labels(y_true, y_score, pos_label)
647}