/// Confusion matrix for classification problems.
///
/// Built either from paired true/predicted integer class labels
/// (`ConfusionMatrix::new`) or from an already-tallied matrix of raw
/// data (`ConfusionMatrix::from_matrix`).
pub struct ConfusionMatrix<F>
where
    F: Float + Debug + Display,
{
    /// The raw confusion matrix data.
    pub matrix: Array2<F>,
    /// Class labels (optional).
    pub labels: Option<Vec<String>>,
    /// Number of classes.
    pub num_classes: usize,
}
Confusion matrix for classification problems.
Fields
- `matrix: Array2<F>` — The raw confusion matrix data
- `labels: Option<Vec<String>>` — Class labels (optional)
- `num_classes: usize` — Number of classes
Implementations
impl<F: Float + Debug + Display> ConfusionMatrix<F>
pub fn new(
    y_true: &ArrayView1<'_, usize>,
    y_pred: &ArrayView1<'_, usize>,
    num_classes: Option<usize>,
    labels: Option<Vec<String>>,
) -> Result<Self>
Create a new confusion matrix from predictions and true labels
Arguments
- `y_true` - True class labels as integers
- `y_pred` - Predicted class labels as integers
- `num_classes` - Number of classes (if `None`, determined from data)
- `labels` - Optional class labels as strings

Returns
`Result<ConfusionMatrix<F>>` - The confusion matrix

Example
use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
Examples found in repository?
7fn main() {
8 // Create a reproducible random number generator
9 let mut rng = SmallRng::seed_from_u64(42);
10
11 // Generate synthetic multiclass classification data
12 let num_classes = 5;
13 let n_samples = 500;
14
15 // Generate true labels (0 to num_classes-1)
16 let mut y_true = Vec::with_capacity(n_samples);
17 for _ in 0..n_samples {
18 y_true.push(rng.random_range(0..num_classes));
19 }
20
21 // Generate predicted labels with controlled accuracy
22 let mut y_pred = Vec::with_capacity(n_samples);
23 for &true_label in &y_true {
24 // 80% chance to predict correctly, 20% chance of error
25 if rng.random::<f64>() < 0.8 {
26 y_pred.push(true_label);
27 } else {
28 // When wrong, tend to predict adjacent classes more often
29 let mut pred = true_label;
30 while pred == true_label {
31 // Generate error that's more likely to be close to true label
32 let error_margin = (rng.random::<f64>() * 2.0).round() as usize; // 0, 1, or 2
33 if rng.random::<bool>() {
34 pred = (true_label + error_margin) % num_classes;
35 } else {
36 pred = (true_label + num_classes - error_margin) % num_classes;
37 }
38 }
39 y_pred.push(pred);
40 }
41 }
42
43 // Convert to ndarray arrays
44 let y_true_array = Array1::from(y_true);
45 let y_pred_array = Array1::from(y_pred);
46
47 // Create class labels
48 let class_labels = vec![
49 "Cat".to_string(),
50 "Dog".to_string(),
51 "Bird".to_string(),
52 "Fish".to_string(),
53 "Rabbit".to_string(),
54 ];
55
56 // Create confusion matrix
57 let cm = ConfusionMatrix::<f64>::new(
58 &y_true_array.view(),
59 &y_pred_array.view(),
60 Some(num_classes),
61 Some(class_labels),
62 )
63 .unwrap();
64
65 // Example 1: Standard confusion matrix
66 println!("Example 1: Standard Confusion Matrix\n");
67 let regular_output = cm.to_ascii(Some("Animal Classification Results"), false);
68 println!("{}", regular_output);
69
70 // Example 2: Confusion matrix with color
71 println!("\n\nExample 2: Colored Confusion Matrix\n");
72 let color_options = ColorOptions {
73 enabled: true,
74 use_bright: true,
75 use_background: false,
76 };
77 let colored_output = cm.to_ascii_with_options(
78 Some("Animal Classification Results (with color)"),
79 false,
80 &color_options,
81 );
82 println!("{}", colored_output);
83
84 // Example 3: Normalized confusion matrix heatmap
85 println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
86 let heatmap_output = cm.to_heatmap_with_options(
87 Some("Animal Classification Heatmap (normalized)"),
88 true, // normalized
89 &color_options,
90 );
91 println!("{}", heatmap_output);
92
93 // Example 4: Raw counts heatmap
94 println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
95 let raw_heatmap = cm.to_heatmap_with_options(
96 Some("Animal Classification Heatmap (raw counts)"),
97 false, // not normalized
98 &color_options,
99 );
100 println!("{}", raw_heatmap);
101}
More examples
7fn main() -> Result<()> {
8 println!("Neural Network Model Evaluation Visualization Example\n");
9
10 // Generate some example data
11 let n_samples = 500;
12 let n_features = 10;
13 let n_classes = 4;
14
15 println!(
16 "Generating {} samples with {} features for {} classes",
17 n_samples, n_features, n_classes
18 );
19
20 // 1. Confusion Matrix Example
21 println!("\n--- Confusion Matrix Visualization ---\n");
22
23 // Create a deterministic RNG for reproducibility
24 let mut rng = SmallRng::seed_from_u64(42);
25
26 // Generate random predictions and true labels
27 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29 // Create slightly correlated predictions (not completely random)
30 let y_pred = Array::from_shape_fn(n_samples, |i| {
31 if rng.random::<f32>() < 0.7 {
32 // 70% chance of correct prediction
33 y_true[i]
34 } else {
35 // 30% chance of random class
36 rng.random_range(0..n_classes)
37 }
38 });
39
40 // Create confusion matrix
41 let class_labels = vec![
42 "Class A".to_string(),
43 "Class B".to_string(),
44 "Class C".to_string(),
45 "Class D".to_string(),
46 ];
47
48 let cm = ConfusionMatrix::<f32>::new(
49 &y_true.view(),
50 &y_pred.view(),
51 Some(n_classes),
52 Some(class_labels),
53 )?;
54
55 // Print raw and normalized confusion matrices
56 println!("Raw Confusion Matrix:\n");
57 println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59 println!("\nNormalized Confusion Matrix:\n");
60 println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62 // Print metrics
63 println!("\nAccuracy: {:.3}", cm.accuracy());
64
65 let precision = cm.precision();
66 let recall = cm.recall();
67 let f1 = cm.f1_score();
68
69 println!("Per-class metrics:");
70 for i in 0..n_classes {
71 println!(
72 " Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73 i, precision[i], recall[i], f1[i]
74 );
75 }
76
77 println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79 // 2. Feature Importance Visualization
80 println!("\n--- Feature Importance Visualization ---\n");
81
82 // Generate random feature importance scores
83 let feature_names = (0..n_features)
84 .map(|i| format!("Feature_{}", i))
85 .collect::<Vec<String>>();
86
87 let importance = Array1::from_shape_fn(n_features, |i| {
88 // Make some features more important than others
89 let base = (n_features - i) as f32 / n_features as f32;
90 base + 0.2 * rng.random::<f32>()
91 });
92
93 let fi = FeatureImportance::new(feature_names, importance)?;
94
95 // Print full feature importance
96 println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98 // Print top-5 features
99 println!("\nTop 5 Most Important Features:\n");
100 println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102 // 3. ROC Curve for Binary Classification
103 println!("\n--- ROC Curve Visualization ---\n");
104
105 // Generate binary classification data
106 let n_binary = 200;
107 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109 // Generate scores with some predictive power
110 let y_scores = Array1::from_shape_fn(n_binary, |i| {
111 if y_true_binary[i] == 1 {
112 // Higher scores for positive class
113 0.6 + 0.4 * rng.random::<f32>()
114 } else {
115 // Lower scores for negative class
116 0.4 * rng.random::<f32>()
117 }
118 });
119
120 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122 println!("ROC AUC: {:.3}", roc.auc);
123 println!("\n{}", roc.to_ascii(None, 50, 20));
124
125 // 4. Learning Curve Visualization
126 println!("\n--- Learning Curve Visualization ---\n");
127
128 // Generate learning curve data
129 let n_points = 10;
130 let n_cv = 5;
131
132 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134 // Generate training scores (decreasing with size due to overfitting)
135 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137 });
138
139 // Generate validation scores (increasing with size)
140 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142 });
143
144 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148 // Print final message
149 println!("\nModel evaluation visualizations completed successfully!");
150
151 Ok(())
152}
10fn main() -> Result<()> {
11 println!(
12 "{}",
13 stylize("Neural Network Model Evaluation with Color", Style::Bold)
14 );
15 println!("{}", "-".repeat(50));
16
17 // Set up color options
18 let color_options = ColorOptions {
19 enabled: true,
20 use_background: false,
21 use_bright: true,
22 };
23
24 // Generate some example data
25 let n_samples = 500;
26 let n_features = 10;
27 let n_classes = 4;
28
29 println!(
30 "\n{} {} {} {} {} {}",
31 colorize("Generating", Color::BrightGreen),
32 colorize(n_samples.to_string(), Color::BrightYellow),
33 colorize("samples with", Color::BrightGreen),
34 colorize(n_features.to_string(), Color::BrightYellow),
35 colorize("features for", Color::BrightGreen),
36 colorize(n_classes.to_string(), Color::BrightYellow),
37 );
38
39 // Create a deterministic RNG for reproducibility
40 let mut rng = SmallRng::seed_from_u64(42);
41
42 // 1. Confusion Matrix Example
43 println!(
44 "\n{}",
45 stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
46 );
47
48 // Generate random predictions and true labels
49 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
50
51 // Create slightly correlated predictions (not completely random)
52 let y_pred = Array::from_shape_fn(n_samples, |i| {
53 if rng.random::<f32>() < 0.7 {
54 // 70% chance of correct prediction
55 y_true[i]
56 } else {
57 // 30% chance of random class
58 rng.random_range(0..n_classes)
59 }
60 });
61
62 // Create confusion matrix
63 let class_labels = vec![
64 "Class A".to_string(),
65 "Class B".to_string(),
66 "Class C".to_string(),
67 "Class D".to_string(),
68 ];
69
70 let cm = ConfusionMatrix::<f32>::new(
71 &y_true.view(),
72 &y_pred.view(),
73 Some(n_classes),
74 Some(class_labels),
75 )?;
76
77 // Print raw and normalized confusion matrices with color
78 println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
79 println!(
80 "{}",
81 cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
82 );
83
84 println!(
85 "\n{}",
86 colorize("Normalized Confusion Matrix:", Color::BrightCyan)
87 );
88 println!(
89 "{}",
90 cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
91 );
92
93 // Print metrics
94 println!(
95 "\n{} {:.3}",
96 colorize("Overall Accuracy:", Color::BrightMagenta),
97 cm.accuracy()
98 );
99
100 let precision = cm.precision();
101 let recall = cm.recall();
102 let f1 = cm.f1_score();
103
104 println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
105 for i in 0..n_classes {
106 println!(
107 " {}: {}={:.3}, {}={:.3}, {}={:.3}",
108 colorize(format!("Class {}", i), Color::BrightYellow),
109 colorize("Precision", Color::BrightCyan),
110 precision[i],
111 colorize("Recall", Color::BrightGreen),
112 recall[i],
113 colorize("F1", Color::BrightBlue),
114 f1[i]
115 );
116 }
117
118 println!(
119 "{} {:.3}",
120 colorize("Macro F1 Score:", Color::BrightMagenta),
121 cm.macro_f1()
122 );
123
124 // 2. Feature Importance Visualization
125 println!(
126 "\n{}",
127 stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
128 );
129
130 // Generate random feature importance scores
131 let feature_names = (0..n_features)
132 .map(|i| format!("Feature_{}", i))
133 .collect::<Vec<String>>();
134
135 let importance = Array1::from_shape_fn(n_features, |i| {
136 // Make some features more important than others
137 let base = (n_features - i) as f32 / n_features as f32;
138 base + 0.2 * rng.random::<f32>()
139 });
140
141 let fi = FeatureImportance::new(feature_names, importance)?;
142
143 // Print full feature importance with color
144 println!(
145 "{}",
146 fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
147 );
148
149 // Print top-5 features with color
150 println!(
151 "\n{}",
152 colorize("Top 5 Most Important Features:", Color::BrightCyan)
153 );
154 println!(
155 "{}",
156 fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
157 );
158
159 // 3. ROC Curve for Binary Classification
160 println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
161
162 // Generate binary classification data
163 let n_binary = 200;
164 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
165
166 // Generate scores with some predictive power
167 let y_scores = Array1::from_shape_fn(n_binary, |i| {
168 if y_true_binary[i] == 1 {
169 // Higher scores for positive class
170 0.6 + 0.4 * rng.random::<f32>()
171 } else {
172 // Lower scores for negative class
173 0.4 * rng.random::<f32>()
174 }
175 });
176
177 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
178
179 println!(
180 "{} {:.3}",
181 colorize("ROC AUC:", Color::BrightMagenta),
182 roc.auc
183 );
184
185 println!("\n{}", roc.to_ascii(None, 50, 20));
186
187 // 4. Learning Curve Visualization
188 println!(
189 "\n{}",
190 stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
191 );
192
193 // Generate learning curve data
194 let n_points = 10;
195 let n_cv = 5;
196
197 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
198
199 // Generate training scores (decreasing with size due to overfitting)
200 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
201 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
202 });
203
204 // Generate validation scores (increasing with size)
205 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
206 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
207 });
208
209 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
210
211 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
212
213 // Print final message with color
214 println!(
215 "\n{}",
216 colorize(
217 "Model evaluation visualizations completed successfully!",
218 Color::BrightGreen
219 )
220 );
221
222 Ok(())
223}
118fn main() -> Result<()> {
119 println!("Neural Network Confusion Matrix Visualization");
120 println!("==============================================\n");
121
122 // Initialize RNG with a fixed seed for reproducibility
123 let mut rng = SmallRng::seed_from_u64(42);
124
125 // Generate spiral dataset for 3-class classification
126 let n_classes = 3;
127 let n_samples_per_class = 100;
128 let noise = 0.15;
129
130 let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131 println!(
132 "Generated spiral dataset with {} classes, {} samples per class",
133 n_classes, n_samples_per_class
134 );
135
136 // Split data into training and test sets (80/20 split)
137 let n_samples = x.shape()[0];
138 let n_train = (n_samples as f32 * 0.8) as usize;
139 let n_test = n_samples - n_train;
140
141 let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142 let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143 let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144 let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146 println!(
147 "Split data into {} training and {} test samples",
148 n_train, n_test
149 );
150
151 // Create a classification model
152 let input_dim = 2; // 2D input (x, y coordinates)
153 let hidden_dim = 32; // Hidden layer size
154 let output_dim = n_classes; // One output per class
155
156 let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157 println!("Created model with {} layers", model.num_layers());
158
159 // Setup loss function and optimizer
160 let loss_fn = MeanSquaredError::new();
161 let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163 // Train the model
164 let epochs = 100;
165 let x_train_dyn = x_train.clone().into_dyn();
166 let y_train_onehot = one_hot_encode(&y_train, n_classes);
167 let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169 // Create visualization callback for training metrics
170 let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171 .with_tracked_metrics(vec![
172 "train_loss".to_string(),
173 "val_accuracy".to_string(),
174 ]);
175
176 // Define class labels for confusion matrix
177 let class_labels = vec![
178 "Class A".to_string(),
179 "Class B".to_string(),
180 "Class C".to_string(),
181 ];
182
183 // Train the model (simple manual training loop)
184 println!("\nTraining model...");
185
186 // Initialize history for tracking metrics
187 let mut epoch_history = HashMap::new();
188 epoch_history.insert("train_loss".to_string(), Vec::new());
189 epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191 // Training loop
192 for epoch in 0..epochs {
193 // Train for one epoch
194 let train_loss =
195 model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197 // Compute validation accuracy
198 let x_test_dyn = x_test.clone().into_dyn();
199 let predictions = model.forward(&x_test_dyn)?;
200 let predicted_classes = predictions_to_classes(&predictions);
201
202 // Calculate validation accuracy
203 let mut correct = 0;
204 for i in 0..n_test {
205 if predicted_classes[i] == y_test[i] {
206 correct += 1;
207 }
208 }
209 let val_accuracy = correct as f32 / n_test as f32;
210
211 // Store metrics
212 epoch_history
213 .get_mut("train_loss")
214 .unwrap()
215 .push(train_loss);
216 epoch_history
217 .get_mut("val_accuracy")
218 .unwrap()
219 .push(val_accuracy);
220
221 // Print progress
222 if (epoch + 1) % 10 == 0 || epoch == 0 {
223 println!(
224 "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225 epoch + 1,
226 epochs,
227 train_loss,
228 val_accuracy
229 );
230 }
231
232 // Update visualization callback
233 let mut context = CallbackContext {
234 epoch,
235 total_epochs: epochs,
236 batch: 0,
237 total_batches: 1,
238 batch_loss: None,
239 epoch_loss: Some(train_loss),
240 val_loss: None,
241 metrics: vec![val_accuracy],
242 history: &epoch_history,
243 stop_training: false,
244 model: None,
245 };
246
247 // Visualize progress with metrics chart
248 if epoch % 10 == 0 || epoch == epochs - 1 {
249 visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250 }
251
252 // Calculate and show confusion matrix during training
253 if epoch % 20 == 0 || epoch == epochs - 1 {
254 // Create confusion matrix
255 let cm = ConfusionMatrix::<f32>::new(
256 &y_test.view(),
257 &predicted_classes.view(),
258 Some(n_classes),
259 Some(class_labels.clone()),
260 )?;
261
262 // Show heatmap visualization
263 println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264 println!(
265 "{}",
266 cm.to_heatmap(
267 Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268 true
269 )
270 );
271 }
272 }
273
274 // Final evaluation
275 println!("\nFinal model evaluation:");
276
277 // Make predictions on test set
278 let x_test_dyn = x_test.clone().into_dyn();
279 let predictions = model.forward(&x_test_dyn)?;
280 let predicted_classes = predictions_to_classes(&predictions);
281
282 // Create confusion matrix
283 let cm = ConfusionMatrix::<f32>::new(
284 &y_test.view(),
285 &predicted_classes.view(),
286 Some(n_classes),
287 Some(class_labels.clone()),
288 )?;
289
290 // Calculate and show metrics
291 let accuracy = cm.accuracy();
292 let precision = cm.precision();
293 let recall = cm.recall();
294 let f1 = cm.f1_score();
295
296 println!("\nFinal Classification Metrics:");
297 println!("Overall Accuracy: {:.4}", accuracy);
298
299 println!("\nPer-Class Metrics:");
300 println!("Class | Precision | Recall | F1-Score");
301 println!("-----------------------------------");
302
303 for i in 0..n_classes {
304 println!(
305 "{} | {:.4} | {:.4} | {:.4}",
306 class_labels[i], precision[i], recall[i], f1[i]
307 );
308 }
309
310 println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312 // Show different confusion matrix visualizations
313 println!("\nFinal Confusion Matrix Visualizations:");
314
315 // 1. Standard confusion matrix
316 println!("\n1. Standard Confusion Matrix:");
317 println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319 // 2. Normalized confusion matrix
320 println!("\n2. Normalized Confusion Matrix:");
321 println!(
322 "{}",
323 cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324 );
325
326 // 3. Confusion matrix heatmap
327 println!("\n3. Confusion Matrix Heatmap:");
328 println!(
329 "{}",
330 cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331 );
332
333 // 4. Error pattern analysis
334 println!("\n4. Error Pattern Analysis:");
335 println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337 println!("\nNeural Network Confusion Matrix Visualization Complete!");
338 Ok(())
339}
pub fn from_matrix(
    matrix: Array2<F>,
    labels: Option<Vec<String>>,
) -> Result<Self>
Create a confusion matrix from raw matrix data
Arguments
- `matrix` - Raw confusion matrix data
- `labels` - Optional class labels

Returns
`Result<ConfusionMatrix<F>>` - The confusion matrix
Examples found in repository?
6fn main() {
7 // Create a reproducible random number generator
8 let mut rng = SmallRng::seed_from_u64(42);
9
10 // Generate synthetic multiclass classification data with specific error patterns
11 let num_classes = 5;
12
13 // Create confusion matrix with controlled error patterns
14 let mut matrix = vec![vec![0; num_classes]; num_classes];
15
16 // Set diagonal elements (correct classifications) with high values
17 for i in 0..num_classes {
18 matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
19 }
20
21 // Create specific error patterns:
22 // - Classes 0 and 1 often confused
23 matrix[0][1] = 25;
24 matrix[1][0] = 15;
25
26 // - Class 2 sometimes confused with Class 3
27 matrix[2][3] = 18;
28
29 // - Class 4 has some misclassifications to all other classes
30 matrix[4][0] = 8;
31 matrix[4][1] = 5;
32 matrix[4][2] = 10;
33 matrix[4][3] = 12;
34
35 // - Some minor errors scattered about
36 for i in 0..num_classes {
37 for j in 0..num_classes {
38 if i != j && matrix[i][j] == 0 {
39 matrix[i][j] = rng.random_range(0..5);
40 }
41 }
42 }
43
44 // Convert to ndarray
45 let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
46 let ndarray_matrix =
47 ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix).unwrap();
48
49 // Create class labels
50 let class_labels = vec![
51 "Class A".to_string(),
52 "Class B".to_string(),
53 "Class C".to_string(),
54 "Class D".to_string(),
55 "Class E".to_string(),
56 ];
57
58 // Create confusion matrix
59 let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
60
61 // Example 1: Standard confusion matrix
62 println!("Example 1: Standard Confusion Matrix\n");
63 let regular_output = cm.to_ascii(Some("Classification Results"), false);
64 println!("{}", regular_output);
65
66 // Example 2: Normal heatmap
67 println!("\n\nExample 2: Standard Heatmap Visualization\n");
68 let color_options = ColorOptions {
69 enabled: true,
70 use_bright: true,
71 use_background: false,
72 };
73 let heatmap_output = cm.to_heatmap_with_options(
74 Some("Classification Heatmap"),
75 true, // normalized
76 &color_options,
77 );
78 println!("{}", heatmap_output);
79
80 // Example 3: Error pattern heatmap
81 println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
82 let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
83 println!("{}", error_heatmap);
84}
pub fn normalized(&self) -> Array2<F>
Get the normalized confusion matrix (rows sum to 1).
pub fn accuracy(&self) -> F
Calculate accuracy from the confusion matrix.
Examples found in repository?
7fn main() -> Result<()> {
8 println!("Neural Network Model Evaluation Visualization Example\n");
9
10 // Generate some example data
11 let n_samples = 500;
12 let n_features = 10;
13 let n_classes = 4;
14
15 println!(
16 "Generating {} samples with {} features for {} classes",
17 n_samples, n_features, n_classes
18 );
19
20 // 1. Confusion Matrix Example
21 println!("\n--- Confusion Matrix Visualization ---\n");
22
23 // Create a deterministic RNG for reproducibility
24 let mut rng = SmallRng::seed_from_u64(42);
25
26 // Generate random predictions and true labels
27 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29 // Create slightly correlated predictions (not completely random)
30 let y_pred = Array::from_shape_fn(n_samples, |i| {
31 if rng.random::<f32>() < 0.7 {
32 // 70% chance of correct prediction
33 y_true[i]
34 } else {
35 // 30% chance of random class
36 rng.random_range(0..n_classes)
37 }
38 });
39
40 // Create confusion matrix
41 let class_labels = vec![
42 "Class A".to_string(),
43 "Class B".to_string(),
44 "Class C".to_string(),
45 "Class D".to_string(),
46 ];
47
48 let cm = ConfusionMatrix::<f32>::new(
49 &y_true.view(),
50 &y_pred.view(),
51 Some(n_classes),
52 Some(class_labels),
53 )?;
54
55 // Print raw and normalized confusion matrices
56 println!("Raw Confusion Matrix:\n");
57 println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59 println!("\nNormalized Confusion Matrix:\n");
60 println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62 // Print metrics
63 println!("\nAccuracy: {:.3}", cm.accuracy());
64
65 let precision = cm.precision();
66 let recall = cm.recall();
67 let f1 = cm.f1_score();
68
69 println!("Per-class metrics:");
70 for i in 0..n_classes {
71 println!(
72 " Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73 i, precision[i], recall[i], f1[i]
74 );
75 }
76
77 println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79 // 2. Feature Importance Visualization
80 println!("\n--- Feature Importance Visualization ---\n");
81
82 // Generate random feature importance scores
83 let feature_names = (0..n_features)
84 .map(|i| format!("Feature_{}", i))
85 .collect::<Vec<String>>();
86
87 let importance = Array1::from_shape_fn(n_features, |i| {
88 // Make some features more important than others
89 let base = (n_features - i) as f32 / n_features as f32;
90 base + 0.2 * rng.random::<f32>()
91 });
92
93 let fi = FeatureImportance::new(feature_names, importance)?;
94
95 // Print full feature importance
96 println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98 // Print top-5 features
99 println!("\nTop 5 Most Important Features:\n");
100 println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102 // 3. ROC Curve for Binary Classification
103 println!("\n--- ROC Curve Visualization ---\n");
104
105 // Generate binary classification data
106 let n_binary = 200;
107 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109 // Generate scores with some predictive power
110 let y_scores = Array1::from_shape_fn(n_binary, |i| {
111 if y_true_binary[i] == 1 {
112 // Higher scores for positive class
113 0.6 + 0.4 * rng.random::<f32>()
114 } else {
115 // Lower scores for negative class
116 0.4 * rng.random::<f32>()
117 }
118 });
119
120 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122 println!("ROC AUC: {:.3}", roc.auc);
123 println!("\n{}", roc.to_ascii(None, 50, 20));
124
125 // 4. Learning Curve Visualization
126 println!("\n--- Learning Curve Visualization ---\n");
127
128 // Generate learning curve data
129 let n_points = 10;
130 let n_cv = 5;
131
132 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134 // Generate training scores (decreasing with size due to overfitting)
135 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137 });
138
139 // Generate validation scores (increasing with size)
140 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142 });
143
144 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148 // Print final message
149 println!("\nModel evaluation visualizations completed successfully!");
150
151 Ok(())
152}
More examples
10fn main() -> Result<()> {
11 println!(
12 "{}",
13 stylize("Neural Network Model Evaluation with Color", Style::Bold)
14 );
15 println!("{}", "-".repeat(50));
16
17 // Set up color options
18 let color_options = ColorOptions {
19 enabled: true,
20 use_background: false,
21 use_bright: true,
22 };
23
24 // Generate some example data
25 let n_samples = 500;
26 let n_features = 10;
27 let n_classes = 4;
28
29 println!(
30 "\n{} {} {} {} {} {}",
31 colorize("Generating", Color::BrightGreen),
32 colorize(n_samples.to_string(), Color::BrightYellow),
33 colorize("samples with", Color::BrightGreen),
34 colorize(n_features.to_string(), Color::BrightYellow),
35 colorize("features for", Color::BrightGreen),
36 colorize(n_classes.to_string(), Color::BrightYellow),
37 );
38
39 // Create a deterministic RNG for reproducibility
40 let mut rng = SmallRng::seed_from_u64(42);
41
42 // 1. Confusion Matrix Example
43 println!(
44 "\n{}",
45 stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
46 );
47
48 // Generate random predictions and true labels
49 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
50
51 // Create slightly correlated predictions (not completely random)
52 let y_pred = Array::from_shape_fn(n_samples, |i| {
53 if rng.random::<f32>() < 0.7 {
54 // 70% chance of correct prediction
55 y_true[i]
56 } else {
57 // 30% chance of random class
58 rng.random_range(0..n_classes)
59 }
60 });
61
62 // Create confusion matrix
63 let class_labels = vec![
64 "Class A".to_string(),
65 "Class B".to_string(),
66 "Class C".to_string(),
67 "Class D".to_string(),
68 ];
69
70 let cm = ConfusionMatrix::<f32>::new(
71 &y_true.view(),
72 &y_pred.view(),
73 Some(n_classes),
74 Some(class_labels),
75 )?;
76
77 // Print raw and normalized confusion matrices with color
78 println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
79 println!(
80 "{}",
81 cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
82 );
83
84 println!(
85 "\n{}",
86 colorize("Normalized Confusion Matrix:", Color::BrightCyan)
87 );
88 println!(
89 "{}",
90 cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
91 );
92
93 // Print metrics
94 println!(
95 "\n{} {:.3}",
96 colorize("Overall Accuracy:", Color::BrightMagenta),
97 cm.accuracy()
98 );
99
100 let precision = cm.precision();
101 let recall = cm.recall();
102 let f1 = cm.f1_score();
103
104 println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
105 for i in 0..n_classes {
106 println!(
107 " {}: {}={:.3}, {}={:.3}, {}={:.3}",
108 colorize(format!("Class {}", i), Color::BrightYellow),
109 colorize("Precision", Color::BrightCyan),
110 precision[i],
111 colorize("Recall", Color::BrightGreen),
112 recall[i],
113 colorize("F1", Color::BrightBlue),
114 f1[i]
115 );
116 }
117
118 println!(
119 "{} {:.3}",
120 colorize("Macro F1 Score:", Color::BrightMagenta),
121 cm.macro_f1()
122 );
123
124 // 2. Feature Importance Visualization
125 println!(
126 "\n{}",
127 stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
128 );
129
130 // Generate random feature importance scores
131 let feature_names = (0..n_features)
132 .map(|i| format!("Feature_{}", i))
133 .collect::<Vec<String>>();
134
135 let importance = Array1::from_shape_fn(n_features, |i| {
136 // Make some features more important than others
137 let base = (n_features - i) as f32 / n_features as f32;
138 base + 0.2 * rng.random::<f32>()
139 });
140
141 let fi = FeatureImportance::new(feature_names, importance)?;
142
143 // Print full feature importance with color
144 println!(
145 "{}",
146 fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
147 );
148
149 // Print top-5 features with color
150 println!(
151 "\n{}",
152 colorize("Top 5 Most Important Features:", Color::BrightCyan)
153 );
154 println!(
155 "{}",
156 fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
157 );
158
159 // 3. ROC Curve for Binary Classification
160 println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
161
162 // Generate binary classification data
163 let n_binary = 200;
164 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
165
166 // Generate scores with some predictive power
167 let y_scores = Array1::from_shape_fn(n_binary, |i| {
168 if y_true_binary[i] == 1 {
169 // Higher scores for positive class
170 0.6 + 0.4 * rng.random::<f32>()
171 } else {
172 // Lower scores for negative class
173 0.4 * rng.random::<f32>()
174 }
175 });
176
177 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
178
179 println!(
180 "{} {:.3}",
181 colorize("ROC AUC:", Color::BrightMagenta),
182 roc.auc
183 );
184
185 println!("\n{}", roc.to_ascii(None, 50, 20));
186
187 // 4. Learning Curve Visualization
188 println!(
189 "\n{}",
190 stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
191 );
192
193 // Generate learning curve data
194 let n_points = 10;
195 let n_cv = 5;
196
197 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
198
199 // Generate training scores (decreasing with size due to overfitting)
200 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
201 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
202 });
203
204 // Generate validation scores (increasing with size)
205 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
206 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
207 });
208
209 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
210
211 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
212
213 // Print final message with color
214 println!(
215 "\n{}",
216 colorize(
217 "Model evaluation visualizations completed successfully!",
218 Color::BrightGreen
219 )
220 );
221
222 Ok(())
223}
118fn main() -> Result<()> {
119 println!("Neural Network Confusion Matrix Visualization");
120 println!("==============================================\n");
121
122 // Initialize RNG with a fixed seed for reproducibility
123 let mut rng = SmallRng::seed_from_u64(42);
124
125 // Generate spiral dataset for 3-class classification
126 let n_classes = 3;
127 let n_samples_per_class = 100;
128 let noise = 0.15;
129
130 let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131 println!(
132 "Generated spiral dataset with {} classes, {} samples per class",
133 n_classes, n_samples_per_class
134 );
135
136 // Split data into training and test sets (80/20 split)
137 let n_samples = x.shape()[0];
138 let n_train = (n_samples as f32 * 0.8) as usize;
139 let n_test = n_samples - n_train;
140
141 let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142 let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143 let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144 let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146 println!(
147 "Split data into {} training and {} test samples",
148 n_train, n_test
149 );
150
151 // Create a classification model
152 let input_dim = 2; // 2D input (x, y coordinates)
153 let hidden_dim = 32; // Hidden layer size
154 let output_dim = n_classes; // One output per class
155
156 let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157 println!("Created model with {} layers", model.num_layers());
158
159 // Setup loss function and optimizer
160 let loss_fn = MeanSquaredError::new();
161 let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163 // Train the model
164 let epochs = 100;
165 let x_train_dyn = x_train.clone().into_dyn();
166 let y_train_onehot = one_hot_encode(&y_train, n_classes);
167 let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169 // Create visualization callback for training metrics
170 let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171 .with_tracked_metrics(vec![
172 "train_loss".to_string(),
173 "val_accuracy".to_string(),
174 ]);
175
176 // Define class labels for confusion matrix
177 let class_labels = vec![
178 "Class A".to_string(),
179 "Class B".to_string(),
180 "Class C".to_string(),
181 ];
182
183 // Train the model (simple manual training loop)
184 println!("\nTraining model...");
185
186 // Initialize history for tracking metrics
187 let mut epoch_history = HashMap::new();
188 epoch_history.insert("train_loss".to_string(), Vec::new());
189 epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191 // Training loop
192 for epoch in 0..epochs {
193 // Train for one epoch
194 let train_loss =
195 model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197 // Compute validation accuracy
198 let x_test_dyn = x_test.clone().into_dyn();
199 let predictions = model.forward(&x_test_dyn)?;
200 let predicted_classes = predictions_to_classes(&predictions);
201
202 // Calculate validation accuracy
203 let mut correct = 0;
204 for i in 0..n_test {
205 if predicted_classes[i] == y_test[i] {
206 correct += 1;
207 }
208 }
209 let val_accuracy = correct as f32 / n_test as f32;
210
211 // Store metrics
212 epoch_history
213 .get_mut("train_loss")
214 .unwrap()
215 .push(train_loss);
216 epoch_history
217 .get_mut("val_accuracy")
218 .unwrap()
219 .push(val_accuracy);
220
221 // Print progress
222 if (epoch + 1) % 10 == 0 || epoch == 0 {
223 println!(
224 "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225 epoch + 1,
226 epochs,
227 train_loss,
228 val_accuracy
229 );
230 }
231
232 // Update visualization callback
233 let mut context = CallbackContext {
234 epoch,
235 total_epochs: epochs,
236 batch: 0,
237 total_batches: 1,
238 batch_loss: None,
239 epoch_loss: Some(train_loss),
240 val_loss: None,
241 metrics: vec![val_accuracy],
242 history: &epoch_history,
243 stop_training: false,
244 model: None,
245 };
246
247 // Visualize progress with metrics chart
248 if epoch % 10 == 0 || epoch == epochs - 1 {
249 visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250 }
251
252 // Calculate and show confusion matrix during training
253 if epoch % 20 == 0 || epoch == epochs - 1 {
254 // Create confusion matrix
255 let cm = ConfusionMatrix::<f32>::new(
256 &y_test.view(),
257 &predicted_classes.view(),
258 Some(n_classes),
259 Some(class_labels.clone()),
260 )?;
261
262 // Show heatmap visualization
263 println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264 println!(
265 "{}",
266 cm.to_heatmap(
267 Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268 true
269 )
270 );
271 }
272 }
273
274 // Final evaluation
275 println!("\nFinal model evaluation:");
276
277 // Make predictions on test set
278 let x_test_dyn = x_test.clone().into_dyn();
279 let predictions = model.forward(&x_test_dyn)?;
280 let predicted_classes = predictions_to_classes(&predictions);
281
282 // Create confusion matrix
283 let cm = ConfusionMatrix::<f32>::new(
284 &y_test.view(),
285 &predicted_classes.view(),
286 Some(n_classes),
287 Some(class_labels.clone()),
288 )?;
289
290 // Calculate and show metrics
291 let accuracy = cm.accuracy();
292 let precision = cm.precision();
293 let recall = cm.recall();
294 let f1 = cm.f1_score();
295
296 println!("\nFinal Classification Metrics:");
297 println!("Overall Accuracy: {:.4}", accuracy);
298
299 println!("\nPer-Class Metrics:");
300 println!("Class | Precision | Recall | F1-Score");
301 println!("-----------------------------------");
302
303 for i in 0..n_classes {
304 println!(
305 "{} | {:.4} | {:.4} | {:.4}",
306 class_labels[i], precision[i], recall[i], f1[i]
307 );
308 }
309
310 println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312 // Show different confusion matrix visualizations
313 println!("\nFinal Confusion Matrix Visualizations:");
314
315 // 1. Standard confusion matrix
316 println!("\n1. Standard Confusion Matrix:");
317 println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319 // 2. Normalized confusion matrix
320 println!("\n2. Normalized Confusion Matrix:");
321 println!(
322 "{}",
323 cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324 );
325
326 // 3. Confusion matrix heatmap
327 println!("\n3. Confusion Matrix Heatmap:");
328 println!(
329 "{}",
330 cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331 );
332
333 // 4. Error pattern analysis
334 println!("\n4. Error Pattern Analysis:");
335 println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337 println!("\nNeural Network Confusion Matrix Visualization Complete!");
338 Ok(())
339}
pub fn precision(&self) -> Array1<F>
Calculate the precision of each class. Returns one value per class, indexed by class id, so `precision[i]` is the precision of class `i`.
Examples found in repository:
7fn main() -> Result<()> {
8 println!("Neural Network Model Evaluation Visualization Example\n");
9
10 // Generate some example data
11 let n_samples = 500;
12 let n_features = 10;
13 let n_classes = 4;
14
15 println!(
16 "Generating {} samples with {} features for {} classes",
17 n_samples, n_features, n_classes
18 );
19
20 // 1. Confusion Matrix Example
21 println!("\n--- Confusion Matrix Visualization ---\n");
22
23 // Create a deterministic RNG for reproducibility
24 let mut rng = SmallRng::seed_from_u64(42);
25
26 // Generate random predictions and true labels
27 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29 // Create slightly correlated predictions (not completely random)
30 let y_pred = Array::from_shape_fn(n_samples, |i| {
31 if rng.random::<f32>() < 0.7 {
32 // 70% chance of correct prediction
33 y_true[i]
34 } else {
35 // 30% chance of random class
36 rng.random_range(0..n_classes)
37 }
38 });
39
40 // Create confusion matrix
41 let class_labels = vec![
42 "Class A".to_string(),
43 "Class B".to_string(),
44 "Class C".to_string(),
45 "Class D".to_string(),
46 ];
47
48 let cm = ConfusionMatrix::<f32>::new(
49 &y_true.view(),
50 &y_pred.view(),
51 Some(n_classes),
52 Some(class_labels),
53 )?;
54
55 // Print raw and normalized confusion matrices
56 println!("Raw Confusion Matrix:\n");
57 println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59 println!("\nNormalized Confusion Matrix:\n");
60 println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62 // Print metrics
63 println!("\nAccuracy: {:.3}", cm.accuracy());
64
65 let precision = cm.precision();
66 let recall = cm.recall();
67 let f1 = cm.f1_score();
68
69 println!("Per-class metrics:");
70 for i in 0..n_classes {
71 println!(
72 " Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73 i, precision[i], recall[i], f1[i]
74 );
75 }
76
77 println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79 // 2. Feature Importance Visualization
80 println!("\n--- Feature Importance Visualization ---\n");
81
82 // Generate random feature importance scores
83 let feature_names = (0..n_features)
84 .map(|i| format!("Feature_{}", i))
85 .collect::<Vec<String>>();
86
87 let importance = Array1::from_shape_fn(n_features, |i| {
88 // Make some features more important than others
89 let base = (n_features - i) as f32 / n_features as f32;
90 base + 0.2 * rng.random::<f32>()
91 });
92
93 let fi = FeatureImportance::new(feature_names, importance)?;
94
95 // Print full feature importance
96 println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98 // Print top-5 features
99 println!("\nTop 5 Most Important Features:\n");
100 println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102 // 3. ROC Curve for Binary Classification
103 println!("\n--- ROC Curve Visualization ---\n");
104
105 // Generate binary classification data
106 let n_binary = 200;
107 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109 // Generate scores with some predictive power
110 let y_scores = Array1::from_shape_fn(n_binary, |i| {
111 if y_true_binary[i] == 1 {
112 // Higher scores for positive class
113 0.6 + 0.4 * rng.random::<f32>()
114 } else {
115 // Lower scores for negative class
116 0.4 * rng.random::<f32>()
117 }
118 });
119
120 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122 println!("ROC AUC: {:.3}", roc.auc);
123 println!("\n{}", roc.to_ascii(None, 50, 20));
124
125 // 4. Learning Curve Visualization
126 println!("\n--- Learning Curve Visualization ---\n");
127
128 // Generate learning curve data
129 let n_points = 10;
130 let n_cv = 5;
131
132 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134 // Generate training scores (decreasing with size due to overfitting)
135 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137 });
138
139 // Generate validation scores (increasing with size)
140 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142 });
143
144 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148 // Print final message
149 println!("\nModel evaluation visualizations completed successfully!");
150
151 Ok(())
152}
More examples
10fn main() -> Result<()> {
11 println!(
12 "{}",
13 stylize("Neural Network Model Evaluation with Color", Style::Bold)
14 );
15 println!("{}", "-".repeat(50));
16
17 // Set up color options
18 let color_options = ColorOptions {
19 enabled: true,
20 use_background: false,
21 use_bright: true,
22 };
23
24 // Generate some example data
25 let n_samples = 500;
26 let n_features = 10;
27 let n_classes = 4;
28
29 println!(
30 "\n{} {} {} {} {} {}",
31 colorize("Generating", Color::BrightGreen),
32 colorize(n_samples.to_string(), Color::BrightYellow),
33 colorize("samples with", Color::BrightGreen),
34 colorize(n_features.to_string(), Color::BrightYellow),
35 colorize("features for", Color::BrightGreen),
36 colorize(n_classes.to_string(), Color::BrightYellow),
37 );
38
39 // Create a deterministic RNG for reproducibility
40 let mut rng = SmallRng::seed_from_u64(42);
41
42 // 1. Confusion Matrix Example
43 println!(
44 "\n{}",
45 stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
46 );
47
48 // Generate random predictions and true labels
49 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
50
51 // Create slightly correlated predictions (not completely random)
52 let y_pred = Array::from_shape_fn(n_samples, |i| {
53 if rng.random::<f32>() < 0.7 {
54 // 70% chance of correct prediction
55 y_true[i]
56 } else {
57 // 30% chance of random class
58 rng.random_range(0..n_classes)
59 }
60 });
61
62 // Create confusion matrix
63 let class_labels = vec![
64 "Class A".to_string(),
65 "Class B".to_string(),
66 "Class C".to_string(),
67 "Class D".to_string(),
68 ];
69
70 let cm = ConfusionMatrix::<f32>::new(
71 &y_true.view(),
72 &y_pred.view(),
73 Some(n_classes),
74 Some(class_labels),
75 )?;
76
77 // Print raw and normalized confusion matrices with color
78 println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
79 println!(
80 "{}",
81 cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
82 );
83
84 println!(
85 "\n{}",
86 colorize("Normalized Confusion Matrix:", Color::BrightCyan)
87 );
88 println!(
89 "{}",
90 cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
91 );
92
93 // Print metrics
94 println!(
95 "\n{} {:.3}",
96 colorize("Overall Accuracy:", Color::BrightMagenta),
97 cm.accuracy()
98 );
99
100 let precision = cm.precision();
101 let recall = cm.recall();
102 let f1 = cm.f1_score();
103
104 println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
105 for i in 0..n_classes {
106 println!(
107 " {}: {}={:.3}, {}={:.3}, {}={:.3}",
108 colorize(format!("Class {}", i), Color::BrightYellow),
109 colorize("Precision", Color::BrightCyan),
110 precision[i],
111 colorize("Recall", Color::BrightGreen),
112 recall[i],
113 colorize("F1", Color::BrightBlue),
114 f1[i]
115 );
116 }
117
118 println!(
119 "{} {:.3}",
120 colorize("Macro F1 Score:", Color::BrightMagenta),
121 cm.macro_f1()
122 );
123
124 // 2. Feature Importance Visualization
125 println!(
126 "\n{}",
127 stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
128 );
129
130 // Generate random feature importance scores
131 let feature_names = (0..n_features)
132 .map(|i| format!("Feature_{}", i))
133 .collect::<Vec<String>>();
134
135 let importance = Array1::from_shape_fn(n_features, |i| {
136 // Make some features more important than others
137 let base = (n_features - i) as f32 / n_features as f32;
138 base + 0.2 * rng.random::<f32>()
139 });
140
141 let fi = FeatureImportance::new(feature_names, importance)?;
142
143 // Print full feature importance with color
144 println!(
145 "{}",
146 fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
147 );
148
149 // Print top-5 features with color
150 println!(
151 "\n{}",
152 colorize("Top 5 Most Important Features:", Color::BrightCyan)
153 );
154 println!(
155 "{}",
156 fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
157 );
158
159 // 3. ROC Curve for Binary Classification
160 println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
161
162 // Generate binary classification data
163 let n_binary = 200;
164 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
165
166 // Generate scores with some predictive power
167 let y_scores = Array1::from_shape_fn(n_binary, |i| {
168 if y_true_binary[i] == 1 {
169 // Higher scores for positive class
170 0.6 + 0.4 * rng.random::<f32>()
171 } else {
172 // Lower scores for negative class
173 0.4 * rng.random::<f32>()
174 }
175 });
176
177 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
178
179 println!(
180 "{} {:.3}",
181 colorize("ROC AUC:", Color::BrightMagenta),
182 roc.auc
183 );
184
185 println!("\n{}", roc.to_ascii(None, 50, 20));
186
187 // 4. Learning Curve Visualization
188 println!(
189 "\n{}",
190 stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
191 );
192
193 // Generate learning curve data
194 let n_points = 10;
195 let n_cv = 5;
196
197 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
198
199 // Generate training scores (decreasing with size due to overfitting)
200 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
201 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
202 });
203
204 // Generate validation scores (increasing with size)
205 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
206 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
207 });
208
209 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
210
211 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
212
213 // Print final message with color
214 println!(
215 "\n{}",
216 colorize(
217 "Model evaluation visualizations completed successfully!",
218 Color::BrightGreen
219 )
220 );
221
222 Ok(())
223}
118fn main() -> Result<()> {
119 println!("Neural Network Confusion Matrix Visualization");
120 println!("==============================================\n");
121
122 // Initialize RNG with a fixed seed for reproducibility
123 let mut rng = SmallRng::seed_from_u64(42);
124
125 // Generate spiral dataset for 3-class classification
126 let n_classes = 3;
127 let n_samples_per_class = 100;
128 let noise = 0.15;
129
130 let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131 println!(
132 "Generated spiral dataset with {} classes, {} samples per class",
133 n_classes, n_samples_per_class
134 );
135
136 // Split data into training and test sets (80/20 split)
137 let n_samples = x.shape()[0];
138 let n_train = (n_samples as f32 * 0.8) as usize;
139 let n_test = n_samples - n_train;
140
141 let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142 let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143 let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144 let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146 println!(
147 "Split data into {} training and {} test samples",
148 n_train, n_test
149 );
150
151 // Create a classification model
152 let input_dim = 2; // 2D input (x, y coordinates)
153 let hidden_dim = 32; // Hidden layer size
154 let output_dim = n_classes; // One output per class
155
156 let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157 println!("Created model with {} layers", model.num_layers());
158
159 // Setup loss function and optimizer
160 let loss_fn = MeanSquaredError::new();
161 let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163 // Train the model
164 let epochs = 100;
165 let x_train_dyn = x_train.clone().into_dyn();
166 let y_train_onehot = one_hot_encode(&y_train, n_classes);
167 let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169 // Create visualization callback for training metrics
170 let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171 .with_tracked_metrics(vec![
172 "train_loss".to_string(),
173 "val_accuracy".to_string(),
174 ]);
175
176 // Define class labels for confusion matrix
177 let class_labels = vec![
178 "Class A".to_string(),
179 "Class B".to_string(),
180 "Class C".to_string(),
181 ];
182
183 // Train the model (simple manual training loop)
184 println!("\nTraining model...");
185
186 // Initialize history for tracking metrics
187 let mut epoch_history = HashMap::new();
188 epoch_history.insert("train_loss".to_string(), Vec::new());
189 epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191 // Training loop
192 for epoch in 0..epochs {
193 // Train for one epoch
194 let train_loss =
195 model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197 // Compute validation accuracy
198 let x_test_dyn = x_test.clone().into_dyn();
199 let predictions = model.forward(&x_test_dyn)?;
200 let predicted_classes = predictions_to_classes(&predictions);
201
202 // Calculate validation accuracy
203 let mut correct = 0;
204 for i in 0..n_test {
205 if predicted_classes[i] == y_test[i] {
206 correct += 1;
207 }
208 }
209 let val_accuracy = correct as f32 / n_test as f32;
210
211 // Store metrics
212 epoch_history
213 .get_mut("train_loss")
214 .unwrap()
215 .push(train_loss);
216 epoch_history
217 .get_mut("val_accuracy")
218 .unwrap()
219 .push(val_accuracy);
220
221 // Print progress
222 if (epoch + 1) % 10 == 0 || epoch == 0 {
223 println!(
224 "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225 epoch + 1,
226 epochs,
227 train_loss,
228 val_accuracy
229 );
230 }
231
232 // Update visualization callback
233 let mut context = CallbackContext {
234 epoch,
235 total_epochs: epochs,
236 batch: 0,
237 total_batches: 1,
238 batch_loss: None,
239 epoch_loss: Some(train_loss),
240 val_loss: None,
241 metrics: vec![val_accuracy],
242 history: &epoch_history,
243 stop_training: false,
244 model: None,
245 };
246
247 // Visualize progress with metrics chart
248 if epoch % 10 == 0 || epoch == epochs - 1 {
249 visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250 }
251
252 // Calculate and show confusion matrix during training
253 if epoch % 20 == 0 || epoch == epochs - 1 {
254 // Create confusion matrix
255 let cm = ConfusionMatrix::<f32>::new(
256 &y_test.view(),
257 &predicted_classes.view(),
258 Some(n_classes),
259 Some(class_labels.clone()),
260 )?;
261
262 // Show heatmap visualization
263 println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264 println!(
265 "{}",
266 cm.to_heatmap(
267 Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268 true
269 )
270 );
271 }
272 }
273
274 // Final evaluation
275 println!("\nFinal model evaluation:");
276
277 // Make predictions on test set
278 let x_test_dyn = x_test.clone().into_dyn();
279 let predictions = model.forward(&x_test_dyn)?;
280 let predicted_classes = predictions_to_classes(&predictions);
281
282 // Create confusion matrix
283 let cm = ConfusionMatrix::<f32>::new(
284 &y_test.view(),
285 &predicted_classes.view(),
286 Some(n_classes),
287 Some(class_labels.clone()),
288 )?;
289
290 // Calculate and show metrics
291 let accuracy = cm.accuracy();
292 let precision = cm.precision();
293 let recall = cm.recall();
294 let f1 = cm.f1_score();
295
296 println!("\nFinal Classification Metrics:");
297 println!("Overall Accuracy: {:.4}", accuracy);
298
299 println!("\nPer-Class Metrics:");
300 println!("Class | Precision | Recall | F1-Score");
301 println!("-----------------------------------");
302
303 for i in 0..n_classes {
304 println!(
305 "{} | {:.4} | {:.4} | {:.4}",
306 class_labels[i], precision[i], recall[i], f1[i]
307 );
308 }
309
310 println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312 // Show different confusion matrix visualizations
313 println!("\nFinal Confusion Matrix Visualizations:");
314
315 // 1. Standard confusion matrix
316 println!("\n1. Standard Confusion Matrix:");
317 println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319 // 2. Normalized confusion matrix
320 println!("\n2. Normalized Confusion Matrix:");
321 println!(
322 "{}",
323 cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324 );
325
326 // 3. Confusion matrix heatmap
327 println!("\n3. Confusion Matrix Heatmap:");
328 println!(
329 "{}",
330 cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331 );
332
333 // 4. Error pattern analysis
334 println!("\n4. Error Pattern Analysis:");
335 println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337 println!("\nNeural Network Confusion Matrix Visualization Complete!");
338 Ok(())
339}
pub fn recall(&self) -> Array1<F>
Calculate the recall of each class. Returns one value per class, indexed by class id, so `recall[i]` is the recall of class `i`.
Examples found in repository:
7fn main() -> Result<()> {
8 println!("Neural Network Model Evaluation Visualization Example\n");
9
10 // Generate some example data
11 let n_samples = 500;
12 let n_features = 10;
13 let n_classes = 4;
14
15 println!(
16 "Generating {} samples with {} features for {} classes",
17 n_samples, n_features, n_classes
18 );
19
20 // 1. Confusion Matrix Example
21 println!("\n--- Confusion Matrix Visualization ---\n");
22
23 // Create a deterministic RNG for reproducibility
24 let mut rng = SmallRng::seed_from_u64(42);
25
26 // Generate random predictions and true labels
27 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29 // Create slightly correlated predictions (not completely random)
30 let y_pred = Array::from_shape_fn(n_samples, |i| {
31 if rng.random::<f32>() < 0.7 {
32 // 70% chance of correct prediction
33 y_true[i]
34 } else {
35 // 30% chance of random class
36 rng.random_range(0..n_classes)
37 }
38 });
39
40 // Create confusion matrix
41 let class_labels = vec![
42 "Class A".to_string(),
43 "Class B".to_string(),
44 "Class C".to_string(),
45 "Class D".to_string(),
46 ];
47
48 let cm = ConfusionMatrix::<f32>::new(
49 &y_true.view(),
50 &y_pred.view(),
51 Some(n_classes),
52 Some(class_labels),
53 )?;
54
55 // Print raw and normalized confusion matrices
56 println!("Raw Confusion Matrix:\n");
57 println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59 println!("\nNormalized Confusion Matrix:\n");
60 println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62 // Print metrics
63 println!("\nAccuracy: {:.3}", cm.accuracy());
64
65 let precision = cm.precision();
66 let recall = cm.recall();
67 let f1 = cm.f1_score();
68
69 println!("Per-class metrics:");
70 for i in 0..n_classes {
71 println!(
72 " Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73 i, precision[i], recall[i], f1[i]
74 );
75 }
76
77 println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79 // 2. Feature Importance Visualization
80 println!("\n--- Feature Importance Visualization ---\n");
81
82 // Generate random feature importance scores
83 let feature_names = (0..n_features)
84 .map(|i| format!("Feature_{}", i))
85 .collect::<Vec<String>>();
86
87 let importance = Array1::from_shape_fn(n_features, |i| {
88 // Make some features more important than others
89 let base = (n_features - i) as f32 / n_features as f32;
90 base + 0.2 * rng.random::<f32>()
91 });
92
93 let fi = FeatureImportance::new(feature_names, importance)?;
94
95 // Print full feature importance
96 println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98 // Print top-5 features
99 println!("\nTop 5 Most Important Features:\n");
100 println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102 // 3. ROC Curve for Binary Classification
103 println!("\n--- ROC Curve Visualization ---\n");
104
105 // Generate binary classification data
106 let n_binary = 200;
107 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109 // Generate scores with some predictive power
110 let y_scores = Array1::from_shape_fn(n_binary, |i| {
111 if y_true_binary[i] == 1 {
112 // Higher scores for positive class
113 0.6 + 0.4 * rng.random::<f32>()
114 } else {
115 // Lower scores for negative class
116 0.4 * rng.random::<f32>()
117 }
118 });
119
120 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122 println!("ROC AUC: {:.3}", roc.auc);
123 println!("\n{}", roc.to_ascii(None, 50, 20));
124
125 // 4. Learning Curve Visualization
126 println!("\n--- Learning Curve Visualization ---\n");
127
128 // Generate learning curve data
129 let n_points = 10;
130 let n_cv = 5;
131
132 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134 // Generate training scores (decreasing with size due to overfitting)
135 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137 });
138
139 // Generate validation scores (increasing with size)
140 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142 });
143
144 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148 // Print final message
149 println!("\nModel evaluation visualizations completed successfully!");
150
151 Ok(())
152}
More examples
10fn main() -> Result<()> {
11 println!(
12 "{}",
13 stylize("Neural Network Model Evaluation with Color", Style::Bold)
14 );
15 println!("{}", "-".repeat(50));
16
17 // Set up color options
18 let color_options = ColorOptions {
19 enabled: true,
20 use_background: false,
21 use_bright: true,
22 };
23
24 // Generate some example data
25 let n_samples = 500;
26 let n_features = 10;
27 let n_classes = 4;
28
29 println!(
30 "\n{} {} {} {} {} {}",
31 colorize("Generating", Color::BrightGreen),
32 colorize(n_samples.to_string(), Color::BrightYellow),
33 colorize("samples with", Color::BrightGreen),
34 colorize(n_features.to_string(), Color::BrightYellow),
35 colorize("features for", Color::BrightGreen),
36 colorize(n_classes.to_string(), Color::BrightYellow),
37 );
38
39 // Create a deterministic RNG for reproducibility
40 let mut rng = SmallRng::seed_from_u64(42);
41
42 // 1. Confusion Matrix Example
43 println!(
44 "\n{}",
45 stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
46 );
47
48 // Generate random predictions and true labels
49 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
50
51 // Create slightly correlated predictions (not completely random)
52 let y_pred = Array::from_shape_fn(n_samples, |i| {
53 if rng.random::<f32>() < 0.7 {
54 // 70% chance of correct prediction
55 y_true[i]
56 } else {
57 // 30% chance of random class
58 rng.random_range(0..n_classes)
59 }
60 });
61
62 // Create confusion matrix
63 let class_labels = vec![
64 "Class A".to_string(),
65 "Class B".to_string(),
66 "Class C".to_string(),
67 "Class D".to_string(),
68 ];
69
70 let cm = ConfusionMatrix::<f32>::new(
71 &y_true.view(),
72 &y_pred.view(),
73 Some(n_classes),
74 Some(class_labels),
75 )?;
76
77 // Print raw and normalized confusion matrices with color
78 println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
79 println!(
80 "{}",
81 cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
82 );
83
84 println!(
85 "\n{}",
86 colorize("Normalized Confusion Matrix:", Color::BrightCyan)
87 );
88 println!(
89 "{}",
90 cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
91 );
92
93 // Print metrics
94 println!(
95 "\n{} {:.3}",
96 colorize("Overall Accuracy:", Color::BrightMagenta),
97 cm.accuracy()
98 );
99
100 let precision = cm.precision();
101 let recall = cm.recall();
102 let f1 = cm.f1_score();
103
104 println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
105 for i in 0..n_classes {
106 println!(
107 " {}: {}={:.3}, {}={:.3}, {}={:.3}",
108 colorize(format!("Class {}", i), Color::BrightYellow),
109 colorize("Precision", Color::BrightCyan),
110 precision[i],
111 colorize("Recall", Color::BrightGreen),
112 recall[i],
113 colorize("F1", Color::BrightBlue),
114 f1[i]
115 );
116 }
117
118 println!(
119 "{} {:.3}",
120 colorize("Macro F1 Score:", Color::BrightMagenta),
121 cm.macro_f1()
122 );
123
124 // 2. Feature Importance Visualization
125 println!(
126 "\n{}",
127 stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
128 );
129
130 // Generate random feature importance scores
131 let feature_names = (0..n_features)
132 .map(|i| format!("Feature_{}", i))
133 .collect::<Vec<String>>();
134
135 let importance = Array1::from_shape_fn(n_features, |i| {
136 // Make some features more important than others
137 let base = (n_features - i) as f32 / n_features as f32;
138 base + 0.2 * rng.random::<f32>()
139 });
140
141 let fi = FeatureImportance::new(feature_names, importance)?;
142
143 // Print full feature importance with color
144 println!(
145 "{}",
146 fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
147 );
148
149 // Print top-5 features with color
150 println!(
151 "\n{}",
152 colorize("Top 5 Most Important Features:", Color::BrightCyan)
153 );
154 println!(
155 "{}",
156 fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
157 );
158
159 // 3. ROC Curve for Binary Classification
160 println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
161
162 // Generate binary classification data
163 let n_binary = 200;
164 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
165
166 // Generate scores with some predictive power
167 let y_scores = Array1::from_shape_fn(n_binary, |i| {
168 if y_true_binary[i] == 1 {
169 // Higher scores for positive class
170 0.6 + 0.4 * rng.random::<f32>()
171 } else {
172 // Lower scores for negative class
173 0.4 * rng.random::<f32>()
174 }
175 });
176
177 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
178
179 println!(
180 "{} {:.3}",
181 colorize("ROC AUC:", Color::BrightMagenta),
182 roc.auc
183 );
184
185 println!("\n{}", roc.to_ascii(None, 50, 20));
186
187 // 4. Learning Curve Visualization
188 println!(
189 "\n{}",
190 stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
191 );
192
193 // Generate learning curve data
194 let n_points = 10;
195 let n_cv = 5;
196
197 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
198
199 // Generate training scores (decreasing with size due to overfitting)
200 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
201 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
202 });
203
204 // Generate validation scores (increasing with size)
205 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
206 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
207 });
208
209 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
210
211 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
212
213 // Print final message with color
214 println!(
215 "\n{}",
216 colorize(
217 "Model evaluation visualizations completed successfully!",
218 Color::BrightGreen
219 )
220 );
221
222 Ok(())
223}
118fn main() -> Result<()> {
119 println!("Neural Network Confusion Matrix Visualization");
120 println!("==============================================\n");
121
122 // Initialize RNG with a fixed seed for reproducibility
123 let mut rng = SmallRng::seed_from_u64(42);
124
125 // Generate spiral dataset for 3-class classification
126 let n_classes = 3;
127 let n_samples_per_class = 100;
128 let noise = 0.15;
129
130 let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131 println!(
132 "Generated spiral dataset with {} classes, {} samples per class",
133 n_classes, n_samples_per_class
134 );
135
136 // Split data into training and test sets (80/20 split)
137 let n_samples = x.shape()[0];
138 let n_train = (n_samples as f32 * 0.8) as usize;
139 let n_test = n_samples - n_train;
140
141 let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142 let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143 let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144 let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146 println!(
147 "Split data into {} training and {} test samples",
148 n_train, n_test
149 );
150
151 // Create a classification model
152 let input_dim = 2; // 2D input (x, y coordinates)
153 let hidden_dim = 32; // Hidden layer size
154 let output_dim = n_classes; // One output per class
155
156 let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157 println!("Created model with {} layers", model.num_layers());
158
159 // Setup loss function and optimizer
160 let loss_fn = MeanSquaredError::new();
161 let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163 // Train the model
164 let epochs = 100;
165 let x_train_dyn = x_train.clone().into_dyn();
166 let y_train_onehot = one_hot_encode(&y_train, n_classes);
167 let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169 // Create visualization callback for training metrics
170 let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171 .with_tracked_metrics(vec![
172 "train_loss".to_string(),
173 "val_accuracy".to_string(),
174 ]);
175
176 // Define class labels for confusion matrix
177 let class_labels = vec![
178 "Class A".to_string(),
179 "Class B".to_string(),
180 "Class C".to_string(),
181 ];
182
183 // Train the model (simple manual training loop)
184 println!("\nTraining model...");
185
186 // Initialize history for tracking metrics
187 let mut epoch_history = HashMap::new();
188 epoch_history.insert("train_loss".to_string(), Vec::new());
189 epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191 // Training loop
192 for epoch in 0..epochs {
193 // Train for one epoch
194 let train_loss =
195 model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197 // Compute validation accuracy
198 let x_test_dyn = x_test.clone().into_dyn();
199 let predictions = model.forward(&x_test_dyn)?;
200 let predicted_classes = predictions_to_classes(&predictions);
201
202 // Calculate validation accuracy
203 let mut correct = 0;
204 for i in 0..n_test {
205 if predicted_classes[i] == y_test[i] {
206 correct += 1;
207 }
208 }
209 let val_accuracy = correct as f32 / n_test as f32;
210
211 // Store metrics
212 epoch_history
213 .get_mut("train_loss")
214 .unwrap()
215 .push(train_loss);
216 epoch_history
217 .get_mut("val_accuracy")
218 .unwrap()
219 .push(val_accuracy);
220
221 // Print progress
222 if (epoch + 1) % 10 == 0 || epoch == 0 {
223 println!(
224 "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225 epoch + 1,
226 epochs,
227 train_loss,
228 val_accuracy
229 );
230 }
231
232 // Update visualization callback
233 let mut context = CallbackContext {
234 epoch,
235 total_epochs: epochs,
236 batch: 0,
237 total_batches: 1,
238 batch_loss: None,
239 epoch_loss: Some(train_loss),
240 val_loss: None,
241 metrics: vec![val_accuracy],
242 history: &epoch_history,
243 stop_training: false,
244 model: None,
245 };
246
247 // Visualize progress with metrics chart
248 if epoch % 10 == 0 || epoch == epochs - 1 {
249 visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250 }
251
252 // Calculate and show confusion matrix during training
253 if epoch % 20 == 0 || epoch == epochs - 1 {
254 // Create confusion matrix
255 let cm = ConfusionMatrix::<f32>::new(
256 &y_test.view(),
257 &predicted_classes.view(),
258 Some(n_classes),
259 Some(class_labels.clone()),
260 )?;
261
262 // Show heatmap visualization
263 println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264 println!(
265 "{}",
266 cm.to_heatmap(
267 Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268 true
269 )
270 );
271 }
272 }
273
274 // Final evaluation
275 println!("\nFinal model evaluation:");
276
277 // Make predictions on test set
278 let x_test_dyn = x_test.clone().into_dyn();
279 let predictions = model.forward(&x_test_dyn)?;
280 let predicted_classes = predictions_to_classes(&predictions);
281
282 // Create confusion matrix
283 let cm = ConfusionMatrix::<f32>::new(
284 &y_test.view(),
285 &predicted_classes.view(),
286 Some(n_classes),
287 Some(class_labels.clone()),
288 )?;
289
290 // Calculate and show metrics
291 let accuracy = cm.accuracy();
292 let precision = cm.precision();
293 let recall = cm.recall();
294 let f1 = cm.f1_score();
295
296 println!("\nFinal Classification Metrics:");
297 println!("Overall Accuracy: {:.4}", accuracy);
298
299 println!("\nPer-Class Metrics:");
300 println!("Class | Precision | Recall | F1-Score");
301 println!("-----------------------------------");
302
303 for i in 0..n_classes {
304 println!(
305 "{} | {:.4} | {:.4} | {:.4}",
306 class_labels[i], precision[i], recall[i], f1[i]
307 );
308 }
309
310 println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312 // Show different confusion matrix visualizations
313 println!("\nFinal Confusion Matrix Visualizations:");
314
315 // 1. Standard confusion matrix
316 println!("\n1. Standard Confusion Matrix:");
317 println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319 // 2. Normalized confusion matrix
320 println!("\n2. Normalized Confusion Matrix:");
321 println!(
322 "{}",
323 cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324 );
325
326 // 3. Confusion matrix heatmap
327 println!("\n3. Confusion Matrix Heatmap:");
328 println!(
329 "{}",
330 cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331 );
332
333 // 4. Error pattern analysis
334 println!("\n4. Error Pattern Analysis:");
335 println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337 println!("\nNeural Network Confusion Matrix Visualization Complete!");
338 Ok(())
339}
pub fn f1_score(&self) -> Array1<F>
Calculates the F1 score (the harmonic mean of precision and recall) for each class. The returned array has one entry per class; entry `i` is the F1 score of class `i`.
Examples found in repository:
7fn main() -> Result<()> {
8 println!("Neural Network Model Evaluation Visualization Example\n");
9
10 // Generate some example data
11 let n_samples = 500;
12 let n_features = 10;
13 let n_classes = 4;
14
15 println!(
16 "Generating {} samples with {} features for {} classes",
17 n_samples, n_features, n_classes
18 );
19
20 // 1. Confusion Matrix Example
21 println!("\n--- Confusion Matrix Visualization ---\n");
22
23 // Create a deterministic RNG for reproducibility
24 let mut rng = SmallRng::seed_from_u64(42);
25
26 // Generate random predictions and true labels
27 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29 // Create slightly correlated predictions (not completely random)
30 let y_pred = Array::from_shape_fn(n_samples, |i| {
31 if rng.random::<f32>() < 0.7 {
32 // 70% chance of correct prediction
33 y_true[i]
34 } else {
35 // 30% chance of random class
36 rng.random_range(0..n_classes)
37 }
38 });
39
40 // Create confusion matrix
41 let class_labels = vec![
42 "Class A".to_string(),
43 "Class B".to_string(),
44 "Class C".to_string(),
45 "Class D".to_string(),
46 ];
47
48 let cm = ConfusionMatrix::<f32>::new(
49 &y_true.view(),
50 &y_pred.view(),
51 Some(n_classes),
52 Some(class_labels),
53 )?;
54
55 // Print raw and normalized confusion matrices
56 println!("Raw Confusion Matrix:\n");
57 println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59 println!("\nNormalized Confusion Matrix:\n");
60 println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62 // Print metrics
63 println!("\nAccuracy: {:.3}", cm.accuracy());
64
65 let precision = cm.precision();
66 let recall = cm.recall();
67 let f1 = cm.f1_score();
68
69 println!("Per-class metrics:");
70 for i in 0..n_classes {
71 println!(
72 " Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73 i, precision[i], recall[i], f1[i]
74 );
75 }
76
77 println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79 // 2. Feature Importance Visualization
80 println!("\n--- Feature Importance Visualization ---\n");
81
82 // Generate random feature importance scores
83 let feature_names = (0..n_features)
84 .map(|i| format!("Feature_{}", i))
85 .collect::<Vec<String>>();
86
87 let importance = Array1::from_shape_fn(n_features, |i| {
88 // Make some features more important than others
89 let base = (n_features - i) as f32 / n_features as f32;
90 base + 0.2 * rng.random::<f32>()
91 });
92
93 let fi = FeatureImportance::new(feature_names, importance)?;
94
95 // Print full feature importance
96 println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98 // Print top-5 features
99 println!("\nTop 5 Most Important Features:\n");
100 println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102 // 3. ROC Curve for Binary Classification
103 println!("\n--- ROC Curve Visualization ---\n");
104
105 // Generate binary classification data
106 let n_binary = 200;
107 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109 // Generate scores with some predictive power
110 let y_scores = Array1::from_shape_fn(n_binary, |i| {
111 if y_true_binary[i] == 1 {
112 // Higher scores for positive class
113 0.6 + 0.4 * rng.random::<f32>()
114 } else {
115 // Lower scores for negative class
116 0.4 * rng.random::<f32>()
117 }
118 });
119
120 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122 println!("ROC AUC: {:.3}", roc.auc);
123 println!("\n{}", roc.to_ascii(None, 50, 20));
124
125 // 4. Learning Curve Visualization
126 println!("\n--- Learning Curve Visualization ---\n");
127
128 // Generate learning curve data
129 let n_points = 10;
130 let n_cv = 5;
131
132 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134 // Generate training scores (decreasing with size due to overfitting)
135 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137 });
138
139 // Generate validation scores (increasing with size)
140 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142 });
143
144 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148 // Print final message
149 println!("\nModel evaluation visualizations completed successfully!");
150
151 Ok(())
152}
More examples
10fn main() -> Result<()> {
11 println!(
12 "{}",
13 stylize("Neural Network Model Evaluation with Color", Style::Bold)
14 );
15 println!("{}", "-".repeat(50));
16
17 // Set up color options
18 let color_options = ColorOptions {
19 enabled: true,
20 use_background: false,
21 use_bright: true,
22 };
23
24 // Generate some example data
25 let n_samples = 500;
26 let n_features = 10;
27 let n_classes = 4;
28
29 println!(
30 "\n{} {} {} {} {} {}",
31 colorize("Generating", Color::BrightGreen),
32 colorize(n_samples.to_string(), Color::BrightYellow),
33 colorize("samples with", Color::BrightGreen),
34 colorize(n_features.to_string(), Color::BrightYellow),
35 colorize("features for", Color::BrightGreen),
36 colorize(n_classes.to_string(), Color::BrightYellow),
37 );
38
39 // Create a deterministic RNG for reproducibility
40 let mut rng = SmallRng::seed_from_u64(42);
41
42 // 1. Confusion Matrix Example
43 println!(
44 "\n{}",
45 stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
46 );
47
48 // Generate random predictions and true labels
49 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
50
51 // Create slightly correlated predictions (not completely random)
52 let y_pred = Array::from_shape_fn(n_samples, |i| {
53 if rng.random::<f32>() < 0.7 {
54 // 70% chance of correct prediction
55 y_true[i]
56 } else {
57 // 30% chance of random class
58 rng.random_range(0..n_classes)
59 }
60 });
61
62 // Create confusion matrix
63 let class_labels = vec![
64 "Class A".to_string(),
65 "Class B".to_string(),
66 "Class C".to_string(),
67 "Class D".to_string(),
68 ];
69
70 let cm = ConfusionMatrix::<f32>::new(
71 &y_true.view(),
72 &y_pred.view(),
73 Some(n_classes),
74 Some(class_labels),
75 )?;
76
77 // Print raw and normalized confusion matrices with color
78 println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
79 println!(
80 "{}",
81 cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
82 );
83
84 println!(
85 "\n{}",
86 colorize("Normalized Confusion Matrix:", Color::BrightCyan)
87 );
88 println!(
89 "{}",
90 cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
91 );
92
93 // Print metrics
94 println!(
95 "\n{} {:.3}",
96 colorize("Overall Accuracy:", Color::BrightMagenta),
97 cm.accuracy()
98 );
99
100 let precision = cm.precision();
101 let recall = cm.recall();
102 let f1 = cm.f1_score();
103
104 println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
105 for i in 0..n_classes {
106 println!(
107 " {}: {}={:.3}, {}={:.3}, {}={:.3}",
108 colorize(format!("Class {}", i), Color::BrightYellow),
109 colorize("Precision", Color::BrightCyan),
110 precision[i],
111 colorize("Recall", Color::BrightGreen),
112 recall[i],
113 colorize("F1", Color::BrightBlue),
114 f1[i]
115 );
116 }
117
118 println!(
119 "{} {:.3}",
120 colorize("Macro F1 Score:", Color::BrightMagenta),
121 cm.macro_f1()
122 );
123
124 // 2. Feature Importance Visualization
125 println!(
126 "\n{}",
127 stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
128 );
129
130 // Generate random feature importance scores
131 let feature_names = (0..n_features)
132 .map(|i| format!("Feature_{}", i))
133 .collect::<Vec<String>>();
134
135 let importance = Array1::from_shape_fn(n_features, |i| {
136 // Make some features more important than others
137 let base = (n_features - i) as f32 / n_features as f32;
138 base + 0.2 * rng.random::<f32>()
139 });
140
141 let fi = FeatureImportance::new(feature_names, importance)?;
142
143 // Print full feature importance with color
144 println!(
145 "{}",
146 fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
147 );
148
149 // Print top-5 features with color
150 println!(
151 "\n{}",
152 colorize("Top 5 Most Important Features:", Color::BrightCyan)
153 );
154 println!(
155 "{}",
156 fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
157 );
158
159 // 3. ROC Curve for Binary Classification
160 println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
161
162 // Generate binary classification data
163 let n_binary = 200;
164 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
165
166 // Generate scores with some predictive power
167 let y_scores = Array1::from_shape_fn(n_binary, |i| {
168 if y_true_binary[i] == 1 {
169 // Higher scores for positive class
170 0.6 + 0.4 * rng.random::<f32>()
171 } else {
172 // Lower scores for negative class
173 0.4 * rng.random::<f32>()
174 }
175 });
176
177 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
178
179 println!(
180 "{} {:.3}",
181 colorize("ROC AUC:", Color::BrightMagenta),
182 roc.auc
183 );
184
185 println!("\n{}", roc.to_ascii(None, 50, 20));
186
187 // 4. Learning Curve Visualization
188 println!(
189 "\n{}",
190 stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
191 );
192
193 // Generate learning curve data
194 let n_points = 10;
195 let n_cv = 5;
196
197 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
198
199 // Generate training scores (decreasing with size due to overfitting)
200 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
201 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
202 });
203
204 // Generate validation scores (increasing with size)
205 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
206 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
207 });
208
209 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
210
211 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
212
213 // Print final message with color
214 println!(
215 "\n{}",
216 colorize(
217 "Model evaluation visualizations completed successfully!",
218 Color::BrightGreen
219 )
220 );
221
222 Ok(())
223}
118fn main() -> Result<()> {
119 println!("Neural Network Confusion Matrix Visualization");
120 println!("==============================================\n");
121
122 // Initialize RNG with a fixed seed for reproducibility
123 let mut rng = SmallRng::seed_from_u64(42);
124
125 // Generate spiral dataset for 3-class classification
126 let n_classes = 3;
127 let n_samples_per_class = 100;
128 let noise = 0.15;
129
130 let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131 println!(
132 "Generated spiral dataset with {} classes, {} samples per class",
133 n_classes, n_samples_per_class
134 );
135
136 // Split data into training and test sets (80/20 split)
137 let n_samples = x.shape()[0];
138 let n_train = (n_samples as f32 * 0.8) as usize;
139 let n_test = n_samples - n_train;
140
141 let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142 let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143 let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144 let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146 println!(
147 "Split data into {} training and {} test samples",
148 n_train, n_test
149 );
150
151 // Create a classification model
152 let input_dim = 2; // 2D input (x, y coordinates)
153 let hidden_dim = 32; // Hidden layer size
154 let output_dim = n_classes; // One output per class
155
156 let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157 println!("Created model with {} layers", model.num_layers());
158
159 // Setup loss function and optimizer
160 let loss_fn = MeanSquaredError::new();
161 let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163 // Train the model
164 let epochs = 100;
165 let x_train_dyn = x_train.clone().into_dyn();
166 let y_train_onehot = one_hot_encode(&y_train, n_classes);
167 let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169 // Create visualization callback for training metrics
170 let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171 .with_tracked_metrics(vec![
172 "train_loss".to_string(),
173 "val_accuracy".to_string(),
174 ]);
175
176 // Define class labels for confusion matrix
177 let class_labels = vec![
178 "Class A".to_string(),
179 "Class B".to_string(),
180 "Class C".to_string(),
181 ];
182
183 // Train the model (simple manual training loop)
184 println!("\nTraining model...");
185
186 // Initialize history for tracking metrics
187 let mut epoch_history = HashMap::new();
188 epoch_history.insert("train_loss".to_string(), Vec::new());
189 epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191 // Training loop
192 for epoch in 0..epochs {
193 // Train for one epoch
194 let train_loss =
195 model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197 // Compute validation accuracy
198 let x_test_dyn = x_test.clone().into_dyn();
199 let predictions = model.forward(&x_test_dyn)?;
200 let predicted_classes = predictions_to_classes(&predictions);
201
202 // Calculate validation accuracy
203 let mut correct = 0;
204 for i in 0..n_test {
205 if predicted_classes[i] == y_test[i] {
206 correct += 1;
207 }
208 }
209 let val_accuracy = correct as f32 / n_test as f32;
210
211 // Store metrics
212 epoch_history
213 .get_mut("train_loss")
214 .unwrap()
215 .push(train_loss);
216 epoch_history
217 .get_mut("val_accuracy")
218 .unwrap()
219 .push(val_accuracy);
220
221 // Print progress
222 if (epoch + 1) % 10 == 0 || epoch == 0 {
223 println!(
224 "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225 epoch + 1,
226 epochs,
227 train_loss,
228 val_accuracy
229 );
230 }
231
232 // Update visualization callback
233 let mut context = CallbackContext {
234 epoch,
235 total_epochs: epochs,
236 batch: 0,
237 total_batches: 1,
238 batch_loss: None,
239 epoch_loss: Some(train_loss),
240 val_loss: None,
241 metrics: vec![val_accuracy],
242 history: &epoch_history,
243 stop_training: false,
244 model: None,
245 };
246
247 // Visualize progress with metrics chart
248 if epoch % 10 == 0 || epoch == epochs - 1 {
249 visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250 }
251
252 // Calculate and show confusion matrix during training
253 if epoch % 20 == 0 || epoch == epochs - 1 {
254 // Create confusion matrix
255 let cm = ConfusionMatrix::<f32>::new(
256 &y_test.view(),
257 &predicted_classes.view(),
258 Some(n_classes),
259 Some(class_labels.clone()),
260 )?;
261
262 // Show heatmap visualization
263 println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264 println!(
265 "{}",
266 cm.to_heatmap(
267 Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268 true
269 )
270 );
271 }
272 }
273
274 // Final evaluation
275 println!("\nFinal model evaluation:");
276
277 // Make predictions on test set
278 let x_test_dyn = x_test.clone().into_dyn();
279 let predictions = model.forward(&x_test_dyn)?;
280 let predicted_classes = predictions_to_classes(&predictions);
281
282 // Create confusion matrix
283 let cm = ConfusionMatrix::<f32>::new(
284 &y_test.view(),
285 &predicted_classes.view(),
286 Some(n_classes),
287 Some(class_labels.clone()),
288 )?;
289
290 // Calculate and show metrics
291 let accuracy = cm.accuracy();
292 let precision = cm.precision();
293 let recall = cm.recall();
294 let f1 = cm.f1_score();
295
296 println!("\nFinal Classification Metrics:");
297 println!("Overall Accuracy: {:.4}", accuracy);
298
299 println!("\nPer-Class Metrics:");
300 println!("Class | Precision | Recall | F1-Score");
301 println!("-----------------------------------");
302
303 for i in 0..n_classes {
304 println!(
305 "{} | {:.4} | {:.4} | {:.4}",
306 class_labels[i], precision[i], recall[i], f1[i]
307 );
308 }
309
310 println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312 // Show different confusion matrix visualizations
313 println!("\nFinal Confusion Matrix Visualizations:");
314
315 // 1. Standard confusion matrix
316 println!("\n1. Standard Confusion Matrix:");
317 println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319 // 2. Normalized confusion matrix
320 println!("\n2. Normalized Confusion Matrix:");
321 println!(
322 "{}",
323 cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324 );
325
326 // 3. Confusion matrix heatmap
327 println!("\n3. Confusion Matrix Heatmap:");
328 println!(
329 "{}",
330 cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331 );
332
333 // 4. Error pattern analysis
334 println!("\n4. Error Pattern Analysis:");
335 println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337 println!("\nNeural Network Confusion Matrix Visualization Complete!");
338 Ok(())
339}
pub fn macro_f1(&self) -> F
Calculates the macro-averaged F1 score, i.e. the unweighted mean of the per-class F1 scores, so every class counts equally regardless of how many samples it has.
Examples found in repository:
7fn main() -> Result<()> {
8 println!("Neural Network Model Evaluation Visualization Example\n");
9
10 // Generate some example data
11 let n_samples = 500;
12 let n_features = 10;
13 let n_classes = 4;
14
15 println!(
16 "Generating {} samples with {} features for {} classes",
17 n_samples, n_features, n_classes
18 );
19
20 // 1. Confusion Matrix Example
21 println!("\n--- Confusion Matrix Visualization ---\n");
22
23 // Create a deterministic RNG for reproducibility
24 let mut rng = SmallRng::seed_from_u64(42);
25
26 // Generate random predictions and true labels
27 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29 // Create slightly correlated predictions (not completely random)
30 let y_pred = Array::from_shape_fn(n_samples, |i| {
31 if rng.random::<f32>() < 0.7 {
32 // 70% chance of correct prediction
33 y_true[i]
34 } else {
35 // 30% chance of random class
36 rng.random_range(0..n_classes)
37 }
38 });
39
40 // Create confusion matrix
41 let class_labels = vec![
42 "Class A".to_string(),
43 "Class B".to_string(),
44 "Class C".to_string(),
45 "Class D".to_string(),
46 ];
47
48 let cm = ConfusionMatrix::<f32>::new(
49 &y_true.view(),
50 &y_pred.view(),
51 Some(n_classes),
52 Some(class_labels),
53 )?;
54
55 // Print raw and normalized confusion matrices
56 println!("Raw Confusion Matrix:\n");
57 println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59 println!("\nNormalized Confusion Matrix:\n");
60 println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62 // Print metrics
63 println!("\nAccuracy: {:.3}", cm.accuracy());
64
65 let precision = cm.precision();
66 let recall = cm.recall();
67 let f1 = cm.f1_score();
68
69 println!("Per-class metrics:");
70 for i in 0..n_classes {
71 println!(
72 " Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73 i, precision[i], recall[i], f1[i]
74 );
75 }
76
77 println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79 // 2. Feature Importance Visualization
80 println!("\n--- Feature Importance Visualization ---\n");
81
82 // Generate random feature importance scores
83 let feature_names = (0..n_features)
84 .map(|i| format!("Feature_{}", i))
85 .collect::<Vec<String>>();
86
87 let importance = Array1::from_shape_fn(n_features, |i| {
88 // Make some features more important than others
89 let base = (n_features - i) as f32 / n_features as f32;
90 base + 0.2 * rng.random::<f32>()
91 });
92
93 let fi = FeatureImportance::new(feature_names, importance)?;
94
95 // Print full feature importance
96 println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98 // Print top-5 features
99 println!("\nTop 5 Most Important Features:\n");
100 println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102 // 3. ROC Curve for Binary Classification
103 println!("\n--- ROC Curve Visualization ---\n");
104
105 // Generate binary classification data
106 let n_binary = 200;
107 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109 // Generate scores with some predictive power
110 let y_scores = Array1::from_shape_fn(n_binary, |i| {
111 if y_true_binary[i] == 1 {
112 // Higher scores for positive class
113 0.6 + 0.4 * rng.random::<f32>()
114 } else {
115 // Lower scores for negative class
116 0.4 * rng.random::<f32>()
117 }
118 });
119
120 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122 println!("ROC AUC: {:.3}", roc.auc);
123 println!("\n{}", roc.to_ascii(None, 50, 20));
124
125 // 4. Learning Curve Visualization
126 println!("\n--- Learning Curve Visualization ---\n");
127
128 // Generate learning curve data
129 let n_points = 10;
130 let n_cv = 5;
131
132 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134 // Generate training scores (decreasing with size due to overfitting)
135 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137 });
138
139 // Generate validation scores (increasing with size)
140 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142 });
143
144 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148 // Print final message
149 println!("\nModel evaluation visualizations completed successfully!");
150
151 Ok(())
152}
More examples
10fn main() -> Result<()> {
11 println!(
12 "{}",
13 stylize("Neural Network Model Evaluation with Color", Style::Bold)
14 );
15 println!("{}", "-".repeat(50));
16
17 // Set up color options
18 let color_options = ColorOptions {
19 enabled: true,
20 use_background: false,
21 use_bright: true,
22 };
23
24 // Generate some example data
25 let n_samples = 500;
26 let n_features = 10;
27 let n_classes = 4;
28
29 println!(
30 "\n{} {} {} {} {} {}",
31 colorize("Generating", Color::BrightGreen),
32 colorize(n_samples.to_string(), Color::BrightYellow),
33 colorize("samples with", Color::BrightGreen),
34 colorize(n_features.to_string(), Color::BrightYellow),
35 colorize("features for", Color::BrightGreen),
36 colorize(n_classes.to_string(), Color::BrightYellow),
37 );
38
39 // Create a deterministic RNG for reproducibility
40 let mut rng = SmallRng::seed_from_u64(42);
41
42 // 1. Confusion Matrix Example
43 println!(
44 "\n{}",
45 stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
46 );
47
48 // Generate random predictions and true labels
49 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
50
51 // Create slightly correlated predictions (not completely random)
52 let y_pred = Array::from_shape_fn(n_samples, |i| {
53 if rng.random::<f32>() < 0.7 {
54 // 70% chance of correct prediction
55 y_true[i]
56 } else {
57 // 30% chance of random class
58 rng.random_range(0..n_classes)
59 }
60 });
61
62 // Create confusion matrix
63 let class_labels = vec![
64 "Class A".to_string(),
65 "Class B".to_string(),
66 "Class C".to_string(),
67 "Class D".to_string(),
68 ];
69
70 let cm = ConfusionMatrix::<f32>::new(
71 &y_true.view(),
72 &y_pred.view(),
73 Some(n_classes),
74 Some(class_labels),
75 )?;
76
77 // Print raw and normalized confusion matrices with color
78 println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
79 println!(
80 "{}",
81 cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
82 );
83
84 println!(
85 "\n{}",
86 colorize("Normalized Confusion Matrix:", Color::BrightCyan)
87 );
88 println!(
89 "{}",
90 cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
91 );
92
93 // Print metrics
94 println!(
95 "\n{} {:.3}",
96 colorize("Overall Accuracy:", Color::BrightMagenta),
97 cm.accuracy()
98 );
99
100 let precision = cm.precision();
101 let recall = cm.recall();
102 let f1 = cm.f1_score();
103
104 println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
105 for i in 0..n_classes {
106 println!(
107 " {}: {}={:.3}, {}={:.3}, {}={:.3}",
108 colorize(format!("Class {}", i), Color::BrightYellow),
109 colorize("Precision", Color::BrightCyan),
110 precision[i],
111 colorize("Recall", Color::BrightGreen),
112 recall[i],
113 colorize("F1", Color::BrightBlue),
114 f1[i]
115 );
116 }
117
118 println!(
119 "{} {:.3}",
120 colorize("Macro F1 Score:", Color::BrightMagenta),
121 cm.macro_f1()
122 );
123
124 // 2. Feature Importance Visualization
125 println!(
126 "\n{}",
127 stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
128 );
129
130 // Generate random feature importance scores
131 let feature_names = (0..n_features)
132 .map(|i| format!("Feature_{}", i))
133 .collect::<Vec<String>>();
134
135 let importance = Array1::from_shape_fn(n_features, |i| {
136 // Make some features more important than others
137 let base = (n_features - i) as f32 / n_features as f32;
138 base + 0.2 * rng.random::<f32>()
139 });
140
141 let fi = FeatureImportance::new(feature_names, importance)?;
142
143 // Print full feature importance with color
144 println!(
145 "{}",
146 fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
147 );
148
149 // Print top-5 features with color
150 println!(
151 "\n{}",
152 colorize("Top 5 Most Important Features:", Color::BrightCyan)
153 );
154 println!(
155 "{}",
156 fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
157 );
158
159 // 3. ROC Curve for Binary Classification
160 println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
161
162 // Generate binary classification data
163 let n_binary = 200;
164 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
165
166 // Generate scores with some predictive power
167 let y_scores = Array1::from_shape_fn(n_binary, |i| {
168 if y_true_binary[i] == 1 {
169 // Higher scores for positive class
170 0.6 + 0.4 * rng.random::<f32>()
171 } else {
172 // Lower scores for negative class
173 0.4 * rng.random::<f32>()
174 }
175 });
176
177 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
178
179 println!(
180 "{} {:.3}",
181 colorize("ROC AUC:", Color::BrightMagenta),
182 roc.auc
183 );
184
185 println!("\n{}", roc.to_ascii(None, 50, 20));
186
187 // 4. Learning Curve Visualization
188 println!(
189 "\n{}",
190 stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
191 );
192
193 // Generate learning curve data
194 let n_points = 10;
195 let n_cv = 5;
196
197 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
198
199 // Generate training scores (decreasing with size due to overfitting)
200 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
201 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
202 });
203
204 // Generate validation scores (increasing with size)
205 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
206 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
207 });
208
209 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
210
211 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
212
213 // Print final message with color
214 println!(
215 "\n{}",
216 colorize(
217 "Model evaluation visualizations completed successfully!",
218 Color::BrightGreen
219 )
220 );
221
222 Ok(())
223}
118fn main() -> Result<()> {
119 println!("Neural Network Confusion Matrix Visualization");
120 println!("==============================================\n");
121
122 // Initialize RNG with a fixed seed for reproducibility
123 let mut rng = SmallRng::seed_from_u64(42);
124
125 // Generate spiral dataset for 3-class classification
126 let n_classes = 3;
127 let n_samples_per_class = 100;
128 let noise = 0.15;
129
130 let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131 println!(
132 "Generated spiral dataset with {} classes, {} samples per class",
133 n_classes, n_samples_per_class
134 );
135
136 // Split data into training and test sets (80/20 split)
137 let n_samples = x.shape()[0];
138 let n_train = (n_samples as f32 * 0.8) as usize;
139 let n_test = n_samples - n_train;
140
141 let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142 let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143 let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144 let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146 println!(
147 "Split data into {} training and {} test samples",
148 n_train, n_test
149 );
150
151 // Create a classification model
152 let input_dim = 2; // 2D input (x, y coordinates)
153 let hidden_dim = 32; // Hidden layer size
154 let output_dim = n_classes; // One output per class
155
156 let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157 println!("Created model with {} layers", model.num_layers());
158
159 // Setup loss function and optimizer
160 let loss_fn = MeanSquaredError::new();
161 let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163 // Train the model
164 let epochs = 100;
165 let x_train_dyn = x_train.clone().into_dyn();
166 let y_train_onehot = one_hot_encode(&y_train, n_classes);
167 let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169 // Create visualization callback for training metrics
170 let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171 .with_tracked_metrics(vec![
172 "train_loss".to_string(),
173 "val_accuracy".to_string(),
174 ]);
175
176 // Define class labels for confusion matrix
177 let class_labels = vec![
178 "Class A".to_string(),
179 "Class B".to_string(),
180 "Class C".to_string(),
181 ];
182
183 // Train the model (simple manual training loop)
184 println!("\nTraining model...");
185
186 // Initialize history for tracking metrics
187 let mut epoch_history = HashMap::new();
188 epoch_history.insert("train_loss".to_string(), Vec::new());
189 epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191 // Training loop
192 for epoch in 0..epochs {
193 // Train for one epoch
194 let train_loss =
195 model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197 // Compute validation accuracy
198 let x_test_dyn = x_test.clone().into_dyn();
199 let predictions = model.forward(&x_test_dyn)?;
200 let predicted_classes = predictions_to_classes(&predictions);
201
202 // Calculate validation accuracy
203 let mut correct = 0;
204 for i in 0..n_test {
205 if predicted_classes[i] == y_test[i] {
206 correct += 1;
207 }
208 }
209 let val_accuracy = correct as f32 / n_test as f32;
210
211 // Store metrics
212 epoch_history
213 .get_mut("train_loss")
214 .unwrap()
215 .push(train_loss);
216 epoch_history
217 .get_mut("val_accuracy")
218 .unwrap()
219 .push(val_accuracy);
220
221 // Print progress
222 if (epoch + 1) % 10 == 0 || epoch == 0 {
223 println!(
224 "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225 epoch + 1,
226 epochs,
227 train_loss,
228 val_accuracy
229 );
230 }
231
232 // Update visualization callback
233 let mut context = CallbackContext {
234 epoch,
235 total_epochs: epochs,
236 batch: 0,
237 total_batches: 1,
238 batch_loss: None,
239 epoch_loss: Some(train_loss),
240 val_loss: None,
241 metrics: vec![val_accuracy],
242 history: &epoch_history,
243 stop_training: false,
244 model: None,
245 };
246
247 // Visualize progress with metrics chart
248 if epoch % 10 == 0 || epoch == epochs - 1 {
249 visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250 }
251
252 // Calculate and show confusion matrix during training
253 if epoch % 20 == 0 || epoch == epochs - 1 {
254 // Create confusion matrix
255 let cm = ConfusionMatrix::<f32>::new(
256 &y_test.view(),
257 &predicted_classes.view(),
258 Some(n_classes),
259 Some(class_labels.clone()),
260 )?;
261
262 // Show heatmap visualization
263 println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264 println!(
265 "{}",
266 cm.to_heatmap(
267 Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268 true
269 )
270 );
271 }
272 }
273
274 // Final evaluation
275 println!("\nFinal model evaluation:");
276
277 // Make predictions on test set
278 let x_test_dyn = x_test.clone().into_dyn();
279 let predictions = model.forward(&x_test_dyn)?;
280 let predicted_classes = predictions_to_classes(&predictions);
281
282 // Create confusion matrix
283 let cm = ConfusionMatrix::<f32>::new(
284 &y_test.view(),
285 &predicted_classes.view(),
286 Some(n_classes),
287 Some(class_labels.clone()),
288 )?;
289
290 // Calculate and show metrics
291 let accuracy = cm.accuracy();
292 let precision = cm.precision();
293 let recall = cm.recall();
294 let f1 = cm.f1_score();
295
296 println!("\nFinal Classification Metrics:");
297 println!("Overall Accuracy: {:.4}", accuracy);
298
299 println!("\nPer-Class Metrics:");
300 println!("Class | Precision | Recall | F1-Score");
301 println!("-----------------------------------");
302
303 for i in 0..n_classes {
304 println!(
305 "{} | {:.4} | {:.4} | {:.4}",
306 class_labels[i], precision[i], recall[i], f1[i]
307 );
308 }
309
310 println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312 // Show different confusion matrix visualizations
313 println!("\nFinal Confusion Matrix Visualizations:");
314
315 // 1. Standard confusion matrix
316 println!("\n1. Standard Confusion Matrix:");
317 println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319 // 2. Normalized confusion matrix
320 println!("\n2. Normalized Confusion Matrix:");
321 println!(
322 "{}",
323 cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324 );
325
326 // 3. Confusion matrix heatmap
327 println!("\n3. Confusion Matrix Heatmap:");
328 println!(
329 "{}",
330 cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331 );
332
333 // 4. Error pattern analysis
334 println!("\n4. Error Pattern Analysis:");
335 println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337 println!("\nNeural Network Confusion Matrix Visualization Complete!");
338 Ok(())
339}
pub fn class_metrics(&self) -> HashMap<String, Vec<F>>
Get class-wise metrics as a HashMap
pub fn to_ascii(&self, title: Option<&str>, normalized: bool) -> String
Convert the confusion matrix to an ASCII representation
Examples found in repository?
6fn main() {
7 // Create a reproducible random number generator
8 let mut rng = SmallRng::seed_from_u64(42);
9
10 // Generate synthetic multiclass classification data with specific error patterns
11 let num_classes = 5;
12
13 // Create confusion matrix with controlled error patterns
14 let mut matrix = vec![vec![0; num_classes]; num_classes];
15
16 // Set diagonal elements (correct classifications) with high values
17 for i in 0..num_classes {
18 matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
19 }
20
21 // Create specific error patterns:
22 // - Classes 0 and 1 often confused
23 matrix[0][1] = 25;
24 matrix[1][0] = 15;
25
26 // - Class 2 sometimes confused with Class 3
27 matrix[2][3] = 18;
28
29 // - Class 4 has some misclassifications to all other classes
30 matrix[4][0] = 8;
31 matrix[4][1] = 5;
32 matrix[4][2] = 10;
33 matrix[4][3] = 12;
34
35 // - Some minor errors scattered about
36 for i in 0..num_classes {
37 for j in 0..num_classes {
38 if i != j && matrix[i][j] == 0 {
39 matrix[i][j] = rng.random_range(0..5);
40 }
41 }
42 }
43
44 // Convert to ndarray
45 let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
46 let ndarray_matrix =
47 ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix).unwrap();
48
49 // Create class labels
50 let class_labels = vec![
51 "Class A".to_string(),
52 "Class B".to_string(),
53 "Class C".to_string(),
54 "Class D".to_string(),
55 "Class E".to_string(),
56 ];
57
58 // Create confusion matrix
59 let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
60
61 // Example 1: Standard confusion matrix
62 println!("Example 1: Standard Confusion Matrix\n");
63 let regular_output = cm.to_ascii(Some("Classification Results"), false);
64 println!("{}", regular_output);
65
66 // Example 2: Normal heatmap
67 println!("\n\nExample 2: Standard Heatmap Visualization\n");
68 let color_options = ColorOptions {
69 enabled: true,
70 use_bright: true,
71 use_background: false,
72 };
73 let heatmap_output = cm.to_heatmap_with_options(
74 Some("Classification Heatmap"),
75 true, // normalized
76 &color_options,
77 );
78 println!("{}", heatmap_output);
79
80 // Example 3: Error pattern heatmap
81 println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
82 let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
83 println!("{}", error_heatmap);
84}
More examples
7fn main() {
8 // Create a reproducible random number generator
9 let mut rng = SmallRng::seed_from_u64(42);
10
11 // Generate synthetic multiclass classification data
12 let num_classes = 5;
13 let n_samples = 500;
14
15 // Generate true labels (0 to num_classes-1)
16 let mut y_true = Vec::with_capacity(n_samples);
17 for _ in 0..n_samples {
18 y_true.push(rng.random_range(0..num_classes));
19 }
20
21 // Generate predicted labels with controlled accuracy
22 let mut y_pred = Vec::with_capacity(n_samples);
23 for &true_label in &y_true {
24 // 80% chance to predict correctly, 20% chance of error
25 if rng.random::<f64>() < 0.8 {
26 y_pred.push(true_label);
27 } else {
28 // When wrong, tend to predict adjacent classes more often
29 let mut pred = true_label;
30 while pred == true_label {
31 // Generate error that's more likely to be close to true label
32 let error_margin = (rng.random::<f64>() * 2.0).round() as usize; // 0, 1, or 2
33 if rng.random::<bool>() {
34 pred = (true_label + error_margin) % num_classes;
35 } else {
36 pred = (true_label + num_classes - error_margin) % num_classes;
37 }
38 }
39 y_pred.push(pred);
40 }
41 }
42
43 // Convert to ndarray arrays
44 let y_true_array = Array1::from(y_true);
45 let y_pred_array = Array1::from(y_pred);
46
47 // Create class labels
48 let class_labels = vec![
49 "Cat".to_string(),
50 "Dog".to_string(),
51 "Bird".to_string(),
52 "Fish".to_string(),
53 "Rabbit".to_string(),
54 ];
55
56 // Create confusion matrix
57 let cm = ConfusionMatrix::<f64>::new(
58 &y_true_array.view(),
59 &y_pred_array.view(),
60 Some(num_classes),
61 Some(class_labels),
62 )
63 .unwrap();
64
65 // Example 1: Standard confusion matrix
66 println!("Example 1: Standard Confusion Matrix\n");
67 let regular_output = cm.to_ascii(Some("Animal Classification Results"), false);
68 println!("{}", regular_output);
69
70 // Example 2: Confusion matrix with color
71 println!("\n\nExample 2: Colored Confusion Matrix\n");
72 let color_options = ColorOptions {
73 enabled: true,
74 use_bright: true,
75 use_background: false,
76 };
77 let colored_output = cm.to_ascii_with_options(
78 Some("Animal Classification Results (with color)"),
79 false,
80 &color_options,
81 );
82 println!("{}", colored_output);
83
84 // Example 3: Normalized confusion matrix heatmap
85 println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
86 let heatmap_output = cm.to_heatmap_with_options(
87 Some("Animal Classification Heatmap (normalized)"),
88 true, // normalized
89 &color_options,
90 );
91 println!("{}", heatmap_output);
92
93 // Example 4: Raw counts heatmap
94 println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
95 let raw_heatmap = cm.to_heatmap_with_options(
96 Some("Animal Classification Heatmap (raw counts)"),
97 false, // not normalized
98 &color_options,
99 );
100 println!("{}", raw_heatmap);
101}
7fn main() -> Result<()> {
8 println!("Neural Network Model Evaluation Visualization Example\n");
9
10 // Generate some example data
11 let n_samples = 500;
12 let n_features = 10;
13 let n_classes = 4;
14
15 println!(
16 "Generating {} samples with {} features for {} classes",
17 n_samples, n_features, n_classes
18 );
19
20 // 1. Confusion Matrix Example
21 println!("\n--- Confusion Matrix Visualization ---\n");
22
23 // Create a deterministic RNG for reproducibility
24 let mut rng = SmallRng::seed_from_u64(42);
25
26 // Generate random predictions and true labels
27 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29 // Create slightly correlated predictions (not completely random)
30 let y_pred = Array::from_shape_fn(n_samples, |i| {
31 if rng.random::<f32>() < 0.7 {
32 // 70% chance of correct prediction
33 y_true[i]
34 } else {
35 // 30% chance of random class
36 rng.random_range(0..n_classes)
37 }
38 });
39
40 // Create confusion matrix
41 let class_labels = vec![
42 "Class A".to_string(),
43 "Class B".to_string(),
44 "Class C".to_string(),
45 "Class D".to_string(),
46 ];
47
48 let cm = ConfusionMatrix::<f32>::new(
49 &y_true.view(),
50 &y_pred.view(),
51 Some(n_classes),
52 Some(class_labels),
53 )?;
54
55 // Print raw and normalized confusion matrices
56 println!("Raw Confusion Matrix:\n");
57 println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59 println!("\nNormalized Confusion Matrix:\n");
60 println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62 // Print metrics
63 println!("\nAccuracy: {:.3}", cm.accuracy());
64
65 let precision = cm.precision();
66 let recall = cm.recall();
67 let f1 = cm.f1_score();
68
69 println!("Per-class metrics:");
70 for i in 0..n_classes {
71 println!(
72 " Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73 i, precision[i], recall[i], f1[i]
74 );
75 }
76
77 println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79 // 2. Feature Importance Visualization
80 println!("\n--- Feature Importance Visualization ---\n");
81
82 // Generate random feature importance scores
83 let feature_names = (0..n_features)
84 .map(|i| format!("Feature_{}", i))
85 .collect::<Vec<String>>();
86
87 let importance = Array1::from_shape_fn(n_features, |i| {
88 // Make some features more important than others
89 let base = (n_features - i) as f32 / n_features as f32;
90 base + 0.2 * rng.random::<f32>()
91 });
92
93 let fi = FeatureImportance::new(feature_names, importance)?;
94
95 // Print full feature importance
96 println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98 // Print top-5 features
99 println!("\nTop 5 Most Important Features:\n");
100 println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102 // 3. ROC Curve for Binary Classification
103 println!("\n--- ROC Curve Visualization ---\n");
104
105 // Generate binary classification data
106 let n_binary = 200;
107 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109 // Generate scores with some predictive power
110 let y_scores = Array1::from_shape_fn(n_binary, |i| {
111 if y_true_binary[i] == 1 {
112 // Higher scores for positive class
113 0.6 + 0.4 * rng.random::<f32>()
114 } else {
115 // Lower scores for negative class
116 0.4 * rng.random::<f32>()
117 }
118 });
119
120 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122 println!("ROC AUC: {:.3}", roc.auc);
123 println!("\n{}", roc.to_ascii(None, 50, 20));
124
125 // 4. Learning Curve Visualization
126 println!("\n--- Learning Curve Visualization ---\n");
127
128 // Generate learning curve data
129 let n_points = 10;
130 let n_cv = 5;
131
132 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134 // Generate training scores (decreasing with size due to overfitting)
135 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137 });
138
139 // Generate validation scores (increasing with size)
140 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142 });
143
144 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148 // Print final message
149 println!("\nModel evaluation visualizations completed successfully!");
150
151 Ok(())
152}
118fn main() -> Result<()> {
119 println!("Neural Network Confusion Matrix Visualization");
120 println!("==============================================\n");
121
122 // Initialize RNG with a fixed seed for reproducibility
123 let mut rng = SmallRng::seed_from_u64(42);
124
125 // Generate spiral dataset for 3-class classification
126 let n_classes = 3;
127 let n_samples_per_class = 100;
128 let noise = 0.15;
129
130 let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131 println!(
132 "Generated spiral dataset with {} classes, {} samples per class",
133 n_classes, n_samples_per_class
134 );
135
136 // Split data into training and test sets (80/20 split)
137 let n_samples = x.shape()[0];
138 let n_train = (n_samples as f32 * 0.8) as usize;
139 let n_test = n_samples - n_train;
140
141 let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142 let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143 let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144 let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146 println!(
147 "Split data into {} training and {} test samples",
148 n_train, n_test
149 );
150
151 // Create a classification model
152 let input_dim = 2; // 2D input (x, y coordinates)
153 let hidden_dim = 32; // Hidden layer size
154 let output_dim = n_classes; // One output per class
155
156 let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157 println!("Created model with {} layers", model.num_layers());
158
159 // Setup loss function and optimizer
160 let loss_fn = MeanSquaredError::new();
161 let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163 // Train the model
164 let epochs = 100;
165 let x_train_dyn = x_train.clone().into_dyn();
166 let y_train_onehot = one_hot_encode(&y_train, n_classes);
167 let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169 // Create visualization callback for training metrics
170 let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171 .with_tracked_metrics(vec![
172 "train_loss".to_string(),
173 "val_accuracy".to_string(),
174 ]);
175
176 // Define class labels for confusion matrix
177 let class_labels = vec![
178 "Class A".to_string(),
179 "Class B".to_string(),
180 "Class C".to_string(),
181 ];
182
183 // Train the model (simple manual training loop)
184 println!("\nTraining model...");
185
186 // Initialize history for tracking metrics
187 let mut epoch_history = HashMap::new();
188 epoch_history.insert("train_loss".to_string(), Vec::new());
189 epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191 // Training loop
192 for epoch in 0..epochs {
193 // Train for one epoch
194 let train_loss =
195 model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197 // Compute validation accuracy
198 let x_test_dyn = x_test.clone().into_dyn();
199 let predictions = model.forward(&x_test_dyn)?;
200 let predicted_classes = predictions_to_classes(&predictions);
201
202 // Calculate validation accuracy
203 let mut correct = 0;
204 for i in 0..n_test {
205 if predicted_classes[i] == y_test[i] {
206 correct += 1;
207 }
208 }
209 let val_accuracy = correct as f32 / n_test as f32;
210
211 // Store metrics
212 epoch_history
213 .get_mut("train_loss")
214 .unwrap()
215 .push(train_loss);
216 epoch_history
217 .get_mut("val_accuracy")
218 .unwrap()
219 .push(val_accuracy);
220
221 // Print progress
222 if (epoch + 1) % 10 == 0 || epoch == 0 {
223 println!(
224 "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225 epoch + 1,
226 epochs,
227 train_loss,
228 val_accuracy
229 );
230 }
231
232 // Update visualization callback
233 let mut context = CallbackContext {
234 epoch,
235 total_epochs: epochs,
236 batch: 0,
237 total_batches: 1,
238 batch_loss: None,
239 epoch_loss: Some(train_loss),
240 val_loss: None,
241 metrics: vec![val_accuracy],
242 history: &epoch_history,
243 stop_training: false,
244 model: None,
245 };
246
247 // Visualize progress with metrics chart
248 if epoch % 10 == 0 || epoch == epochs - 1 {
249 visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250 }
251
252 // Calculate and show confusion matrix during training
253 if epoch % 20 == 0 || epoch == epochs - 1 {
254 // Create confusion matrix
255 let cm = ConfusionMatrix::<f32>::new(
256 &y_test.view(),
257 &predicted_classes.view(),
258 Some(n_classes),
259 Some(class_labels.clone()),
260 )?;
261
262 // Show heatmap visualization
263 println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264 println!(
265 "{}",
266 cm.to_heatmap(
267 Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268 true
269 )
270 );
271 }
272 }
273
274 // Final evaluation
275 println!("\nFinal model evaluation:");
276
277 // Make predictions on test set
278 let x_test_dyn = x_test.clone().into_dyn();
279 let predictions = model.forward(&x_test_dyn)?;
280 let predicted_classes = predictions_to_classes(&predictions);
281
282 // Create confusion matrix
283 let cm = ConfusionMatrix::<f32>::new(
284 &y_test.view(),
285 &predicted_classes.view(),
286 Some(n_classes),
287 Some(class_labels.clone()),
288 )?;
289
290 // Calculate and show metrics
291 let accuracy = cm.accuracy();
292 let precision = cm.precision();
293 let recall = cm.recall();
294 let f1 = cm.f1_score();
295
296 println!("\nFinal Classification Metrics:");
297 println!("Overall Accuracy: {:.4}", accuracy);
298
299 println!("\nPer-Class Metrics:");
300 println!("Class | Precision | Recall | F1-Score");
301 println!("-----------------------------------");
302
303 for i in 0..n_classes {
304 println!(
305 "{} | {:.4} | {:.4} | {:.4}",
306 class_labels[i], precision[i], recall[i], f1[i]
307 );
308 }
309
310 println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312 // Show different confusion matrix visualizations
313 println!("\nFinal Confusion Matrix Visualizations:");
314
315 // 1. Standard confusion matrix
316 println!("\n1. Standard Confusion Matrix:");
317 println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319 // 2. Normalized confusion matrix
320 println!("\n2. Normalized Confusion Matrix:");
321 println!(
322 "{}",
323 cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324 );
325
326 // 3. Confusion matrix heatmap
327 println!("\n3. Confusion Matrix Heatmap:");
328 println!(
329 "{}",
330 cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331 );
332
333 // 4. Error pattern analysis
334 println!("\n4. Error Pattern Analysis:");
335 println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337 println!("\nNeural Network Confusion Matrix Visualization Complete!");
338 Ok(())
339}
pub fn to_ascii_with_options(
    &self,
    title: Option<&str>,
    normalized: bool,
    color_options: &ColorOptions,
) -> String
Convert the confusion matrix to an ASCII representation with color options
Examples found in repository?
7fn main() {
8 // Create a reproducible random number generator
9 let mut rng = SmallRng::seed_from_u64(42);
10
11 // Generate synthetic multiclass classification data
12 let num_classes = 5;
13 let n_samples = 500;
14
15 // Generate true labels (0 to num_classes-1)
16 let mut y_true = Vec::with_capacity(n_samples);
17 for _ in 0..n_samples {
18 y_true.push(rng.random_range(0..num_classes));
19 }
20
21 // Generate predicted labels with controlled accuracy
22 let mut y_pred = Vec::with_capacity(n_samples);
23 for &true_label in &y_true {
24 // 80% chance to predict correctly, 20% chance of error
25 if rng.random::<f64>() < 0.8 {
26 y_pred.push(true_label);
27 } else {
28 // When wrong, tend to predict adjacent classes more often
29 let mut pred = true_label;
30 while pred == true_label {
31 // Generate error that's more likely to be close to true label
32 let error_margin = (rng.random::<f64>() * 2.0).round() as usize; // 0, 1, or 2
33 if rng.random::<bool>() {
34 pred = (true_label + error_margin) % num_classes;
35 } else {
36 pred = (true_label + num_classes - error_margin) % num_classes;
37 }
38 }
39 y_pred.push(pred);
40 }
41 }
42
43 // Convert to ndarray arrays
44 let y_true_array = Array1::from(y_true);
45 let y_pred_array = Array1::from(y_pred);
46
47 // Create class labels
48 let class_labels = vec![
49 "Cat".to_string(),
50 "Dog".to_string(),
51 "Bird".to_string(),
52 "Fish".to_string(),
53 "Rabbit".to_string(),
54 ];
55
56 // Create confusion matrix
57 let cm = ConfusionMatrix::<f64>::new(
58 &y_true_array.view(),
59 &y_pred_array.view(),
60 Some(num_classes),
61 Some(class_labels),
62 )
63 .unwrap();
64
65 // Example 1: Standard confusion matrix
66 println!("Example 1: Standard Confusion Matrix\n");
67 let regular_output = cm.to_ascii(Some("Animal Classification Results"), false);
68 println!("{}", regular_output);
69
70 // Example 2: Confusion matrix with color
71 println!("\n\nExample 2: Colored Confusion Matrix\n");
72 let color_options = ColorOptions {
73 enabled: true,
74 use_bright: true,
75 use_background: false,
76 };
77 let colored_output = cm.to_ascii_with_options(
78 Some("Animal Classification Results (with color)"),
79 false,
80 &color_options,
81 );
82 println!("{}", colored_output);
83
84 // Example 3: Normalized confusion matrix heatmap
85 println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
86 let heatmap_output = cm.to_heatmap_with_options(
87 Some("Animal Classification Heatmap (normalized)"),
88 true, // normalized
89 &color_options,
90 );
91 println!("{}", heatmap_output);
92
93 // Example 4: Raw counts heatmap
94 println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
95 let raw_heatmap = cm.to_heatmap_with_options(
96 Some("Animal Classification Heatmap (raw counts)"),
97 false, // not normalized
98 &color_options,
99 );
100 println!("{}", raw_heatmap);
101}
More examples
10fn main() -> Result<()> {
11 println!(
12 "{}",
13 stylize("Neural Network Model Evaluation with Color", Style::Bold)
14 );
15 println!("{}", "-".repeat(50));
16
17 // Set up color options
18 let color_options = ColorOptions {
19 enabled: true,
20 use_background: false,
21 use_bright: true,
22 };
23
24 // Generate some example data
25 let n_samples = 500;
26 let n_features = 10;
27 let n_classes = 4;
28
29 println!(
30 "\n{} {} {} {} {} {}",
31 colorize("Generating", Color::BrightGreen),
32 colorize(n_samples.to_string(), Color::BrightYellow),
33 colorize("samples with", Color::BrightGreen),
34 colorize(n_features.to_string(), Color::BrightYellow),
35 colorize("features for", Color::BrightGreen),
36 colorize(n_classes.to_string(), Color::BrightYellow),
37 );
38
39 // Create a deterministic RNG for reproducibility
40 let mut rng = SmallRng::seed_from_u64(42);
41
42 // 1. Confusion Matrix Example
43 println!(
44 "\n{}",
45 stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
46 );
47
48 // Generate random predictions and true labels
49 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
50
51 // Create slightly correlated predictions (not completely random)
52 let y_pred = Array::from_shape_fn(n_samples, |i| {
53 if rng.random::<f32>() < 0.7 {
54 // 70% chance of correct prediction
55 y_true[i]
56 } else {
57 // 30% chance of random class
58 rng.random_range(0..n_classes)
59 }
60 });
61
62 // Create confusion matrix
63 let class_labels = vec![
64 "Class A".to_string(),
65 "Class B".to_string(),
66 "Class C".to_string(),
67 "Class D".to_string(),
68 ];
69
70 let cm = ConfusionMatrix::<f32>::new(
71 &y_true.view(),
72 &y_pred.view(),
73 Some(n_classes),
74 Some(class_labels),
75 )?;
76
77 // Print raw and normalized confusion matrices with color
78 println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
79 println!(
80 "{}",
81 cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
82 );
83
84 println!(
85 "\n{}",
86 colorize("Normalized Confusion Matrix:", Color::BrightCyan)
87 );
88 println!(
89 "{}",
90 cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
91 );
92
93 // Print metrics
94 println!(
95 "\n{} {:.3}",
96 colorize("Overall Accuracy:", Color::BrightMagenta),
97 cm.accuracy()
98 );
99
100 let precision = cm.precision();
101 let recall = cm.recall();
102 let f1 = cm.f1_score();
103
104 println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
105 for i in 0..n_classes {
106 println!(
107 " {}: {}={:.3}, {}={:.3}, {}={:.3}",
108 colorize(format!("Class {}", i), Color::BrightYellow),
109 colorize("Precision", Color::BrightCyan),
110 precision[i],
111 colorize("Recall", Color::BrightGreen),
112 recall[i],
113 colorize("F1", Color::BrightBlue),
114 f1[i]
115 );
116 }
117
118 println!(
119 "{} {:.3}",
120 colorize("Macro F1 Score:", Color::BrightMagenta),
121 cm.macro_f1()
122 );
123
124 // 2. Feature Importance Visualization
125 println!(
126 "\n{}",
127 stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
128 );
129
130 // Generate random feature importance scores
131 let feature_names = (0..n_features)
132 .map(|i| format!("Feature_{}", i))
133 .collect::<Vec<String>>();
134
135 let importance = Array1::from_shape_fn(n_features, |i| {
136 // Make some features more important than others
137 let base = (n_features - i) as f32 / n_features as f32;
138 base + 0.2 * rng.random::<f32>()
139 });
140
141 let fi = FeatureImportance::new(feature_names, importance)?;
142
143 // Print full feature importance with color
144 println!(
145 "{}",
146 fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
147 );
148
149 // Print top-5 features with color
150 println!(
151 "\n{}",
152 colorize("Top 5 Most Important Features:", Color::BrightCyan)
153 );
154 println!(
155 "{}",
156 fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
157 );
158
159 // 3. ROC Curve for Binary Classification
160 println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
161
162 // Generate binary classification data
163 let n_binary = 200;
164 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
165
166 // Generate scores with some predictive power
167 let y_scores = Array1::from_shape_fn(n_binary, |i| {
168 if y_true_binary[i] == 1 {
169 // Higher scores for positive class
170 0.6 + 0.4 * rng.random::<f32>()
171 } else {
172 // Lower scores for negative class
173 0.4 * rng.random::<f32>()
174 }
175 });
176
177 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
178
179 println!(
180 "{} {:.3}",
181 colorize("ROC AUC:", Color::BrightMagenta),
182 roc.auc
183 );
184
185 println!("\n{}", roc.to_ascii(None, 50, 20));
186
187 // 4. Learning Curve Visualization
188 println!(
189 "\n{}",
190 stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
191 );
192
193 // Generate learning curve data
194 let n_points = 10;
195 let n_cv = 5;
196
197 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
198
199 // Generate training scores (decreasing with size due to overfitting)
200 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
201 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
202 });
203
204 // Generate validation scores (increasing with size)
205 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
206 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
207 });
208
209 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
210
211 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
212
213 // Print final message with color
214 println!(
215 "\n{}",
216 colorize(
217 "Model evaluation visualizations completed successfully!",
218 Color::BrightGreen
219 )
220 );
221
222 Ok(())
223}
Sourcepub fn to_heatmap(&self, title: Option<&str>, normalized: bool) -> String
pub fn to_heatmap(&self, title: Option<&str>, normalized: bool) -> String
Convert the confusion matrix to a heatmap visualization
This creates a colorful heatmap visualization of the confusion matrix where cell colors represent the intensity of values using a detailed color gradient.
§Arguments
- `title` - Optional title for the heatmap
- `normalized` - Whether to normalize the matrix (row values sum to 1)
§Returns
- `String` - ASCII heatmap representation
Examples found in repository?
118fn main() -> Result<()> {
119 println!("Neural Network Confusion Matrix Visualization");
120 println!("==============================================\n");
121
122 // Initialize RNG with a fixed seed for reproducibility
123 let mut rng = SmallRng::seed_from_u64(42);
124
125 // Generate spiral dataset for 3-class classification
126 let n_classes = 3;
127 let n_samples_per_class = 100;
128 let noise = 0.15;
129
130 let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131 println!(
132 "Generated spiral dataset with {} classes, {} samples per class",
133 n_classes, n_samples_per_class
134 );
135
136 // Split data into training and test sets (80/20 split)
137 let n_samples = x.shape()[0];
138 let n_train = (n_samples as f32 * 0.8) as usize;
139 let n_test = n_samples - n_train;
140
141 let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142 let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143 let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144 let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146 println!(
147 "Split data into {} training and {} test samples",
148 n_train, n_test
149 );
150
151 // Create a classification model
152 let input_dim = 2; // 2D input (x, y coordinates)
153 let hidden_dim = 32; // Hidden layer size
154 let output_dim = n_classes; // One output per class
155
156 let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157 println!("Created model with {} layers", model.num_layers());
158
159 // Setup loss function and optimizer
160 let loss_fn = MeanSquaredError::new();
161 let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163 // Train the model
164 let epochs = 100;
165 let x_train_dyn = x_train.clone().into_dyn();
166 let y_train_onehot = one_hot_encode(&y_train, n_classes);
167 let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169 // Create visualization callback for training metrics
170 let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171 .with_tracked_metrics(vec![
172 "train_loss".to_string(),
173 "val_accuracy".to_string(),
174 ]);
175
176 // Define class labels for confusion matrix
177 let class_labels = vec![
178 "Class A".to_string(),
179 "Class B".to_string(),
180 "Class C".to_string(),
181 ];
182
183 // Train the model (simple manual training loop)
184 println!("\nTraining model...");
185
186 // Initialize history for tracking metrics
187 let mut epoch_history = HashMap::new();
188 epoch_history.insert("train_loss".to_string(), Vec::new());
189 epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191 // Training loop
192 for epoch in 0..epochs {
193 // Train for one epoch
194 let train_loss =
195 model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197 // Compute validation accuracy
198 let x_test_dyn = x_test.clone().into_dyn();
199 let predictions = model.forward(&x_test_dyn)?;
200 let predicted_classes = predictions_to_classes(&predictions);
201
202 // Calculate validation accuracy
203 let mut correct = 0;
204 for i in 0..n_test {
205 if predicted_classes[i] == y_test[i] {
206 correct += 1;
207 }
208 }
209 let val_accuracy = correct as f32 / n_test as f32;
210
211 // Store metrics
212 epoch_history
213 .get_mut("train_loss")
214 .unwrap()
215 .push(train_loss);
216 epoch_history
217 .get_mut("val_accuracy")
218 .unwrap()
219 .push(val_accuracy);
220
221 // Print progress
222 if (epoch + 1) % 10 == 0 || epoch == 0 {
223 println!(
224 "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225 epoch + 1,
226 epochs,
227 train_loss,
228 val_accuracy
229 );
230 }
231
232 // Update visualization callback
233 let mut context = CallbackContext {
234 epoch,
235 total_epochs: epochs,
236 batch: 0,
237 total_batches: 1,
238 batch_loss: None,
239 epoch_loss: Some(train_loss),
240 val_loss: None,
241 metrics: vec![val_accuracy],
242 history: &epoch_history,
243 stop_training: false,
244 model: None,
245 };
246
247 // Visualize progress with metrics chart
248 if epoch % 10 == 0 || epoch == epochs - 1 {
249 visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250 }
251
252 // Calculate and show confusion matrix during training
253 if epoch % 20 == 0 || epoch == epochs - 1 {
254 // Create confusion matrix
255 let cm = ConfusionMatrix::<f32>::new(
256 &y_test.view(),
257 &predicted_classes.view(),
258 Some(n_classes),
259 Some(class_labels.clone()),
260 )?;
261
262 // Show heatmap visualization
263 println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264 println!(
265 "{}",
266 cm.to_heatmap(
267 Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268 true
269 )
270 );
271 }
272 }
273
274 // Final evaluation
275 println!("\nFinal model evaluation:");
276
277 // Make predictions on test set
278 let x_test_dyn = x_test.clone().into_dyn();
279 let predictions = model.forward(&x_test_dyn)?;
280 let predicted_classes = predictions_to_classes(&predictions);
281
282 // Create confusion matrix
283 let cm = ConfusionMatrix::<f32>::new(
284 &y_test.view(),
285 &predicted_classes.view(),
286 Some(n_classes),
287 Some(class_labels.clone()),
288 )?;
289
290 // Calculate and show metrics
291 let accuracy = cm.accuracy();
292 let precision = cm.precision();
293 let recall = cm.recall();
294 let f1 = cm.f1_score();
295
296 println!("\nFinal Classification Metrics:");
297 println!("Overall Accuracy: {:.4}", accuracy);
298
299 println!("\nPer-Class Metrics:");
300 println!("Class | Precision | Recall | F1-Score");
301 println!("-----------------------------------");
302
303 for i in 0..n_classes {
304 println!(
305 "{} | {:.4} | {:.4} | {:.4}",
306 class_labels[i], precision[i], recall[i], f1[i]
307 );
308 }
309
310 println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312 // Show different confusion matrix visualizations
313 println!("\nFinal Confusion Matrix Visualizations:");
314
315 // 1. Standard confusion matrix
316 println!("\n1. Standard Confusion Matrix:");
317 println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319 // 2. Normalized confusion matrix
320 println!("\n2. Normalized Confusion Matrix:");
321 println!(
322 "{}",
323 cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324 );
325
326 // 3. Confusion matrix heatmap
327 println!("\n3. Confusion Matrix Heatmap:");
328 println!(
329 "{}",
330 cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331 );
332
333 // 4. Error pattern analysis
334 println!("\n4. Error Pattern Analysis:");
335 println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337 println!("\nNeural Network Confusion Matrix Visualization Complete!");
338 Ok(())
339}
Sourcepub fn error_heatmap(&self, title: Option<&str>) -> String
pub fn error_heatmap(&self, title: Option<&str>) -> String
Create a confusion matrix heatmap that focuses on misclassification patterns
This visualization is specialized to highlight where the model makes mistakes, with emphasis on the off-diagonal elements to help identify error patterns. The key features of this visualization are:
- Diagonal elements (correct classifications) are de-emphasized with dim styling
- Off-diagonal elements (errors) are highlighted with a color gradient
- Colors are normalized relative to the maximum off-diagonal value
- A specialized legend explains error intensity levels
§Arguments
- `title` - Optional title for the heatmap
§Returns
- `String` - ASCII error-pattern heatmap
§Example
use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;
// Create some example data
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 1, 1, 0, 0]);
let class_labels = vec!["Class A".to_string(), "Class B".to_string(), "Class C".to_string()];
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, Some(class_labels)).unwrap();
// Generate the error pattern heatmap
let error_viz = cm.error_heatmap(Some("Misclassification Analysis"));
println!("{}", error_viz);
Examples found in repository?
6fn main() {
7 // Create a reproducible random number generator
8 let mut rng = SmallRng::seed_from_u64(42);
9
10 // Generate synthetic multiclass classification data with specific error patterns
11 let num_classes = 5;
12
13 // Create confusion matrix with controlled error patterns
14 let mut matrix = vec![vec![0; num_classes]; num_classes];
15
16 // Set diagonal elements (correct classifications) with high values
17 for i in 0..num_classes {
18 matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
19 }
20
21 // Create specific error patterns:
22 // - Classes 0 and 1 often confused
23 matrix[0][1] = 25;
24 matrix[1][0] = 15;
25
26 // - Class 2 sometimes confused with Class 3
27 matrix[2][3] = 18;
28
29 // - Class 4 has some misclassifications to all other classes
30 matrix[4][0] = 8;
31 matrix[4][1] = 5;
32 matrix[4][2] = 10;
33 matrix[4][3] = 12;
34
35 // - Some minor errors scattered about
36 for i in 0..num_classes {
37 for j in 0..num_classes {
38 if i != j && matrix[i][j] == 0 {
39 matrix[i][j] = rng.random_range(0..5);
40 }
41 }
42 }
43
44 // Convert to ndarray
45 let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
46 let ndarray_matrix =
47 ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix).unwrap();
48
49 // Create class labels
50 let class_labels = vec![
51 "Class A".to_string(),
52 "Class B".to_string(),
53 "Class C".to_string(),
54 "Class D".to_string(),
55 "Class E".to_string(),
56 ];
57
58 // Create confusion matrix
59 let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
60
61 // Example 1: Standard confusion matrix
62 println!("Example 1: Standard Confusion Matrix\n");
63 let regular_output = cm.to_ascii(Some("Classification Results"), false);
64 println!("{}", regular_output);
65
66 // Example 2: Normal heatmap
67 println!("\n\nExample 2: Standard Heatmap Visualization\n");
68 let color_options = ColorOptions {
69 enabled: true,
70 use_bright: true,
71 use_background: false,
72 };
73 let heatmap_output = cm.to_heatmap_with_options(
74 Some("Classification Heatmap"),
75 true, // normalized
76 &color_options,
77 );
78 println!("{}", heatmap_output);
79
80 // Example 3: Error pattern heatmap
81 println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
82 let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
83 println!("{}", error_heatmap);
84}
More examples
118fn main() -> Result<()> {
119 println!("Neural Network Confusion Matrix Visualization");
120 println!("==============================================\n");
121
122 // Initialize RNG with a fixed seed for reproducibility
123 let mut rng = SmallRng::seed_from_u64(42);
124
125 // Generate spiral dataset for 3-class classification
126 let n_classes = 3;
127 let n_samples_per_class = 100;
128 let noise = 0.15;
129
130 let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131 println!(
132 "Generated spiral dataset with {} classes, {} samples per class",
133 n_classes, n_samples_per_class
134 );
135
136 // Split data into training and test sets (80/20 split)
137 let n_samples = x.shape()[0];
138 let n_train = (n_samples as f32 * 0.8) as usize;
139 let n_test = n_samples - n_train;
140
141 let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142 let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143 let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144 let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146 println!(
147 "Split data into {} training and {} test samples",
148 n_train, n_test
149 );
150
151 // Create a classification model
152 let input_dim = 2; // 2D input (x, y coordinates)
153 let hidden_dim = 32; // Hidden layer size
154 let output_dim = n_classes; // One output per class
155
156 let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157 println!("Created model with {} layers", model.num_layers());
158
159 // Setup loss function and optimizer
160 let loss_fn = MeanSquaredError::new();
161 let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163 // Train the model
164 let epochs = 100;
165 let x_train_dyn = x_train.clone().into_dyn();
166 let y_train_onehot = one_hot_encode(&y_train, n_classes);
167 let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169 // Create visualization callback for training metrics
170 let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171 .with_tracked_metrics(vec![
172 "train_loss".to_string(),
173 "val_accuracy".to_string(),
174 ]);
175
176 // Define class labels for confusion matrix
177 let class_labels = vec![
178 "Class A".to_string(),
179 "Class B".to_string(),
180 "Class C".to_string(),
181 ];
182
183 // Train the model (simple manual training loop)
184 println!("\nTraining model...");
185
186 // Initialize history for tracking metrics
187 let mut epoch_history = HashMap::new();
188 epoch_history.insert("train_loss".to_string(), Vec::new());
189 epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191 // Training loop
192 for epoch in 0..epochs {
193 // Train for one epoch
194 let train_loss =
195 model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197 // Compute validation accuracy
198 let x_test_dyn = x_test.clone().into_dyn();
199 let predictions = model.forward(&x_test_dyn)?;
200 let predicted_classes = predictions_to_classes(&predictions);
201
202 // Calculate validation accuracy
203 let mut correct = 0;
204 for i in 0..n_test {
205 if predicted_classes[i] == y_test[i] {
206 correct += 1;
207 }
208 }
209 let val_accuracy = correct as f32 / n_test as f32;
210
211 // Store metrics
212 epoch_history
213 .get_mut("train_loss")
214 .unwrap()
215 .push(train_loss);
216 epoch_history
217 .get_mut("val_accuracy")
218 .unwrap()
219 .push(val_accuracy);
220
221 // Print progress
222 if (epoch + 1) % 10 == 0 || epoch == 0 {
223 println!(
224 "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225 epoch + 1,
226 epochs,
227 train_loss,
228 val_accuracy
229 );
230 }
231
232 // Update visualization callback
233 let mut context = CallbackContext {
234 epoch,
235 total_epochs: epochs,
236 batch: 0,
237 total_batches: 1,
238 batch_loss: None,
239 epoch_loss: Some(train_loss),
240 val_loss: None,
241 metrics: vec![val_accuracy],
242 history: &epoch_history,
243 stop_training: false,
244 model: None,
245 };
246
247 // Visualize progress with metrics chart
248 if epoch % 10 == 0 || epoch == epochs - 1 {
249 visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250 }
251
252 // Calculate and show confusion matrix during training
253 if epoch % 20 == 0 || epoch == epochs - 1 {
254 // Create confusion matrix
255 let cm = ConfusionMatrix::<f32>::new(
256 &y_test.view(),
257 &predicted_classes.view(),
258 Some(n_classes),
259 Some(class_labels.clone()),
260 )?;
261
262 // Show heatmap visualization
263 println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264 println!(
265 "{}",
266 cm.to_heatmap(
267 Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268 true
269 )
270 );
271 }
272 }
273
274 // Final evaluation
275 println!("\nFinal model evaluation:");
276
277 // Make predictions on test set
278 let x_test_dyn = x_test.clone().into_dyn();
279 let predictions = model.forward(&x_test_dyn)?;
280 let predicted_classes = predictions_to_classes(&predictions);
281
282 // Create confusion matrix
283 let cm = ConfusionMatrix::<f32>::new(
284 &y_test.view(),
285 &predicted_classes.view(),
286 Some(n_classes),
287 Some(class_labels.clone()),
288 )?;
289
290 // Calculate and show metrics
291 let accuracy = cm.accuracy();
292 let precision = cm.precision();
293 let recall = cm.recall();
294 let f1 = cm.f1_score();
295
296 println!("\nFinal Classification Metrics:");
297 println!("Overall Accuracy: {:.4}", accuracy);
298
299 println!("\nPer-Class Metrics:");
300 println!("Class | Precision | Recall | F1-Score");
301 println!("-----------------------------------");
302
303 for i in 0..n_classes {
304 println!(
305 "{} | {:.4} | {:.4} | {:.4}",
306 class_labels[i], precision[i], recall[i], f1[i]
307 );
308 }
309
310 println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312 // Show different confusion matrix visualizations
313 println!("\nFinal Confusion Matrix Visualizations:");
314
315 // 1. Standard confusion matrix
316 println!("\n1. Standard Confusion Matrix:");
317 println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319 // 2. Normalized confusion matrix
320 println!("\n2. Normalized Confusion Matrix:");
321 println!(
322 "{}",
323 cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324 );
325
326 // 3. Confusion matrix heatmap
327 println!("\n3. Confusion Matrix Heatmap:");
328 println!(
329 "{}",
330 cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331 );
332
333 // 4. Error pattern analysis
334 println!("\n4. Error Pattern Analysis:");
335 println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337 println!("\nNeural Network Confusion Matrix Visualization Complete!");
338 Ok(())
339}
Sourcepub fn to_heatmap_with_options(
&self,
title: Option<&str>,
normalized: bool,
color_options: &ColorOptions,
) -> String
pub fn to_heatmap_with_options( &self, title: Option<&str>, normalized: bool, color_options: &ColorOptions, ) -> String
Convert the confusion matrix to a heatmap visualization with customizable options
§Arguments
- `title` - Optional title for the heatmap
- `normalized` - Whether to normalize the matrix (row values sum to 1)
- `color_options` - Color options for the visualization
§Returns
- `String` - ASCII heatmap representation with colors
Examples found in repository?
6fn main() {
7 // Create a reproducible random number generator
8 let mut rng = SmallRng::seed_from_u64(42);
9
10 // Generate synthetic multiclass classification data with specific error patterns
11 let num_classes = 5;
12
13 // Create confusion matrix with controlled error patterns
14 let mut matrix = vec![vec![0; num_classes]; num_classes];
15
16 // Set diagonal elements (correct classifications) with high values
17 for i in 0..num_classes {
18 matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
19 }
20
21 // Create specific error patterns:
22 // - Classes 0 and 1 often confused
23 matrix[0][1] = 25;
24 matrix[1][0] = 15;
25
26 // - Class 2 sometimes confused with Class 3
27 matrix[2][3] = 18;
28
29 // - Class 4 has some misclassifications to all other classes
30 matrix[4][0] = 8;
31 matrix[4][1] = 5;
32 matrix[4][2] = 10;
33 matrix[4][3] = 12;
34
35 // - Some minor errors scattered about
36 for i in 0..num_classes {
37 for j in 0..num_classes {
38 if i != j && matrix[i][j] == 0 {
39 matrix[i][j] = rng.random_range(0..5);
40 }
41 }
42 }
43
44 // Convert to ndarray
45 let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
46 let ndarray_matrix =
47 ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix).unwrap();
48
49 // Create class labels
50 let class_labels = vec![
51 "Class A".to_string(),
52 "Class B".to_string(),
53 "Class C".to_string(),
54 "Class D".to_string(),
55 "Class E".to_string(),
56 ];
57
58 // Create confusion matrix
59 let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
60
61 // Example 1: Standard confusion matrix
62 println!("Example 1: Standard Confusion Matrix\n");
63 let regular_output = cm.to_ascii(Some("Classification Results"), false);
64 println!("{}", regular_output);
65
66 // Example 2: Normal heatmap
67 println!("\n\nExample 2: Standard Heatmap Visualization\n");
68 let color_options = ColorOptions {
69 enabled: true,
70 use_bright: true,
71 use_background: false,
72 };
73 let heatmap_output = cm.to_heatmap_with_options(
74 Some("Classification Heatmap"),
75 true, // normalized
76 &color_options,
77 );
78 println!("{}", heatmap_output);
79
80 // Example 3: Error pattern heatmap
81 println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
82 let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
83 println!("{}", error_heatmap);
84}
More examples
7fn main() {
8 // Create a reproducible random number generator
9 let mut rng = SmallRng::seed_from_u64(42);
10
11 // Generate synthetic multiclass classification data
12 let num_classes = 5;
13 let n_samples = 500;
14
15 // Generate true labels (0 to num_classes-1)
16 let mut y_true = Vec::with_capacity(n_samples);
17 for _ in 0..n_samples {
18 y_true.push(rng.random_range(0..num_classes));
19 }
20
21 // Generate predicted labels with controlled accuracy
22 let mut y_pred = Vec::with_capacity(n_samples);
23 for &true_label in &y_true {
24 // 80% chance to predict correctly, 20% chance of error
25 if rng.random::<f64>() < 0.8 {
26 y_pred.push(true_label);
27 } else {
28 // When wrong, tend to predict adjacent classes more often
29 let mut pred = true_label;
30 while pred == true_label {
31 // Generate error that's more likely to be close to true label
32 let error_margin = (rng.random::<f64>() * 2.0).round() as usize; // 0, 1, or 2
33 if rng.random::<bool>() {
34 pred = (true_label + error_margin) % num_classes;
35 } else {
36 pred = (true_label + num_classes - error_margin) % num_classes;
37 }
38 }
39 y_pred.push(pred);
40 }
41 }
42
43 // Convert to ndarray arrays
44 let y_true_array = Array1::from(y_true);
45 let y_pred_array = Array1::from(y_pred);
46
47 // Create class labels
48 let class_labels = vec![
49 "Cat".to_string(),
50 "Dog".to_string(),
51 "Bird".to_string(),
52 "Fish".to_string(),
53 "Rabbit".to_string(),
54 ];
55
56 // Create confusion matrix
57 let cm = ConfusionMatrix::<f64>::new(
58 &y_true_array.view(),
59 &y_pred_array.view(),
60 Some(num_classes),
61 Some(class_labels),
62 )
63 .unwrap();
64
65 // Example 1: Standard confusion matrix
66 println!("Example 1: Standard Confusion Matrix\n");
67 let regular_output = cm.to_ascii(Some("Animal Classification Results"), false);
68 println!("{}", regular_output);
69
70 // Example 2: Confusion matrix with color
71 println!("\n\nExample 2: Colored Confusion Matrix\n");
72 let color_options = ColorOptions {
73 enabled: true,
74 use_bright: true,
75 use_background: false,
76 };
77 let colored_output = cm.to_ascii_with_options(
78 Some("Animal Classification Results (with color)"),
79 false,
80 &color_options,
81 );
82 println!("{}", colored_output);
83
84 // Example 3: Normalized confusion matrix heatmap
85 println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
86 let heatmap_output = cm.to_heatmap_with_options(
87 Some("Animal Classification Heatmap (normalized)"),
88 true, // normalized
89 &color_options,
90 );
91 println!("{}", heatmap_output);
92
93 // Example 4: Raw counts heatmap
94 println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
95 let raw_heatmap = cm.to_heatmap_with_options(
96 Some("Animal Classification Heatmap (raw counts)"),
97 false, // not normalized
98 &color_options,
99 );
100 println!("{}", raw_heatmap);
101}
Trait Implementations§
Auto Trait Implementations§
impl<F> Freeze for ConfusionMatrix<F>
impl<F> RefUnwindSafe for ConfusionMatrix<F>
where
    F: RefUnwindSafe,
impl<F> Send for ConfusionMatrix<F>
where
    F: Send,
impl<F> Sync for ConfusionMatrix<F>
where
    F: Sync,
impl<F> Unpin for ConfusionMatrix<F>
impl<F> UnwindSafe for ConfusionMatrix<F>
where
    F: RefUnwindSafe,
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
    T: ?Sized,
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
Source§impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> CloneToUninit for T
where
    T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more