scirs2_neural/utils/evaluation/roc_curve.rs
1//! ROC curve for binary classification problems
2//!
3//! This module provides tools for computing, analyzing, and visualizing
4//! Receiver Operating Characteristic (ROC) curves for binary classifiers.
5
6use crate::error::{NeuralError, Result};
7use crate::utils::colors::{
8 colored_metric_cell, colorize, stylize, Color, ColorOptions, Style, RESET,
9};
10use crate::utils::evaluation::helpers::draw_line_with_coords;
11use scirs2_core::ndarray::{Array1, ArrayView1};
12use scirs2_core::numeric::Float;
13use std::fmt::{Debug, Display};
14
15/// ROC curve data structure for binary classification evaluation
16///
17/// This struct represents a Receiver Operating Characteristic (ROC) curve,
18/// which plots the True Positive Rate (TPR) against the False Positive Rate (FPR)
19/// at various classification thresholds. It also calculates the Area Under the
20/// Curve (AUC), a common metric for binary classification performance.
21pub struct ROCCurve<F: Float + Debug + Display> {
22 /// False positive rates at different thresholds
23 pub fpr: Array1<F>,
24 /// True positive rates at different thresholds
25 pub tpr: Array1<F>,
26 /// Classification thresholds
27 pub thresholds: Array1<F>,
28 /// Area Under the ROC Curve (AUC)
29 pub auc: F,
30}
31
32impl<F: Float + Debug + Display> ROCCurve<F> {
33 /// Compute ROC curve and AUC from binary classification scores
34 ///
35 /// # Arguments
36 /// * `y_true` - True binary labels (0 or 1)
37 /// * `y_score` - Predicted probabilities or decision function
38 ///
39 /// # Returns
40 /// * `Result<ROCCurve<F>>` - ROC curve data
41 ///
42 /// # Example
43 /// ```
44 /// use scirs2_core::ndarray::{Array1, ArrayView1};
45 /// use scirs2_neural::utils::evaluation::ROCCurve;
46 ///
47 /// // Create some example data
48 /// let y_true = Array1::from_vec(vec![0, 1, 1, 0, 1, 0, 1, 0, 1, 0]);
49 /// let y_score = Array1::from_vec(vec![0.1, 0.9, 0.8, 0.3, 0.7, 0.2, 0.6, 0.4, 0.8, 0.3]);
50 ///
51 /// // Compute ROC curve
52 /// let roc = ROCCurve::<f64>::new(&y_true.view(), &y_score.view()).unwrap();
53 ///
54 /// // AUC should be > 0.5 for a model better than random guessing
55 /// assert!(roc.auc > 0.5);
56 /// ```
57 pub fn new(y_true: &ArrayView1<usize>, yscore: &ArrayView1<F>) -> Result<Self> {
58 if y_true.len() != yscore.len() {
59 return Err(NeuralError::ValidationError(
60 "Labels and scores must have the same length".to_string(),
61 ));
62 }
63
64 // Check if y_true contains only binary values (0 or 1)
65 for &label in y_true.iter() {
66 if label != 0 && label != 1 {
67 return Err(NeuralError::ValidationError(
68 "Labels must be binary (0 or 1)".to_string(),
69 ));
70 }
71 }
72
73 // Sort scores and corresponding labels in descending order
74 let mut score_label_pairs: Vec<(F, usize)> = yscore
75 .iter()
76 .zip(y_true.iter())
77 .map(|(&_score, &label)| (_score, label))
78 .collect();
79 score_label_pairs
80 .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
81
82 // Count positives and negatives
83 let n_pos = y_true.iter().filter(|&&label| label == 1).count();
84 let n_neg = y_true.len() - n_pos;
85 if n_pos == 0 || n_neg == 0 {
86 return Err(NeuralError::ValidationError(
87 "Both positive and negative samples are required".to_string(),
88 ));
89 }
90
91 // Initialize arrays for ROC curve
92 let n_thresholds = score_label_pairs.len() + 1;
93 let mut fpr = Array1::zeros(n_thresholds);
94 let mut tpr = Array1::zeros(n_thresholds);
95 let mut thresholds = Array1::zeros(n_thresholds);
96
97 // Set the first point (0,0) and last threshold to infinity
98 thresholds[0] = F::infinity();
99
100 // Compute ROC curve
101 let mut tp = 0;
102 let mut fp = 0;
103 for i in 0..score_label_pairs.len() {
104 let (score, label) = score_label_pairs[i];
105 // Update counts
106 if label == 1 {
107 tp += 1;
108 } else {
109 fp += 1;
110 }
111
112 // Set threshold for this point
113 thresholds[i + 1] = score;
114 // Compute rates
115 tpr[i + 1] = F::from(tp).unwrap() / F::from(n_pos).unwrap();
116 fpr[i + 1] = F::from(fp).unwrap() / F::from(n_neg).unwrap();
117 }
118
119 // Compute AUC using trapezoidal rule
120 let mut auc = F::zero();
121 for i in 0..fpr.len() - 1 {
122 auc = auc + (fpr[i + 1] - fpr[i]) * (tpr[i] + tpr[i + 1]) * F::from(0.5).unwrap();
123 }
124
125 Ok(ROCCurve {
126 fpr,
127 tpr,
128 thresholds,
129 auc,
130 })
131 }
132
133 /// Create an ASCII line plot of the ROC curve
134 ///
135 /// # Arguments
136 /// * `title` - Optional title for the plot
137 /// * `width` - Width of the plot
138 /// * `height` - Height of the plot
139 ///
140 /// # Returns
141 /// * `String` - ASCII line plot
142 pub fn to_ascii(&self, title: Option<&str>, width: usize, height: usize) -> String {
143 self.to_ascii_with_options(title, width, height, &ColorOptions::default())
144 }
145
146 /// Create an ASCII line plot of the ROC curve with color options
147 /// This method provides a customizable visualization of the ROC curve
148 /// with controls for colors and styling.
149 ///
150 /// # Arguments
151 /// * `title` - Optional title for the plot
152 /// * `width` - Width of the plot
153 /// * `height` - Height of the plot
154 /// * `color_options` - Color options for visualization
155 ///
156 /// # Returns
157 /// * `String` - ASCII line plot with colors
158 ///
159 /// # Example
160 /// ```
161 /// use scirs2_neural::utils::colors::ColorOptions;
162 /// use scirs2_neural::utils::ROCCurve;
163 /// use scirs2_core::ndarray::Array1;
164 ///
165 /// // Create test data
166 /// let y_true = Array1::from_vec(vec![0, 0, 1, 1]);
167 /// let y_scores = Array1::from_vec(vec![0.1, 0.4, 0.35, 0.8]);
168 /// let roc = ROCCurve::new(&y_true.view(), &y_scores.view()).unwrap();
169 ///
170 /// // Create ROC curve visualization
171 /// let options = ColorOptions::default();
172 /// let plot = roc.to_ascii_with_options(Some("Model Performance"), 50, 20, &options);
173 ///
174 /// // Visualization will show the curve with the AUC value
175 /// assert!(plot.contains("AUC ="));
176 /// ```
177 pub fn to_ascii_with_options(
178 &self,
179 title: Option<&str>,
180 width: usize,
181 height: usize,
182 color_options: &ColorOptions,
183 ) -> String {
184 // Pre-allocate result string with estimated capacity
185 let mut result = String::with_capacity(width * height * 2);
186
187 // Add title and AUC with coloring if enabled
188 if let Some(titletext) = title {
189 if color_options.enabled {
190 let styled_title = stylize(titletext, Style::Bold);
191 let auc_value = self.auc.to_f64().unwrap_or(0.0);
192 let colored_auc =
193 colored_metric_cell(format!("{:.3}", self.auc), auc_value, color_options);
194 result.push_str(&format!("{styled_title} (AUC = {colored_auc})\n\n"));
195 } else {
196 result.push_str(&format!("{} (AUC = {:.3})\n\n", titletext, self.auc));
197 }
198 } else if color_options.enabled {
199 let styled_title = stylize("ROC Curve", Style::Bold);
200 let auc_value = self.auc.to_f64().unwrap_or(0.0);
201 let colored_auc =
202 colored_metric_cell(format!("{:.3}", self.auc), auc_value, color_options);
203 result.push_str(&format!("{styled_title} (AUC = {colored_auc})\n\n"));
204 } else {
205 result.push_str(&format!("ROC Curve (AUC = {:.3})\n\n", self.auc));
206 }
207
208 // Create a 2D grid for the plot
209 let mut grid = vec![vec![' '; width]; height];
210
211 // Draw the diagonal (random classifier line)
212 for i in 0..std::cmp::min(width, height) {
213 let x = i;
214 let y = height - 1 - i * (height - 1) / (width - 1);
215 if x < width && y < height {
216 grid[y][x] = '.';
217 }
218 }
219
220 // Convert ROC curve points to line segments
221 let mut prev_x = 0;
222 let mut prev_y = height - 1; // Start at (0,0) in ROC space
223 for i in 1..self.fpr.len() {
224 let x = (self.fpr[i].to_f64().unwrap() * (width - 1) as f64).round() as usize;
225 let y =
226 height - 1 - (self.tpr[i].to_f64().unwrap() * (height - 1) as f64).round() as usize;
227
228 // Draw line segments between points
229 if x != prev_x || y != prev_y {
230 for (line_x, line_y) in
231 draw_line_with_coords(prev_x, prev_y, x, y, Some(width), Some(height))
232 {
233 grid[line_y][line_x] = '●';
234 }
235 prev_x = x;
236 prev_y = y;
237 }
238 }
239
240 // Draw the grid
241 for (y, row) in grid.iter().enumerate() {
242 // Y-axis labels with styling
243 if y == height - 1 {
244 if color_options.enabled {
245 let fg_code = Color::BrightCyan.fg_code();
246 result.push_str(&format!("{fg_code}0.0{RESET} |"));
247 } else {
248 result.push_str("0.0 |");
249 }
250 } else if y == 0 {
251 if color_options.enabled {
252 let fg_code = Color::BrightCyan.fg_code();
253 result.push_str(&format!("{fg_code}1.0{RESET} |"));
254 } else {
255 result.push_str("1.0 |");
256 }
257 } else if y == height / 2 {
258 if color_options.enabled {
259 let fg_code = Color::BrightCyan.fg_code();
260 result.push_str(&format!("{fg_code}0.5{RESET} |"));
261 } else {
262 result.push_str("0.5 |");
263 }
264 } else {
265 result.push_str(" |");
266 }
267
268 // Grid content with coloring
269 for char in row.iter().take(width) {
270 if color_options.enabled {
271 match char {
272 '●' => {
273 // Color the ROC curve points
274 result.push_str(&colorize("●", Color::BrightGreen));
275 }
276 '.' => {
277 // Color the diagonal line
278 result.push_str(&colorize(".", Color::BrightBlack));
279 }
280 _ => result.push(*char),
281 }
282 } else {
283 result.push(*char);
284 }
285 }
286 result.push('\n');
287 }
288
289 // X-axis
290 result.push_str(" +");
291 result.push_str(&"-".repeat(width));
292 result.push('\n');
293
294 // X-axis labels with styling
295 result.push_str(" ");
296 if color_options.enabled {
297 result.push_str(&colorize("0.0", Color::BrightCyan));
298 result.push_str(&" ".repeat(width - 6));
299 result.push_str(&colorize("1.0", Color::BrightCyan));
300 } else {
301 result.push_str("0.0");
302 result.push_str(&" ".repeat(width - 6));
303 result.push_str("1.0");
304 }
305 result.push('\n');
306
307 // Axis labels with styling
308 if color_options.enabled {
309 result.push_str(&format!(
310 " {}\n",
311 stylize("False Positive Rate (FPR)", Style::Bold)
312 ));
313 } else {
314 result.push_str(" False Positive Rate (FPR)\n");
315 }
316
317 // Add legend if colors are enabled
318 if color_options.enabled {
319 result.push_str(&format!(
320 " {} ROC curve {} Random classifier\n",
321 colorize("●", Color::BrightGreen),
322 colorize(".", Color::BrightBlack)
323 ));
324 }
325
326 result
327 }
328}