pub struct ROCCurve<F: Float + Debug + Display> {
pub fpr: Array1<F>,
pub tpr: Array1<F>,
pub thresholds: Array1<F>,
pub auc: F,
}
Expand description
ROC curve data structure for binary classification evaluation
This struct represents a Receiver Operating Characteristic (ROC) curve, which plots the True Positive Rate (TPR) against the False Positive Rate (FPR) at various classification thresholds. It also calculates the Area Under the Curve (AUC), a common metric for binary classification performance.
Fields§
fpr: Array1<F>
False positive rates at different thresholds
tpr: Array1<F>
True positive rates at different thresholds
thresholds: Array1<F>
Classification thresholds
auc: F
Area Under the ROC Curve (AUC)
Implementations§
Source§impl<F: Float + Debug + Display> ROCCurve<F>
impl<F: Float + Debug + Display> ROCCurve<F>
Sourcepub fn new(
y_true: &ArrayView1<'_, usize>,
yscore: &ArrayView1<'_, F>,
) -> Result<Self>
pub fn new( y_true: &ArrayView1<'_, usize>, yscore: &ArrayView1<'_, F>, ) -> Result<Self>
Compute ROC curve and AUC from binary classification scores
§Arguments
y_true
- True binary labels (0 or 1)
yscore
- Predicted probabilities or decision function values
§Returns
Result<ROCCurve<F>>
- ROC curve data
§Example
use ndarray::{Array1, ArrayView1};
use scirs2_neural::utils::evaluation::ROCCurve;

// Ten labelled samples; every positive (1) is scored above every negative (0)
let labels: Array1<usize> = Array1::from_vec(vec![0, 1, 1, 0, 1, 0, 1, 0, 1, 0]);
let scores = Array1::from_vec(vec![0.1, 0.9, 0.8, 0.3, 0.7, 0.2, 0.6, 0.4, 0.8, 0.3]);

// Borrow both arrays as 1-D views and build the curve from them
let label_view: ArrayView1<usize> = labels.view();
let score_view: ArrayView1<f64> = scores.view();
let roc = ROCCurve::<f64>::new(&label_view, &score_view).unwrap();

// A model that beats random guessing has an AUC above 0.5
assert!(roc.auc > 0.5);
Examples found in repository?
11fn main() -> Result<()> {
12 println!(
13 "{}",
14 stylize("Neural Network Model Evaluation with Color", Style::Bold)
15 );
16 println!("{}", "-".repeat(50));
17 // Set up color options
18 let color_options = ColorOptions {
19 enabled: true,
20 use_background: false,
21 use_bright: true,
22 };
23 // Generate some example data
24 let n_samples = 500;
25 let n_features = 10;
26 let n_classes = 4;
27 println!(
28 "\n{} {} {} {} {} {}",
29 colorize("Generating", Color::BrightGreen),
30 colorize(n_samples.to_string(), Color::BrightYellow),
31 colorize("samples with", Color::BrightGreen),
32 colorize(n_features.to_string(), Color::BrightYellow),
33 colorize("features for", Color::BrightGreen),
34 colorize(n_classes.to_string(), Color::BrightYellow),
35 );
36
37 // Create a deterministic RNG for reproducibility
38 let mut rng = SmallRng::from_seed([42; 32]);
39
40 // 1. Confusion Matrix Example
41 println!(
42 "\n{}",
43 stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
44 );
45 // Generate random predictions and true labels
46 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
47 // Create slightly correlated predictions (not completely random)
48 let y_pred = Array::from_shape_fn(n_samples, |i| {
49 if rng.random::<f32>() < 0.7 {
50 // 70% chance of correct prediction
51 y_true[i]
52 } else {
53 // 30% chance of random class
54 rng.random_range(0..n_classes)
55 }
56 });
57 // Create confusion matrix
58 let class_labels = vec![
59 "Class A".to_string(),
60 "Class B".to_string(),
61 "Class C".to_string(),
62 "Class D".to_string(),
63 ];
64 let cm = ConfusionMatrix::<f32>::new(
65 &y_true.view(),
66 &y_pred.view(),
67 Some(n_classes),
68 Some(class_labels),
69 )?;
70 // Print raw and normalized confusion matrices with color
71 println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
72 println!(
73 "{}",
74 cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
75 );
76 println!(
77 "\n{}",
78 colorize("Normalized Confusion Matrix:", Color::BrightCyan)
79 );
80 println!(
81 "{}",
82 cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
83 );
84 // Print metrics
85 println!(
86 "\n{} {:.3}",
87 colorize("Overall Accuracy:", Color::BrightMagenta),
88 cm.accuracy()
89 );
90 let precision = cm.precision();
91 let recall = cm.recall();
92 let f1 = cm.f1_score();
93 println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
94 for i in 0..n_classes {
95 println!(
96 " {}: {}={:.3}, {}={:.3}, {}={:.3}",
97 colorize(format!("Class {i}"), Color::BrightYellow),
98 colorize("Precision", Color::BrightCyan),
99 precision[i],
100 colorize("Recall", Color::BrightGreen),
101 recall[i],
102 colorize("F1", Color::BrightBlue),
103 f1[i]
104 );
105 }
106 println!(
107 "{} {:.3}",
108 colorize("Macro F1 Score:", Color::BrightMagenta),
109 cm.macro_f1()
110 );
111 // 2. Feature Importance Visualization
112 println!(
113 "{}",
114 stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
115 );
116 // Generate random feature importance scores
117 let feature_names = (0..n_features)
118 .map(|i| format!("Feature_{i}"))
119 .collect::<Vec<String>>();
120 let importance = Array1::from_shape_fn(n_features, |i| {
121 // Make some features more important than others
122 let base = (n_features - i) as f32 / n_features as f32;
123 base + 0.2 * rng.random::<f32>()
124 });
125
126 let fi = FeatureImportance::new(feature_names, importance)?;
127
128 // Print full feature importance with color
129 println!(
130 "{}",
131 fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
132 );
133
134 // Print top-5 features with color
135 println!(
136 "{}",
137 colorize("Top 5 Most Important Features:", Color::BrightCyan)
138 );
139 println!(
140 "{}",
141 fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
142 );
143 // 3. ROC Curve for Binary Classification
144 println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
145 // Generate binary classification data
146 let n_binary = 200;
147 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
148 // Generate scores with some predictive power
149 let y_scores = Array1::from_shape_fn(n_binary, |i| {
150 if y_true_binary[i] == 1 {
151 // Higher scores for positive class
152 0.6 + 0.4 * rng.random::<f32>()
153 } else {
154 // Lower scores for negative class
155 0.4 * rng.random::<f32>()
156 }
157 });
158
159 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
160 println!(
161 "{}: {:.3}",
162 colorize("ROC AUC:", Color::BrightMagenta),
163 roc.auc
164 );
165 println!("\n{}", roc.to_ascii(None, 50, 20));
166
167 // 4. Learning Curve Visualization
168 println!(
169 "\n{}",
170 stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
171 );
172 // Generate learning curve data
173 let n_points = 10;
174 let n_cv = 5;
175 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
176 // Generate training scores (decreasing with size due to overfitting)
177 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
178 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
179 });
180
181 // Generate validation scores (increasing with size)
182 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
183 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
184 });
185
186 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
187 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
188
189 // Print final message with color
190 println!(
191 "{}",
192 colorize(
193 "Model evaluation visualizations completed successfully!",
194 Color::BrightGreen
195 )
196 );
197 Ok(())
198}
Sourcepub fn to_ascii(
&self,
title: Option<&str>,
width: usize,
height: usize,
) -> String
pub fn to_ascii( &self, title: Option<&str>, width: usize, height: usize, ) -> String
Create an ASCII line plot of the ROC curve
§Arguments
title
- Optional title for the plot
width
- Width of the plot
height
- Height of the plot
§Returns
String
- ASCII line plot
Examples found in repository?
11fn main() -> Result<()> {
12 println!(
13 "{}",
14 stylize("Neural Network Model Evaluation with Color", Style::Bold)
15 );
16 println!("{}", "-".repeat(50));
17 // Set up color options
18 let color_options = ColorOptions {
19 enabled: true,
20 use_background: false,
21 use_bright: true,
22 };
23 // Generate some example data
24 let n_samples = 500;
25 let n_features = 10;
26 let n_classes = 4;
27 println!(
28 "\n{} {} {} {} {} {}",
29 colorize("Generating", Color::BrightGreen),
30 colorize(n_samples.to_string(), Color::BrightYellow),
31 colorize("samples with", Color::BrightGreen),
32 colorize(n_features.to_string(), Color::BrightYellow),
33 colorize("features for", Color::BrightGreen),
34 colorize(n_classes.to_string(), Color::BrightYellow),
35 );
36
37 // Create a deterministic RNG for reproducibility
38 let mut rng = SmallRng::from_seed([42; 32]);
39
40 // 1. Confusion Matrix Example
41 println!(
42 "\n{}",
43 stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
44 );
45 // Generate random predictions and true labels
46 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
47 // Create slightly correlated predictions (not completely random)
48 let y_pred = Array::from_shape_fn(n_samples, |i| {
49 if rng.random::<f32>() < 0.7 {
50 // 70% chance of correct prediction
51 y_true[i]
52 } else {
53 // 30% chance of random class
54 rng.random_range(0..n_classes)
55 }
56 });
57 // Create confusion matrix
58 let class_labels = vec![
59 "Class A".to_string(),
60 "Class B".to_string(),
61 "Class C".to_string(),
62 "Class D".to_string(),
63 ];
64 let cm = ConfusionMatrix::<f32>::new(
65 &y_true.view(),
66 &y_pred.view(),
67 Some(n_classes),
68 Some(class_labels),
69 )?;
70 // Print raw and normalized confusion matrices with color
71 println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
72 println!(
73 "{}",
74 cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
75 );
76 println!(
77 "\n{}",
78 colorize("Normalized Confusion Matrix:", Color::BrightCyan)
79 );
80 println!(
81 "{}",
82 cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
83 );
84 // Print metrics
85 println!(
86 "\n{} {:.3}",
87 colorize("Overall Accuracy:", Color::BrightMagenta),
88 cm.accuracy()
89 );
90 let precision = cm.precision();
91 let recall = cm.recall();
92 let f1 = cm.f1_score();
93 println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
94 for i in 0..n_classes {
95 println!(
96 " {}: {}={:.3}, {}={:.3}, {}={:.3}",
97 colorize(format!("Class {i}"), Color::BrightYellow),
98 colorize("Precision", Color::BrightCyan),
99 precision[i],
100 colorize("Recall", Color::BrightGreen),
101 recall[i],
102 colorize("F1", Color::BrightBlue),
103 f1[i]
104 );
105 }
106 println!(
107 "{} {:.3}",
108 colorize("Macro F1 Score:", Color::BrightMagenta),
109 cm.macro_f1()
110 );
111 // 2. Feature Importance Visualization
112 println!(
113 "{}",
114 stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
115 );
116 // Generate random feature importance scores
117 let feature_names = (0..n_features)
118 .map(|i| format!("Feature_{i}"))
119 .collect::<Vec<String>>();
120 let importance = Array1::from_shape_fn(n_features, |i| {
121 // Make some features more important than others
122 let base = (n_features - i) as f32 / n_features as f32;
123 base + 0.2 * rng.random::<f32>()
124 });
125
126 let fi = FeatureImportance::new(feature_names, importance)?;
127
128 // Print full feature importance with color
129 println!(
130 "{}",
131 fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
132 );
133
134 // Print top-5 features with color
135 println!(
136 "{}",
137 colorize("Top 5 Most Important Features:", Color::BrightCyan)
138 );
139 println!(
140 "{}",
141 fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
142 );
143 // 3. ROC Curve for Binary Classification
144 println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
145 // Generate binary classification data
146 let n_binary = 200;
147 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
148 // Generate scores with some predictive power
149 let y_scores = Array1::from_shape_fn(n_binary, |i| {
150 if y_true_binary[i] == 1 {
151 // Higher scores for positive class
152 0.6 + 0.4 * rng.random::<f32>()
153 } else {
154 // Lower scores for negative class
155 0.4 * rng.random::<f32>()
156 }
157 });
158
159 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
160 println!(
161 "{}: {:.3}",
162 colorize("ROC AUC:", Color::BrightMagenta),
163 roc.auc
164 );
165 println!("\n{}", roc.to_ascii(None, 50, 20));
166
167 // 4. Learning Curve Visualization
168 println!(
169 "\n{}",
170 stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
171 );
172 // Generate learning curve data
173 let n_points = 10;
174 let n_cv = 5;
175 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
176 // Generate training scores (decreasing with size due to overfitting)
177 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
178 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
179 });
180
181 // Generate validation scores (increasing with size)
182 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
183 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
184 });
185
186 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
187 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
188
189 // Print final message with color
190 println!(
191 "{}",
192 colorize(
193 "Model evaluation visualizations completed successfully!",
194 Color::BrightGreen
195 )
196 );
197 Ok(())
198}
Sourcepub fn to_ascii_with_options(
&self,
title: Option<&str>,
width: usize,
height: usize,
color_options: &ColorOptions,
) -> String
pub fn to_ascii_with_options( &self, title: Option<&str>, width: usize, height: usize, color_options: &ColorOptions, ) -> String
Create an ASCII line plot of the ROC curve with color options. This method provides a customizable visualization of the ROC curve with controls for colors and styling.
§Arguments
title
- Optional title for the plot
width
- Width of the plot
height
- Height of the plot
color_options
- Color options for visualization
§Returns
String
- ASCII line plot with colors
§Example
use ndarray::Array1;
use scirs2_neural::utils::colors::ColorOptions;
use scirs2_neural::utils::ROCCurve;

// Four samples: two negatives followed by two positives
let labels = Array1::from_vec(vec![0, 0, 1, 1]);
let scores = Array1::from_vec(vec![0.1, 0.4, 0.35, 0.8]);
let roc = ROCCurve::new(&labels.view(), &scores.view()).unwrap();

// Render the curve using the default color settings
let plot = roc.to_ascii_with_options(Some("Model Performance"), 50, 20, &ColorOptions::default());

// The rendered plot reports the AUC value
assert!(plot.contains("AUC ="));
Auto Trait Implementations§
impl<F> Freeze for ROCCurve<F> where
F: Freeze,
impl<F> RefUnwindSafe for ROCCurve<F> where
F: RefUnwindSafe,
impl<F> Send for ROCCurve<F> where
F: Send,
impl<F> Sync for ROCCurve<F> where
F: Sync,
impl<F> Unpin for ROCCurve<F> where
F: Unpin,
impl<F> UnwindSafe for ROCCurve<F> where
F: UnwindSafe + RefUnwindSafe,
Blanket Implementations§
impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> IntoEither for T
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more