scirs2_metrics/visualization/interactive/
roc_curve.rs

1//! Interactive ROC curve visualization
2//!
3//! This module provides tools for creating interactive ROC curve visualizations with
4//! threshold adjustment and performance metrics display.
5
6use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
7use std::collections::HashMap;
8use std::error::Error;
9
10use crate::classification::curves::roc_curve;
11use crate::error::{MetricsError, Result};
12use crate::visualization::{
13    MetricVisualizer, PlotType, VisualizationData, VisualizationMetadata, VisualizationOptions,
14};
15
/// Result of a ROC computation, in the order `(fpr, tpr, thresholds, auc)`.
///
/// `thresholds` may be empty when the curve was supplied without thresholds, and `auc`
/// is `None` when no area under the curve is known.
pub(crate) type ROCComputeResult = (
    Vec<f64>,    // false positive rates
    Vec<f64>,    // true positive rates
    Vec<f64>,    // thresholds
    Option<f64>, // AUC
);

/// Confusion-matrix counts at one threshold, in the order `(TP, FP, TN, FN)`.
pub(crate) type ConfusionMatrixValues = (usize, usize, usize, usize);
21
22/// Interactive ROC curve visualizer
23///
24/// This struct provides methods for visualizing interactive ROC curves with
25/// threshold adjustment and performance metrics display.
26#[derive(Debug, Clone)]
27pub struct InteractiveROCVisualizer<'a, T, S>
28where
29    T: Clone + PartialOrd,
30    S: Data<Elem = T>,
31{
32    /// True positive rates
33    tpr: Option<Vec<f64>>,
34    /// False positive rates
35    fpr: Option<Vec<f64>>,
36    /// Thresholds
37    thresholds: Option<Vec<f64>>,
38    /// AUC value
39    auc: Option<f64>,
40    /// Title for the plot
41    title: String,
42    /// Whether to display the AUC in the legend
43    show_auc: bool,
44    /// Whether to display the baseline
45    show_baseline: bool,
46    /// Original y_true data
47    y_true: Option<&'a ArrayBase<S, Ix1>>,
48    /// Original y_score data
49    y_score: Option<&'a ArrayBase<S, Ix1>>,
50    /// Class label for multi-class ROC curve
51    pos_label: Option<T>,
52    /// Current selected threshold index
53    current_threshold_idx: Option<usize>,
54    /// Display metrics on the plot
55    show_metrics: bool,
56    /// Layout options for the interactive plot
57    interactive_options: InteractiveOptions,
58}
59
/// Layout and widget options for an interactive plot.
#[derive(Debug, Clone)]
pub struct InteractiveOptions {
    /// Plot width.
    pub width: usize,
    /// Plot height.
    pub height: usize,
    /// Render a slider for choosing the threshold.
    pub show_threshold_slider: bool,
    /// Display metric values at the selected threshold.
    pub show_metric_values: bool,
    /// Display the confusion matrix at the selected threshold.
    pub show_confusion_matrix: bool,
    /// Free-form extra layout settings, forwarded as key/value strings.
    pub custom_layout: HashMap<String, String>,
}

impl Default for InteractiveOptions {
    /// An 800x600 canvas with every interactive widget enabled and no custom layout.
    fn default() -> Self {
        let (width, height) = (800, 600);
        let enabled = true;
        Self {
            custom_layout: HashMap::new(),
            show_confusion_matrix: enabled,
            show_metric_values: enabled,
            show_threshold_slider: enabled,
            height,
            width,
        }
    }
}
89
90impl<'a, T, S> InteractiveROCVisualizer<'a, T, S>
91where
92    T: Clone + PartialOrd + 'static,
93    S: Data<Elem = T>,
94    f64: From<T>,
95{
96    /// Create a new InteractiveROCVisualizer from pre-computed ROC curve data
97    ///
98    /// # Arguments
99    ///
100    /// * `fpr` - False positive rates
101    /// * `tpr` - True positive rates
102    /// * `thresholds` - Optional thresholds
103    /// * `auc` - Optional AUC value
104    ///
105    /// # Returns
106    ///
107    /// * A new InteractiveROCVisualizer
108    pub fn new(
109        fpr: Vec<f64>,
110        tpr: Vec<f64>,
111        thresholds: Option<Vec<f64>>,
112        auc: Option<f64>,
113    ) -> Self {
114        InteractiveROCVisualizer {
115            tpr: Some(tpr),
116            fpr: Some(fpr),
117            thresholds,
118            auc,
119            title: "Interactive ROC Curve".to_string(),
120            show_auc: true,
121            show_baseline: true,
122            y_true: None,
123            y_score: None,
124            pos_label: None,
125            current_threshold_idx: None,
126            show_metrics: true,
127            interactive_options: InteractiveOptions::default(),
128        }
129    }
130
131    /// Create a ROCCurveVisualizer from true labels and scores
132    ///
133    /// # Arguments
134    ///
135    /// * `y_true` - True binary labels
136    /// * `y_score` - Target scores (probabilities or decision function output)
137    /// * `pos_label` - Label of the positive class
138    ///
139    /// # Returns
140    ///
141    /// * A new InteractiveROCVisualizer
142    pub fn from_labels(
143        y_true: &'a ArrayBase<S, Ix1>,
144        y_score: &'a ArrayBase<S, Ix1>,
145        pos_label: Option<T>,
146    ) -> Self {
147        InteractiveROCVisualizer {
148            tpr: None,
149            fpr: None,
150            thresholds: None,
151            auc: None,
152            title: "Interactive ROC Curve".to_string(),
153            show_auc: true,
154            show_baseline: true,
155            y_true: Some(y_true),
156            y_score: Some(y_score),
157            pos_label,
158            current_threshold_idx: None,
159            show_metrics: true,
160            interactive_options: InteractiveOptions::default(),
161        }
162    }
163
164    /// Set the title for the plot
165    ///
166    /// # Arguments
167    ///
168    /// * `title` - Title for the plot
169    ///
170    /// # Returns
171    ///
172    /// * Self for method chaining
173    pub fn with_title(mut self, title: String) -> Self {
174        self.title = title;
175        self
176    }
177
178    /// Set whether to display the AUC in the legend
179    ///
180    /// # Arguments
181    ///
182    /// * `show_auc` - Whether to display the AUC
183    ///
184    /// # Returns
185    ///
186    /// * Self for method chaining
187    pub fn with_show_auc(mut self, showauc: bool) -> Self {
188        self.show_auc = showauc;
189        self
190    }
191
192    /// Set whether to display the baseline
193    ///
194    /// # Arguments
195    ///
196    /// * `show_baseline` - Whether to display the baseline
197    ///
198    /// # Returns
199    ///
200    /// * Self for method chaining
201    pub fn with_show_baseline(mut self, showbaseline: bool) -> Self {
202        self.show_baseline = showbaseline;
203        self
204    }
205
206    /// Set the AUC value
207    ///
208    /// # Arguments
209    ///
210    /// * `auc` - AUC value
211    ///
212    /// # Returns
213    ///
214    /// * Self for method chaining
215    pub fn with_auc(mut self, auc: f64) -> Self {
216        self.auc = Some(auc);
217        self
218    }
219
220    /// Set whether to display metrics on the plot
221    ///
222    /// # Arguments
223    ///
224    /// * `show_metrics` - Whether to display metrics
225    ///
226    /// # Returns
227    ///
228    /// * Self for method chaining
229    pub fn with_show_metrics(mut self, showmetrics: bool) -> Self {
230        self.show_metrics = showmetrics;
231        self
232    }
233
234    /// Set interactive options
235    ///
236    /// # Arguments
237    ///
238    /// * `options` - Interactive options
239    ///
240    /// # Returns
241    ///
242    /// * Self for method chaining
243    pub fn with_interactive_options(mut self, options: InteractiveOptions) -> Self {
244        self.interactive_options = options;
245        self
246    }
247
248    /// Set current threshold index
249    ///
250    /// # Arguments
251    ///
252    /// * `idx` - Threshold index
253    ///
254    /// # Returns
255    ///
256    /// * Self for method chaining
257    pub fn with_threshold_index(mut self, idx: usize) -> Self {
258        self.current_threshold_idx = Some(idx);
259        self
260    }
261
262    /// Set current threshold value
263    ///
264    /// # Arguments
265    ///
266    /// * `threshold` - Threshold value
267    ///
268    /// # Returns
269    ///
270    /// * Result containing self for method chaining
271    pub fn with_threshold_value(mut self, threshold: f64) -> Result<Self> {
272        // Ensure thresholds are computed
273        let (_, _, thresholds_, _) = self.compute_roc()?;
274
275        if thresholds_.is_empty() {
276            return Err(MetricsError::InvalidInput(
277                "No thresholds available".to_string(),
278            ));
279        }
280
281        // Find the closest threshold index
282        let mut closest_idx = 0;
283        let mut min_diff = f64::INFINITY;
284
285        for (i, &t) in thresholds_.iter().enumerate() {
286            let diff = (t - threshold).abs();
287            if diff < min_diff {
288                min_diff = diff;
289                closest_idx = i;
290            }
291        }
292
293        self.current_threshold_idx = Some(closest_idx);
294        Ok(self)
295    }
296
297    /// Compute the ROC curve if not already computed
298    ///
299    /// # Returns
300    ///
301    /// * Result containing (fpr, tpr, thresholds, auc)
302    fn compute_roc(&self) -> Result<ROCComputeResult> {
303        if self.fpr.is_some() && self.tpr.is_some() {
304            // Return pre-computed values
305            return Ok((
306                self.fpr.clone().unwrap(),
307                self.tpr.clone().unwrap(),
308                self.thresholds.clone().unwrap_or_default(),
309                self.auc,
310            ));
311        }
312
313        if self.y_true.is_none() || self.y_score.is_none() {
314            return Err(MetricsError::InvalidInput(
315                "No data provided for ROC curve computation".to_string(),
316            ));
317        }
318
319        let y_true = self.y_true.unwrap();
320        let y_score = self.y_score.unwrap();
321
322        // Compute ROC curve
323        let (fpr, tpr, thresholds) = roc_curve(y_true, y_score)?;
324
325        // Compute AUC if not already provided
326        let auc = if self.auc.is_none() {
327            // AUC is the area under the ROC curve, which we can approximate
328            // using the trapezoidal rule
329            let n = fpr.len();
330
331            let mut area = 0.0;
332            for i in 1..n {
333                // Trapezoidal area: (b - a) * (f(a) + f(b)) / 2
334                area += (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2.0;
335            }
336
337            Some(area)
338        } else {
339            self.auc
340        };
341
342        Ok((fpr.to_vec(), tpr.to_vec(), thresholds.to_vec(), auc))
343    }
344
345    /// Calculate confusion matrix values at a given threshold index
346    ///
347    /// # Arguments
348    ///
349    /// * `threshold_idx` - Index of the threshold
350    ///
351    /// # Returns
352    ///
353    /// * Result containing confusion matrix values (TP, FP, TN, FN)
354    pub fn calculate_confusion_matrix(
355        &self,
356        threshold_idx: usize,
357    ) -> Result<ConfusionMatrixValues> {
358        if self.y_true.is_none() || self.y_score.is_none() {
359            return Err(MetricsError::InvalidInput(
360                "Original data required for confusion matrix calculation".to_string(),
361            ));
362        }
363
364        let (_, _, thresholds_, _) = self.compute_roc()?;
365
366        if threshold_idx >= thresholds_.len() {
367            return Err(MetricsError::InvalidArgument(
368                "Threshold index out of range".to_string(),
369            ));
370        }
371
372        let threshold = thresholds_[threshold_idx];
373        let y_true = self.y_true.unwrap();
374        let y_score = self.y_score.unwrap();
375
376        let mut tp = 0;
377        let mut fp = 0;
378        let mut tn = 0;
379        let mut fn_ = 0;
380
381        // Convert positive label to f64 for comparison
382        let pos_label_f64 = match &self.pos_label {
383            Some(label) => f64::from(label.clone()),
384            None => 1.0, // Default positive label is 1.0
385        };
386
387        for i in 0..y_true.len() {
388            let true_val = f64::from(y_true[i].clone());
389            let score = f64::from(y_score[i].clone());
390
391            let pred = if score >= threshold {
392                pos_label_f64
393            } else {
394                0.0
395            };
396
397            if pred == pos_label_f64 && true_val == pos_label_f64 {
398                tp += 1;
399            } else if pred == pos_label_f64 && true_val != pos_label_f64 {
400                fp += 1;
401            } else if pred != pos_label_f64 && true_val != pos_label_f64 {
402                tn += 1;
403            } else {
404                fn_ += 1;
405            }
406        }
407
408        Ok((tp, fp, tn, fn_))
409    }
410
411    /// Calculate metrics at a given threshold index
412    ///
413    /// # Arguments
414    ///
415    /// * `threshold_idx` - Index of the threshold
416    ///
417    /// # Returns
418    ///
419    /// * Result containing a HashMap with metric names and values
420    pub fn calculate_metrics(&self, thresholdidx: usize) -> Result<HashMap<String, f64>> {
421        let (tp, fp, tn, fn_) = self.calculate_confusion_matrix(thresholdidx)?;
422
423        let mut metrics = HashMap::new();
424
425        // Accuracy
426        let accuracy = (tp + tn) as f64 / (tp + fp + tn + fn_) as f64;
427        metrics.insert("accuracy".to_string(), accuracy);
428
429        // Precision
430        let precision = if tp + fp > 0 {
431            tp as f64 / (tp + fp) as f64
432        } else {
433            0.0
434        };
435        metrics.insert("precision".to_string(), precision);
436
437        // Recall (Sensitivity, True Positive Rate)
438        let recall = if tp + fn_ > 0 {
439            tp as f64 / (tp + fn_) as f64
440        } else {
441            0.0
442        };
443        metrics.insert("recall".to_string(), recall);
444
445        // Specificity (True Negative Rate)
446        let specificity = if tn + fp > 0 {
447            tn as f64 / (tn + fp) as f64
448        } else {
449            0.0
450        };
451        metrics.insert("specificity".to_string(), specificity);
452
453        // F1 Score
454        let f1 = if precision + recall > 0.0 {
455            2.0 * precision * recall / (precision + recall)
456        } else {
457            0.0
458        };
459        metrics.insert("f1_score".to_string(), f1);
460
461        // Add threshold value
462        let (_, _, thresholds_, _) = self.compute_roc()?;
463        metrics.insert("threshold".to_string(), thresholds_[thresholdidx]);
464
465        Ok(metrics)
466    }
467
468    /// Get the current threshold index or a default
469    ///
470    /// # Returns
471    ///
472    /// * The current threshold index or the middle index if not set
473    pub fn get_current_threshold_idx(&self) -> Result<usize> {
474        let (_, _, thresholds_, _) = self.compute_roc()?;
475
476        if thresholds_.is_empty() {
477            return Err(MetricsError::InvalidInput(
478                "No thresholds available".to_string(),
479            ));
480        }
481
482        match self.current_threshold_idx {
483            Some(idx) if idx < thresholds_.len() => Ok(idx),
484            _ => Ok(thresholds_.len() / 2), // Default to middle threshold
485        }
486    }
487}
488
489impl<T, S> MetricVisualizer for InteractiveROCVisualizer<'_, T, S>
490where
491    T: Clone + PartialOrd + 'static,
492    S: Data<Elem = T>,
493    f64: From<T>,
494{
495    fn prepare_data(&self) -> std::result::Result<VisualizationData, Box<dyn Error>> {
496        let (fpr, tpr, thresholds, auc) = self
497            .compute_roc()
498            .map_err(|e| Box::new(e) as Box<dyn Error>)?;
499
500        // Prepare data for visualization
501        let mut data = VisualizationData::new();
502
503        // ROC curve points
504        data.x = fpr.clone();
505        data.y = tpr.clone();
506
507        // Store thresholds in auxiliary data
508        data.add_auxiliary_data("thresholds".to_string(), thresholds.clone());
509
510        // Add AUC if available
511        if let Some(auc_val) = auc {
512            data.add_auxiliary_metadata("auc".to_string(), auc_val.to_string());
513        }
514
515        // Add current threshold point if available
516        if let Ok(threshold_idx) = self.get_current_threshold_idx() {
517            let current_point_x = vec![fpr[threshold_idx]];
518            let current_point_y = vec![tpr[threshold_idx]];
519
520            data.add_auxiliary_data("current_point_x".to_string(), current_point_x);
521            data.add_auxiliary_data("current_point_y".to_string(), current_point_y);
522            data.add_auxiliary_metadata(
523                "current_threshold".to_string(),
524                thresholds[threshold_idx].to_string(),
525            );
526
527            // Add metrics at current threshold if requested
528            if self.show_metrics {
529                if let Ok(metrics) = self.calculate_metrics(threshold_idx) {
530                    for (name, value) in metrics {
531                        data.add_auxiliary_metadata(format!("metric_{name}"), value.to_string());
532                    }
533                }
534            }
535        }
536
537        // Add interactive options
538        data.add_auxiliary_metadata(
539            "interactive_width".to_string(),
540            self.interactive_options.width.to_string(),
541        );
542        data.add_auxiliary_metadata(
543            "interactive_height".to_string(),
544            self.interactive_options.height.to_string(),
545        );
546        data.add_auxiliary_metadata(
547            "show_threshold_slider".to_string(),
548            self.interactive_options.show_threshold_slider.to_string(),
549        );
550        data.add_auxiliary_metadata(
551            "show_metric_values".to_string(),
552            self.interactive_options.show_metric_values.to_string(),
553        );
554        data.add_auxiliary_metadata(
555            "show_confusion_matrix".to_string(),
556            self.interactive_options.show_confusion_matrix.to_string(),
557        );
558
559        // Add custom layout options
560        for (key, value) in &self.interactive_options.custom_layout {
561            data.add_auxiliary_metadata(format!("layout_{key}"), value.clone());
562        }
563
564        // Add baseline if requested
565        if self.show_baseline {
566            data.add_auxiliary_data("baseline_x".to_string(), vec![0.0, 1.0]);
567            data.add_auxiliary_data("baseline_y".to_string(), vec![0.0, 1.0]);
568        }
569
570        // Prepare series names
571        let mut series_names = Vec::new();
572
573        if self.show_auc && auc.is_some() {
574            series_names.push(format!("ROC curve (AUC = {:.3})", auc.unwrap()));
575        } else {
576            series_names.push("ROC curve".to_string());
577        }
578
579        if self.show_baseline {
580            series_names.push("Random classifier".to_string());
581        }
582
583        // Point at current threshold
584        series_names.push("Current threshold".to_string());
585
586        data.add_series_names(series_names);
587
588        Ok(data)
589    }
590
591    fn get_metadata(&self) -> VisualizationMetadata {
592        let mut metadata = VisualizationMetadata::new(self.title.clone());
593        metadata.set_plot_type(PlotType::Line);
594        metadata.set_x_label("False Positive Rate".to_string());
595        metadata.set_y_label("True Positive Rate".to_string());
596        metadata.set_description("Interactive ROC curve showing the trade-off between true positive rate and false positive rate. Adjust the threshold to see performance metrics.".to_string());
597
598        metadata
599    }
600}
601
602/// Create an interactive ROC curve visualization from pre-computed ROC curve data
603///
604/// # Arguments
605///
606/// * `fpr` - False positive rates
607/// * `tpr` - True positive rates
608/// * `thresholds` - Optional thresholds
609/// * `auc` - Optional AUC value
610///
611/// # Returns
612///
613/// * An InteractiveROCVisualizer
614#[allow(dead_code)]
615pub fn interactive_roc_curve_visualization(
616    fpr: Vec<f64>,
617    tpr: Vec<f64>,
618    thresholds: Option<Vec<f64>>,
619    auc: Option<f64>,
620) -> InteractiveROCVisualizer<'static, f64, scirs2_core::ndarray::OwnedRepr<f64>> {
621    InteractiveROCVisualizer::new(fpr, tpr, thresholds, auc)
622}
623
624/// Create an interactive ROC curve visualization from true labels and scores
625///
626/// # Arguments
627///
628/// * `y_true` - True binary labels
629/// * `y_score` - Target scores (probabilities or decision function output)
630/// * `pos_label` - Optional label of the positive class
631///
632/// # Returns
633///
634/// * An InteractiveROCVisualizer
635#[allow(dead_code)]
636pub fn interactive_roc_curve_from_labels<'a, T, S>(
637    y_true: &'a ArrayBase<S, Ix1>,
638    y_score: &'a ArrayBase<S, Ix1>,
639    pos_label: Option<T>,
640) -> InteractiveROCVisualizer<'a, T, S>
641where
642    T: Clone + PartialOrd + 'static,
643    S: Data<Elem = T>,
644    f64: From<T>,
645{
646    InteractiveROCVisualizer::from_labels(y_true, y_score, pos_label)
647}