ConfusionMatrix

Struct ConfusionMatrix 

Source
/// Confusion matrix for classification problems.
///
/// `F` is the numeric element type used to store the matrix entries
/// (the repository examples instantiate it with `f32` and `f64`).
pub struct ConfusionMatrix<F: Float + Debug + Display> {
    /// The raw confusion matrix data.
    pub matrix: Array2<F>,
    /// Class labels (optional); `None` when no label names were supplied.
    pub labels: Option<Vec<String>>,
    /// Number of classes.
    pub num_classes: usize,
}
Expand description

Confusion matrix for classification problems, tabulating how true class labels compare with predicted class labels.

Fields§

§matrix: Array2<F>

The raw confusion matrix data

§labels: Option<Vec<String>>

Class labels (optional)

§num_classes: usize

Number of classes

Implementations§

Source§

impl<F: Float + Debug + Display> ConfusionMatrix<F>

Source

pub fn new( y_true: &ArrayView1<'_, usize>, y_pred: &ArrayView1<'_, usize>, num_classes: Option<usize>, labels: Option<Vec<String>>, ) -> Result<Self>

Create a new confusion matrix from true and predicted class labels

§Arguments
  • y_true - True class labels as integers
  • y_pred - Predicted class labels as integers
  • num_classes - Number of classes (if None, determined from data)
  • labels - Optional class labels as strings
§Returns
  • Result<ConfusionMatrix<F>> - The confusion matrix
§Example
use scirs2_neural::utils::evaluation::ConfusionMatrix;
use scirs2_core::ndarray::Array1;
// Seven samples over classes 0..=2; only the sample at index 2 is
// misclassified (true class 2, predicted class 1).
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
// `None` for `num_classes` lets the class count be determined from the data;
// `None` for `labels` builds the matrix without class names.
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
Examples found in repository?
examples/confusion_matrix_heatmap.rs (lines 52-57)
8fn main() {
9    // Create a reproducible random number generator
10    let mut rng = SmallRng::from_seed([42; 32]);
11    // Generate synthetic multiclass classification data
12    let num_classes = 5;
13    let n_samples = 500;
14    // Generate true labels (0 to num_classes-1)
15    let mut y_true = Vec::with_capacity(n_samples);
16    for _ in 0..n_samples {
17        y_true.push(rng.random_range(0..num_classes));
18    }
19    // Generate predicted labels with controlled accuracy
20    let mut y_pred = Vec::with_capacity(n_samples);
21    for &true_label in &y_true {
22        // 80% chance to predict correctly..20% chance of error
23        if rng.random::<f64>() < 0.8 {
24            y_pred.push(true_label);
25        } else {
26            // When wrong, tend to predict adjacent classes more often
27            let mut pred = true_label;
28            while pred == true_label {
29                // Generate error that's more likely to be close to true label
30                let error_margin = (rng.random::<f64>() * 2.0).round() as usize; // 0, 1, or 2
31                if rng.random::<bool>() {
32                    pred = (true_label + error_margin) % num_classes;
33                } else {
34                    pred = (true_label + num_classes - error_margin) % num_classes;
35                }
36            }
37            y_pred.push(pred);
38        }
39    }
40    // Convert to ndarray arrays
41    let y_true_array = Array1::from(y_true);
42    let y_pred_array = Array1::from(y_pred);
43    // Create class labels
44    let class_labels = vec![
45        "Cat".to_string(),
46        "Dog".to_string(),
47        "Bird".to_string(),
48        "Fish".to_string(),
49        "Rabbit".to_string(),
50    ];
51    // Create confusion matrix
52    let cm = ConfusionMatrix::<f64>::new(
53        &y_true_array.view(),
54        &y_pred_array.view(),
55        Some(num_classes),
56        Some(class_labels),
57    )
58    .unwrap();
59    // Example 1: Standard confusion matrix
60    println!("Example 1: Standard Confusion Matrix\n");
61    let regular_output = cm.to_ascii(Some("Animal Classification Results"), false);
62    println!("{regular_output}");
63    // Example 2: Confusion matrix with color
64    println!("\n\nExample 2: Colored Confusion Matrix\n");
65    let color_options = ColorOptions {
66        enabled: true,
67        use_bright: true,
68        use_background: false,
69    };
70    let colored_output = cm.to_ascii_with_options(
71        Some("Animal Classification Results (with color)"),
72        false,
73        &color_options,
74    );
75    println!("{colored_output}");
76    // Example 3: Normalized confusion matrix heatmap
77    println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
78    let heatmap_output = cm.to_heatmap_with_options(
79        Some("Animal Classification Heatmap (normalized)"),
80        true, // normalized
81        &color_options,
82    );
83    println!("{heatmap_output}");
84
85    // Example 4: Raw counts heatmap
86    println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
87    let raw_heatmap = cm.to_heatmap_with_options(
88        Some("Animal Classification Heatmap (raw counts)"),
89        false, // not normalized
90        &color_options,
91    );
92    println!("{raw_heatmap}");
93}
More examples
Hide additional examples
examples/colored_eval_visualization.rs (lines 64-69)
11fn main() -> Result<()> {
12    println!(
13        "{}",
14        stylize("Neural Network Model Evaluation with Color", Style::Bold)
15    );
16    println!("{}", "-".repeat(50));
17    // Set up color options
18    let color_options = ColorOptions {
19        enabled: true,
20        use_background: false,
21        use_bright: true,
22    };
23    // Generate some example data
24    let n_samples = 500;
25    let n_features = 10;
26    let n_classes = 4;
27    println!(
28        "\n{} {} {} {} {} {}",
29        colorize("Generating", Color::BrightGreen),
30        colorize(n_samples.to_string(), Color::BrightYellow),
31        colorize("samples with", Color::BrightGreen),
32        colorize(n_features.to_string(), Color::BrightYellow),
33        colorize("features for", Color::BrightGreen),
34        colorize(n_classes.to_string(), Color::BrightYellow),
35    );
36
37    // Create a deterministic RNG for reproducibility
38    let mut rng = SmallRng::from_seed([42; 32]);
39
40    // 1. Confusion Matrix Example
41    println!(
42        "\n{}",
43        stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
44    );
45    // Generate random predictions and true labels
46    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
47    // Create slightly correlated predictions (not completely random)
48    let y_pred = Array::from_shape_fn(n_samples, |i| {
49        if rng.random::<f32>() < 0.7 {
50            // 70% chance of correct prediction
51            y_true[i]
52        } else {
53            // 30% chance of random class
54            rng.random_range(0..n_classes)
55        }
56    });
57    // Create confusion matrix
58    let class_labels = vec![
59        "Class A".to_string(),
60        "Class B".to_string(),
61        "Class C".to_string(),
62        "Class D".to_string(),
63    ];
64    let cm = ConfusionMatrix::<f32>::new(
65        &y_true.view(),
66        &y_pred.view(),
67        Some(n_classes),
68        Some(class_labels),
69    )?;
70    // Print raw and normalized confusion matrices with color
71    println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
72    println!(
73        "{}",
74        cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
75    );
76    println!(
77        "\n{}",
78        colorize("Normalized Confusion Matrix:", Color::BrightCyan)
79    );
80    println!(
81        "{}",
82        cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
83    );
84    // Print metrics
85    println!(
86        "\n{} {:.3}",
87        colorize("Overall Accuracy:", Color::BrightMagenta),
88        cm.accuracy()
89    );
90    let precision = cm.precision();
91    let recall = cm.recall();
92    let f1 = cm.f1_score();
93    println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
94    for i in 0..n_classes {
95        println!(
96            "  {}: {}={:.3}, {}={:.3}, {}={:.3}",
97            colorize(format!("Class {i}"), Color::BrightYellow),
98            colorize("Precision", Color::BrightCyan),
99            precision[i],
100            colorize("Recall", Color::BrightGreen),
101            recall[i],
102            colorize("F1", Color::BrightBlue),
103            f1[i]
104        );
105    }
106    println!(
107        "{} {:.3}",
108        colorize("Macro F1 Score:", Color::BrightMagenta),
109        cm.macro_f1()
110    );
111    // 2. Feature Importance Visualization
112    println!(
113        "{}",
114        stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
115    );
116    // Generate random feature importance scores
117    let feature_names = (0..n_features)
118        .map(|i| format!("Feature_{i}"))
119        .collect::<Vec<String>>();
120    let importance = Array1::from_shape_fn(n_features, |i| {
121        // Make some features more important than others
122        let base = (n_features - i) as f32 / n_features as f32;
123        base + 0.2 * rng.random::<f32>()
124    });
125
126    let fi = FeatureImportance::new(feature_names, importance)?;
127
128    // Print full feature importance with color
129    println!(
130        "{}",
131        fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
132    );
133
134    // Print top-5 features with color
135    println!(
136        "{}",
137        colorize("Top 5 Most Important Features:", Color::BrightCyan)
138    );
139    println!(
140        "{}",
141        fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
142    );
143    // 3. ROC Curve for Binary Classification
144    println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
145    // Generate binary classification data
146    let n_binary = 200;
147    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
148    // Generate scores with some predictive power
149    let y_scores = Array1::from_shape_fn(n_binary, |i| {
150        if y_true_binary[i] == 1 {
151            // Higher scores for positive class
152            0.6 + 0.4 * rng.random::<f32>()
153        } else {
154            // Lower scores for negative class
155            0.4 * rng.random::<f32>()
156        }
157    });
158
159    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
160    println!(
161        "{}: {:.3}",
162        colorize("ROC AUC:", Color::BrightMagenta),
163        roc.auc
164    );
165    println!("\n{}", roc.to_ascii(None, 50, 20));
166
167    // 4. Learning Curve Visualization
168    println!(
169        "\n{}",
170        stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
171    );
172    // Generate learning curve data
173    let n_points = 10;
174    let n_cv = 5;
175    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
176    // Generate training scores (decreasing with size due to overfitting)
177    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
178        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
179    });
180
181    // Generate validation scores (increasing with size)
182    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
183        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
184    });
185
186    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
187    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
188
189    // Print final message with color
190    println!(
191        "{}",
192        colorize(
193            "Model evaluation visualizations completed successfully!",
194            Color::BrightGreen
195        )
196    );
197    Ok(())
198}
Source

pub fn from_matrix( matrix: Array2<F>, labels: Option<Vec<String>>, ) -> Result<Self>

Create a confusion matrix from raw matrix data

§Arguments
  • matrix - Raw confusion matrix data
  • labels - Optional class labels
Examples found in repository?
examples/error_pattern_heatmap.rs (line 53)
7fn main() {
8    // Create a reproducible random number generator
9    let mut rng = SmallRng::from_seed([42; 32]);
10    // Generate synthetic multiclass classification data with specific error patterns
11    let num_classes = 5;
12    // Create confusion matrix with controlled error patterns
13    let mut matrix = vec![vec![0; num_classes]; num_classes];
14    // Set diagonal elements (correct classifications) with high values
15    #[allow(clippy::needless_range_loop)]
16    for i in 0..num_classes {
17        matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
18    }
19    // Create specific error patterns:
20    // - Classes 0 and 1 often confused
21    matrix[0][1] = 25;
22    matrix[1][0] = 15;
23    // - Class 2 sometimes confused with Class 3
24    matrix[2][3] = 18;
25    // - Class 4 has some misclassifications to all other classes
26    matrix[4][0] = 8;
27    matrix[4][1] = 5;
28    matrix[4][2] = 10;
29    matrix[4][3] = 12;
30    // - Some minor errors scattered about
31    #[allow(clippy::needless_range_loop)]
32    for i in 0..num_classes {
33        for j in 0..num_classes {
34            if i != j && matrix[i][j] == 0 {
35                matrix[i][j] = rng.random_range(0..5);
36            }
37        }
38    }
39    // Convert to ndarray
40    let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
41    let ndarray_matrix =
42        scirs2_core::ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix)
43            .unwrap();
44    // Create class labels
45    let class_labels = vec![
46        "Class A".to_string(),
47        "Class B".to_string(),
48        "Class C".to_string(),
49        "Class D".to_string(),
50        "Class E".to_string(),
51    ];
52    // Create confusion matrix
53    let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
54    // Example 1: Standard confusion matrix
55    println!("Example 1: Standard Confusion Matrix\n");
56    let regular_output = cm.to_ascii(Some("Classification Results"), false);
57    println!("{regular_output}");
58    // Example 2: Normal heatmap
59    println!("\n\nExample 2: Standard Heatmap Visualization\n");
60    let color_options = ColorOptions {
61        enabled: true,
62        use_bright: true,
63        use_background: false,
64    };
65    let heatmap_output = cm.to_heatmap_with_options(
66        Some("Classification Heatmap"),
67        true, // normalized
68        &color_options,
69    );
70    println!("{heatmap_output}");
71    // Example 3: Error pattern heatmap
72    println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
73    let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
74    println!("{error_heatmap}");
75}
Source

pub fn normalized(&self) -> Array2<F>

Return a row-normalized copy of the confusion matrix (each row sums to 1); the stored matrix is not modified

Source

pub fn accuracy(&self) -> F

Calculate overall accuracy (the fraction of correctly classified samples) from the confusion matrix

Examples found in repository?
examples/colored_eval_visualization.rs (line 88)
11fn main() -> Result<()> {
12    println!(
13        "{}",
14        stylize("Neural Network Model Evaluation with Color", Style::Bold)
15    );
16    println!("{}", "-".repeat(50));
17    // Set up color options
18    let color_options = ColorOptions {
19        enabled: true,
20        use_background: false,
21        use_bright: true,
22    };
23    // Generate some example data
24    let n_samples = 500;
25    let n_features = 10;
26    let n_classes = 4;
27    println!(
28        "\n{} {} {} {} {} {}",
29        colorize("Generating", Color::BrightGreen),
30        colorize(n_samples.to_string(), Color::BrightYellow),
31        colorize("samples with", Color::BrightGreen),
32        colorize(n_features.to_string(), Color::BrightYellow),
33        colorize("features for", Color::BrightGreen),
34        colorize(n_classes.to_string(), Color::BrightYellow),
35    );
36
37    // Create a deterministic RNG for reproducibility
38    let mut rng = SmallRng::from_seed([42; 32]);
39
40    // 1. Confusion Matrix Example
41    println!(
42        "\n{}",
43        stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
44    );
45    // Generate random predictions and true labels
46    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
47    // Create slightly correlated predictions (not completely random)
48    let y_pred = Array::from_shape_fn(n_samples, |i| {
49        if rng.random::<f32>() < 0.7 {
50            // 70% chance of correct prediction
51            y_true[i]
52        } else {
53            // 30% chance of random class
54            rng.random_range(0..n_classes)
55        }
56    });
57    // Create confusion matrix
58    let class_labels = vec![
59        "Class A".to_string(),
60        "Class B".to_string(),
61        "Class C".to_string(),
62        "Class D".to_string(),
63    ];
64    let cm = ConfusionMatrix::<f32>::new(
65        &y_true.view(),
66        &y_pred.view(),
67        Some(n_classes),
68        Some(class_labels),
69    )?;
70    // Print raw and normalized confusion matrices with color
71    println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
72    println!(
73        "{}",
74        cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
75    );
76    println!(
77        "\n{}",
78        colorize("Normalized Confusion Matrix:", Color::BrightCyan)
79    );
80    println!(
81        "{}",
82        cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
83    );
84    // Print metrics
85    println!(
86        "\n{} {:.3}",
87        colorize("Overall Accuracy:", Color::BrightMagenta),
88        cm.accuracy()
89    );
90    let precision = cm.precision();
91    let recall = cm.recall();
92    let f1 = cm.f1_score();
93    println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
94    for i in 0..n_classes {
95        println!(
96            "  {}: {}={:.3}, {}={:.3}, {}={:.3}",
97            colorize(format!("Class {i}"), Color::BrightYellow),
98            colorize("Precision", Color::BrightCyan),
99            precision[i],
100            colorize("Recall", Color::BrightGreen),
101            recall[i],
102            colorize("F1", Color::BrightBlue),
103            f1[i]
104        );
105    }
106    println!(
107        "{} {:.3}",
108        colorize("Macro F1 Score:", Color::BrightMagenta),
109        cm.macro_f1()
110    );
111    // 2. Feature Importance Visualization
112    println!(
113        "{}",
114        stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
115    );
116    // Generate random feature importance scores
117    let feature_names = (0..n_features)
118        .map(|i| format!("Feature_{i}"))
119        .collect::<Vec<String>>();
120    let importance = Array1::from_shape_fn(n_features, |i| {
121        // Make some features more important than others
122        let base = (n_features - i) as f32 / n_features as f32;
123        base + 0.2 * rng.random::<f32>()
124    });
125
126    let fi = FeatureImportance::new(feature_names, importance)?;
127
128    // Print full feature importance with color
129    println!(
130        "{}",
131        fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
132    );
133
134    // Print top-5 features with color
135    println!(
136        "{}",
137        colorize("Top 5 Most Important Features:", Color::BrightCyan)
138    );
139    println!(
140        "{}",
141        fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
142    );
143    // 3. ROC Curve for Binary Classification
144    println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
145    // Generate binary classification data
146    let n_binary = 200;
147    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
148    // Generate scores with some predictive power
149    let y_scores = Array1::from_shape_fn(n_binary, |i| {
150        if y_true_binary[i] == 1 {
151            // Higher scores for positive class
152            0.6 + 0.4 * rng.random::<f32>()
153        } else {
154            // Lower scores for negative class
155            0.4 * rng.random::<f32>()
156        }
157    });
158
159    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
160    println!(
161        "{}: {:.3}",
162        colorize("ROC AUC:", Color::BrightMagenta),
163        roc.auc
164    );
165    println!("\n{}", roc.to_ascii(None, 50, 20));
166
167    // 4. Learning Curve Visualization
168    println!(
169        "\n{}",
170        stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
171    );
172    // Generate learning curve data
173    let n_points = 10;
174    let n_cv = 5;
175    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
176    // Generate training scores (decreasing with size due to overfitting)
177    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
178        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
179    });
180
181    // Generate validation scores (increasing with size)
182    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
183        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
184    });
185
186    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
187    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
188
189    // Print final message with color
190    println!(
191        "{}",
192        colorize(
193            "Model evaluation visualizations completed successfully!",
194            Color::BrightGreen
195        )
196    );
197    Ok(())
198}
Source

pub fn precision(&self) -> Array1<F>

Calculate precision for each class, returned as one value per class index

Examples found in repository?
examples/colored_eval_visualization.rs (line 90)
11fn main() -> Result<()> {
12    println!(
13        "{}",
14        stylize("Neural Network Model Evaluation with Color", Style::Bold)
15    );
16    println!("{}", "-".repeat(50));
17    // Set up color options
18    let color_options = ColorOptions {
19        enabled: true,
20        use_background: false,
21        use_bright: true,
22    };
23    // Generate some example data
24    let n_samples = 500;
25    let n_features = 10;
26    let n_classes = 4;
27    println!(
28        "\n{} {} {} {} {} {}",
29        colorize("Generating", Color::BrightGreen),
30        colorize(n_samples.to_string(), Color::BrightYellow),
31        colorize("samples with", Color::BrightGreen),
32        colorize(n_features.to_string(), Color::BrightYellow),
33        colorize("features for", Color::BrightGreen),
34        colorize(n_classes.to_string(), Color::BrightYellow),
35    );
36
37    // Create a deterministic RNG for reproducibility
38    let mut rng = SmallRng::from_seed([42; 32]);
39
40    // 1. Confusion Matrix Example
41    println!(
42        "\n{}",
43        stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
44    );
45    // Generate random predictions and true labels
46    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
47    // Create slightly correlated predictions (not completely random)
48    let y_pred = Array::from_shape_fn(n_samples, |i| {
49        if rng.random::<f32>() < 0.7 {
50            // 70% chance of correct prediction
51            y_true[i]
52        } else {
53            // 30% chance of random class
54            rng.random_range(0..n_classes)
55        }
56    });
57    // Create confusion matrix
58    let class_labels = vec![
59        "Class A".to_string(),
60        "Class B".to_string(),
61        "Class C".to_string(),
62        "Class D".to_string(),
63    ];
64    let cm = ConfusionMatrix::<f32>::new(
65        &y_true.view(),
66        &y_pred.view(),
67        Some(n_classes),
68        Some(class_labels),
69    )?;
70    // Print raw and normalized confusion matrices with color
71    println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
72    println!(
73        "{}",
74        cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
75    );
76    println!(
77        "\n{}",
78        colorize("Normalized Confusion Matrix:", Color::BrightCyan)
79    );
80    println!(
81        "{}",
82        cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
83    );
84    // Print metrics
85    println!(
86        "\n{} {:.3}",
87        colorize("Overall Accuracy:", Color::BrightMagenta),
88        cm.accuracy()
89    );
90    let precision = cm.precision();
91    let recall = cm.recall();
92    let f1 = cm.f1_score();
93    println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
94    for i in 0..n_classes {
95        println!(
96            "  {}: {}={:.3}, {}={:.3}, {}={:.3}",
97            colorize(format!("Class {i}"), Color::BrightYellow),
98            colorize("Precision", Color::BrightCyan),
99            precision[i],
100            colorize("Recall", Color::BrightGreen),
101            recall[i],
102            colorize("F1", Color::BrightBlue),
103            f1[i]
104        );
105    }
106    println!(
107        "{} {:.3}",
108        colorize("Macro F1 Score:", Color::BrightMagenta),
109        cm.macro_f1()
110    );
111    // 2. Feature Importance Visualization
112    println!(
113        "{}",
114        stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
115    );
116    // Generate random feature importance scores
117    let feature_names = (0..n_features)
118        .map(|i| format!("Feature_{i}"))
119        .collect::<Vec<String>>();
120    let importance = Array1::from_shape_fn(n_features, |i| {
121        // Make some features more important than others
122        let base = (n_features - i) as f32 / n_features as f32;
123        base + 0.2 * rng.random::<f32>()
124    });
125
126    let fi = FeatureImportance::new(feature_names, importance)?;
127
128    // Print full feature importance with color
129    println!(
130        "{}",
131        fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
132    );
133
134    // Print top-5 features with color
135    println!(
136        "{}",
137        colorize("Top 5 Most Important Features:", Color::BrightCyan)
138    );
139    println!(
140        "{}",
141        fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
142    );
143    // 3. ROC Curve for Binary Classification
144    println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
145    // Generate binary classification data
146    let n_binary = 200;
147    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
148    // Generate scores with some predictive power
149    let y_scores = Array1::from_shape_fn(n_binary, |i| {
150        if y_true_binary[i] == 1 {
151            // Higher scores for positive class
152            0.6 + 0.4 * rng.random::<f32>()
153        } else {
154            // Lower scores for negative class
155            0.4 * rng.random::<f32>()
156        }
157    });
158
159    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
160    println!(
161        "{}: {:.3}",
162        colorize("ROC AUC:", Color::BrightMagenta),
163        roc.auc
164    );
165    println!("\n{}", roc.to_ascii(None, 50, 20));
166
167    // 4. Learning Curve Visualization
168    println!(
169        "\n{}",
170        stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
171    );
172    // Generate learning curve data
173    let n_points = 10;
174    let n_cv = 5;
175    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
176    // Generate training scores (decreasing with size due to overfitting)
177    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
178        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
179    });
180
181    // Generate validation scores (increasing with size)
182    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
183        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
184    });
185
186    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
187    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
188
189    // Print final message with color
190    println!(
191        "{}",
192        colorize(
193            "Model evaluation visualizations completed successfully!",
194            Color::BrightGreen
195        )
196    );
197    Ok(())
198}
Source

pub fn recall(&self) -> Array1<F>

Calculate recall for each class, returned as one value per class index

Examples found in repository?
examples/colored_eval_visualization.rs (line 91)
11fn main() -> Result<()> {
12    println!(
13        "{}",
14        stylize("Neural Network Model Evaluation with Color", Style::Bold)
15    );
16    println!("{}", "-".repeat(50));
17    // Set up color options
18    let color_options = ColorOptions {
19        enabled: true,
20        use_background: false,
21        use_bright: true,
22    };
23    // Generate some example data
24    let n_samples = 500;
25    let n_features = 10;
26    let n_classes = 4;
27    println!(
28        "\n{} {} {} {} {} {}",
29        colorize("Generating", Color::BrightGreen),
30        colorize(n_samples.to_string(), Color::BrightYellow),
31        colorize("samples with", Color::BrightGreen),
32        colorize(n_features.to_string(), Color::BrightYellow),
33        colorize("features for", Color::BrightGreen),
34        colorize(n_classes.to_string(), Color::BrightYellow),
35    );
36
37    // Create a deterministic RNG for reproducibility
38    let mut rng = SmallRng::from_seed([42; 32]);
39
40    // 1. Confusion Matrix Example
41    println!(
42        "\n{}",
43        stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
44    );
45    // Generate random predictions and true labels
46    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
47    // Create slightly correlated predictions (not completely random)
48    let y_pred = Array::from_shape_fn(n_samples, |i| {
49        if rng.random::<f32>() < 0.7 {
50            // 70% chance of correct prediction
51            y_true[i]
52        } else {
53            // 30% chance of random class
54            rng.random_range(0..n_classes)
55        }
56    });
57    // Create confusion matrix
58    let class_labels = vec![
59        "Class A".to_string(),
60        "Class B".to_string(),
61        "Class C".to_string(),
62        "Class D".to_string(),
63    ];
64    let cm = ConfusionMatrix::<f32>::new(
65        &y_true.view(),
66        &y_pred.view(),
67        Some(n_classes),
68        Some(class_labels),
69    )?;
70    // Print raw and normalized confusion matrices with color
71    println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
72    println!(
73        "{}",
74        cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
75    );
76    println!(
77        "\n{}",
78        colorize("Normalized Confusion Matrix:", Color::BrightCyan)
79    );
80    println!(
81        "{}",
82        cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
83    );
84    // Print metrics
85    println!(
86        "\n{} {:.3}",
87        colorize("Overall Accuracy:", Color::BrightMagenta),
88        cm.accuracy()
89    );
90    let precision = cm.precision();
91    let recall = cm.recall();
92    let f1 = cm.f1_score();
93    println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
94    for i in 0..n_classes {
95        println!(
96            "  {}: {}={:.3}, {}={:.3}, {}={:.3}",
97            colorize(format!("Class {i}"), Color::BrightYellow),
98            colorize("Precision", Color::BrightCyan),
99            precision[i],
100            colorize("Recall", Color::BrightGreen),
101            recall[i],
102            colorize("F1", Color::BrightBlue),
103            f1[i]
104        );
105    }
106    println!(
107        "{} {:.3}",
108        colorize("Macro F1 Score:", Color::BrightMagenta),
109        cm.macro_f1()
110    );
111    // 2. Feature Importance Visualization
112    println!(
113        "{}",
114        stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
115    );
116    // Generate random feature importance scores
117    let feature_names = (0..n_features)
118        .map(|i| format!("Feature_{i}"))
119        .collect::<Vec<String>>();
120    let importance = Array1::from_shape_fn(n_features, |i| {
121        // Make some features more important than others
122        let base = (n_features - i) as f32 / n_features as f32;
123        base + 0.2 * rng.random::<f32>()
124    });
125
126    let fi = FeatureImportance::new(feature_names, importance)?;
127
128    // Print full feature importance with color
129    println!(
130        "{}",
131        fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
132    );
133
134    // Print top-5 features with color
135    println!(
136        "{}",
137        colorize("Top 5 Most Important Features:", Color::BrightCyan)
138    );
139    println!(
140        "{}",
141        fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
142    );
143    // 3. ROC Curve for Binary Classification
144    println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
145    // Generate binary classification data
146    let n_binary = 200;
147    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
148    // Generate scores with some predictive power
149    let y_scores = Array1::from_shape_fn(n_binary, |i| {
150        if y_true_binary[i] == 1 {
151            // Higher scores for positive class
152            0.6 + 0.4 * rng.random::<f32>()
153        } else {
154            // Lower scores for negative class
155            0.4 * rng.random::<f32>()
156        }
157    });
158
159    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
160    println!(
161        "{}: {:.3}",
162        colorize("ROC AUC:", Color::BrightMagenta),
163        roc.auc
164    );
165    println!("\n{}", roc.to_ascii(None, 50, 20));
166
167    // 4. Learning Curve Visualization
168    println!(
169        "\n{}",
170        stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
171    );
172    // Generate learning curve data
173    let n_points = 10;
174    let n_cv = 5;
175    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
176    // Generate training scores (decreasing with size due to overfitting)
177    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
178        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
179    });
180
181    // Generate validation scores (increasing with size)
182    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
183        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
184    });
185
186    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
187    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
188
189    // Print final message with color
190    println!(
191        "{}",
192        colorize(
193            "Model evaluation visualizations completed successfully!",
194            Color::BrightGreen
195        )
196    );
197    Ok(())
198}
Source

pub fn f1_score(&self) -> Array1<F>

Calculate the F1 score for each class (the harmonic mean of that class's precision and recall).

Examples found in repository?
examples/colored_eval_visualization.rs (line 92)
11fn main() -> Result<()> {
12    println!(
13        "{}",
14        stylize("Neural Network Model Evaluation with Color", Style::Bold)
15    );
16    println!("{}", "-".repeat(50));
17    // Set up color options
18    let color_options = ColorOptions {
19        enabled: true,
20        use_background: false,
21        use_bright: true,
22    };
23    // Generate some example data
24    let n_samples = 500;
25    let n_features = 10;
26    let n_classes = 4;
27    println!(
28        "\n{} {} {} {} {} {}",
29        colorize("Generating", Color::BrightGreen),
30        colorize(n_samples.to_string(), Color::BrightYellow),
31        colorize("samples with", Color::BrightGreen),
32        colorize(n_features.to_string(), Color::BrightYellow),
33        colorize("features for", Color::BrightGreen),
34        colorize(n_classes.to_string(), Color::BrightYellow),
35    );
36
37    // Create a deterministic RNG for reproducibility
38    let mut rng = SmallRng::from_seed([42; 32]);
39
40    // 1. Confusion Matrix Example
41    println!(
42        "\n{}",
43        stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
44    );
45    // Generate random predictions and true labels
46    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
47    // Create slightly correlated predictions (not completely random)
48    let y_pred = Array::from_shape_fn(n_samples, |i| {
49        if rng.random::<f32>() < 0.7 {
50            // 70% chance of correct prediction
51            y_true[i]
52        } else {
53            // 30% chance of random class
54            rng.random_range(0..n_classes)
55        }
56    });
57    // Create confusion matrix
58    let class_labels = vec![
59        "Class A".to_string(),
60        "Class B".to_string(),
61        "Class C".to_string(),
62        "Class D".to_string(),
63    ];
64    let cm = ConfusionMatrix::<f32>::new(
65        &y_true.view(),
66        &y_pred.view(),
67        Some(n_classes),
68        Some(class_labels),
69    )?;
70    // Print raw and normalized confusion matrices with color
71    println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
72    println!(
73        "{}",
74        cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
75    );
76    println!(
77        "\n{}",
78        colorize("Normalized Confusion Matrix:", Color::BrightCyan)
79    );
80    println!(
81        "{}",
82        cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
83    );
84    // Print metrics
85    println!(
86        "\n{} {:.3}",
87        colorize("Overall Accuracy:", Color::BrightMagenta),
88        cm.accuracy()
89    );
90    let precision = cm.precision();
91    let recall = cm.recall();
92    let f1 = cm.f1_score();
93    println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
94    for i in 0..n_classes {
95        println!(
96            "  {}: {}={:.3}, {}={:.3}, {}={:.3}",
97            colorize(format!("Class {i}"), Color::BrightYellow),
98            colorize("Precision", Color::BrightCyan),
99            precision[i],
100            colorize("Recall", Color::BrightGreen),
101            recall[i],
102            colorize("F1", Color::BrightBlue),
103            f1[i]
104        );
105    }
106    println!(
107        "{} {:.3}",
108        colorize("Macro F1 Score:", Color::BrightMagenta),
109        cm.macro_f1()
110    );
111    // 2. Feature Importance Visualization
112    println!(
113        "{}",
114        stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
115    );
116    // Generate random feature importance scores
117    let feature_names = (0..n_features)
118        .map(|i| format!("Feature_{i}"))
119        .collect::<Vec<String>>();
120    let importance = Array1::from_shape_fn(n_features, |i| {
121        // Make some features more important than others
122        let base = (n_features - i) as f32 / n_features as f32;
123        base + 0.2 * rng.random::<f32>()
124    });
125
126    let fi = FeatureImportance::new(feature_names, importance)?;
127
128    // Print full feature importance with color
129    println!(
130        "{}",
131        fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
132    );
133
134    // Print top-5 features with color
135    println!(
136        "{}",
137        colorize("Top 5 Most Important Features:", Color::BrightCyan)
138    );
139    println!(
140        "{}",
141        fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
142    );
143    // 3. ROC Curve for Binary Classification
144    println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
145    // Generate binary classification data
146    let n_binary = 200;
147    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
148    // Generate scores with some predictive power
149    let y_scores = Array1::from_shape_fn(n_binary, |i| {
150        if y_true_binary[i] == 1 {
151            // Higher scores for positive class
152            0.6 + 0.4 * rng.random::<f32>()
153        } else {
154            // Lower scores for negative class
155            0.4 * rng.random::<f32>()
156        }
157    });
158
159    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
160    println!(
161        "{}: {:.3}",
162        colorize("ROC AUC:", Color::BrightMagenta),
163        roc.auc
164    );
165    println!("\n{}", roc.to_ascii(None, 50, 20));
166
167    // 4. Learning Curve Visualization
168    println!(
169        "\n{}",
170        stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
171    );
172    // Generate learning curve data
173    let n_points = 10;
174    let n_cv = 5;
175    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
176    // Generate training scores (decreasing with size due to overfitting)
177    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
178        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
179    });
180
181    // Generate validation scores (increasing with size)
182    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
183        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
184    });
185
186    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
187    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
188
189    // Print final message with color
190    println!(
191        "{}",
192        colorize(
193            "Model evaluation visualizations completed successfully!",
194            Color::BrightGreen
195        )
196    );
197    Ok(())
198}
Source

pub fn macro_f1(&self) -> F

Calculate the macro-averaged F1 score (the unweighted mean of the per-class F1 scores).

Examples found in repository?
examples/colored_eval_visualization.rs (line 109)
11fn main() -> Result<()> {
12    println!(
13        "{}",
14        stylize("Neural Network Model Evaluation with Color", Style::Bold)
15    );
16    println!("{}", "-".repeat(50));
17    // Set up color options
18    let color_options = ColorOptions {
19        enabled: true,
20        use_background: false,
21        use_bright: true,
22    };
23    // Generate some example data
24    let n_samples = 500;
25    let n_features = 10;
26    let n_classes = 4;
27    println!(
28        "\n{} {} {} {} {} {}",
29        colorize("Generating", Color::BrightGreen),
30        colorize(n_samples.to_string(), Color::BrightYellow),
31        colorize("samples with", Color::BrightGreen),
32        colorize(n_features.to_string(), Color::BrightYellow),
33        colorize("features for", Color::BrightGreen),
34        colorize(n_classes.to_string(), Color::BrightYellow),
35    );
36
37    // Create a deterministic RNG for reproducibility
38    let mut rng = SmallRng::from_seed([42; 32]);
39
40    // 1. Confusion Matrix Example
41    println!(
42        "\n{}",
43        stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
44    );
45    // Generate random predictions and true labels
46    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
47    // Create slightly correlated predictions (not completely random)
48    let y_pred = Array::from_shape_fn(n_samples, |i| {
49        if rng.random::<f32>() < 0.7 {
50            // 70% chance of correct prediction
51            y_true[i]
52        } else {
53            // 30% chance of random class
54            rng.random_range(0..n_classes)
55        }
56    });
57    // Create confusion matrix
58    let class_labels = vec![
59        "Class A".to_string(),
60        "Class B".to_string(),
61        "Class C".to_string(),
62        "Class D".to_string(),
63    ];
64    let cm = ConfusionMatrix::<f32>::new(
65        &y_true.view(),
66        &y_pred.view(),
67        Some(n_classes),
68        Some(class_labels),
69    )?;
70    // Print raw and normalized confusion matrices with color
71    println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
72    println!(
73        "{}",
74        cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
75    );
76    println!(
77        "\n{}",
78        colorize("Normalized Confusion Matrix:", Color::BrightCyan)
79    );
80    println!(
81        "{}",
82        cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
83    );
84    // Print metrics
85    println!(
86        "\n{} {:.3}",
87        colorize("Overall Accuracy:", Color::BrightMagenta),
88        cm.accuracy()
89    );
90    let precision = cm.precision();
91    let recall = cm.recall();
92    let f1 = cm.f1_score();
93    println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
94    for i in 0..n_classes {
95        println!(
96            "  {}: {}={:.3}, {}={:.3}, {}={:.3}",
97            colorize(format!("Class {i}"), Color::BrightYellow),
98            colorize("Precision", Color::BrightCyan),
99            precision[i],
100            colorize("Recall", Color::BrightGreen),
101            recall[i],
102            colorize("F1", Color::BrightBlue),
103            f1[i]
104        );
105    }
106    println!(
107        "{} {:.3}",
108        colorize("Macro F1 Score:", Color::BrightMagenta),
109        cm.macro_f1()
110    );
111    // 2. Feature Importance Visualization
112    println!(
113        "{}",
114        stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
115    );
116    // Generate random feature importance scores
117    let feature_names = (0..n_features)
118        .map(|i| format!("Feature_{i}"))
119        .collect::<Vec<String>>();
120    let importance = Array1::from_shape_fn(n_features, |i| {
121        // Make some features more important than others
122        let base = (n_features - i) as f32 / n_features as f32;
123        base + 0.2 * rng.random::<f32>()
124    });
125
126    let fi = FeatureImportance::new(feature_names, importance)?;
127
128    // Print full feature importance with color
129    println!(
130        "{}",
131        fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
132    );
133
134    // Print top-5 features with color
135    println!(
136        "{}",
137        colorize("Top 5 Most Important Features:", Color::BrightCyan)
138    );
139    println!(
140        "{}",
141        fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
142    );
143    // 3. ROC Curve for Binary Classification
144    println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
145    // Generate binary classification data
146    let n_binary = 200;
147    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
148    // Generate scores with some predictive power
149    let y_scores = Array1::from_shape_fn(n_binary, |i| {
150        if y_true_binary[i] == 1 {
151            // Higher scores for positive class
152            0.6 + 0.4 * rng.random::<f32>()
153        } else {
154            // Lower scores for negative class
155            0.4 * rng.random::<f32>()
156        }
157    });
158
159    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
160    println!(
161        "{}: {:.3}",
162        colorize("ROC AUC:", Color::BrightMagenta),
163        roc.auc
164    );
165    println!("\n{}", roc.to_ascii(None, 50, 20));
166
167    // 4. Learning Curve Visualization
168    println!(
169        "\n{}",
170        stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
171    );
172    // Generate learning curve data
173    let n_points = 10;
174    let n_cv = 5;
175    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
176    // Generate training scores (decreasing with size due to overfitting)
177    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
178        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
179    });
180
181    // Generate validation scores (increasing with size)
182    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
183        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
184    });
185
186    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
187    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
188
189    // Print final message with color
190    println!(
191        "{}",
192        colorize(
193            "Model evaluation visualizations completed successfully!",
194            Color::BrightGreen
195        )
196    );
197    Ok(())
198}
Source

pub fn class_metrics(&self) -> HashMap<String, Vec<F>>

Get class-wise metrics as a HashMap

Source

pub fn to_ascii(&self, title: Option<&str>, normalized: bool) -> String

Convert the confusion matrix to an ASCII representation

Examples found in repository?
examples/error_pattern_heatmap.rs (line 56)
7fn main() {
8    // Create a reproducible random number generator
9    let mut rng = SmallRng::from_seed([42; 32]);
10    // Generate synthetic multiclass classification data with specific error patterns
11    let num_classes = 5;
12    // Create confusion matrix with controlled error patterns
13    let mut matrix = vec![vec![0; num_classes]; num_classes];
14    // Set diagonal elements (correct classifications) with high values
15    #[allow(clippy::needless_range_loop)]
16    for i in 0..num_classes {
17        matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
18    }
19    // Create specific error patterns:
20    // - Classes 0 and 1 often confused
21    matrix[0][1] = 25;
22    matrix[1][0] = 15;
23    // - Class 2 sometimes confused with Class 3
24    matrix[2][3] = 18;
25    // - Class 4 has some misclassifications to all other classes
26    matrix[4][0] = 8;
27    matrix[4][1] = 5;
28    matrix[4][2] = 10;
29    matrix[4][3] = 12;
30    // - Some minor errors scattered about
31    #[allow(clippy::needless_range_loop)]
32    for i in 0..num_classes {
33        for j in 0..num_classes {
34            if i != j && matrix[i][j] == 0 {
35                matrix[i][j] = rng.random_range(0..5);
36            }
37        }
38    }
39    // Convert to ndarray
40    let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
41    let ndarray_matrix =
42        scirs2_core::ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix)
43            .unwrap();
44    // Create class labels
45    let class_labels = vec![
46        "Class A".to_string(),
47        "Class B".to_string(),
48        "Class C".to_string(),
49        "Class D".to_string(),
50        "Class E".to_string(),
51    ];
52    // Create confusion matrix
53    let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
54    // Example 1: Standard confusion matrix
55    println!("Example 1: Standard Confusion Matrix\n");
56    let regular_output = cm.to_ascii(Some("Classification Results"), false);
57    println!("{regular_output}");
58    // Example 2: Normal heatmap
59    println!("\n\nExample 2: Standard Heatmap Visualization\n");
60    let color_options = ColorOptions {
61        enabled: true,
62        use_bright: true,
63        use_background: false,
64    };
65    let heatmap_output = cm.to_heatmap_with_options(
66        Some("Classification Heatmap"),
67        true, // normalized
68        &color_options,
69    );
70    println!("{heatmap_output}");
71    // Example 3: Error pattern heatmap
72    println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
73    let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
74    println!("{error_heatmap}");
75}
More examples
Hide additional examples
examples/confusion_matrix_heatmap.rs (line 61)
8fn main() {
9    // Create a reproducible random number generator
10    let mut rng = SmallRng::from_seed([42; 32]);
11    // Generate synthetic multiclass classification data
12    let num_classes = 5;
13    let n_samples = 500;
14    // Generate true labels (0 to num_classes-1)
15    let mut y_true = Vec::with_capacity(n_samples);
16    for _ in 0..n_samples {
17        y_true.push(rng.random_range(0..num_classes));
18    }
19    // Generate predicted labels with controlled accuracy
20    let mut y_pred = Vec::with_capacity(n_samples);
21    for &true_label in &y_true {
22        // 80% chance to predict correctly..20% chance of error
23        if rng.random::<f64>() < 0.8 {
24            y_pred.push(true_label);
25        } else {
26            // When wrong, tend to predict adjacent classes more often
27            let mut pred = true_label;
28            while pred == true_label {
29                // Generate error that's more likely to be close to true label
30                let error_margin = (rng.random::<f64>() * 2.0).round() as usize; // 0, 1, or 2
31                if rng.random::<bool>() {
32                    pred = (true_label + error_margin) % num_classes;
33                } else {
34                    pred = (true_label + num_classes - error_margin) % num_classes;
35                }
36            }
37            y_pred.push(pred);
38        }
39    }
40    // Convert to ndarray arrays
41    let y_true_array = Array1::from(y_true);
42    let y_pred_array = Array1::from(y_pred);
43    // Create class labels
44    let class_labels = vec![
45        "Cat".to_string(),
46        "Dog".to_string(),
47        "Bird".to_string(),
48        "Fish".to_string(),
49        "Rabbit".to_string(),
50    ];
51    // Create confusion matrix
52    let cm = ConfusionMatrix::<f64>::new(
53        &y_true_array.view(),
54        &y_pred_array.view(),
55        Some(num_classes),
56        Some(class_labels),
57    )
58    .unwrap();
59    // Example 1: Standard confusion matrix
60    println!("Example 1: Standard Confusion Matrix\n");
61    let regular_output = cm.to_ascii(Some("Animal Classification Results"), false);
62    println!("{regular_output}");
63    // Example 2: Confusion matrix with color
64    println!("\n\nExample 2: Colored Confusion Matrix\n");
65    let color_options = ColorOptions {
66        enabled: true,
67        use_bright: true,
68        use_background: false,
69    };
70    let colored_output = cm.to_ascii_with_options(
71        Some("Animal Classification Results (with color)"),
72        false,
73        &color_options,
74    );
75    println!("{colored_output}");
76    // Example 3: Normalized confusion matrix heatmap
77    println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
78    let heatmap_output = cm.to_heatmap_with_options(
79        Some("Animal Classification Heatmap (normalized)"),
80        true, // normalized
81        &color_options,
82    );
83    println!("{heatmap_output}");
84
85    // Example 4: Raw counts heatmap
86    println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
87    let raw_heatmap = cm.to_heatmap_with_options(
88        Some("Animal Classification Heatmap (raw counts)"),
89        false, // not normalized
90        &color_options,
91    );
92    println!("{raw_heatmap}");
93}
Source

pub fn to_ascii_with_options( &self, title: Option<&str>, normalized: bool, color_options: &ColorOptions, ) -> String

Convert the confusion matrix to an ASCII representation with color options

Examples found in repository?
examples/confusion_matrix_heatmap.rs (lines 70-74)
8fn main() {
9    // Create a reproducible random number generator
10    let mut rng = SmallRng::from_seed([42; 32]);
11    // Generate synthetic multiclass classification data
12    let num_classes = 5;
13    let n_samples = 500;
14    // Generate true labels (0 to num_classes-1)
15    let mut y_true = Vec::with_capacity(n_samples);
16    for _ in 0..n_samples {
17        y_true.push(rng.random_range(0..num_classes));
18    }
19    // Generate predicted labels with controlled accuracy
20    let mut y_pred = Vec::with_capacity(n_samples);
21    for &true_label in &y_true {
22        // 80% chance to predict correctly..20% chance of error
23        if rng.random::<f64>() < 0.8 {
24            y_pred.push(true_label);
25        } else {
26            // When wrong, tend to predict adjacent classes more often
27            let mut pred = true_label;
28            while pred == true_label {
29                // Generate error that's more likely to be close to true label
30                let error_margin = (rng.random::<f64>() * 2.0).round() as usize; // 0, 1, or 2
31                if rng.random::<bool>() {
32                    pred = (true_label + error_margin) % num_classes;
33                } else {
34                    pred = (true_label + num_classes - error_margin) % num_classes;
35                }
36            }
37            y_pred.push(pred);
38        }
39    }
40    // Convert to ndarray arrays
41    let y_true_array = Array1::from(y_true);
42    let y_pred_array = Array1::from(y_pred);
43    // Create class labels
44    let class_labels = vec![
45        "Cat".to_string(),
46        "Dog".to_string(),
47        "Bird".to_string(),
48        "Fish".to_string(),
49        "Rabbit".to_string(),
50    ];
51    // Create confusion matrix
52    let cm = ConfusionMatrix::<f64>::new(
53        &y_true_array.view(),
54        &y_pred_array.view(),
55        Some(num_classes),
56        Some(class_labels),
57    )
58    .unwrap();
59    // Example 1: Standard confusion matrix
60    println!("Example 1: Standard Confusion Matrix\n");
61    let regular_output = cm.to_ascii(Some("Animal Classification Results"), false);
62    println!("{regular_output}");
63    // Example 2: Confusion matrix with color
64    println!("\n\nExample 2: Colored Confusion Matrix\n");
65    let color_options = ColorOptions {
66        enabled: true,
67        use_bright: true,
68        use_background: false,
69    };
70    let colored_output = cm.to_ascii_with_options(
71        Some("Animal Classification Results (with color)"),
72        false,
73        &color_options,
74    );
75    println!("{colored_output}");
76    // Example 3: Normalized confusion matrix heatmap
77    println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
78    let heatmap_output = cm.to_heatmap_with_options(
79        Some("Animal Classification Heatmap (normalized)"),
80        true, // normalized
81        &color_options,
82    );
83    println!("{heatmap_output}");
84
85    // Example 4: Raw counts heatmap
86    println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
87    let raw_heatmap = cm.to_heatmap_with_options(
88        Some("Animal Classification Heatmap (raw counts)"),
89        false, // not normalized
90        &color_options,
91    );
92    println!("{raw_heatmap}");
93}
More examples
Hide additional examples
examples/colored_eval_visualization.rs (line 74)
11fn main() -> Result<()> {
12    println!(
13        "{}",
14        stylize("Neural Network Model Evaluation with Color", Style::Bold)
15    );
16    println!("{}", "-".repeat(50));
17    // Set up color options
18    let color_options = ColorOptions {
19        enabled: true,
20        use_background: false,
21        use_bright: true,
22    };
23    // Generate some example data
24    let n_samples = 500;
25    let n_features = 10;
26    let n_classes = 4;
27    println!(
28        "\n{} {} {} {} {} {}",
29        colorize("Generating", Color::BrightGreen),
30        colorize(n_samples.to_string(), Color::BrightYellow),
31        colorize("samples with", Color::BrightGreen),
32        colorize(n_features.to_string(), Color::BrightYellow),
33        colorize("features for", Color::BrightGreen),
34        colorize(n_classes.to_string(), Color::BrightYellow),
35    );
36
37    // Create a deterministic RNG for reproducibility
38    let mut rng = SmallRng::from_seed([42; 32]);
39
40    // 1. Confusion Matrix Example
41    println!(
42        "\n{}",
43        stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
44    );
45    // Generate random predictions and true labels
46    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
47    // Create slightly correlated predictions (not completely random)
48    let y_pred = Array::from_shape_fn(n_samples, |i| {
49        if rng.random::<f32>() < 0.7 {
50            // 70% chance of correct prediction
51            y_true[i]
52        } else {
53            // 30% chance of random class
54            rng.random_range(0..n_classes)
55        }
56    });
57    // Create confusion matrix
58    let class_labels = vec![
59        "Class A".to_string(),
60        "Class B".to_string(),
61        "Class C".to_string(),
62        "Class D".to_string(),
63    ];
64    let cm = ConfusionMatrix::<f32>::new(
65        &y_true.view(),
66        &y_pred.view(),
67        Some(n_classes),
68        Some(class_labels),
69    )?;
70    // Print raw and normalized confusion matrices with color
71    println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
72    println!(
73        "{}",
74        cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
75    );
76    println!(
77        "\n{}",
78        colorize("Normalized Confusion Matrix:", Color::BrightCyan)
79    );
80    println!(
81        "{}",
82        cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
83    );
84    // Print metrics
85    println!(
86        "\n{} {:.3}",
87        colorize("Overall Accuracy:", Color::BrightMagenta),
88        cm.accuracy()
89    );
90    let precision = cm.precision();
91    let recall = cm.recall();
92    let f1 = cm.f1_score();
93    println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
94    for i in 0..n_classes {
95        println!(
96            "  {}: {}={:.3}, {}={:.3}, {}={:.3}",
97            colorize(format!("Class {i}"), Color::BrightYellow),
98            colorize("Precision", Color::BrightCyan),
99            precision[i],
100            colorize("Recall", Color::BrightGreen),
101            recall[i],
102            colorize("F1", Color::BrightBlue),
103            f1[i]
104        );
105    }
106    println!(
107        "{} {:.3}",
108        colorize("Macro F1 Score:", Color::BrightMagenta),
109        cm.macro_f1()
110    );
111    // 2. Feature Importance Visualization
112    println!(
113        "{}",
114        stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
115    );
116    // Generate random feature importance scores
117    let feature_names = (0..n_features)
118        .map(|i| format!("Feature_{i}"))
119        .collect::<Vec<String>>();
120    let importance = Array1::from_shape_fn(n_features, |i| {
121        // Make some features more important than others
122        let base = (n_features - i) as f32 / n_features as f32;
123        base + 0.2 * rng.random::<f32>()
124    });
125
126    let fi = FeatureImportance::new(feature_names, importance)?;
127
128    // Print full feature importance with color
129    println!(
130        "{}",
131        fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
132    );
133
134    // Print top-5 features with color
135    println!(
136        "{}",
137        colorize("Top 5 Most Important Features:", Color::BrightCyan)
138    );
139    println!(
140        "{}",
141        fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
142    );
143    // 3. ROC Curve for Binary Classification
144    println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
145    // Generate binary classification data
146    let n_binary = 200;
147    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
148    // Generate scores with some predictive power
149    let y_scores = Array1::from_shape_fn(n_binary, |i| {
150        if y_true_binary[i] == 1 {
151            // Higher scores for positive class
152            0.6 + 0.4 * rng.random::<f32>()
153        } else {
154            // Lower scores for negative class
155            0.4 * rng.random::<f32>()
156        }
157    });
158
159    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
160    println!(
161        "{}: {:.3}",
162        colorize("ROC AUC:", Color::BrightMagenta),
163        roc.auc
164    );
165    println!("\n{}", roc.to_ascii(None, 50, 20));
166
167    // 4. Learning Curve Visualization
168    println!(
169        "\n{}",
170        stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
171    );
172    // Generate learning curve data
173    let n_points = 10;
174    let n_cv = 5;
175    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
176    // Generate training scores (decreasing with size due to overfitting)
177    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
178        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
179    });
180
181    // Generate validation scores (increasing with size)
182    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
183        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
184    });
185
186    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
187    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
188
189    // Print final message with color
190    println!(
191        "{}",
192        colorize(
193            "Model evaluation visualizations completed successfully!",
194            Color::BrightGreen
195        )
196    );
197    Ok(())
198}
Source

pub fn to_heatmap(&self, title: Option<&str>, normalized: bool) -> String

Convert the confusion matrix to a heatmap visualization. This creates a colorful heatmap of the confusion matrix in which cell colors represent the intensity of values using a detailed color gradient.

§Arguments
  • title - Optional title for the heatmap
  • normalized - Whether to normalize the matrix (row values sum to 1)
§Returns
  • String - ASCII heatmap representation
Source

pub fn error_heatmap(&self, title: Option<&str>) -> String

Create a confusion matrix heatmap that focuses on misclassification patterns. This visualization is specialized to highlight where the model makes mistakes, emphasizing the off-diagonal elements to help identify error patterns.

The key features of this visualization are:

  • Diagonal elements (correct classifications) are de-emphasized with dim styling
  • Off-diagonal elements (errors) are highlighted with a color gradient
  • Colors are normalized relative to the maximum off-diagonal value
  • A specialized legend explains error intensity levels
§Arguments
  • title - Optional title for the error heatmap
§Returns
  • String - ASCII error pattern heatmap
§Example
use scirs2_core::ndarray::Array1;
use scirs2_neural::utils::ConfusionMatrix;
// Create some example data
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 1, 1, 0, 0]);
let class_labels = vec!["Class A".to_string(), "Class B".to_string(), "Class C".to_string()];
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, Some(class_labels)).unwrap();
// Generate the error pattern heatmap
let error_viz = cm.error_heatmap(Some("Misclassification Analysis"));
println!("{}", error_viz);
Examples found in repository?
examples/error_pattern_heatmap.rs (line 73)
7fn main() {
8    // Create a reproducible random number generator
9    let mut rng = SmallRng::from_seed([42; 32]);
10    // Generate synthetic multiclass classification data with specific error patterns
11    let num_classes = 5;
12    // Create confusion matrix with controlled error patterns
13    let mut matrix = vec![vec![0; num_classes]; num_classes];
14    // Set diagonal elements (correct classifications) with high values
15    #[allow(clippy::needless_range_loop)]
16    for i in 0..num_classes {
17        matrix[i][i] = 70 + rng.random_range(0..15); // 70-84 correct per class
18    }
19    // Create specific error patterns:
20    // - Classes 0 and 1 often confused
21    matrix[0][1] = 25;
22    matrix[1][0] = 15;
23    // - Class 2 sometimes confused with Class 3
24    matrix[2][3] = 18;
25    // - Class 4 has some misclassifications to all other classes
26    matrix[4][0] = 8;
27    matrix[4][1] = 5;
28    matrix[4][2] = 10;
29    matrix[4][3] = 12;
30    // - Some minor errors scattered about
31    #[allow(clippy::needless_range_loop)]
32    for i in 0..num_classes {
33        for j in 0..num_classes {
34            if i != j && matrix[i][j] == 0 {
35                matrix[i][j] = rng.random_range(0..5);
36            }
37        }
38    }
39    // Convert to ndarray
40    let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
41    let ndarray_matrix =
42        scirs2_core::ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix)
43            .unwrap();
44    // Create class labels
45    let class_labels = vec![
46        "Class A".to_string(),
47        "Class B".to_string(),
48        "Class C".to_string(),
49        "Class D".to_string(),
50        "Class E".to_string(),
51    ];
52    // Create confusion matrix
53    let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
54    // Example 1: Standard confusion matrix
55    println!("Example 1: Standard Confusion Matrix\n");
56    let regular_output = cm.to_ascii(Some("Classification Results"), false);
57    println!("{regular_output}");
58    // Example 2: Normal heatmap
59    println!("\n\nExample 2: Standard Heatmap Visualization\n");
60    let color_options = ColorOptions {
61        enabled: true,
62        use_bright: true,
63        use_background: false,
64    };
65    let heatmap_output = cm.to_heatmap_with_options(
66        Some("Classification Heatmap"),
67        true, // normalized
68        &color_options,
69    );
70    println!("{heatmap_output}");
71    // Example 3: Error pattern heatmap
72    println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
73    let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
74    println!("{error_heatmap}");
75}
Source

pub fn to_heatmap_with_options( &self, title: Option<&str>, normalized: bool, color_options: &ColorOptions, ) -> String

Convert the confusion matrix to a heatmap visualization with customizable options

§Arguments
  • title - Optional title for the heatmap
  • normalized - Whether to normalize the matrix (row values sum to 1)
  • color_options - Color options for visualization
§Returns
  • String - ASCII heatmap representation with colors
Examples found in repository?
examples/error_pattern_heatmap.rs (lines 65-69)
7fn main() {
8    // Create a reproducible random number generator
9    let mut rng = SmallRng::from_seed([42; 32]);
10    // Generate synthetic multiclass classification data with specific error patterns
11    let num_classes = 5;
12    // Create confusion matrix with controlled error patterns
13    let mut matrix = vec![vec![0; num_classes]; num_classes];
14    // Set diagonal elements (correct classifications) with high values
15    #[allow(clippy::needless_range_loop)]
16    for i in 0..num_classes {
17        matrix[i][i] = 70 + rng.random_range(0..15); // 70-84 correct per class
18    }
19    // Create specific error patterns:
20    // - Classes 0 and 1 often confused
21    matrix[0][1] = 25;
22    matrix[1][0] = 15;
23    // - Class 2 sometimes confused with Class 3
24    matrix[2][3] = 18;
25    // - Class 4 has some misclassifications to all other classes
26    matrix[4][0] = 8;
27    matrix[4][1] = 5;
28    matrix[4][2] = 10;
29    matrix[4][3] = 12;
30    // - Some minor errors scattered about
31    #[allow(clippy::needless_range_loop)]
32    for i in 0..num_classes {
33        for j in 0..num_classes {
34            if i != j && matrix[i][j] == 0 {
35                matrix[i][j] = rng.random_range(0..5);
36            }
37        }
38    }
39    // Convert to ndarray
40    let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
41    let ndarray_matrix =
42        scirs2_core::ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix)
43            .unwrap();
44    // Create class labels
45    let class_labels = vec![
46        "Class A".to_string(),
47        "Class B".to_string(),
48        "Class C".to_string(),
49        "Class D".to_string(),
50        "Class E".to_string(),
51    ];
52    // Create confusion matrix
53    let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
54    // Example 1: Standard confusion matrix
55    println!("Example 1: Standard Confusion Matrix\n");
56    let regular_output = cm.to_ascii(Some("Classification Results"), false);
57    println!("{regular_output}");
58    // Example 2: Normal heatmap
59    println!("\n\nExample 2: Standard Heatmap Visualization\n");
60    let color_options = ColorOptions {
61        enabled: true,
62        use_bright: true,
63        use_background: false,
64    };
65    let heatmap_output = cm.to_heatmap_with_options(
66        Some("Classification Heatmap"),
67        true, // normalized
68        &color_options,
69    );
70    println!("{heatmap_output}");
71    // Example 3: Error pattern heatmap
72    println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
73    let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
74    println!("{error_heatmap}");
75}
More examples
Hide additional examples
examples/confusion_matrix_heatmap.rs (lines 78-82)
8fn main() {
9    // Create a reproducible random number generator
10    let mut rng = SmallRng::from_seed([42; 32]);
11    // Generate synthetic multiclass classification data
12    let num_classes = 5;
13    let n_samples = 500;
14    // Generate true labels (0 to num_classes-1)
15    let mut y_true = Vec::with_capacity(n_samples);
16    for _ in 0..n_samples {
17        y_true.push(rng.random_range(0..num_classes));
18    }
19    // Generate predicted labels with controlled accuracy
20    let mut y_pred = Vec::with_capacity(n_samples);
21    for &true_label in &y_true {
22        // 80% chance to predict correctly, 20% chance of error
23        if rng.random::<f64>() < 0.8 {
24            y_pred.push(true_label);
25        } else {
26            // When wrong, tend to predict adjacent classes more often
27            let mut pred = true_label;
28            while pred == true_label {
29                // Generate error that's more likely to be close to true label
30                let error_margin = (rng.random::<f64>() * 2.0).round() as usize; // 0, 1, or 2
31                if rng.random::<bool>() {
32                    pred = (true_label + error_margin) % num_classes;
33                } else {
34                    pred = (true_label + num_classes - error_margin) % num_classes;
35                }
36            }
37            y_pred.push(pred);
38        }
39    }
40    // Convert to ndarray arrays
41    let y_true_array = Array1::from(y_true);
42    let y_pred_array = Array1::from(y_pred);
43    // Create class labels
44    let class_labels = vec![
45        "Cat".to_string(),
46        "Dog".to_string(),
47        "Bird".to_string(),
48        "Fish".to_string(),
49        "Rabbit".to_string(),
50    ];
51    // Create confusion matrix
52    let cm = ConfusionMatrix::<f64>::new(
53        &y_true_array.view(),
54        &y_pred_array.view(),
55        Some(num_classes),
56        Some(class_labels),
57    )
58    .unwrap();
59    // Example 1: Standard confusion matrix
60    println!("Example 1: Standard Confusion Matrix\n");
61    let regular_output = cm.to_ascii(Some("Animal Classification Results"), false);
62    println!("{regular_output}");
63    // Example 2: Confusion matrix with color
64    println!("\n\nExample 2: Colored Confusion Matrix\n");
65    let color_options = ColorOptions {
66        enabled: true,
67        use_bright: true,
68        use_background: false,
69    };
70    let colored_output = cm.to_ascii_with_options(
71        Some("Animal Classification Results (with color)"),
72        false,
73        &color_options,
74    );
75    println!("{colored_output}");
76    // Example 3: Normalized confusion matrix heatmap
77    println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
78    let heatmap_output = cm.to_heatmap_with_options(
79        Some("Animal Classification Heatmap (normalized)"),
80        true, // normalized
81        &color_options,
82    );
83    println!("{heatmap_output}");
84
85    // Example 4: Raw counts heatmap
86    println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
87    let raw_heatmap = cm.to_heatmap_with_options(
88        Some("Animal Classification Heatmap (raw counts)"),
89        false, // not normalized
90        &color_options,
91    );
92    println!("{raw_heatmap}");
93}

Trait Implementations§

Source§

impl<F: Clone + Float + Debug + Display> Clone for ConfusionMatrix<F>

Source§

fn clone(&self) -> ConfusionMatrix<F>

Returns a duplicate of the value. Read more
1.0.0 · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl<F: Debug + Float + Debug + Display> Debug for ConfusionMatrix<F>

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more

Auto Trait Implementations§

§

impl<F> Freeze for ConfusionMatrix<F>

§

impl<F> RefUnwindSafe for ConfusionMatrix<F>
where F: RefUnwindSafe,

§

impl<F> Send for ConfusionMatrix<F>
where F: Send,

§

impl<F> Sync for ConfusionMatrix<F>
where F: Sync,

§

impl<F> Unpin for ConfusionMatrix<F>

§

impl<F> UnwindSafe for ConfusionMatrix<F>
where F: RefUnwindSafe,

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V