error_pattern_heatmap/
error_pattern_heatmap.rs

use scirs2_core::random::prelude::*;
use scirs2_core::random::rngs::SmallRng;
use scirs2_neural::utils::colors::ColorOptions;
use scirs2_neural::utils::evaluation::ConfusionMatrix;

6#[allow(dead_code)]
7fn main() {
8    // Create a reproducible random number generator
9    let mut rng = SmallRng::from_seed([42; 32]);
10    // Generate synthetic multiclass classification data with specific error patterns
11    let num_classes = 5;
12    // Create confusion matrix with controlled error patterns
13    let mut matrix = vec![vec![0; num_classes]; num_classes];
14    // Set diagonal elements (correct classifications) with high values
15    #[allow(clippy::needless_range_loop)]
16    for i in 0..num_classes {
17        matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
18    }
19    // Create specific error patterns:
20    // - Classes 0 and 1 often confused
21    matrix[0][1] = 25;
22    matrix[1][0] = 15;
23    // - Class 2 sometimes confused with Class 3
24    matrix[2][3] = 18;
25    // - Class 4 has some misclassifications to all other classes
26    matrix[4][0] = 8;
27    matrix[4][1] = 5;
28    matrix[4][2] = 10;
29    matrix[4][3] = 12;
30    // - Some minor errors scattered about
31    #[allow(clippy::needless_range_loop)]
32    for i in 0..num_classes {
33        for j in 0..num_classes {
34            if i != j && matrix[i][j] == 0 {
35                matrix[i][j] = rng.random_range(0..5);
36            }
37        }
38    }
39    // Convert to ndarray
40    let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
41    let ndarray_matrix =
42        scirs2_core::ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix)
43            .unwrap();
44    // Create class labels
45    let class_labels = vec![
46        "Class A".to_string(),
47        "Class B".to_string(),
48        "Class C".to_string(),
49        "Class D".to_string(),
50        "Class E".to_string(),
51    ];
52    // Create confusion matrix
53    let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
54    // Example 1: Standard confusion matrix
55    println!("Example 1: Standard Confusion Matrix\n");
56    let regular_output = cm.to_ascii(Some("Classification Results"), false);
57    println!("{regular_output}");
58    // Example 2: Normal heatmap
59    println!("\n\nExample 2: Standard Heatmap Visualization\n");
60    let color_options = ColorOptions {
61        enabled: true,
62        use_bright: true,
63        use_background: false,
64    };
65    let heatmap_output = cm.to_heatmap_with_options(
66        Some("Classification Heatmap"),
67        true, // normalized
68        &color_options,
69    );
70    println!("{heatmap_output}");
71    // Example 3: Error pattern heatmap
72    println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
73    let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
74    println!("{error_heatmap}");
75}