Struct ConfusionMatrix

pub struct ConfusionMatrix<F: Float + Debug + Display> {
    pub matrix: Array2<F>,
    pub labels: Option<Vec<String>>,
    pub num_classes: usize,
}

Confusion matrix for classification problems, storing counts of true versus predicted class labels

Fields§

§matrix: Array2<F>

The raw confusion matrix data

§labels: Option<Vec<String>>

Class labels (optional)

§num_classes: usize

Number of classes
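
A short sketch of constructing a matrix and reading these fields directly; the label strings here are illustrative, not part of the API:

use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;

let y_true = Array1::from_vec(vec![0, 1, 0, 1]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 1]);
let labels = vec!["neg".to_string(), "pos".to_string()];
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, Some(labels)).unwrap();
assert_eq!(cm.num_classes, 2);
assert_eq!(cm.matrix.shape(), &[2, 2]); // one row and one column per class
assert_eq!(cm.labels.as_ref().unwrap().len(), 2);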

Implementations§


impl<F: Float + Debug + Display> ConfusionMatrix<F>


pub fn new(
    y_true: &ArrayView1<'_, usize>,
    y_pred: &ArrayView1<'_, usize>,
    num_classes: Option<usize>,
    labels: Option<Vec<String>>,
) -> Result<Self>

Create a new confusion matrix from predictions and true labels

§Arguments
  • y_true - True class labels as integers
  • y_pred - Predicted class labels as integers
  • num_classes - Number of classes (if None, determined from data)
  • labels - Optional class labels as strings
§Returns
  • Result<ConfusionMatrix<F>> - The confusion matrix
§Example
use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;

let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
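// With num_classes = None, the class count is inferred from the data (labels 0, 1, 2 -> 3 classes).
assert_eq!(cm.num_classes, 3);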
Examples found in repository
examples/confusion_matrix_heatmap.rs (lines 57-62)
7fn main() {
8    // Create a reproducible random number generator
9    let mut rng = SmallRng::seed_from_u64(42);
10
11    // Generate synthetic multiclass classification data
12    let num_classes = 5;
13    let n_samples = 500;
14
15    // Generate true labels (0 to num_classes-1)
16    let mut y_true = Vec::with_capacity(n_samples);
17    for _ in 0..n_samples {
18        y_true.push(rng.random_range(0..num_classes));
19    }
20
21    // Generate predicted labels with controlled accuracy
22    let mut y_pred = Vec::with_capacity(n_samples);
23    for &true_label in &y_true {
24        // 80% chance to predict correctly, 20% chance of error
25        if rng.random::<f64>() < 0.8 {
26            y_pred.push(true_label);
27        } else {
28            // When wrong, tend to predict adjacent classes more often
29            let mut pred = true_label;
30            while pred == true_label {
31                // Generate error that's more likely to be close to true label
32                let error_margin = (rng.random::<f64>() * 2.0).round() as usize; // 0, 1, or 2
33                if rng.random::<bool>() {
34                    pred = (true_label + error_margin) % num_classes;
35                } else {
36                    pred = (true_label + num_classes - error_margin) % num_classes;
37                }
38            }
39            y_pred.push(pred);
40        }
41    }
42
43    // Convert to ndarray arrays
44    let y_true_array = Array1::from(y_true);
45    let y_pred_array = Array1::from(y_pred);
46
47    // Create class labels
48    let class_labels = vec![
49        "Cat".to_string(),
50        "Dog".to_string(),
51        "Bird".to_string(),
52        "Fish".to_string(),
53        "Rabbit".to_string(),
54    ];
55
56    // Create confusion matrix
57    let cm = ConfusionMatrix::<f64>::new(
58        &y_true_array.view(),
59        &y_pred_array.view(),
60        Some(num_classes),
61        Some(class_labels),
62    )
63    .unwrap();
64
65    // Example 1: Standard confusion matrix
66    println!("Example 1: Standard Confusion Matrix\n");
67    let regular_output = cm.to_ascii(Some("Animal Classification Results"), false);
68    println!("{}", regular_output);
69
70    // Example 2: Confusion matrix with color
71    println!("\n\nExample 2: Colored Confusion Matrix\n");
72    let color_options = ColorOptions {
73        enabled: true,
74        use_bright: true,
75        use_background: false,
76    };
77    let colored_output = cm.to_ascii_with_options(
78        Some("Animal Classification Results (with color)"),
79        false,
80        &color_options,
81    );
82    println!("{}", colored_output);
83
84    // Example 3: Normalized confusion matrix heatmap
85    println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
86    let heatmap_output = cm.to_heatmap_with_options(
87        Some("Animal Classification Heatmap (normalized)"),
88        true, // normalized
89        &color_options,
90    );
91    println!("{}", heatmap_output);
92
93    // Example 4: Raw counts heatmap
94    println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
95    let raw_heatmap = cm.to_heatmap_with_options(
96        Some("Animal Classification Heatmap (raw counts)"),
97        false, // not normalized
98        &color_options,
99    );
100    println!("{}", raw_heatmap);
101}
More examples
examples/model_visualization_example.rs (lines 48-53)
7fn main() -> Result<()> {
8    println!("Neural Network Model Evaluation Visualization Example\n");
9
10    // Generate some example data
11    let n_samples = 500;
12    let n_features = 10;
13    let n_classes = 4;
14
15    println!(
16        "Generating {} samples with {} features for {} classes",
17        n_samples, n_features, n_classes
18    );
19
20    // 1. Confusion Matrix Example
21    println!("\n--- Confusion Matrix Visualization ---\n");
22
23    // Create a deterministic RNG for reproducibility
24    let mut rng = SmallRng::seed_from_u64(42);
25
26    // Generate random predictions and true labels
27    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29    // Create slightly correlated predictions (not completely random)
30    let y_pred = Array::from_shape_fn(n_samples, |i| {
31        if rng.random::<f32>() < 0.7 {
32            // 70% chance of correct prediction
33            y_true[i]
34        } else {
35            // 30% chance of random class
36            rng.random_range(0..n_classes)
37        }
38    });
39
40    // Create confusion matrix
41    let class_labels = vec![
42        "Class A".to_string(),
43        "Class B".to_string(),
44        "Class C".to_string(),
45        "Class D".to_string(),
46    ];
47
48    let cm = ConfusionMatrix::<f32>::new(
49        &y_true.view(),
50        &y_pred.view(),
51        Some(n_classes),
52        Some(class_labels),
53    )?;
54
55    // Print raw and normalized confusion matrices
56    println!("Raw Confusion Matrix:\n");
57    println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59    println!("\nNormalized Confusion Matrix:\n");
60    println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62    // Print metrics
63    println!("\nAccuracy: {:.3}", cm.accuracy());
64
65    let precision = cm.precision();
66    let recall = cm.recall();
67    let f1 = cm.f1_score();
68
69    println!("Per-class metrics:");
70    for i in 0..n_classes {
71        println!(
72            "  Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73            i, precision[i], recall[i], f1[i]
74        );
75    }
76
77    println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79    // 2. Feature Importance Visualization
80    println!("\n--- Feature Importance Visualization ---\n");
81
82    // Generate random feature importance scores
83    let feature_names = (0..n_features)
84        .map(|i| format!("Feature_{}", i))
85        .collect::<Vec<String>>();
86
87    let importance = Array1::from_shape_fn(n_features, |i| {
88        // Make some features more important than others
89        let base = (n_features - i) as f32 / n_features as f32;
90        base + 0.2 * rng.random::<f32>()
91    });
92
93    let fi = FeatureImportance::new(feature_names, importance)?;
94
95    // Print full feature importance
96    println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98    // Print top-5 features
99    println!("\nTop 5 Most Important Features:\n");
100    println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102    // 3. ROC Curve for Binary Classification
103    println!("\n--- ROC Curve Visualization ---\n");
104
105    // Generate binary classification data
106    let n_binary = 200;
107    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109    // Generate scores with some predictive power
110    let y_scores = Array1::from_shape_fn(n_binary, |i| {
111        if y_true_binary[i] == 1 {
112            // Higher scores for positive class
113            0.6 + 0.4 * rng.random::<f32>()
114        } else {
115            // Lower scores for negative class
116            0.4 * rng.random::<f32>()
117        }
118    });
119
120    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122    println!("ROC AUC: {:.3}", roc.auc);
123    println!("\n{}", roc.to_ascii(None, 50, 20));
124
125    // 4. Learning Curve Visualization
126    println!("\n--- Learning Curve Visualization ---\n");
127
128    // Generate learning curve data
129    let n_points = 10;
130    let n_cv = 5;
131
132    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134    // Generate training scores (decreasing with size due to overfitting)
135    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137    });
138
139    // Generate validation scores (increasing with size)
140    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142    });
143
144    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148    // Print final message
149    println!("\nModel evaluation visualizations completed successfully!");
150
151    Ok(())
152}
examples/colored_eval_visualization.rs (lines 70-75)
10fn main() -> Result<()> {
11    println!(
12        "{}",
13        stylize("Neural Network Model Evaluation with Color", Style::Bold)
14    );
15    println!("{}", "-".repeat(50));
16
17    // Set up color options
18    let color_options = ColorOptions {
19        enabled: true,
20        use_background: false,
21        use_bright: true,
22    };
23
24    // Generate some example data
25    let n_samples = 500;
26    let n_features = 10;
27    let n_classes = 4;
28
29    println!(
30        "\n{} {} {} {} {} {}",
31        colorize("Generating", Color::BrightGreen),
32        colorize(n_samples.to_string(), Color::BrightYellow),
33        colorize("samples with", Color::BrightGreen),
34        colorize(n_features.to_string(), Color::BrightYellow),
35        colorize("features for", Color::BrightGreen),
36        colorize(n_classes.to_string(), Color::BrightYellow),
37    );
38
39    // Create a deterministic RNG for reproducibility
40    let mut rng = SmallRng::seed_from_u64(42);
41
42    // 1. Confusion Matrix Example
43    println!(
44        "\n{}",
45        stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
46    );
47
48    // Generate random predictions and true labels
49    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
50
51    // Create slightly correlated predictions (not completely random)
52    let y_pred = Array::from_shape_fn(n_samples, |i| {
53        if rng.random::<f32>() < 0.7 {
54            // 70% chance of correct prediction
55            y_true[i]
56        } else {
57            // 30% chance of random class
58            rng.random_range(0..n_classes)
59        }
60    });
61
62    // Create confusion matrix
63    let class_labels = vec![
64        "Class A".to_string(),
65        "Class B".to_string(),
66        "Class C".to_string(),
67        "Class D".to_string(),
68    ];
69
70    let cm = ConfusionMatrix::<f32>::new(
71        &y_true.view(),
72        &y_pred.view(),
73        Some(n_classes),
74        Some(class_labels),
75    )?;
76
77    // Print raw and normalized confusion matrices with color
78    println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
79    println!(
80        "{}",
81        cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
82    );
83
84    println!(
85        "\n{}",
86        colorize("Normalized Confusion Matrix:", Color::BrightCyan)
87    );
88    println!(
89        "{}",
90        cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
91    );
92
93    // Print metrics
94    println!(
95        "\n{} {:.3}",
96        colorize("Overall Accuracy:", Color::BrightMagenta),
97        cm.accuracy()
98    );
99
100    let precision = cm.precision();
101    let recall = cm.recall();
102    let f1 = cm.f1_score();
103
104    println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
105    for i in 0..n_classes {
106        println!(
107            "  {}: {}={:.3}, {}={:.3}, {}={:.3}",
108            colorize(format!("Class {}", i), Color::BrightYellow),
109            colorize("Precision", Color::BrightCyan),
110            precision[i],
111            colorize("Recall", Color::BrightGreen),
112            recall[i],
113            colorize("F1", Color::BrightBlue),
114            f1[i]
115        );
116    }
117
118    println!(
119        "{} {:.3}",
120        colorize("Macro F1 Score:", Color::BrightMagenta),
121        cm.macro_f1()
122    );
123
124    // 2. Feature Importance Visualization
125    println!(
126        "\n{}",
127        stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
128    );
129
130    // Generate random feature importance scores
131    let feature_names = (0..n_features)
132        .map(|i| format!("Feature_{}", i))
133        .collect::<Vec<String>>();
134
135    let importance = Array1::from_shape_fn(n_features, |i| {
136        // Make some features more important than others
137        let base = (n_features - i) as f32 / n_features as f32;
138        base + 0.2 * rng.random::<f32>()
139    });
140
141    let fi = FeatureImportance::new(feature_names, importance)?;
142
143    // Print full feature importance with color
144    println!(
145        "{}",
146        fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
147    );
148
149    // Print top-5 features with color
150    println!(
151        "\n{}",
152        colorize("Top 5 Most Important Features:", Color::BrightCyan)
153    );
154    println!(
155        "{}",
156        fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
157    );
158
159    // 3. ROC Curve for Binary Classification
160    println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
161
162    // Generate binary classification data
163    let n_binary = 200;
164    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
165
166    // Generate scores with some predictive power
167    let y_scores = Array1::from_shape_fn(n_binary, |i| {
168        if y_true_binary[i] == 1 {
169            // Higher scores for positive class
170            0.6 + 0.4 * rng.random::<f32>()
171        } else {
172            // Lower scores for negative class
173            0.4 * rng.random::<f32>()
174        }
175    });
176
177    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
178
179    println!(
180        "{} {:.3}",
181        colorize("ROC AUC:", Color::BrightMagenta),
182        roc.auc
183    );
184
185    println!("\n{}", roc.to_ascii(None, 50, 20));
186
187    // 4. Learning Curve Visualization
188    println!(
189        "\n{}",
190        stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
191    );
192
193    // Generate learning curve data
194    let n_points = 10;
195    let n_cv = 5;
196
197    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
198
199    // Generate training scores (decreasing with size due to overfitting)
200    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
201        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
202    });
203
204    // Generate validation scores (increasing with size)
205    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
206        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
207    });
208
209    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
210
211    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
212
213    // Print final message with color
214    println!(
215        "\n{}",
216        colorize(
217            "Model evaluation visualizations completed successfully!",
218            Color::BrightGreen
219        )
220    );
221
222    Ok(())
223}
examples/neural_confusion_matrix.rs (lines 255-260)
118fn main() -> Result<()> {
119    println!("Neural Network Confusion Matrix Visualization");
120    println!("==============================================\n");
121
122    // Initialize RNG with a fixed seed for reproducibility
123    let mut rng = SmallRng::seed_from_u64(42);
124
125    // Generate spiral dataset for 3-class classification
126    let n_classes = 3;
127    let n_samples_per_class = 100;
128    let noise = 0.15;
129
130    let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131    println!(
132        "Generated spiral dataset with {} classes, {} samples per class",
133        n_classes, n_samples_per_class
134    );
135
136    // Split data into training and test sets (80/20 split)
137    let n_samples = x.shape()[0];
138    let n_train = (n_samples as f32 * 0.8) as usize;
139    let n_test = n_samples - n_train;
140
141    let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142    let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143    let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144    let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146    println!(
147        "Split data into {} training and {} test samples",
148        n_train, n_test
149    );
150
151    // Create a classification model
152    let input_dim = 2; // 2D input (x, y coordinates)
153    let hidden_dim = 32; // Hidden layer size
154    let output_dim = n_classes; // One output per class
155
156    let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157    println!("Created model with {} layers", model.num_layers());
158
159    // Setup loss function and optimizer
160    let loss_fn = MeanSquaredError::new();
161    let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163    // Train the model
164    let epochs = 100;
165    let x_train_dyn = x_train.clone().into_dyn();
166    let y_train_onehot = one_hot_encode(&y_train, n_classes);
167    let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169    // Create visualization callback for training metrics
170    let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171        .with_tracked_metrics(vec![
172            "train_loss".to_string(),
173            "val_accuracy".to_string(),
174        ]);
175
176    // Define class labels for confusion matrix
177    let class_labels = vec![
178        "Class A".to_string(),
179        "Class B".to_string(),
180        "Class C".to_string(),
181    ];
182
183    // Train the model (simple manual training loop)
184    println!("\nTraining model...");
185
186    // Initialize history for tracking metrics
187    let mut epoch_history = HashMap::new();
188    epoch_history.insert("train_loss".to_string(), Vec::new());
189    epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191    // Training loop
192    for epoch in 0..epochs {
193        // Train for one epoch
194        let train_loss =
195            model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197        // Compute validation accuracy
198        let x_test_dyn = x_test.clone().into_dyn();
199        let predictions = model.forward(&x_test_dyn)?;
200        let predicted_classes = predictions_to_classes(&predictions);
201
202        // Calculate validation accuracy
203        let mut correct = 0;
204        for i in 0..n_test {
205            if predicted_classes[i] == y_test[i] {
206                correct += 1;
207            }
208        }
209        let val_accuracy = correct as f32 / n_test as f32;
210
211        // Store metrics
212        epoch_history
213            .get_mut("train_loss")
214            .unwrap()
215            .push(train_loss);
216        epoch_history
217            .get_mut("val_accuracy")
218            .unwrap()
219            .push(val_accuracy);
220
221        // Print progress
222        if (epoch + 1) % 10 == 0 || epoch == 0 {
223            println!(
224                "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225                epoch + 1,
226                epochs,
227                train_loss,
228                val_accuracy
229            );
230        }
231
232        // Update visualization callback
233        let mut context = CallbackContext {
234            epoch,
235            total_epochs: epochs,
236            batch: 0,
237            total_batches: 1,
238            batch_loss: None,
239            epoch_loss: Some(train_loss),
240            val_loss: None,
241            metrics: vec![val_accuracy],
242            history: &epoch_history,
243            stop_training: false,
244            model: None,
245        };
246
247        // Visualize progress with metrics chart
248        if epoch % 10 == 0 || epoch == epochs - 1 {
249            visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250        }
251
252        // Calculate and show confusion matrix during training
253        if epoch % 20 == 0 || epoch == epochs - 1 {
254            // Create confusion matrix
255            let cm = ConfusionMatrix::<f32>::new(
256                &y_test.view(),
257                &predicted_classes.view(),
258                Some(n_classes),
259                Some(class_labels.clone()),
260            )?;
261
262            // Show heatmap visualization
263            println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264            println!(
265                "{}",
266                cm.to_heatmap(
267                    Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268                    true
269                )
270            );
271        }
272    }
273
274    // Final evaluation
275    println!("\nFinal model evaluation:");
276
277    // Make predictions on test set
278    let x_test_dyn = x_test.clone().into_dyn();
279    let predictions = model.forward(&x_test_dyn)?;
280    let predicted_classes = predictions_to_classes(&predictions);
281
282    // Create confusion matrix
283    let cm = ConfusionMatrix::<f32>::new(
284        &y_test.view(),
285        &predicted_classes.view(),
286        Some(n_classes),
287        Some(class_labels.clone()),
288    )?;
289
290    // Calculate and show metrics
291    let accuracy = cm.accuracy();
292    let precision = cm.precision();
293    let recall = cm.recall();
294    let f1 = cm.f1_score();
295
296    println!("\nFinal Classification Metrics:");
297    println!("Overall Accuracy: {:.4}", accuracy);
298
299    println!("\nPer-Class Metrics:");
300    println!("Class | Precision | Recall | F1-Score");
301    println!("-----------------------------------");
302
303    for i in 0..n_classes {
304        println!(
305            "{}    | {:.4}     | {:.4}  | {:.4}",
306            class_labels[i], precision[i], recall[i], f1[i]
307        );
308    }
309
310    println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312    // Show different confusion matrix visualizations
313    println!("\nFinal Confusion Matrix Visualizations:");
314
315    // 1. Standard confusion matrix
316    println!("\n1. Standard Confusion Matrix:");
317    println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319    // 2. Normalized confusion matrix
320    println!("\n2. Normalized Confusion Matrix:");
321    println!(
322        "{}",
323        cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324    );
325
326    // 3. Confusion matrix heatmap
327    println!("\n3. Confusion Matrix Heatmap:");
328    println!(
329        "{}",
330        cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331    );
332
333    // 4. Error pattern analysis
334    println!("\n4. Error Pattern Analysis:");
335    println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337    println!("\nNeural Network Confusion Matrix Visualization Complete!");
338    Ok(())
339}

pub fn from_matrix(
    matrix: Array2<F>,
    labels: Option<Vec<String>>,
) -> Result<Self>

Create a confusion matrix from raw matrix data

§Arguments
  • matrix - Raw confusion matrix data
  • labels - Optional class labels
§Returns
  • Result<ConfusionMatrix<F>> - The confusion matrix
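§Example
A minimal sketch of building a matrix from precomputed counts; it assumes rows index the true class and columns the predicted class, consistent with the row-wise normalization described below:

use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::array;

// Counts for a 2-class problem (rows: true class, columns: predicted class).
let counts = array![[50.0_f64, 10.0], [5.0, 35.0]];
let labels = vec!["Negative".to_string(), "Positive".to_string()];
let cm = ConfusionMatrix::from_matrix(counts, Some(labels)).unwrap();
assert_eq!(cm.num_classes, 2);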
Examples found in repository
examples/error_pattern_heatmap.rs (line 59)
6fn main() {
7    // Create a reproducible random number generator
8    let mut rng = SmallRng::seed_from_u64(42);
9
10    // Generate synthetic multiclass classification data with specific error patterns
11    let num_classes = 5;
12
13    // Create confusion matrix with controlled error patterns
14    let mut matrix = vec![vec![0; num_classes]; num_classes];
15
16    // Set diagonal elements (correct classifications) with high values
17    for i in 0..num_classes {
18        matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
19    }
20
21    // Create specific error patterns:
22    // - Classes 0 and 1 often confused
23    matrix[0][1] = 25;
24    matrix[1][0] = 15;
25
26    // - Class 2 sometimes confused with Class 3
27    matrix[2][3] = 18;
28
29    // - Class 4 has some misclassifications to all other classes
30    matrix[4][0] = 8;
31    matrix[4][1] = 5;
32    matrix[4][2] = 10;
33    matrix[4][3] = 12;
34
35    // - Some minor errors scattered about
36    for i in 0..num_classes {
37        for j in 0..num_classes {
38            if i != j && matrix[i][j] == 0 {
39                matrix[i][j] = rng.random_range(0..5);
40            }
41        }
42    }
43
44    // Convert to ndarray
45    let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
46    let ndarray_matrix =
47        ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix).unwrap();
48
49    // Create class labels
50    let class_labels = vec![
51        "Class A".to_string(),
52        "Class B".to_string(),
53        "Class C".to_string(),
54        "Class D".to_string(),
55        "Class E".to_string(),
56    ];
57
58    // Create confusion matrix
59    let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
60
61    // Example 1: Standard confusion matrix
62    println!("Example 1: Standard Confusion Matrix\n");
63    let regular_output = cm.to_ascii(Some("Classification Results"), false);
64    println!("{}", regular_output);
65
66    // Example 2: Normal heatmap
67    println!("\n\nExample 2: Standard Heatmap Visualization\n");
68    let color_options = ColorOptions {
69        enabled: true,
70        use_bright: true,
71        use_background: false,
72    };
73    let heatmap_output = cm.to_heatmap_with_options(
74        Some("Classification Heatmap"),
75        true, // normalized
76        &color_options,
77    );
78    println!("{}", heatmap_output);
79
80    // Example 3: Error pattern heatmap
81    println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
82    let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
83    println!("{}", error_heatmap);
84}

pub fn normalized(&self) -> Array2<F>

Get the normalized confusion matrix (rows sum to 1)
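
§Example
A short sketch of the row-normalized view (assuming each row is divided by its total, so the entries in a row are the per-true-class prediction fractions):

use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;

let y_true = Array1::from_vec(vec![0, 0, 1, 1]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 1]);
let cm = ConfusionMatrix::<f64>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
let norm = cm.normalized();
// Every row sums to 1; row 0 here is [0.5, 0.5].
for row in norm.rows() {
    assert!((row.sum() - 1.0).abs() < 1e-6);
}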


pub fn accuracy(&self) -> F

Calculate accuracy from the confusion matrix
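
§Example
A small sketch; the expected value assumes accuracy is the fraction of predictions that fall on the diagonal:

use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;

let y_true = Array1::from_vec(vec![0, 1, 1, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 0, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
// 3 of the 4 predictions are correct.
assert!((cm.accuracy() - 0.75).abs() < 1e-6);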

Examples found in repository
examples/model_visualization_example.rs (line 63)
More examples
examples/colored_eval_visualization.rs (line 97)
examples/neural_confusion_matrix.rs (line 291)

pub fn precision(&self) -> Array1<F>

Calculate precision for each class
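
§Example
A minimal sketch, assuming the usual definition (correct predictions for class i divided by all predictions of class i):

use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;

let y_true = Array1::from_vec(vec![0, 0, 1, 1]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 1]);
let cm = ConfusionMatrix::<f64>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
let precision = cm.precision(); // one entry per class
// Class 0: 1 correct of 1 predicted -> 1.0; class 1: 2 correct of 3 predicted -> 2/3.
assert!((precision[0] - 1.0).abs() < 1e-6);
assert!((precision[1] - 2.0 / 3.0).abs() < 1e-6);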

Examples found in repository
examples/model_visualization_example.rs (line 65)
More examples
examples/colored_eval_visualization.rs (line 100)
examples/neural_confusion_matrix.rs (line 292)
118fn main() -> Result<()> {
119    println!("Neural Network Confusion Matrix Visualization");
120    println!("==============================================\n");
121
122    // Initialize RNG with a fixed seed for reproducibility
123    let mut rng = SmallRng::seed_from_u64(42);
124
125    // Generate spiral dataset for 3-class classification
126    let n_classes = 3;
127    let n_samples_per_class = 100;
128    let noise = 0.15;
129
130    let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131    println!(
132        "Generated spiral dataset with {} classes, {} samples per class",
133        n_classes, n_samples_per_class
134    );
135
136    // Split data into training and test sets (80/20 split)
137    let n_samples = x.shape()[0];
138    let n_train = (n_samples as f32 * 0.8) as usize;
139    let n_test = n_samples - n_train;
140
141    let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142    let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143    let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144    let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146    println!(
147        "Split data into {} training and {} test samples",
148        n_train, n_test
149    );
150
151    // Create a classification model
152    let input_dim = 2; // 2D input (x, y coordinates)
153    let hidden_dim = 32; // Hidden layer size
154    let output_dim = n_classes; // One output per class
155
156    let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157    println!("Created model with {} layers", model.num_layers());
158
159    // Setup loss function and optimizer
160    let loss_fn = MeanSquaredError::new();
161    let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163    // Train the model
164    let epochs = 100;
165    let x_train_dyn = x_train.clone().into_dyn();
166    let y_train_onehot = one_hot_encode(&y_train, n_classes);
167    let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169    // Create visualization callback for training metrics
170    let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171        .with_tracked_metrics(vec![
172            "train_loss".to_string(),
173            "val_accuracy".to_string(),
174        ]);
175
176    // Define class labels for confusion matrix
177    let class_labels = vec![
178        "Class A".to_string(),
179        "Class B".to_string(),
180        "Class C".to_string(),
181    ];
182
183    // Train the model (simple manual training loop)
184    println!("\nTraining model...");
185
186    // Initialize history for tracking metrics
187    let mut epoch_history = HashMap::new();
188    epoch_history.insert("train_loss".to_string(), Vec::new());
189    epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191    // Training loop
192    for epoch in 0..epochs {
193        // Train for one epoch
194        let train_loss =
195            model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197        // Compute validation accuracy
198        let x_test_dyn = x_test.clone().into_dyn();
199        let predictions = model.forward(&x_test_dyn)?;
200        let predicted_classes = predictions_to_classes(&predictions);
201
202        // Calculate validation accuracy
203        let mut correct = 0;
204        for i in 0..n_test {
205            if predicted_classes[i] == y_test[i] {
206                correct += 1;
207            }
208        }
209        let val_accuracy = correct as f32 / n_test as f32;
210
211        // Store metrics
212        epoch_history
213            .get_mut("train_loss")
214            .unwrap()
215            .push(train_loss);
216        epoch_history
217            .get_mut("val_accuracy")
218            .unwrap()
219            .push(val_accuracy);
220
221        // Print progress
222        if (epoch + 1) % 10 == 0 || epoch == 0 {
223            println!(
224                "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225                epoch + 1,
226                epochs,
227                train_loss,
228                val_accuracy
229            );
230        }
231
232        // Update visualization callback
233        let mut context = CallbackContext {
234            epoch,
235            total_epochs: epochs,
236            batch: 0,
237            total_batches: 1,
238            batch_loss: None,
239            epoch_loss: Some(train_loss),
240            val_loss: None,
241            metrics: vec![val_accuracy],
242            history: &epoch_history,
243            stop_training: false,
244            model: None,
245        };
246
247        // Visualize progress with metrics chart
248        if epoch % 10 == 0 || epoch == epochs - 1 {
249            visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250        }
251
252        // Calculate and show confusion matrix during training
253        if epoch % 20 == 0 || epoch == epochs - 1 {
254            // Create confusion matrix
255            let cm = ConfusionMatrix::<f32>::new(
256                &y_test.view(),
257                &predicted_classes.view(),
258                Some(n_classes),
259                Some(class_labels.clone()),
260            )?;
261
262            // Show heatmap visualization
263            println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264            println!(
265                "{}",
266                cm.to_heatmap(
267                    Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268                    true
269                )
270            );
271        }
272    }
273
274    // Final evaluation
275    println!("\nFinal model evaluation:");
276
277    // Make predictions on test set
278    let x_test_dyn = x_test.clone().into_dyn();
279    let predictions = model.forward(&x_test_dyn)?;
280    let predicted_classes = predictions_to_classes(&predictions);
281
282    // Create confusion matrix
283    let cm = ConfusionMatrix::<f32>::new(
284        &y_test.view(),
285        &predicted_classes.view(),
286        Some(n_classes),
287        Some(class_labels.clone()),
288    )?;
289
290    // Calculate and show metrics
291    let accuracy = cm.accuracy();
292    let precision = cm.precision();
293    let recall = cm.recall();
294    let f1 = cm.f1_score();
295
296    println!("\nFinal Classification Metrics:");
297    println!("Overall Accuracy: {:.4}", accuracy);
298
299    println!("\nPer-Class Metrics:");
300    println!("Class | Precision | Recall | F1-Score");
301    println!("-----------------------------------");
302
303    for i in 0..n_classes {
304        println!(
305            "{}    | {:.4}     | {:.4}  | {:.4}",
306            class_labels[i], precision[i], recall[i], f1[i]
307        );
308    }
309
310    println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312    // Show different confusion matrix visualizations
313    println!("\nFinal Confusion Matrix Visualizations:");
314
315    // 1. Standard confusion matrix
316    println!("\n1. Standard Confusion Matrix:");
317    println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319    // 2. Normalized confusion matrix
320    println!("\n2. Normalized Confusion Matrix:");
321    println!(
322        "{}",
323        cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324    );
325
326    // 3. Confusion matrix heatmap
327    println!("\n3. Confusion Matrix Heatmap:");
328    println!(
329        "{}",
330        cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331    );
332
333    // 4. Error pattern analysis
334    println!("\n4. Error Pattern Analysis:");
335    println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337    println!("\nNeural Network Confusion Matrix Visualization Complete!");
338    Ok(())
339}
Source

pub fn recall(&self) -> Array1<F>

Calculate recall for each class
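
For reference, per-class recall is conventionally the diagonal entry divided by its row sum (true positives over all samples whose true label is that class). The sketch below computes it directly from a raw confusion matrix; the row-major layout (rows = true labels, columns = predictions) and the helper name per_class_recall are illustrative assumptions, not taken from the crate source.

use ndarray::{array, Array1, Array2};

// Per-class recall: TP_i / (all samples whose true label is i).
// Assumes rows index true labels and columns index predictions.
fn per_class_recall(m: &Array2<f32>) -> Array1<f32> {
    Array1::from_shape_fn(m.nrows(), |i| {
        let tp = m[(i, i)];
        let support = m.row(i).sum();
        if support > 0.0 { tp / support } else { 0.0 }
    })
}

fn main() {
    // 3-class toy matrix: 10 samples of class 0, 8 of class 1, 7 of class 2
    let m = array![[8.0, 1.0, 1.0], [2.0, 5.0, 1.0], [0.0, 1.0, 6.0]];
    println!("{:?}", per_class_recall(&m)); // ~[0.800, 0.625, 0.857]
}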

Examples found in repository?
examples/model_visualization_example.rs (line 66)
7fn main() -> Result<()> {
8    println!("Neural Network Model Evaluation Visualization Example\n");
9
10    // Generate some example data
11    let n_samples = 500;
12    let n_features = 10;
13    let n_classes = 4;
14
15    println!(
16        "Generating {} samples with {} features for {} classes",
17        n_samples, n_features, n_classes
18    );
19
20    // 1. Confusion Matrix Example
21    println!("\n--- Confusion Matrix Visualization ---\n");
22
23    // Create a deterministic RNG for reproducibility
24    let mut rng = SmallRng::seed_from_u64(42);
25
26    // Generate random predictions and true labels
27    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29    // Create slightly correlated predictions (not completely random)
30    let y_pred = Array::from_shape_fn(n_samples, |i| {
31        if rng.random::<f32>() < 0.7 {
32            // 70% chance of correct prediction
33            y_true[i]
34        } else {
35            // 30% chance of random class
36            rng.random_range(0..n_classes)
37        }
38    });
39
40    // Create confusion matrix
41    let class_labels = vec![
42        "Class A".to_string(),
43        "Class B".to_string(),
44        "Class C".to_string(),
45        "Class D".to_string(),
46    ];
47
48    let cm = ConfusionMatrix::<f32>::new(
49        &y_true.view(),
50        &y_pred.view(),
51        Some(n_classes),
52        Some(class_labels),
53    )?;
54
55    // Print raw and normalized confusion matrices
56    println!("Raw Confusion Matrix:\n");
57    println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59    println!("\nNormalized Confusion Matrix:\n");
60    println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62    // Print metrics
63    println!("\nAccuracy: {:.3}", cm.accuracy());
64
65    let precision = cm.precision();
66    let recall = cm.recall();
67    let f1 = cm.f1_score();
68
69    println!("Per-class metrics:");
70    for i in 0..n_classes {
71        println!(
72            "  Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73            i, precision[i], recall[i], f1[i]
74        );
75    }
76
77    println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79    // 2. Feature Importance Visualization
80    println!("\n--- Feature Importance Visualization ---\n");
81
82    // Generate random feature importance scores
83    let feature_names = (0..n_features)
84        .map(|i| format!("Feature_{}", i))
85        .collect::<Vec<String>>();
86
87    let importance = Array1::from_shape_fn(n_features, |i| {
88        // Make some features more important than others
89        let base = (n_features - i) as f32 / n_features as f32;
90        base + 0.2 * rng.random::<f32>()
91    });
92
93    let fi = FeatureImportance::new(feature_names, importance)?;
94
95    // Print full feature importance
96    println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98    // Print top-5 features
99    println!("\nTop 5 Most Important Features:\n");
100    println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102    // 3. ROC Curve for Binary Classification
103    println!("\n--- ROC Curve Visualization ---\n");
104
105    // Generate binary classification data
106    let n_binary = 200;
107    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109    // Generate scores with some predictive power
110    let y_scores = Array1::from_shape_fn(n_binary, |i| {
111        if y_true_binary[i] == 1 {
112            // Higher scores for positive class
113            0.6 + 0.4 * rng.random::<f32>()
114        } else {
115            // Lower scores for negative class
116            0.4 * rng.random::<f32>()
117        }
118    });
119
120    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122    println!("ROC AUC: {:.3}", roc.auc);
123    println!("\n{}", roc.to_ascii(None, 50, 20));
124
125    // 4. Learning Curve Visualization
126    println!("\n--- Learning Curve Visualization ---\n");
127
128    // Generate learning curve data
129    let n_points = 10;
130    let n_cv = 5;
131
132    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134    // Generate training scores (decreasing with size due to overfitting)
135    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137    });
138
139    // Generate validation scores (increasing with size)
140    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142    });
143
144    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148    // Print final message
149    println!("\nModel evaluation visualizations completed successfully!");
150
151    Ok(())
152}
More examples
examples/colored_eval_visualization.rs (line 101)
examples/neural_confusion_matrix.rs (line 293)
(full listings shown above)
Source

pub fn f1_score(&self) -> Array1<F>

Calculate F1 score for each class
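
Per-class F1 is the harmonic mean of precision and recall for that class. A minimal sketch of that computation from a raw confusion matrix follows; as above, the row = true label / column = prediction layout and the helper name per_class_f1 are illustrative assumptions rather than the crate's internals.

use ndarray::{array, Array1, Array2};

// F1_i = 2 * P_i * R_i / (P_i + R_i), with precision taken over column i
// and recall over row i (rows = true labels, columns = predictions).
fn per_class_f1(m: &Array2<f32>) -> Array1<f32> {
    Array1::from_shape_fn(m.nrows(), |i| {
        let tp = m[(i, i)];
        let col = m.column(i).sum(); // predicted as class i
        let row = m.row(i).sum();    // truly class i
        let p = if col > 0.0 { tp / col } else { 0.0 };
        let r = if row > 0.0 { tp / row } else { 0.0 };
        if p + r > 0.0 { 2.0 * p * r / (p + r) } else { 0.0 }
    })
}

fn main() {
    let m = array![[8.0, 1.0, 1.0], [2.0, 5.0, 1.0], [0.0, 1.0, 6.0]];
    println!("{:?}", per_class_f1(&m)); // ~[0.800, 0.667, 0.800]
}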

Examples found in repository?
examples/model_visualization_example.rs (line 67)
examples/colored_eval_visualization.rs (line 102)
examples/neural_confusion_matrix.rs (line 294)
(full listings shown above)
Source

pub fn macro_f1(&self) -> F

Calculate macro-averaged F1 score
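
Macro averaging treats every class equally: the macro F1 is the unweighted mean of the per-class F1 scores, regardless of class support. A minimal sketch, assuming per-class F1 values have already been computed (the helper name is illustrative):

use ndarray::{array, Array1};

// Macro F1: unweighted mean of the per-class F1 scores.
fn macro_f1(per_class_f1: &Array1<f32>) -> f32 {
    per_class_f1.sum() / per_class_f1.len() as f32
}

fn main() {
    let f1 = array![0.800_f32, 0.667, 0.800];
    println!("{:.3}", macro_f1(&f1)); // 0.756
}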

Examples found in repository?
examples/model_visualization_example.rs (line 77)
(full listing shown above)
More examples
examples/colored_eval_visualization.rs (line 121)
10fn main() -> Result<()> {
11    println!(
12        "{}",
13        stylize("Neural Network Model Evaluation with Color", Style::Bold)
14    );
15    println!("{}", "-".repeat(50));
16
17    // Set up color options
18    let color_options = ColorOptions {
19        enabled: true,
20        use_background: false,
21        use_bright: true,
22    };
23
24    // Generate some example data
25    let n_samples = 500;
26    let n_features = 10;
27    let n_classes = 4;
28
29    println!(
30        "\n{} {} {} {} {} {}",
31        colorize("Generating", Color::BrightGreen),
32        colorize(n_samples.to_string(), Color::BrightYellow),
33        colorize("samples with", Color::BrightGreen),
34        colorize(n_features.to_string(), Color::BrightYellow),
35        colorize("features for", Color::BrightGreen),
36        colorize(n_classes.to_string(), Color::BrightYellow),
37    );
38
39    // Create a deterministic RNG for reproducibility
40    let mut rng = SmallRng::seed_from_u64(42);
41
42    // 1. Confusion Matrix Example
43    println!(
44        "\n{}",
45        stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
46    );
47
48    // Generate random predictions and true labels
49    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
50
51    // Create slightly correlated predictions (not completely random)
52    let y_pred = Array::from_shape_fn(n_samples, |i| {
53        if rng.random::<f32>() < 0.7 {
54            // 70% chance of correct prediction
55            y_true[i]
56        } else {
57            // 30% chance of random class
58            rng.random_range(0..n_classes)
59        }
60    });
61
62    // Create confusion matrix
63    let class_labels = vec![
64        "Class A".to_string(),
65        "Class B".to_string(),
66        "Class C".to_string(),
67        "Class D".to_string(),
68    ];
69
70    let cm = ConfusionMatrix::<f32>::new(
71        &y_true.view(),
72        &y_pred.view(),
73        Some(n_classes),
74        Some(class_labels),
75    )?;
76
77    // Print raw and normalized confusion matrices with color
78    println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
79    println!(
80        "{}",
81        cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
82    );
83
84    println!(
85        "\n{}",
86        colorize("Normalized Confusion Matrix:", Color::BrightCyan)
87    );
88    println!(
89        "{}",
90        cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
91    );
92
93    // Print metrics
94    println!(
95        "\n{} {:.3}",
96        colorize("Overall Accuracy:", Color::BrightMagenta),
97        cm.accuracy()
98    );
99
100    let precision = cm.precision();
101    let recall = cm.recall();
102    let f1 = cm.f1_score();
103
104    println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
105    for i in 0..n_classes {
106        println!(
107            "  {}: {}={:.3}, {}={:.3}, {}={:.3}",
108            colorize(format!("Class {}", i), Color::BrightYellow),
109            colorize("Precision", Color::BrightCyan),
110            precision[i],
111            colorize("Recall", Color::BrightGreen),
112            recall[i],
113            colorize("F1", Color::BrightBlue),
114            f1[i]
115        );
116    }
117
118    println!(
119        "{} {:.3}",
120        colorize("Macro F1 Score:", Color::BrightMagenta),
121        cm.macro_f1()
122    );
123
124    // 2. Feature Importance Visualization
125    println!(
126        "\n{}",
127        stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
128    );
129
130    // Generate random feature importance scores
131    let feature_names = (0..n_features)
132        .map(|i| format!("Feature_{}", i))
133        .collect::<Vec<String>>();
134
135    let importance = Array1::from_shape_fn(n_features, |i| {
136        // Make some features more important than others
137        let base = (n_features - i) as f32 / n_features as f32;
138        base + 0.2 * rng.random::<f32>()
139    });
140
141    let fi = FeatureImportance::new(feature_names, importance)?;
142
143    // Print full feature importance with color
144    println!(
145        "{}",
146        fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
147    );
148
149    // Print top-5 features with color
150    println!(
151        "\n{}",
152        colorize("Top 5 Most Important Features:", Color::BrightCyan)
153    );
154    println!(
155        "{}",
156        fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
157    );
158
159    // 3. ROC Curve for Binary Classification
160    println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
161
162    // Generate binary classification data
163    let n_binary = 200;
164    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
165
166    // Generate scores with some predictive power
167    let y_scores = Array1::from_shape_fn(n_binary, |i| {
168        if y_true_binary[i] == 1 {
169            // Higher scores for positive class
170            0.6 + 0.4 * rng.random::<f32>()
171        } else {
172            // Lower scores for negative class
173            0.4 * rng.random::<f32>()
174        }
175    });
176
177    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
178
179    println!(
180        "{} {:.3}",
181        colorize("ROC AUC:", Color::BrightMagenta),
182        roc.auc
183    );
184
185    println!("\n{}", roc.to_ascii(None, 50, 20));
186
187    // 4. Learning Curve Visualization
188    println!(
189        "\n{}",
190        stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
191    );
192
193    // Generate learning curve data
194    let n_points = 10;
195    let n_cv = 5;
196
197    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
198
199    // Generate training scores (decreasing with size due to overfitting)
200    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
201        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
202    });
203
204    // Generate validation scores (increasing with size)
205    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
206        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
207    });
208
209    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
210
211    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
212
213    // Print final message with color
214    println!(
215        "\n{}",
216        colorize(
217            "Model evaluation visualizations completed successfully!",
218            Color::BrightGreen
219        )
220    );
221
222    Ok(())
223}
examples/neural_confusion_matrix.rs (line 310)
118fn main() -> Result<()> {
119    println!("Neural Network Confusion Matrix Visualization");
120    println!("==============================================\n");
121
122    // Initialize RNG with a fixed seed for reproducibility
123    let mut rng = SmallRng::seed_from_u64(42);
124
125    // Generate spiral dataset for 3-class classification
126    let n_classes = 3;
127    let n_samples_per_class = 100;
128    let noise = 0.15;
129
130    let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131    println!(
132        "Generated spiral dataset with {} classes, {} samples per class",
133        n_classes, n_samples_per_class
134    );
135
136    // Split data into training and test sets (80/20 split)
137    let n_samples = x.shape()[0];
138    let n_train = (n_samples as f32 * 0.8) as usize;
139    let n_test = n_samples - n_train;
140
141    let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142    let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143    let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144    let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146    println!(
147        "Split data into {} training and {} test samples",
148        n_train, n_test
149    );
150
151    // Create a classification model
152    let input_dim = 2; // 2D input (x, y coordinates)
153    let hidden_dim = 32; // Hidden layer size
154    let output_dim = n_classes; // One output per class
155
156    let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157    println!("Created model with {} layers", model.num_layers());
158
159    // Setup loss function and optimizer
160    let loss_fn = MeanSquaredError::new();
161    let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163    // Train the model
164    let epochs = 100;
165    let x_train_dyn = x_train.clone().into_dyn();
166    let y_train_onehot = one_hot_encode(&y_train, n_classes);
167    let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169    // Create visualization callback for training metrics
170    let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171        .with_tracked_metrics(vec![
172            "train_loss".to_string(),
173            "val_accuracy".to_string(),
174        ]);
175
176    // Define class labels for confusion matrix
177    let class_labels = vec![
178        "Class A".to_string(),
179        "Class B".to_string(),
180        "Class C".to_string(),
181    ];
182
183    // Train the model (simple manual training loop)
184    println!("\nTraining model...");
185
186    // Initialize history for tracking metrics
187    let mut epoch_history = HashMap::new();
188    epoch_history.insert("train_loss".to_string(), Vec::new());
189    epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191    // Training loop
192    for epoch in 0..epochs {
193        // Train for one epoch
194        let train_loss =
195            model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197        // Compute validation accuracy
198        let x_test_dyn = x_test.clone().into_dyn();
199        let predictions = model.forward(&x_test_dyn)?;
200        let predicted_classes = predictions_to_classes(&predictions);
201
202        // Calculate validation accuracy
203        let mut correct = 0;
204        for i in 0..n_test {
205            if predicted_classes[i] == y_test[i] {
206                correct += 1;
207            }
208        }
209        let val_accuracy = correct as f32 / n_test as f32;
210
211        // Store metrics
212        epoch_history
213            .get_mut("train_loss")
214            .unwrap()
215            .push(train_loss);
216        epoch_history
217            .get_mut("val_accuracy")
218            .unwrap()
219            .push(val_accuracy);
220
221        // Print progress
222        if (epoch + 1) % 10 == 0 || epoch == 0 {
223            println!(
224                "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225                epoch + 1,
226                epochs,
227                train_loss,
228                val_accuracy
229            );
230        }
231
232        // Update visualization callback
233        let mut context = CallbackContext {
234            epoch,
235            total_epochs: epochs,
236            batch: 0,
237            total_batches: 1,
238            batch_loss: None,
239            epoch_loss: Some(train_loss),
240            val_loss: None,
241            metrics: vec![val_accuracy],
242            history: &epoch_history,
243            stop_training: false,
244            model: None,
245        };
246
247        // Visualize progress with metrics chart
248        if epoch % 10 == 0 || epoch == epochs - 1 {
249            visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250        }
251
252        // Calculate and show confusion matrix during training
253        if epoch % 20 == 0 || epoch == epochs - 1 {
254            // Create confusion matrix
255            let cm = ConfusionMatrix::<f32>::new(
256                &y_test.view(),
257                &predicted_classes.view(),
258                Some(n_classes),
259                Some(class_labels.clone()),
260            )?;
261
262            // Show heatmap visualization
263            println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264            println!(
265                "{}",
266                cm.to_heatmap(
267                    Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268                    true
269                )
270            );
271        }
272    }
273
274    // Final evaluation
275    println!("\nFinal model evaluation:");
276
277    // Make predictions on test set
278    let x_test_dyn = x_test.clone().into_dyn();
279    let predictions = model.forward(&x_test_dyn)?;
280    let predicted_classes = predictions_to_classes(&predictions);
281
282    // Create confusion matrix
283    let cm = ConfusionMatrix::<f32>::new(
284        &y_test.view(),
285        &predicted_classes.view(),
286        Some(n_classes),
287        Some(class_labels.clone()),
288    )?;
289
290    // Calculate and show metrics
291    let accuracy = cm.accuracy();
292    let precision = cm.precision();
293    let recall = cm.recall();
294    let f1 = cm.f1_score();
295
296    println!("\nFinal Classification Metrics:");
297    println!("Overall Accuracy: {:.4}", accuracy);
298
299    println!("\nPer-Class Metrics:");
300    println!("Class | Precision | Recall | F1-Score");
301    println!("-----------------------------------");
302
303    for i in 0..n_classes {
304        println!(
305            "{}    | {:.4}     | {:.4}  | {:.4}",
306            class_labels[i], precision[i], recall[i], f1[i]
307        );
308    }
309
310    println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312    // Show different confusion matrix visualizations
313    println!("\nFinal Confusion Matrix Visualizations:");
314
315    // 1. Standard confusion matrix
316    println!("\n1. Standard Confusion Matrix:");
317    println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319    // 2. Normalized confusion matrix
320    println!("\n2. Normalized Confusion Matrix:");
321    println!(
322        "{}",
323        cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324    );
325
326    // 3. Confusion matrix heatmap
327    println!("\n3. Confusion Matrix Heatmap:");
328    println!(
329        "{}",
330        cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331    );
332
333    // 4. Error pattern analysis
334    println!("\n4. Error Pattern Analysis:");
335    println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337    println!("\nNeural Network Confusion Matrix Visualization Complete!");
338    Ok(())
339}
Source

pub fn class_metrics(&self) -> HashMap<String, Vec<F>>

Get class-wise metrics as a HashMap
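A minimal sketch (not taken from the repository) showing how the returned map can be inspected; the exact metric names stored in the map are not listed on this page, so the example simply prints every entry, assuming each one pairs a metric name with one value per class:

use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;

let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();

// Iterate the HashMap<String, Vec<F>> and print each metric name with its per-class values.
let metrics = cm.class_metrics();
for (name, values) in &metrics {
    println!("{}: {:?}", name, values);
}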

Source

pub fn to_ascii(&self, title: Option<&str>, normalized: bool) -> String

Convert the confusion matrix to an ASCII representation
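For quick reference, a minimal sketch condensed from the repository examples below; `title` is optional, and `normalized` selects between raw counts (`false`) and the normalized view (`true`):

use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;

let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();

// Raw counts with a title
println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
// Normalized view without a title
println!("{}", cm.to_ascii(None, true));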

Examples found in repository?
examples/error_pattern_heatmap.rs (line 63)
6fn main() {
7    // Create a reproducible random number generator
8    let mut rng = SmallRng::seed_from_u64(42);
9
10    // Generate synthetic multiclass classification data with specific error patterns
11    let num_classes = 5;
12
13    // Create confusion matrix with controlled error patterns
14    let mut matrix = vec![vec![0; num_classes]; num_classes];
15
16    // Set diagonal elements (correct classifications) with high values
17    for i in 0..num_classes {
18        matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
19    }
20
21    // Create specific error patterns:
22    // - Classes 0 and 1 often confused
23    matrix[0][1] = 25;
24    matrix[1][0] = 15;
25
26    // - Class 2 sometimes confused with Class 3
27    matrix[2][3] = 18;
28
29    // - Class 4 has some misclassifications to all other classes
30    matrix[4][0] = 8;
31    matrix[4][1] = 5;
32    matrix[4][2] = 10;
33    matrix[4][3] = 12;
34
35    // - Some minor errors scattered about
36    for i in 0..num_classes {
37        for j in 0..num_classes {
38            if i != j && matrix[i][j] == 0 {
39                matrix[i][j] = rng.random_range(0..5);
40            }
41        }
42    }
43
44    // Convert to ndarray
45    let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
46    let ndarray_matrix =
47        ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix).unwrap();
48
49    // Create class labels
50    let class_labels = vec![
51        "Class A".to_string(),
52        "Class B".to_string(),
53        "Class C".to_string(),
54        "Class D".to_string(),
55        "Class E".to_string(),
56    ];
57
58    // Create confusion matrix
59    let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
60
61    // Example 1: Standard confusion matrix
62    println!("Example 1: Standard Confusion Matrix\n");
63    let regular_output = cm.to_ascii(Some("Classification Results"), false);
64    println!("{}", regular_output);
65
66    // Example 2: Normal heatmap
67    println!("\n\nExample 2: Standard Heatmap Visualization\n");
68    let color_options = ColorOptions {
69        enabled: true,
70        use_bright: true,
71        use_background: false,
72    };
73    let heatmap_output = cm.to_heatmap_with_options(
74        Some("Classification Heatmap"),
75        true, // normalized
76        &color_options,
77    );
78    println!("{}", heatmap_output);
79
80    // Example 3: Error pattern heatmap
81    println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
82    let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
83    println!("{}", error_heatmap);
84}
examples/confusion_matrix_heatmap.rs (line 67)
7fn main() {
8    // Create a reproducible random number generator
9    let mut rng = SmallRng::seed_from_u64(42);
10
11    // Generate synthetic multiclass classification data
12    let num_classes = 5;
13    let n_samples = 500;
14
15    // Generate true labels (0 to num_classes-1)
16    let mut y_true = Vec::with_capacity(n_samples);
17    for _ in 0..n_samples {
18        y_true.push(rng.random_range(0..num_classes));
19    }
20
21    // Generate predicted labels with controlled accuracy
22    let mut y_pred = Vec::with_capacity(n_samples);
23    for &true_label in &y_true {
24        // 80% chance to predict correctly, 20% chance of error
25        if rng.random::<f64>() < 0.8 {
26            y_pred.push(true_label);
27        } else {
28            // When wrong, tend to predict adjacent classes more often
29            let mut pred = true_label;
30            while pred == true_label {
31                // Generate error that's more likely to be close to true label
32                let error_margin = (rng.random::<f64>() * 2.0).round() as usize; // 0, 1, or 2
33                if rng.random::<bool>() {
34                    pred = (true_label + error_margin) % num_classes;
35                } else {
36                    pred = (true_label + num_classes - error_margin) % num_classes;
37                }
38            }
39            y_pred.push(pred);
40        }
41    }
42
43    // Convert to ndarray arrays
44    let y_true_array = Array1::from(y_true);
45    let y_pred_array = Array1::from(y_pred);
46
47    // Create class labels
48    let class_labels = vec![
49        "Cat".to_string(),
50        "Dog".to_string(),
51        "Bird".to_string(),
52        "Fish".to_string(),
53        "Rabbit".to_string(),
54    ];
55
56    // Create confusion matrix
57    let cm = ConfusionMatrix::<f64>::new(
58        &y_true_array.view(),
59        &y_pred_array.view(),
60        Some(num_classes),
61        Some(class_labels),
62    )
63    .unwrap();
64
65    // Example 1: Standard confusion matrix
66    println!("Example 1: Standard Confusion Matrix\n");
67    let regular_output = cm.to_ascii(Some("Animal Classification Results"), false);
68    println!("{}", regular_output);
69
70    // Example 2: Confusion matrix with color
71    println!("\n\nExample 2: Colored Confusion Matrix\n");
72    let color_options = ColorOptions {
73        enabled: true,
74        use_bright: true,
75        use_background: false,
76    };
77    let colored_output = cm.to_ascii_with_options(
78        Some("Animal Classification Results (with color)"),
79        false,
80        &color_options,
81    );
82    println!("{}", colored_output);
83
84    // Example 3: Normalized confusion matrix heatmap
85    println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
86    let heatmap_output = cm.to_heatmap_with_options(
87        Some("Animal Classification Heatmap (normalized)"),
88        true, // normalized
89        &color_options,
90    );
91    println!("{}", heatmap_output);
92
93    // Example 4: Raw counts heatmap
94    println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
95    let raw_heatmap = cm.to_heatmap_with_options(
96        Some("Animal Classification Heatmap (raw counts)"),
97        false, // not normalized
98        &color_options,
99    );
100    println!("{}", raw_heatmap);
101}
examples/model_visualization_example.rs (line 57)
examples/neural_confusion_matrix.rs (line 317)
Source

pub fn to_ascii_with_options( &self, title: Option<&str>, normalized: bool, color_options: &ColorOptions, ) -> String

Convert the confusion matrix to an ASCII representation with color options
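A minimal sketch assuming `ColorOptions` (with the `enabled`, `use_bright`, and `use_background` fields used in the repository examples) is already in scope; its import path is not shown on this page and is therefore omitted here:

use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;

// ColorOptions is assumed to be in scope; its import is omitted here.
let color_options = ColorOptions {
    enabled: true,
    use_bright: true,
    use_background: false,
};

let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();

// Colored, normalized rendering
println!(
    "{}",
    cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
);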

Examples found in repository?
examples/confusion_matrix_heatmap.rs (lines 77-81)
examples/colored_eval_visualization.rs (line 81)
Source

pub fn to_heatmap(&self, title: Option<&str>, normalized: bool) -> String

Convert the confusion matrix to a heatmap visualization

This creates a heatmap visualization of the confusion matrix in which each cell's color encodes the magnitude of its value via a color gradient.

§Arguments
  • title - Optional title for the heatmap
  • normalized - Whether to normalize the matrix (row values sum to 1)
§Returns
  • String - ASCII heatmap representation
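A minimal sketch (mirroring the repository example below) that prints both the raw-count and row-normalized heatmaps:

use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;

let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();

// Raw counts
println!("{}", cm.to_heatmap(Some("Confusion Matrix Heatmap"), false));
// Row-normalized (each row sums to 1)
println!("{}", cm.to_heatmap(Some("Normalized Heatmap"), true));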
Examples found in repository?
examples/neural_confusion_matrix.rs (lines 266-269)
Source

pub fn error_heatmap(&self, title: Option<&str>) -> String

Create a confusion matrix heatmap that focuses on misclassification patterns

This visualization highlights where the model makes mistakes, emphasizing the off-diagonal elements so that error patterns stand out. Its key features are:

  • Diagonal elements (correct classifications) are de-emphasized with dim styling
  • Off-diagonal elements (errors) are highlighted with a color gradient
  • Colors are normalized relative to the maximum off-diagonal value
  • A specialized legend explains error intensity levels
§Arguments
  • title - Optional title for the heatmap
§Returns
  • String - ASCII error pattern heatmap
§Example
use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;

// Create some example data
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 1, 1, 0, 0]);

let class_labels = vec!["Class A".to_string(), "Class B".to_string(), "Class C".to_string()];
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, Some(class_labels)).unwrap();

// Generate the error pattern heatmap
let error_viz = cm.error_heatmap(Some("Misclassification Analysis"));
println!("{}", error_viz);
Examples found in repository?
examples/error_pattern_heatmap.rs (line 82)
examples/neural_confusion_matrix.rs (line 335)
118fn main() -> Result<()> {
119    println!("Neural Network Confusion Matrix Visualization");
120    println!("==============================================\n");
121
122    // Initialize RNG with a fixed seed for reproducibility
123    let mut rng = SmallRng::seed_from_u64(42);
124
125    // Generate spiral dataset for 3-class classification
126    let n_classes = 3;
127    let n_samples_per_class = 100;
128    let noise = 0.15;
129
130    let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
131    println!(
132        "Generated spiral dataset with {} classes, {} samples per class",
133        n_classes, n_samples_per_class
134    );
135
136    // Split data into training and test sets (80/20 split)
137    let n_samples = x.shape()[0];
138    let n_train = (n_samples as f32 * 0.8) as usize;
139    let n_test = n_samples - n_train;
140
141    let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
142    let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
143    let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
144    let y_test = y.slice(ndarray::s![n_train..]).to_owned();
145
146    println!(
147        "Split data into {} training and {} test samples",
148        n_train, n_test
149    );
150
151    // Create a classification model
152    let input_dim = 2; // 2D input (x, y coordinates)
153    let hidden_dim = 32; // Hidden layer size
154    let output_dim = n_classes; // One output per class
155
156    let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
157    println!("Created model with {} layers", model.num_layers());
158
159    // Setup loss function and optimizer
160    let loss_fn = MeanSquaredError::new();
161    let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
162
163    // Train the model
164    let epochs = 100;
165    let x_train_dyn = x_train.clone().into_dyn();
166    let y_train_onehot = one_hot_encode(&y_train, n_classes);
167    let y_train_onehot_dyn = y_train_onehot.into_dyn();
168
169    // Create visualization callback for training metrics
170    let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
171        .with_tracked_metrics(vec![
172            "train_loss".to_string(),
173            "val_accuracy".to_string(),
174        ]);
175
176    // Define class labels for confusion matrix
177    let class_labels = vec![
178        "Class A".to_string(),
179        "Class B".to_string(),
180        "Class C".to_string(),
181    ];
182
183    // Train the model (simple manual training loop)
184    println!("\nTraining model...");
185
186    // Initialize history for tracking metrics
187    let mut epoch_history = HashMap::new();
188    epoch_history.insert("train_loss".to_string(), Vec::new());
189    epoch_history.insert("val_accuracy".to_string(), Vec::new());
190
191    // Training loop
192    for epoch in 0..epochs {
193        // Train for one epoch
194        let train_loss =
195            model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;
196
197        // Compute validation accuracy
198        let x_test_dyn = x_test.clone().into_dyn();
199        let predictions = model.forward(&x_test_dyn)?;
200        let predicted_classes = predictions_to_classes(&predictions);
201
202        // Calculate validation accuracy
203        let mut correct = 0;
204        for i in 0..n_test {
205            if predicted_classes[i] == y_test[i] {
206                correct += 1;
207            }
208        }
209        let val_accuracy = correct as f32 / n_test as f32;
210
211        // Store metrics
212        epoch_history
213            .get_mut("train_loss")
214            .unwrap()
215            .push(train_loss);
216        epoch_history
217            .get_mut("val_accuracy")
218            .unwrap()
219            .push(val_accuracy);
220
221        // Print progress
222        if (epoch + 1) % 10 == 0 || epoch == 0 {
223            println!(
224                "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
225                epoch + 1,
226                epochs,
227                train_loss,
228                val_accuracy
229            );
230        }
231
232        // Update visualization callback
233        let mut context = CallbackContext {
234            epoch,
235            total_epochs: epochs,
236            batch: 0,
237            total_batches: 1,
238            batch_loss: None,
239            epoch_loss: Some(train_loss),
240            val_loss: None,
241            metrics: vec![val_accuracy],
242            history: &epoch_history,
243            stop_training: false,
244            model: None,
245        };
246
247        // Visualize progress with metrics chart
248        if epoch % 10 == 0 || epoch == epochs - 1 {
249            visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
250        }
251
252        // Calculate and show confusion matrix during training
253        if epoch % 20 == 0 || epoch == epochs - 1 {
254            // Create confusion matrix
255            let cm = ConfusionMatrix::<f32>::new(
256                &y_test.view(),
257                &predicted_classes.view(),
258                Some(n_classes),
259                Some(class_labels.clone()),
260            )?;
261
262            // Show heatmap visualization
263            println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
264            println!(
265                "{}",
266                cm.to_heatmap(
267                    Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
268                    true
269                )
270            );
271        }
272    }
273
274    // Final evaluation
275    println!("\nFinal model evaluation:");
276
277    // Make predictions on test set
278    let x_test_dyn = x_test.clone().into_dyn();
279    let predictions = model.forward(&x_test_dyn)?;
280    let predicted_classes = predictions_to_classes(&predictions);
281
282    // Create confusion matrix
283    let cm = ConfusionMatrix::<f32>::new(
284        &y_test.view(),
285        &predicted_classes.view(),
286        Some(n_classes),
287        Some(class_labels.clone()),
288    )?;
289
290    // Calculate and show metrics
291    let accuracy = cm.accuracy();
292    let precision = cm.precision();
293    let recall = cm.recall();
294    let f1 = cm.f1_score();
295
296    println!("\nFinal Classification Metrics:");
297    println!("Overall Accuracy: {:.4}", accuracy);
298
299    println!("\nPer-Class Metrics:");
300    println!("Class | Precision | Recall | F1-Score");
301    println!("-----------------------------------");
302
303    for i in 0..n_classes {
304        println!(
305            "{}    | {:.4}     | {:.4}  | {:.4}",
306            class_labels[i], precision[i], recall[i], f1[i]
307        );
308    }
309
310    println!("\nMacro F1-Score: {:.4}", cm.macro_f1());
311
312    // Show different confusion matrix visualizations
313    println!("\nFinal Confusion Matrix Visualizations:");
314
315    // 1. Standard confusion matrix
316    println!("\n1. Standard Confusion Matrix:");
317    println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));
318
319    // 2. Normalized confusion matrix
320    println!("\n2. Normalized Confusion Matrix:");
321    println!(
322        "{}",
323        cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
324    );
325
326    // 3. Confusion matrix heatmap
327    println!("\n3. Confusion Matrix Heatmap:");
328    println!(
329        "{}",
330        cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
331    );
332
333    // 4. Error pattern analysis
334    println!("\n4. Error Pattern Analysis:");
335    println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));
336
337    println!("\nNeural Network Confusion Matrix Visualization Complete!");
338    Ok(())
339}
Source

pub fn to_heatmap_with_options( &self, title: Option<&str>, normalized: bool, color_options: &ColorOptions, ) -> String

Convert the confusion matrix to a heatmap visualization with customizable options

§Arguments
  • title - Optional title for the heatmap
  • normalized - Whether to normalize the matrix (row values sum to 1)
  • color_options - Color options for visualization
§Returns
  • String - ASCII heatmap representation with colors
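§Example

A minimal sketch of producing a colored, normalized heatmap. The import path for ColorOptions is an assumption (the repository examples below construct it with exactly the fields enabled, use_bright, and use_background); adjust the use line to wherever the crate actually exports it.

use scirs2_neural::utils::evaluation::{ColorOptions, ConfusionMatrix};
use ndarray::Array1;

let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
let y_pred = Array1::from_vec(vec![0, 2, 2, 0, 1, 1]);
// num_classes is inferred from the data when None; no string labels are attached
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();

// Assumed import path for ColorOptions; fields taken from the repository examples below
let color_options = ColorOptions {
    enabled: true,
    use_bright: true,
    use_background: false,
};

// normalized = true: each row is scaled so its values sum to 1
let heatmap = cm.to_heatmap_with_options(Some("Validation Heatmap"), true, &color_options);
println!("{}", heatmap);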
Examples found in repository?
examples/error_pattern_heatmap.rs (lines 73-77)
66    // Example 2: Normal heatmap
67    println!("\n\nExample 2: Standard Heatmap Visualization\n");
68    let color_options = ColorOptions {
69        enabled: true,
70        use_bright: true,
71        use_background: false,
72    };
73    let heatmap_output = cm.to_heatmap_with_options(
74        Some("Classification Heatmap"),
75        true, // normalized
76        &color_options,
77    );
78    println!("{}", heatmap_output);
More examples
examples/confusion_matrix_heatmap.rs (lines 86-90)
70    // Example 2: Confusion matrix with color
71    println!("\n\nExample 2: Colored Confusion Matrix\n");
72    let color_options = ColorOptions {
73        enabled: true,
74        use_bright: true,
75        use_background: false,
76    };
77    let colored_output = cm.to_ascii_with_options(
78        Some("Animal Classification Results (with color)"),
79        false,
80        &color_options,
81    );
82    println!("{}", colored_output);
83
84    // Example 3: Normalized confusion matrix heatmap
85    println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
86    let heatmap_output = cm.to_heatmap_with_options(
87        Some("Animal Classification Heatmap (normalized)"),
88        true, // normalized
89        &color_options,
90    );
91    println!("{}", heatmap_output);
92
93    // Example 4: Raw counts heatmap
94    println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
95    let raw_heatmap = cm.to_heatmap_with_options(
96        Some("Animal Classification Heatmap (raw counts)"),
97        false, // not normalized
98        &color_options,
99    );
100    println!("{}", raw_heatmap);
101}

Trait Implementations§

Source§

impl<F: Clone + Float + Debug + Display> Clone for ConfusionMatrix<F>

Source§

fn clone(&self) -> ConfusionMatrix<F>

Returns a duplicate of the value. Read more
1.0.0 · Source§

const fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl<F: Debug + Float + Debug + Display> Debug for ConfusionMatrix<F>

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
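Both impls appear to be derived (note the extra F: Clone and F: Debug bounds), so a matrix can be duplicated and its raw contents dumped for quick inspection. A minimal sketch:

use scirs2_neural::utils::evaluation::ConfusionMatrix;
use ndarray::Array1;

let y_true = Array1::from_vec(vec![0, 1, 1, 0, 2, 2]);
let y_pred = Array1::from_vec(vec![0, 1, 0, 0, 2, 1]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();

// Clone leaves the original usable; Debug prints the struct's raw contents
let snapshot = cm.clone();
println!("{:?}", snapshot);
println!("{}", cm.to_ascii(None, false));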

Auto Trait Implementations§

impl<F> Freeze for ConfusionMatrix<F>

impl<F> RefUnwindSafe for ConfusionMatrix<F>
where F: RefUnwindSafe,

impl<F> Send for ConfusionMatrix<F>
where F: Send,

impl<F> Sync for ConfusionMatrix<F>
where F: Sync,

impl<F> Unpin for ConfusionMatrix<F>

impl<F> UnwindSafe for ConfusionMatrix<F>
where F: RefUnwindSafe,

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V