scirs2_neural/utils/evaluation/
roc_curve.rs

1//! ROC curve for binary classification problems
2//!
3//! This module provides tools for computing, analyzing, and visualizing
4//! Receiver Operating Characteristic (ROC) curves for binary classifiers.
5
6use crate::error::{NeuralError, Result};
7use crate::utils::colors::{
8    colored_metric_cell, colorize, stylize, Color, ColorOptions, Style, RESET,
9};
10use crate::utils::evaluation::helpers::draw_line_with_coords;
11use scirs2_core::ndarray::{Array1, ArrayView1};
12use scirs2_core::numeric::Float;
13use std::fmt::{Debug, Display};
14
15/// ROC curve data structure for binary classification evaluation
16///
17/// This struct represents a Receiver Operating Characteristic (ROC) curve,
18/// which plots the True Positive Rate (TPR) against the False Positive Rate (FPR)
19/// at various classification thresholds. It also calculates the Area Under the
20/// Curve (AUC), a common metric for binary classification performance.
21pub struct ROCCurve<F: Float + Debug + Display> {
22    /// False positive rates at different thresholds
23    pub fpr: Array1<F>,
24    /// True positive rates at different thresholds
25    pub tpr: Array1<F>,
26    /// Classification thresholds
27    pub thresholds: Array1<F>,
28    /// Area Under the ROC Curve (AUC)
29    pub auc: F,
30}
31
32impl<F: Float + Debug + Display> ROCCurve<F> {
33    /// Compute ROC curve and AUC from binary classification scores
34    ///
35    /// # Arguments
36    /// * `y_true` - True binary labels (0 or 1)
37    /// * `y_score` - Predicted probabilities or decision function
38    ///
39    /// # Returns
40    /// * `Result<ROCCurve<F>>` - ROC curve data
41    ///
42    /// # Example
43    /// ```
44    /// use scirs2_core::ndarray::{Array1, ArrayView1};
45    /// use scirs2_neural::utils::evaluation::ROCCurve;
46    ///
47    /// // Create some example data
48    /// let y_true = Array1::from_vec(vec![0, 1, 1, 0, 1, 0, 1, 0, 1, 0]);
49    /// let y_score = Array1::from_vec(vec![0.1, 0.9, 0.8, 0.3, 0.7, 0.2, 0.6, 0.4, 0.8, 0.3]);
50    ///
51    /// // Compute ROC curve
52    /// let roc = ROCCurve::<f64>::new(&y_true.view(), &y_score.view()).unwrap();
53    ///
54    /// // AUC should be > 0.5 for a model better than random guessing
55    /// assert!(roc.auc > 0.5);
56    /// ```
57    pub fn new(y_true: &ArrayView1<usize>, yscore: &ArrayView1<F>) -> Result<Self> {
58        if y_true.len() != yscore.len() {
59            return Err(NeuralError::ValidationError(
60                "Labels and scores must have the same length".to_string(),
61            ));
62        }
63
64        // Check if y_true contains only binary values (0 or 1)
65        for &label in y_true.iter() {
66            if label != 0 && label != 1 {
67                return Err(NeuralError::ValidationError(
68                    "Labels must be binary (0 or 1)".to_string(),
69                ));
70            }
71        }
72
73        // Sort scores and corresponding labels in descending order
74        let mut score_label_pairs: Vec<(F, usize)> = yscore
75            .iter()
76            .zip(y_true.iter())
77            .map(|(&_score, &label)| (_score, label))
78            .collect();
79        score_label_pairs
80            .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
81
82        // Count positives and negatives
83        let n_pos = y_true.iter().filter(|&&label| label == 1).count();
84        let n_neg = y_true.len() - n_pos;
85        if n_pos == 0 || n_neg == 0 {
86            return Err(NeuralError::ValidationError(
87                "Both positive and negative samples are required".to_string(),
88            ));
89        }
90
91        // Initialize arrays for ROC curve
92        let n_thresholds = score_label_pairs.len() + 1;
93        let mut fpr = Array1::zeros(n_thresholds);
94        let mut tpr = Array1::zeros(n_thresholds);
95        let mut thresholds = Array1::zeros(n_thresholds);
96
97        // Set the first point (0,0) and last threshold to infinity
98        thresholds[0] = F::infinity();
99
100        // Compute ROC curve
101        let mut tp = 0;
102        let mut fp = 0;
103        for i in 0..score_label_pairs.len() {
104            let (score, label) = score_label_pairs[i];
105            // Update counts
106            if label == 1 {
107                tp += 1;
108            } else {
109                fp += 1;
110            }
111
112            // Set threshold for this point
113            thresholds[i + 1] = score;
114            // Compute rates
115            tpr[i + 1] = F::from(tp).unwrap() / F::from(n_pos).unwrap();
116            fpr[i + 1] = F::from(fp).unwrap() / F::from(n_neg).unwrap();
117        }
118
119        // Compute AUC using trapezoidal rule
120        let mut auc = F::zero();
121        for i in 0..fpr.len() - 1 {
122            auc = auc + (fpr[i + 1] - fpr[i]) * (tpr[i] + tpr[i + 1]) * F::from(0.5).unwrap();
123        }
124
125        Ok(ROCCurve {
126            fpr,
127            tpr,
128            thresholds,
129            auc,
130        })
131    }
132
133    /// Create an ASCII line plot of the ROC curve
134    ///
135    /// # Arguments
136    /// * `title` - Optional title for the plot
137    /// * `width` - Width of the plot
138    /// * `height` - Height of the plot
139    ///
140    /// # Returns
141    /// * `String` - ASCII line plot
142    pub fn to_ascii(&self, title: Option<&str>, width: usize, height: usize) -> String {
143        self.to_ascii_with_options(title, width, height, &ColorOptions::default())
144    }
145
146    /// Create an ASCII line plot of the ROC curve with color options
147    /// This method provides a customizable visualization of the ROC curve
148    /// with controls for colors and styling.
149    ///
150    /// # Arguments
151    /// * `title` - Optional title for the plot
152    /// * `width` - Width of the plot
153    /// * `height` - Height of the plot
154    /// * `color_options` - Color options for visualization
155    ///
156    /// # Returns
157    /// * `String` - ASCII line plot with colors
158    ///
159    /// # Example
160    /// ```
161    /// use scirs2_neural::utils::colors::ColorOptions;
162    /// use scirs2_neural::utils::ROCCurve;
163    /// use scirs2_core::ndarray::Array1;
164    ///
165    /// // Create test data
166    /// let y_true = Array1::from_vec(vec![0, 0, 1, 1]);
167    /// let y_scores = Array1::from_vec(vec![0.1, 0.4, 0.35, 0.8]);
168    /// let roc = ROCCurve::new(&y_true.view(), &y_scores.view()).unwrap();
169    ///
170    /// // Create ROC curve visualization
171    /// let options = ColorOptions::default();
172    /// let plot = roc.to_ascii_with_options(Some("Model Performance"), 50, 20, &options);
173    ///
174    /// // Visualization will show the curve with the AUC value
175    /// assert!(plot.contains("AUC ="));
176    /// ```
177    pub fn to_ascii_with_options(
178        &self,
179        title: Option<&str>,
180        width: usize,
181        height: usize,
182        color_options: &ColorOptions,
183    ) -> String {
184        // Pre-allocate result string with estimated capacity
185        let mut result = String::with_capacity(width * height * 2);
186
187        // Add title and AUC with coloring if enabled
188        if let Some(titletext) = title {
189            if color_options.enabled {
190                let styled_title = stylize(titletext, Style::Bold);
191                let auc_value = self.auc.to_f64().unwrap_or(0.0);
192                let colored_auc =
193                    colored_metric_cell(format!("{:.3}", self.auc), auc_value, color_options);
194                result.push_str(&format!("{styled_title} (AUC = {colored_auc})\n\n"));
195            } else {
196                result.push_str(&format!("{} (AUC = {:.3})\n\n", titletext, self.auc));
197            }
198        } else if color_options.enabled {
199            let styled_title = stylize("ROC Curve", Style::Bold);
200            let auc_value = self.auc.to_f64().unwrap_or(0.0);
201            let colored_auc =
202                colored_metric_cell(format!("{:.3}", self.auc), auc_value, color_options);
203            result.push_str(&format!("{styled_title} (AUC = {colored_auc})\n\n"));
204        } else {
205            result.push_str(&format!("ROC Curve (AUC = {:.3})\n\n", self.auc));
206        }
207
208        // Create a 2D grid for the plot
209        let mut grid = vec![vec![' '; width]; height];
210
211        // Draw the diagonal (random classifier line)
212        for i in 0..std::cmp::min(width, height) {
213            let x = i;
214            let y = height - 1 - i * (height - 1) / (width - 1);
215            if x < width && y < height {
216                grid[y][x] = '.';
217            }
218        }
219
220        // Convert ROC curve points to line segments
221        let mut prev_x = 0;
222        let mut prev_y = height - 1; // Start at (0,0) in ROC space
223        for i in 1..self.fpr.len() {
224            let x = (self.fpr[i].to_f64().unwrap() * (width - 1) as f64).round() as usize;
225            let y =
226                height - 1 - (self.tpr[i].to_f64().unwrap() * (height - 1) as f64).round() as usize;
227
228            // Draw line segments between points
229            if x != prev_x || y != prev_y {
230                for (line_x, line_y) in
231                    draw_line_with_coords(prev_x, prev_y, x, y, Some(width), Some(height))
232                {
233                    grid[line_y][line_x] = '●';
234                }
235                prev_x = x;
236                prev_y = y;
237            }
238        }
239
240        // Draw the grid
241        for (y, row) in grid.iter().enumerate() {
242            // Y-axis labels with styling
243            if y == height - 1 {
244                if color_options.enabled {
245                    let fg_code = Color::BrightCyan.fg_code();
246                    result.push_str(&format!("{fg_code}0.0{RESET} |"));
247                } else {
248                    result.push_str("0.0 |");
249                }
250            } else if y == 0 {
251                if color_options.enabled {
252                    let fg_code = Color::BrightCyan.fg_code();
253                    result.push_str(&format!("{fg_code}1.0{RESET} |"));
254                } else {
255                    result.push_str("1.0 |");
256                }
257            } else if y == height / 2 {
258                if color_options.enabled {
259                    let fg_code = Color::BrightCyan.fg_code();
260                    result.push_str(&format!("{fg_code}0.5{RESET} |"));
261                } else {
262                    result.push_str("0.5 |");
263                }
264            } else {
265                result.push_str("    |");
266            }
267
268            // Grid content with coloring
269            for char in row.iter().take(width) {
270                if color_options.enabled {
271                    match char {
272                        '●' => {
273                            // Color the ROC curve points
274                            result.push_str(&colorize("●", Color::BrightGreen));
275                        }
276                        '.' => {
277                            // Color the diagonal line
278                            result.push_str(&colorize(".", Color::BrightBlack));
279                        }
280                        _ => result.push(*char),
281                    }
282                } else {
283                    result.push(*char);
284                }
285            }
286            result.push('\n');
287        }
288
289        // X-axis
290        result.push_str("    +");
291        result.push_str(&"-".repeat(width));
292        result.push('\n');
293
294        // X-axis labels with styling
295        result.push_str("     ");
296        if color_options.enabled {
297            result.push_str(&colorize("0.0", Color::BrightCyan));
298            result.push_str(&" ".repeat(width - 6));
299            result.push_str(&colorize("1.0", Color::BrightCyan));
300        } else {
301            result.push_str("0.0");
302            result.push_str(&" ".repeat(width - 6));
303            result.push_str("1.0");
304        }
305        result.push('\n');
306
307        // Axis labels with styling
308        if color_options.enabled {
309            result.push_str(&format!(
310                "     {}\n",
311                stylize("False Positive Rate (FPR)", Style::Bold)
312            ));
313        } else {
314            result.push_str("     False Positive Rate (FPR)\n");
315        }
316
317        // Add legend if colors are enabled
318        if color_options.enabled {
319            result.push_str(&format!(
320                "     {} ROC curve     {} Random classifier\n",
321                colorize("●", Color::BrightGreen),
322                colorize(".", Color::BrightBlack)
323            ));
324        }
325
326        result
327    }
328}