/// Learning curve data structure for visualizing model performance.
///
/// Represents learning curves that show how model performance changes as the
/// training set size increases, comparing training and validation metrics to
/// help diagnose overfitting, underfitting, and other training issues.
///
/// Score matrices are laid out with one row per training-set size and one
/// column per cross-validation fold; the `*_mean` / `*_std` vectors are the
/// per-size aggregates across folds.
pub struct LearningCurve<F: Float + Debug + Display> {
/// Training set sizes used for evaluation (one entry per curve point).
pub train_sizes: Array1<usize>,
/// Training scores for each size and fold (rows = sizes, cols = folds).
pub train_scores: Array2<F>,
/// Validation scores for each size and fold (rows = sizes, cols = folds).
pub val_scores: Array2<F>,
/// Mean training score across folds, one entry per training size.
pub train_mean: Array1<F>,
/// Standard deviation of training scores across folds, per training size.
pub train_std: Array1<F>,
/// Mean validation score across folds, one entry per training size.
pub val_mean: Array1<F>,
/// Standard deviation of validation scores across folds, per training size.
pub val_std: Array1<F>,
}
Expand description
Learning curve data structure for visualizing model performance
This structure represents learning curves that show how model performance changes as the training set size increases, comparing training and validation metrics to help diagnose overfitting, underfitting, and other training issues.
Fields§
§train_sizes: Array1<usize>
Training set sizes used for evaluation
train_scores: Array2<F>
Training scores for each size and fold (rows=sizes, cols=folds)
val_scores: Array2<F>
Validation scores for each size and fold (rows=sizes, cols=folds)
train_mean: Array1<F>
Mean training scores across folds
train_std: Array1<F>
Standard deviation of training scores
val_mean: Array1<F>
Mean validation scores across folds
val_std: Array1<F>
Standard deviation of validation scores
Implementations§
Source§impl<F: Float + Debug + Display + FromPrimitive> LearningCurve<F>
impl<F: Float + Debug + Display + FromPrimitive> LearningCurve<F>
Sourcepub fn new(
train_sizes: Array1<usize>,
train_scores: Array2<F>,
val_scores: Array2<F>,
) -> Result<Self>
pub fn new( train_sizes: Array1<usize>, train_scores: Array2<F>, val_scores: Array2<F>, ) -> Result<Self>
Create a new learning curve from training and validation scores
§Arguments
train_sizes - Array of training set sizes
train_scores - 2D array of training scores (rows=sizes, cols=cv folds)
val_scores - 2D array of validation scores (rows=sizes, cols=cv folds)
§Returns
Result<LearningCurve<F>>
- Learning curve data
§Example
use ndarray::{Array1, Array2};
use scirs2_neural::utils::evaluation::LearningCurve;
// Create sample data
let train_sizes = Array1::from_vec(vec![100, 200, 300, 400, 500]);
let train_scores = Array2::from_shape_vec((5, 3), vec![
0.6, 0.62, 0.58, // 100 samples, 3 folds
0.7, 0.72, 0.68, // 200 samples, 3 folds
0.8, 0.78, 0.79, // 300 samples, 3 folds
0.85, 0.83, 0.84, // 400 samples, 3 folds
0.87, 0.88, 0.86, // 500 samples, 3 folds
]).unwrap();
let val_scores = Array2::from_shape_vec((5, 3), vec![
0.55, 0.53, 0.54, // 100 samples, 3 folds
0.65, 0.63, 0.64, // 200 samples, 3 folds
0.75, 0.73, 0.74, // 300 samples, 3 folds
0.76, 0.74, 0.75, // 400 samples, 3 folds
0.77, 0.76, 0.76, // 500 samples, 3 folds
]).unwrap();
// Create learning curve
let curve = LearningCurve::<f64>::new(train_sizes, train_scores, val_scores).unwrap();
Examples found in repository?
8fn main() {
9 // Create a reproducible random number generator
10 let mut rng = SmallRng::seed_from_u64(42);
11
12 // Example 1: ROC Curve with color
13 println!("Example 1: ROC Curve Visualization (with color)\n");
14
15 // Generate synthetic binary classification data
16 let n_samples = 200;
17
18 // Generate true labels: 0 or 1
19 let y_true: Vec<usize> = (0..n_samples)
20 .map(|_| if rng.random::<f64>() > 0.5 { 1 } else { 0 })
21 .collect();
22
23 // Generate scores with some separability
24 let y_score: Vec<f64> = y_true
25 .iter()
26 .map(|&label| {
27 if label == 1 {
28 0.7 + 0.3 * rng.sample::<f64, _>(StandardNormal)
29 } else {
30 0.3 + 0.3 * rng.sample::<f64, _>(StandardNormal)
31 }
32 })
33 .collect();
34
35 // Convert to ndarray views
36 let y_true_array = Array1::from(y_true.clone());
37 let y_score_array = Array1::from(y_score.clone());
38
39 // Create ROC curve
40 let roc = ROCCurve::new(&y_true_array.view(), &y_score_array.view()).unwrap();
41
42 // Enable color options
43 let color_options = ColorOptions {
44 enabled: true,
45 use_bright: true,
46 use_background: false,
47 };
48
49 // Plot ROC curve with color
50 let roc_plot = roc.to_ascii_with_options(
51 Some("Binary Classification ROC Curve"),
52 60,
53 20,
54 &color_options,
55 );
56 println!("{}", roc_plot);
57
58 // Example 2: Learning Curve with color
59 println!("\nExample 2: Learning Curve Visualization (with color)\n");
60
61 // Simulate learning curves for different training set sizes
62 let train_sizes = Array1::from(vec![100, 200, 300, 400, 500]);
63
64 // Simulated training scores for each size (5 sizes, 3 CV folds)
65 let train_scores = Array2::from_shape_fn((5, 3), |(i, _j)| {
66 let base = 0.5 + 0.4 * (i as f64 / 4.0);
67 let noise = 0.05 * rng.sample::<f64, _>(StandardNormal);
68 base + noise
69 });
70
71 // Simulated validation scores (typically lower than training)
72 let val_scores = Array2::from_shape_fn((5, 3), |(i, _j)| {
73 let base = 0.4 + 0.3 * (i as f64 / 4.0);
74 let noise = 0.07 * rng.sample::<f64, _>(StandardNormal);
75 base + noise
76 });
77
78 // Create learning curve
79 let learning_curve = LearningCurve::new(train_sizes, train_scores, val_scores).unwrap();
80
81 // Plot learning curve with color
82 let learning_plot = learning_curve.to_ascii_with_options(
83 Some("Neural Network Training"),
84 70,
85 20,
86 "Accuracy",
87 &color_options,
88 );
89 println!("{}", learning_plot);
90}
More examples
7fn main() -> Result<()> {
8 println!("Neural Network Model Evaluation Visualization Example\n");
9
10 // Generate some example data
11 let n_samples = 500;
12 let n_features = 10;
13 let n_classes = 4;
14
15 println!(
16 "Generating {} samples with {} features for {} classes",
17 n_samples, n_features, n_classes
18 );
19
20 // 1. Confusion Matrix Example
21 println!("\n--- Confusion Matrix Visualization ---\n");
22
23 // Create a deterministic RNG for reproducibility
24 let mut rng = SmallRng::seed_from_u64(42);
25
26 // Generate random predictions and true labels
27 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29 // Create slightly correlated predictions (not completely random)
30 let y_pred = Array::from_shape_fn(n_samples, |i| {
31 if rng.random::<f32>() < 0.7 {
32 // 70% chance of correct prediction
33 y_true[i]
34 } else {
35 // 30% chance of random class
36 rng.random_range(0..n_classes)
37 }
38 });
39
40 // Create confusion matrix
41 let class_labels = vec![
42 "Class A".to_string(),
43 "Class B".to_string(),
44 "Class C".to_string(),
45 "Class D".to_string(),
46 ];
47
48 let cm = ConfusionMatrix::<f32>::new(
49 &y_true.view(),
50 &y_pred.view(),
51 Some(n_classes),
52 Some(class_labels),
53 )?;
54
55 // Print raw and normalized confusion matrices
56 println!("Raw Confusion Matrix:\n");
57 println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59 println!("\nNormalized Confusion Matrix:\n");
60 println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62 // Print metrics
63 println!("\nAccuracy: {:.3}", cm.accuracy());
64
65 let precision = cm.precision();
66 let recall = cm.recall();
67 let f1 = cm.f1_score();
68
69 println!("Per-class metrics:");
70 for i in 0..n_classes {
71 println!(
72 " Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73 i, precision[i], recall[i], f1[i]
74 );
75 }
76
77 println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79 // 2. Feature Importance Visualization
80 println!("\n--- Feature Importance Visualization ---\n");
81
82 // Generate random feature importance scores
83 let feature_names = (0..n_features)
84 .map(|i| format!("Feature_{}", i))
85 .collect::<Vec<String>>();
86
87 let importance = Array1::from_shape_fn(n_features, |i| {
88 // Make some features more important than others
89 let base = (n_features - i) as f32 / n_features as f32;
90 base + 0.2 * rng.random::<f32>()
91 });
92
93 let fi = FeatureImportance::new(feature_names, importance)?;
94
95 // Print full feature importance
96 println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98 // Print top-5 features
99 println!("\nTop 5 Most Important Features:\n");
100 println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102 // 3. ROC Curve for Binary Classification
103 println!("\n--- ROC Curve Visualization ---\n");
104
105 // Generate binary classification data
106 let n_binary = 200;
107 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109 // Generate scores with some predictive power
110 let y_scores = Array1::from_shape_fn(n_binary, |i| {
111 if y_true_binary[i] == 1 {
112 // Higher scores for positive class
113 0.6 + 0.4 * rng.random::<f32>()
114 } else {
115 // Lower scores for negative class
116 0.4 * rng.random::<f32>()
117 }
118 });
119
120 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122 println!("ROC AUC: {:.3}", roc.auc);
123 println!("\n{}", roc.to_ascii(None, 50, 20));
124
125 // 4. Learning Curve Visualization
126 println!("\n--- Learning Curve Visualization ---\n");
127
128 // Generate learning curve data
129 let n_points = 10;
130 let n_cv = 5;
131
132 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134 // Generate training scores (decreasing with size due to overfitting)
135 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137 });
138
139 // Generate validation scores (increasing with size)
140 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142 });
143
144 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148 // Print final message
149 println!("\nModel evaluation visualizations completed successfully!");
150
151 Ok(())
152}
10fn main() -> Result<()> {
11 println!(
12 "{}",
13 stylize("Neural Network Model Evaluation with Color", Style::Bold)
14 );
15 println!("{}", "-".repeat(50));
16
17 // Set up color options
18 let color_options = ColorOptions {
19 enabled: true,
20 use_background: false,
21 use_bright: true,
22 };
23
24 // Generate some example data
25 let n_samples = 500;
26 let n_features = 10;
27 let n_classes = 4;
28
29 println!(
30 "\n{} {} {} {} {} {}",
31 colorize("Generating", Color::BrightGreen),
32 colorize(n_samples.to_string(), Color::BrightYellow),
33 colorize("samples with", Color::BrightGreen),
34 colorize(n_features.to_string(), Color::BrightYellow),
35 colorize("features for", Color::BrightGreen),
36 colorize(n_classes.to_string(), Color::BrightYellow),
37 );
38
39 // Create a deterministic RNG for reproducibility
40 let mut rng = SmallRng::seed_from_u64(42);
41
42 // 1. Confusion Matrix Example
43 println!(
44 "\n{}",
45 stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
46 );
47
48 // Generate random predictions and true labels
49 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
50
51 // Create slightly correlated predictions (not completely random)
52 let y_pred = Array::from_shape_fn(n_samples, |i| {
53 if rng.random::<f32>() < 0.7 {
54 // 70% chance of correct prediction
55 y_true[i]
56 } else {
57 // 30% chance of random class
58 rng.random_range(0..n_classes)
59 }
60 });
61
62 // Create confusion matrix
63 let class_labels = vec![
64 "Class A".to_string(),
65 "Class B".to_string(),
66 "Class C".to_string(),
67 "Class D".to_string(),
68 ];
69
70 let cm = ConfusionMatrix::<f32>::new(
71 &y_true.view(),
72 &y_pred.view(),
73 Some(n_classes),
74 Some(class_labels),
75 )?;
76
77 // Print raw and normalized confusion matrices with color
78 println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
79 println!(
80 "{}",
81 cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
82 );
83
84 println!(
85 "\n{}",
86 colorize("Normalized Confusion Matrix:", Color::BrightCyan)
87 );
88 println!(
89 "{}",
90 cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
91 );
92
93 // Print metrics
94 println!(
95 "\n{} {:.3}",
96 colorize("Overall Accuracy:", Color::BrightMagenta),
97 cm.accuracy()
98 );
99
100 let precision = cm.precision();
101 let recall = cm.recall();
102 let f1 = cm.f1_score();
103
104 println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
105 for i in 0..n_classes {
106 println!(
107 " {}: {}={:.3}, {}={:.3}, {}={:.3}",
108 colorize(format!("Class {}", i), Color::BrightYellow),
109 colorize("Precision", Color::BrightCyan),
110 precision[i],
111 colorize("Recall", Color::BrightGreen),
112 recall[i],
113 colorize("F1", Color::BrightBlue),
114 f1[i]
115 );
116 }
117
118 println!(
119 "{} {:.3}",
120 colorize("Macro F1 Score:", Color::BrightMagenta),
121 cm.macro_f1()
122 );
123
124 // 2. Feature Importance Visualization
125 println!(
126 "\n{}",
127 stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
128 );
129
130 // Generate random feature importance scores
131 let feature_names = (0..n_features)
132 .map(|i| format!("Feature_{}", i))
133 .collect::<Vec<String>>();
134
135 let importance = Array1::from_shape_fn(n_features, |i| {
136 // Make some features more important than others
137 let base = (n_features - i) as f32 / n_features as f32;
138 base + 0.2 * rng.random::<f32>()
139 });
140
141 let fi = FeatureImportance::new(feature_names, importance)?;
142
143 // Print full feature importance with color
144 println!(
145 "{}",
146 fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
147 );
148
149 // Print top-5 features with color
150 println!(
151 "\n{}",
152 colorize("Top 5 Most Important Features:", Color::BrightCyan)
153 );
154 println!(
155 "{}",
156 fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
157 );
158
159 // 3. ROC Curve for Binary Classification
160 println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
161
162 // Generate binary classification data
163 let n_binary = 200;
164 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
165
166 // Generate scores with some predictive power
167 let y_scores = Array1::from_shape_fn(n_binary, |i| {
168 if y_true_binary[i] == 1 {
169 // Higher scores for positive class
170 0.6 + 0.4 * rng.random::<f32>()
171 } else {
172 // Lower scores for negative class
173 0.4 * rng.random::<f32>()
174 }
175 });
176
177 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
178
179 println!(
180 "{} {:.3}",
181 colorize("ROC AUC:", Color::BrightMagenta),
182 roc.auc
183 );
184
185 println!("\n{}", roc.to_ascii(None, 50, 20));
186
187 // 4. Learning Curve Visualization
188 println!(
189 "\n{}",
190 stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
191 );
192
193 // Generate learning curve data
194 let n_points = 10;
195 let n_cv = 5;
196
197 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
198
199 // Generate training scores (decreasing with size due to overfitting)
200 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
201 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
202 });
203
204 // Generate validation scores (increasing with size)
205 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
206 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
207 });
208
209 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
210
211 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
212
213 // Print final message with color
214 println!(
215 "\n{}",
216 colorize(
217 "Model evaluation visualizations completed successfully!",
218 Color::BrightGreen
219 )
220 );
221
222 Ok(())
223}
Sourcepub fn to_ascii(
&self,
title: Option<&str>,
width: usize,
height: usize,
metric_name: &str,
) -> String
pub fn to_ascii( &self, title: Option<&str>, width: usize, height: usize, metric_name: &str, ) -> String
Create an ASCII line plot of the learning curve
§Arguments
title - Optional title for the plot
width - Width of the plot
height - Height of the plot
metric_name - Name of the metric (e.g., "Accuracy")
§Returns
String
- ASCII line plot
Examples found in repository?
7fn main() -> Result<()> {
8 println!("Neural Network Model Evaluation Visualization Example\n");
9
10 // Generate some example data
11 let n_samples = 500;
12 let n_features = 10;
13 let n_classes = 4;
14
15 println!(
16 "Generating {} samples with {} features for {} classes",
17 n_samples, n_features, n_classes
18 );
19
20 // 1. Confusion Matrix Example
21 println!("\n--- Confusion Matrix Visualization ---\n");
22
23 // Create a deterministic RNG for reproducibility
24 let mut rng = SmallRng::seed_from_u64(42);
25
26 // Generate random predictions and true labels
27 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29 // Create slightly correlated predictions (not completely random)
30 let y_pred = Array::from_shape_fn(n_samples, |i| {
31 if rng.random::<f32>() < 0.7 {
32 // 70% chance of correct prediction
33 y_true[i]
34 } else {
35 // 30% chance of random class
36 rng.random_range(0..n_classes)
37 }
38 });
39
40 // Create confusion matrix
41 let class_labels = vec![
42 "Class A".to_string(),
43 "Class B".to_string(),
44 "Class C".to_string(),
45 "Class D".to_string(),
46 ];
47
48 let cm = ConfusionMatrix::<f32>::new(
49 &y_true.view(),
50 &y_pred.view(),
51 Some(n_classes),
52 Some(class_labels),
53 )?;
54
55 // Print raw and normalized confusion matrices
56 println!("Raw Confusion Matrix:\n");
57 println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59 println!("\nNormalized Confusion Matrix:\n");
60 println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62 // Print metrics
63 println!("\nAccuracy: {:.3}", cm.accuracy());
64
65 let precision = cm.precision();
66 let recall = cm.recall();
67 let f1 = cm.f1_score();
68
69 println!("Per-class metrics:");
70 for i in 0..n_classes {
71 println!(
72 " Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73 i, precision[i], recall[i], f1[i]
74 );
75 }
76
77 println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79 // 2. Feature Importance Visualization
80 println!("\n--- Feature Importance Visualization ---\n");
81
82 // Generate random feature importance scores
83 let feature_names = (0..n_features)
84 .map(|i| format!("Feature_{}", i))
85 .collect::<Vec<String>>();
86
87 let importance = Array1::from_shape_fn(n_features, |i| {
88 // Make some features more important than others
89 let base = (n_features - i) as f32 / n_features as f32;
90 base + 0.2 * rng.random::<f32>()
91 });
92
93 let fi = FeatureImportance::new(feature_names, importance)?;
94
95 // Print full feature importance
96 println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98 // Print top-5 features
99 println!("\nTop 5 Most Important Features:\n");
100 println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102 // 3. ROC Curve for Binary Classification
103 println!("\n--- ROC Curve Visualization ---\n");
104
105 // Generate binary classification data
106 let n_binary = 200;
107 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109 // Generate scores with some predictive power
110 let y_scores = Array1::from_shape_fn(n_binary, |i| {
111 if y_true_binary[i] == 1 {
112 // Higher scores for positive class
113 0.6 + 0.4 * rng.random::<f32>()
114 } else {
115 // Lower scores for negative class
116 0.4 * rng.random::<f32>()
117 }
118 });
119
120 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122 println!("ROC AUC: {:.3}", roc.auc);
123 println!("\n{}", roc.to_ascii(None, 50, 20));
124
125 // 4. Learning Curve Visualization
126 println!("\n--- Learning Curve Visualization ---\n");
127
128 // Generate learning curve data
129 let n_points = 10;
130 let n_cv = 5;
131
132 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134 // Generate training scores (decreasing with size due to overfitting)
135 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137 });
138
139 // Generate validation scores (increasing with size)
140 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142 });
143
144 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148 // Print final message
149 println!("\nModel evaluation visualizations completed successfully!");
150
151 Ok(())
152}
More examples
10fn main() -> Result<()> {
11 println!(
12 "{}",
13 stylize("Neural Network Model Evaluation with Color", Style::Bold)
14 );
15 println!("{}", "-".repeat(50));
16
17 // Set up color options
18 let color_options = ColorOptions {
19 enabled: true,
20 use_background: false,
21 use_bright: true,
22 };
23
24 // Generate some example data
25 let n_samples = 500;
26 let n_features = 10;
27 let n_classes = 4;
28
29 println!(
30 "\n{} {} {} {} {} {}",
31 colorize("Generating", Color::BrightGreen),
32 colorize(n_samples.to_string(), Color::BrightYellow),
33 colorize("samples with", Color::BrightGreen),
34 colorize(n_features.to_string(), Color::BrightYellow),
35 colorize("features for", Color::BrightGreen),
36 colorize(n_classes.to_string(), Color::BrightYellow),
37 );
38
39 // Create a deterministic RNG for reproducibility
40 let mut rng = SmallRng::seed_from_u64(42);
41
42 // 1. Confusion Matrix Example
43 println!(
44 "\n{}",
45 stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
46 );
47
48 // Generate random predictions and true labels
49 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
50
51 // Create slightly correlated predictions (not completely random)
52 let y_pred = Array::from_shape_fn(n_samples, |i| {
53 if rng.random::<f32>() < 0.7 {
54 // 70% chance of correct prediction
55 y_true[i]
56 } else {
57 // 30% chance of random class
58 rng.random_range(0..n_classes)
59 }
60 });
61
62 // Create confusion matrix
63 let class_labels = vec![
64 "Class A".to_string(),
65 "Class B".to_string(),
66 "Class C".to_string(),
67 "Class D".to_string(),
68 ];
69
70 let cm = ConfusionMatrix::<f32>::new(
71 &y_true.view(),
72 &y_pred.view(),
73 Some(n_classes),
74 Some(class_labels),
75 )?;
76
77 // Print raw and normalized confusion matrices with color
78 println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
79 println!(
80 "{}",
81 cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
82 );
83
84 println!(
85 "\n{}",
86 colorize("Normalized Confusion Matrix:", Color::BrightCyan)
87 );
88 println!(
89 "{}",
90 cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
91 );
92
93 // Print metrics
94 println!(
95 "\n{} {:.3}",
96 colorize("Overall Accuracy:", Color::BrightMagenta),
97 cm.accuracy()
98 );
99
100 let precision = cm.precision();
101 let recall = cm.recall();
102 let f1 = cm.f1_score();
103
104 println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
105 for i in 0..n_classes {
106 println!(
107 " {}: {}={:.3}, {}={:.3}, {}={:.3}",
108 colorize(format!("Class {}", i), Color::BrightYellow),
109 colorize("Precision", Color::BrightCyan),
110 precision[i],
111 colorize("Recall", Color::BrightGreen),
112 recall[i],
113 colorize("F1", Color::BrightBlue),
114 f1[i]
115 );
116 }
117
118 println!(
119 "{} {:.3}",
120 colorize("Macro F1 Score:", Color::BrightMagenta),
121 cm.macro_f1()
122 );
123
124 // 2. Feature Importance Visualization
125 println!(
126 "\n{}",
127 stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
128 );
129
130 // Generate random feature importance scores
131 let feature_names = (0..n_features)
132 .map(|i| format!("Feature_{}", i))
133 .collect::<Vec<String>>();
134
135 let importance = Array1::from_shape_fn(n_features, |i| {
136 // Make some features more important than others
137 let base = (n_features - i) as f32 / n_features as f32;
138 base + 0.2 * rng.random::<f32>()
139 });
140
141 let fi = FeatureImportance::new(feature_names, importance)?;
142
143 // Print full feature importance with color
144 println!(
145 "{}",
146 fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
147 );
148
149 // Print top-5 features with color
150 println!(
151 "\n{}",
152 colorize("Top 5 Most Important Features:", Color::BrightCyan)
153 );
154 println!(
155 "{}",
156 fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
157 );
158
159 // 3. ROC Curve for Binary Classification
160 println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
161
162 // Generate binary classification data
163 let n_binary = 200;
164 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
165
166 // Generate scores with some predictive power
167 let y_scores = Array1::from_shape_fn(n_binary, |i| {
168 if y_true_binary[i] == 1 {
169 // Higher scores for positive class
170 0.6 + 0.4 * rng.random::<f32>()
171 } else {
172 // Lower scores for negative class
173 0.4 * rng.random::<f32>()
174 }
175 });
176
177 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
178
179 println!(
180 "{} {:.3}",
181 colorize("ROC AUC:", Color::BrightMagenta),
182 roc.auc
183 );
184
185 println!("\n{}", roc.to_ascii(None, 50, 20));
186
187 // 4. Learning Curve Visualization
188 println!(
189 "\n{}",
190 stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
191 );
192
193 // Generate learning curve data
194 let n_points = 10;
195 let n_cv = 5;
196
197 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
198
199 // Generate training scores (decreasing with size due to overfitting)
200 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
201 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
202 });
203
204 // Generate validation scores (increasing with size)
205 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
206 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
207 });
208
209 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
210
211 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
212
213 // Print final message with color
214 println!(
215 "\n{}",
216 colorize(
217 "Model evaluation visualizations completed successfully!",
218 Color::BrightGreen
219 )
220 );
221
222 Ok(())
223}
Sourcepub fn to_ascii_with_options(
&self,
title: Option<&str>,
width: usize,
height: usize,
metric_name: &str,
color_options: &ColorOptions,
) -> String
pub fn to_ascii_with_options( &self, title: Option<&str>, width: usize, height: usize, metric_name: &str, color_options: &ColorOptions, ) -> String
Create an ASCII line plot of the learning curve with customizable colors
This method allows fine-grained control over the color scheme using the provided ColorOptions parameter.
§Arguments
title - Optional title for the plot
width - Width of the plot
height - Height of the plot
metric_name - Name of the metric (e.g., "Accuracy")
color_options - Color options for visualization
§Returns
String
- ASCII line plot with colors
Examples found in repository?
8fn main() {
9 // Create a reproducible random number generator
10 let mut rng = SmallRng::seed_from_u64(42);
11
12 // Example 1: ROC Curve with color
13 println!("Example 1: ROC Curve Visualization (with color)\n");
14
15 // Generate synthetic binary classification data
16 let n_samples = 200;
17
18 // Generate true labels: 0 or 1
19 let y_true: Vec<usize> = (0..n_samples)
20 .map(|_| if rng.random::<f64>() > 0.5 { 1 } else { 0 })
21 .collect();
22
23 // Generate scores with some separability
24 let y_score: Vec<f64> = y_true
25 .iter()
26 .map(|&label| {
27 if label == 1 {
28 0.7 + 0.3 * rng.sample::<f64, _>(StandardNormal)
29 } else {
30 0.3 + 0.3 * rng.sample::<f64, _>(StandardNormal)
31 }
32 })
33 .collect();
34
35 // Convert to ndarray views
36 let y_true_array = Array1::from(y_true.clone());
37 let y_score_array = Array1::from(y_score.clone());
38
39 // Create ROC curve
40 let roc = ROCCurve::new(&y_true_array.view(), &y_score_array.view()).unwrap();
41
42 // Enable color options
43 let color_options = ColorOptions {
44 enabled: true,
45 use_bright: true,
46 use_background: false,
47 };
48
49 // Plot ROC curve with color
50 let roc_plot = roc.to_ascii_with_options(
51 Some("Binary Classification ROC Curve"),
52 60,
53 20,
54 &color_options,
55 );
56 println!("{}", roc_plot);
57
58 // Example 2: Learning Curve with color
59 println!("\nExample 2: Learning Curve Visualization (with color)\n");
60
61 // Simulate learning curves for different training set sizes
62 let train_sizes = Array1::from(vec![100, 200, 300, 400, 500]);
63
64 // Simulated training scores for each size (5 sizes, 3 CV folds)
65 let train_scores = Array2::from_shape_fn((5, 3), |(i, _j)| {
66 let base = 0.5 + 0.4 * (i as f64 / 4.0);
67 let noise = 0.05 * rng.sample::<f64, _>(StandardNormal);
68 base + noise
69 });
70
71 // Simulated validation scores (typically lower than training)
72 let val_scores = Array2::from_shape_fn((5, 3), |(i, _j)| {
73 let base = 0.4 + 0.3 * (i as f64 / 4.0);
74 let noise = 0.07 * rng.sample::<f64, _>(StandardNormal);
75 base + noise
76 });
77
78 // Create learning curve
79 let learning_curve = LearningCurve::new(train_sizes, train_scores, val_scores).unwrap();
80
81 // Plot learning curve with color
82 let learning_plot = learning_curve.to_ascii_with_options(
83 Some("Neural Network Training"),
84 70,
85 20,
86 "Accuracy",
87 &color_options,
88 );
89 println!("{}", learning_plot);
90}
Auto Trait Implementations§
impl<F> Freeze for LearningCurve<F>
impl<F> RefUnwindSafe for LearningCurve<F>where
F: RefUnwindSafe,
impl<F> Send for LearningCurve<F>where
F: Send,
impl<F> Sync for LearningCurve<F>where
F: Sync,
impl<F> Unpin for LearningCurve<F>
impl<F> UnwindSafe for LearningCurve<F>where
F: RefUnwindSafe,
Blanket Implementations§
Source§impl<T> BorrowMut<T> for Twhere
T: ?Sized,
impl<T> BorrowMut<T> for Twhere
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more