error_pattern_heatmap/
error_pattern_heatmap.rs

1use rand::prelude::*;
2use rand::rngs::SmallRng;
3use scirs2_neural::utils::colors::ColorOptions;
4use scirs2_neural::utils::evaluation::ConfusionMatrix;
5
6#[allow(dead_code)]
7fn main() {
8    // Create a reproducible random number generator
9    let mut rng = SmallRng::from_seed([42; 32]);
10    // Generate synthetic multiclass classification data with specific error patterns
11    let num_classes = 5;
12    // Create confusion matrix with controlled error patterns
13    let mut matrix = vec![vec![0; num_classes]; num_classes];
14    // Set diagonal elements (correct classifications) with high values
15    #[allow(clippy::needless_range_loop)]
16    for i in 0..num_classes {
17        matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
18    }
19    // Create specific error patterns:
20    // - Classes 0 and 1 often confused
21    matrix[0][1] = 25;
22    matrix[1][0] = 15;
23    // - Class 2 sometimes confused with Class 3
24    matrix[2][3] = 18;
25    // - Class 4 has some misclassifications to all other classes
26    matrix[4][0] = 8;
27    matrix[4][1] = 5;
28    matrix[4][2] = 10;
29    matrix[4][3] = 12;
30    // - Some minor errors scattered about
31    #[allow(clippy::needless_range_loop)]
32    for i in 0..num_classes {
33        for j in 0..num_classes {
34            if i != j && matrix[i][j] == 0 {
35                matrix[i][j] = rng.random_range(0..5);
36            }
37        }
38    }
39    // Convert to ndarray
40    let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
41    let ndarray_matrix =
42        ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix).unwrap();
43    // Create class labels
44    let class_labels = vec![
45        "Class A".to_string(),
46        "Class B".to_string(),
47        "Class C".to_string(),
48        "Class D".to_string(),
49        "Class E".to_string(),
50    ];
51    // Create confusion matrix
52    let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
53    // Example 1: Standard confusion matrix
54    println!("Example 1: Standard Confusion Matrix\n");
55    let regular_output = cm.to_ascii(Some("Classification Results"), false);
56    println!("{regular_output}");
57    // Example 2: Normal heatmap
58    println!("\n\nExample 2: Standard Heatmap Visualization\n");
59    let color_options = ColorOptions {
60        enabled: true,
61        use_bright: true,
62        use_background: false,
63    };
64    let heatmap_output = cm.to_heatmap_with_options(
65        Some("Classification Heatmap"),
66        true, // normalized
67        &color_options,
68    );
69    println!("{heatmap_output}");
70    // Example 3: Error pattern heatmap
71    println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
72    let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
73    println!("{error_heatmap}");
74}