pub struct ConfusionMatrix<F: Float + Debug + Display> {
    pub matrix: Array2<F>,
    pub labels: Option<Vec<String>>,
    pub num_classes: usize,
}
Confusion matrix for classification problems
Fields

matrix: Array2<F>
The raw confusion matrix data

labels: Option<Vec<String>>
Class labels (optional)

num_classes: usize
Number of classes
Implementations

impl<F: Float + Debug + Display> ConfusionMatrix<F>
pub fn new(
    y_true: &ArrayView1<'_, usize>,
    y_pred: &ArrayView1<'_, usize>,
    num_classes: Option<usize>,
    labels: Option<Vec<String>>,
) -> Result<Self>
Create a new confusion matrix from predictions and true labels
Arguments

y_true - True class labels as integers
y_pred - Predicted class labels as integers
num_classes - Number of classes (if None, determined from data)
labels - Optional class labels as strings
Returns

Result<ConfusionMatrix<F>> - The confusion matrix
Example
use scirs2_neural::utils::evaluation::ConfusionMatrix;
use scirs2_core::ndarray::Array1;
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
Examples found in repository
fn main() {
    // Create a reproducible random number generator
    let mut rng = SmallRng::from_seed([42; 32]);
    // Generate synthetic multiclass classification data
    let num_classes = 5;
    let n_samples = 500;
    // Generate true labels (0 to num_classes-1)
    let mut y_true = Vec::with_capacity(n_samples);
    for _ in 0..n_samples {
        y_true.push(rng.random_range(0..num_classes));
    }
    // Generate predicted labels with controlled accuracy
    let mut y_pred = Vec::with_capacity(n_samples);
    for &true_label in &y_true {
        // 80% chance to predict correctly, 20% chance of error
        if rng.random::<f64>() < 0.8 {
            y_pred.push(true_label);
        } else {
            // When wrong, tend to predict adjacent classes more often
            let mut pred = true_label;
            while pred == true_label {
                // Generate an error that is more likely to be close to the true label
                let error_margin = (rng.random::<f64>() * 2.0).round() as usize; // 0, 1, or 2
                if rng.random::<bool>() {
                    pred = (true_label + error_margin) % num_classes;
                } else {
                    pred = (true_label + num_classes - error_margin) % num_classes;
                }
            }
            y_pred.push(pred);
        }
    }
    // Convert to ndarray arrays
    let y_true_array = Array1::from(y_true);
    let y_pred_array = Array1::from(y_pred);
    // Create class labels
    let class_labels = vec![
        "Cat".to_string(),
        "Dog".to_string(),
        "Bird".to_string(),
        "Fish".to_string(),
        "Rabbit".to_string(),
    ];
    // Create confusion matrix
    let cm = ConfusionMatrix::<f64>::new(
        &y_true_array.view(),
        &y_pred_array.view(),
        Some(num_classes),
        Some(class_labels),
    )
    .unwrap();
    // Example 1: Standard confusion matrix
    println!("Example 1: Standard Confusion Matrix\n");
    let regular_output = cm.to_ascii(Some("Animal Classification Results"), false);
    println!("{regular_output}");
    // Example 2: Confusion matrix with color
    println!("\n\nExample 2: Colored Confusion Matrix\n");
    let color_options = ColorOptions {
        enabled: true,
        use_bright: true,
        use_background: false,
    };
    let colored_output = cm.to_ascii_with_options(
        Some("Animal Classification Results (with color)"),
        false,
        &color_options,
    );
    println!("{colored_output}");
    // Example 3: Normalized confusion matrix heatmap
    println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
    let heatmap_output = cm.to_heatmap_with_options(
        Some("Animal Classification Heatmap (normalized)"),
        true, // normalized
        &color_options,
    );
    println!("{heatmap_output}");

    // Example 4: Raw counts heatmap
    println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
    let raw_heatmap = cm.to_heatmap_with_options(
        Some("Animal Classification Heatmap (raw counts)"),
        false, // not normalized
        &color_options,
    );
    println!("{raw_heatmap}");
}
More examples
fn main() -> Result<()> {
    println!(
        "{}",
        stylize("Neural Network Model Evaluation with Color", Style::Bold)
    );
    println!("{}", "-".repeat(50));
    // Set up color options
    let color_options = ColorOptions {
        enabled: true,
        use_background: false,
        use_bright: true,
    };
    // Generate some example data
    let n_samples = 500;
    let n_features = 10;
    let n_classes = 4;
    println!(
        "\n{} {} {} {} {} {}",
        colorize("Generating", Color::BrightGreen),
        colorize(n_samples.to_string(), Color::BrightYellow),
        colorize("samples with", Color::BrightGreen),
        colorize(n_features.to_string(), Color::BrightYellow),
        colorize("features for", Color::BrightGreen),
        colorize(n_classes.to_string(), Color::BrightYellow),
    );

    // Create a deterministic RNG for reproducibility
    let mut rng = SmallRng::from_seed([42; 32]);

    // 1. Confusion Matrix Example
    println!(
        "\n{}",
        stylize("1. CONFUSION MATRIX VISUALIZATION", Style::Bold)
    );
    // Generate random predictions and true labels
    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
    // Create slightly correlated predictions (not completely random)
    let y_pred = Array::from_shape_fn(n_samples, |i| {
        if rng.random::<f32>() < 0.7 {
            // 70% chance of correct prediction
            y_true[i]
        } else {
            // 30% chance of random class
            rng.random_range(0..n_classes)
        }
    });
    // Create confusion matrix
    let class_labels = vec![
        "Class A".to_string(),
        "Class B".to_string(),
        "Class C".to_string(),
        "Class D".to_string(),
    ];
    let cm = ConfusionMatrix::<f32>::new(
        &y_true.view(),
        &y_pred.view(),
        Some(n_classes),
        Some(class_labels),
    )?;
    // Print raw and normalized confusion matrices with color
    println!("\n{}", colorize("Raw Confusion Matrix:", Color::BrightCyan));
    println!(
        "{}",
        cm.to_ascii_with_options(Some("Confusion Matrix"), false, &color_options)
    );
    println!(
        "\n{}",
        colorize("Normalized Confusion Matrix:", Color::BrightCyan)
    );
    println!(
        "{}",
        cm.to_ascii_with_options(Some("Normalized Confusion Matrix"), true, &color_options)
    );
    // Print metrics
    println!(
        "\n{} {:.3}",
        colorize("Overall Accuracy:", Color::BrightMagenta),
        cm.accuracy()
    );
    let precision = cm.precision();
    let recall = cm.recall();
    let f1 = cm.f1_score();
    println!("{}", colorize("Per-class metrics:", Color::BrightMagenta));
    for i in 0..n_classes {
        println!(
            " {}: {}={:.3}, {}={:.3}, {}={:.3}",
            colorize(format!("Class {i}"), Color::BrightYellow),
            colorize("Precision", Color::BrightCyan),
            precision[i],
            colorize("Recall", Color::BrightGreen),
            recall[i],
            colorize("F1", Color::BrightBlue),
            f1[i]
        );
    }
    println!(
        "{} {:.3}",
        colorize("Macro F1 Score:", Color::BrightMagenta),
        cm.macro_f1()
    );
    // 2. Feature Importance Visualization
    println!(
        "{}",
        stylize("2. FEATURE IMPORTANCE VISUALIZATION", Style::Bold)
    );
    // Generate random feature importance scores
    let feature_names = (0..n_features)
        .map(|i| format!("Feature_{i}"))
        .collect::<Vec<String>>();
    let importance = Array1::from_shape_fn(n_features, |i| {
        // Make some features more important than others
        let base = (n_features - i) as f32 / n_features as f32;
        base + 0.2 * rng.random::<f32>()
    });

    let fi = FeatureImportance::new(feature_names, importance)?;

    // Print full feature importance with color
    println!(
        "{}",
        fi.to_ascii_with_options(Some("Feature Importance"), 60, None, &color_options)
    );

    // Print top-5 features with color
    println!(
        "{}",
        colorize("Top 5 Most Important Features:", Color::BrightCyan)
    );
    println!(
        "{}",
        fi.to_ascii_with_options(Some("Top 5 Features"), 60, Some(5), &color_options)
    );
    // 3. ROC Curve for Binary Classification
    println!("\n{}", stylize("3. ROC CURVE VISUALIZATION", Style::Bold));
    // Generate binary classification data
    let n_binary = 200;
    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
    // Generate scores with some predictive power
    let y_scores = Array1::from_shape_fn(n_binary, |i| {
        if y_true_binary[i] == 1 {
            // Higher scores for positive class
            0.6 + 0.4 * rng.random::<f32>()
        } else {
            // Lower scores for negative class
            0.4 * rng.random::<f32>()
        }
    });

    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
    println!(
        "{}: {:.3}",
        colorize("ROC AUC:", Color::BrightMagenta),
        roc.auc
    );
    println!("\n{}", roc.to_ascii(None, 50, 20));

    // 4. Learning Curve Visualization
    println!(
        "\n{}",
        stylize("4. LEARNING CURVE VISUALIZATION", Style::Bold)
    );
    // Generate learning curve data
    let n_points = 10;
    let n_cv = 5;
    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
    // Generate training scores (decreasing with size due to overfitting)
    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
    });

    // Generate validation scores (increasing with size)
    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
    });

    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));

    // Print final message with color
    println!(
        "{}",
        colorize(
            "Model evaluation visualizations completed successfully!",
            Color::BrightGreen
        )
    );
    Ok(())
}
pub fn from_matrix(
    matrix: Array2<F>,
    labels: Option<Vec<String>>,
) -> Result<Self>
Create a confusion matrix from raw matrix data
Arguments

matrix - Raw confusion matrix data
labels - Optional class labels
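For a quick start without label arrays, a minimal sketch is shown below; the 2x2 counts, the label names, and the assumption that rows hold true classes and columns hold predictions are illustrative only.

use scirs2_core::ndarray::Array2;
use scirs2_neural::utils::evaluation::ConfusionMatrix;
// Hypothetical counts; rows = true classes, columns = predictions (assumed convention)
let counts = Array2::from_shape_vec((2, 2), vec![40.0_f64, 10.0, 5.0, 45.0]).unwrap();
let labels = Some(vec!["Negative".to_string(), "Positive".to_string()]);
let cm = ConfusionMatrix::from_matrix(counts, labels).unwrap();
assert_eq!(cm.num_classes, 2);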
Examples found in repository
fn main() {
    // Create a reproducible random number generator
    let mut rng = SmallRng::from_seed([42; 32]);
    // Generate synthetic multiclass classification data with specific error patterns
    let num_classes = 5;
    // Create confusion matrix with controlled error patterns
    let mut matrix = vec![vec![0; num_classes]; num_classes];
    // Set diagonal elements (correct classifications) with high values
    #[allow(clippy::needless_range_loop)]
    for i in 0..num_classes {
        matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
    }
    // Create specific error patterns:
    // - Classes 0 and 1 often confused
    matrix[0][1] = 25;
    matrix[1][0] = 15;
    // - Class 2 sometimes confused with Class 3
    matrix[2][3] = 18;
    // - Class 4 has some misclassifications to all other classes
    matrix[4][0] = 8;
    matrix[4][1] = 5;
    matrix[4][2] = 10;
    matrix[4][3] = 12;
    // - Some minor errors scattered about
    #[allow(clippy::needless_range_loop)]
    for i in 0..num_classes {
        for j in 0..num_classes {
            if i != j && matrix[i][j] == 0 {
                matrix[i][j] = rng.random_range(0..5);
            }
        }
    }
    // Convert to ndarray
    let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
    let ndarray_matrix =
        scirs2_core::ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix)
            .unwrap();
    // Create class labels
    let class_labels = vec![
        "Class A".to_string(),
        "Class B".to_string(),
        "Class C".to_string(),
        "Class D".to_string(),
        "Class E".to_string(),
    ];
    // Create confusion matrix
    let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
    // Example 1: Standard confusion matrix
    println!("Example 1: Standard Confusion Matrix\n");
    let regular_output = cm.to_ascii(Some("Classification Results"), false);
    println!("{regular_output}");
    // Example 2: Normal heatmap
    println!("\n\nExample 2: Standard Heatmap Visualization\n");
    let color_options = ColorOptions {
        enabled: true,
        use_bright: true,
        use_background: false,
    };
    let heatmap_output = cm.to_heatmap_with_options(
        Some("Classification Heatmap"),
        true, // normalized
        &color_options,
    );
    println!("{heatmap_output}");
    // Example 3: Error pattern heatmap
    println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
    let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
    println!("{error_heatmap}");
}
pub fn normalized(&self) -> Array2<F>
Get the normalized confusion matrix (rows sum to 1)
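A minimal sketch using the toy data from the new example; since each class appears at least once in y_true, every row of the normalized matrix should sum to one.

use scirs2_core::ndarray::Array1;
use scirs2_neural::utils::evaluation::ConfusionMatrix;
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
let norm = cm.normalized();
// Each row should sum to (approximately) 1.0
for row in norm.outer_iter() {
    let s: f32 = row.sum();
    assert!((s - 1.0).abs() < 1e-5);
}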
pub fn accuracy(&self) -> F
Calculate accuracy from the confusion matrix
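Assuming the usual correct-predictions-over-total definition, a short sketch on the same toy data:

use scirs2_core::ndarray::Array1;
use scirs2_neural::utils::evaluation::ConfusionMatrix;
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
// 6 of the 7 predictions match the true labels
let acc = cm.accuracy();
assert!((acc - 6.0 / 7.0).abs() < 1e-5);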
pub fn precision(&self) -> Array1<F>
Calculate precision for each class
pub fn recall(&self) -> Array1<F>
Calculate recall for each class
pub fn f1_score(&self) -> Array1<F>
Calculate F1 score for each class
pub fn macro_f1(&self) -> F
Calculate macro-averaged F1 score
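The per-class metrics (precision, recall, f1_score) and this macro average are often reported together; a minimal sketch on the same toy data:

use scirs2_core::ndarray::Array1;
use scirs2_neural::utils::evaluation::ConfusionMatrix;
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f64>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
let precision = cm.precision();
let recall = cm.recall();
let f1 = cm.f1_score();
for i in 0..cm.num_classes {
    println!("class {i}: P={:.3} R={:.3} F1={:.3}", precision[i], recall[i], f1[i]);
}
// Macro F1 averages the per-class F1 scores without class weighting
println!("macro F1 = {:.3}", cm.macro_f1());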
pub fn class_metrics(&self) -> HashMap<String, Vec<F>>
Get class-wise metrics as a HashMap
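A short sketch of iterating the returned map; the exact metric names used as keys are not listed here, so the loop simply prints whatever the map contains (presumably one value per class in each Vec).

use scirs2_core::ndarray::Array1;
use scirs2_neural::utils::evaluation::ConfusionMatrix;
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
for (metric, values) in cm.class_metrics() {
    println!("{metric}: {values:?}");
}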
pub fn to_ascii(&self, title: Option<&str>, normalized: bool) -> String
Convert the confusion matrix to an ASCII representation
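A minimal usage sketch on the same toy data; the second argument toggles row normalization.

use scirs2_core::ndarray::Array1;
use scirs2_neural::utils::evaluation::ConfusionMatrix;
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
// Raw counts
println!("{}", cm.to_ascii(Some("Toy Confusion Matrix"), false));
// Row-normalized view
println!("{}", cm.to_ascii(Some("Toy Confusion Matrix (normalized)"), true));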
pub fn to_ascii_with_options(
    &self,
    title: Option<&str>,
    normalized: bool,
    color_options: &ColorOptions,
) -> String
Convert the confusion matrix to an ASCII representation with color options
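A brief sketch, assuming ColorOptions can be imported from scirs2_neural::utils alongside ConfusionMatrix and that the three fields used in the repository examples are the complete set; adjust the import path if it differs.

use scirs2_core::ndarray::Array1;
use scirs2_neural::utils::{ColorOptions, ConfusionMatrix};
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
// Enable ANSI color output (field set taken from the repository examples)
let color_options = ColorOptions {
    enabled: true,
    use_bright: true,
    use_background: false,
};
println!(
    "{}",
    cm.to_ascii_with_options(Some("Colored Confusion Matrix"), false, &color_options)
);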
Sourcepub fn to_heatmap(&self, title: Option<&str>, normalized: bool) -> String
pub fn to_heatmap(&self, title: Option<&str>, normalized: bool) -> String
Convert the confusion matrix to a heatmap visualization. This creates a colorful heatmap of the confusion matrix in which cell colors represent the intensity of values, using a detailed color gradient.
§Arguments
title - Optional title for the heatmap
normalized - Whether to normalize the matrix (row values sum to 1)
§Returns
String - ASCII heatmap representation
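§Example
A minimal usage sketch; the data here is arbitrary and only meant to show the two normalization modes (see to_heatmap_with_options below for explicit color control):
use scirs2_core::ndarray::Array1;
use scirs2_neural::utils::ConfusionMatrix;
// Build a small confusion matrix from integer class labels
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
// Heatmap of raw counts
println!("{}", cm.to_heatmap(Some("Counts Heatmap"), false));
// Row-normalized heatmap (each row sums to 1)
println!("{}", cm.to_heatmap(Some("Normalized Heatmap"), true));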
Sourcepub fn error_heatmap(&self, title: Option<&str>) -> String
pub fn error_heatmap(&self, title: Option<&str>) -> String
Create a confusion matrix heatmap that focuses on misclassification patterns. This visualization is specialized to highlight where the model makes mistakes, with emphasis on the off-diagonal elements so that error patterns are easier to identify.
The key features of this visualization are:
- Diagonal elements (correct classifications) are de-emphasized with dim styling
- Off-diagonal elements (errors) are highlighted with a color gradient
- Colors are normalized relative to the maximum off-diagonal value
- A specialized legend explains error intensity levels
§Arguments
title - Optional title for the error heatmap
§Returns
String - ASCII error pattern heatmap
§Example
use scirs2_core::ndarray::Array1;
use scirs2_neural::utils::ConfusionMatrix;
// Create some example data
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 1, 1, 0, 0]);
let class_labels = vec!["Class A".to_string(), "Class B".to_string(), "Class C".to_string()];
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, Some(class_labels)).unwrap();
// Generate the error pattern heatmap
let error_viz = cm.error_heatmap(Some("Misclassification Analysis"));
println!("{}", error_viz);
Examples found in repository?
7fn main() {
8 // Create a reproducible random number generator
9 let mut rng = SmallRng::from_seed([42; 32]);
10 // Generate synthetic multiclass classification data with specific error patterns
11 let num_classes = 5;
12 // Create confusion matrix with controlled error patterns
13 let mut matrix = vec![vec![0; num_classes]; num_classes];
14 // Set diagonal elements (correct classifications) with high values
15 #[allow(clippy::needless_range_loop)]
16 for i in 0..num_classes {
17 matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
18 }
19 // Create specific error patterns:
20 // - Classes 0 and 1 often confused
21 matrix[0][1] = 25;
22 matrix[1][0] = 15;
23 // - Class 2 sometimes confused with Class 3
24 matrix[2][3] = 18;
25 // - Class 4 has some misclassifications to all other classes
26 matrix[4][0] = 8;
27 matrix[4][1] = 5;
28 matrix[4][2] = 10;
29 matrix[4][3] = 12;
30 // - Some minor errors scattered about
31 #[allow(clippy::needless_range_loop)]
32 for i in 0..num_classes {
33 for j in 0..num_classes {
34 if i != j && matrix[i][j] == 0 {
35 matrix[i][j] = rng.random_range(0..5);
36 }
37 }
38 }
39 // Convert to ndarray
40 let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
41 let ndarray_matrix =
42 scirs2_core::ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix)
43 .unwrap();
44 // Create class labels
45 let class_labels = vec![
46 "Class A".to_string(),
47 "Class B".to_string(),
48 "Class C".to_string(),
49 "Class D".to_string(),
50 "Class E".to_string(),
51 ];
52 // Create confusion matrix
53 let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
54 // Example 1: Standard confusion matrix
55 println!("Example 1: Standard Confusion Matrix\n");
56 let regular_output = cm.to_ascii(Some("Classification Results"), false);
57 println!("{regular_output}");
58 // Example 2: Normal heatmap
59 println!("\n\nExample 2: Standard Heatmap Visualization\n");
60 let color_options = ColorOptions {
61 enabled: true,
62 use_bright: true,
63 use_background: false,
64 };
65 let heatmap_output = cm.to_heatmap_with_options(
66 Some("Classification Heatmap"),
67 true, // normalized
68 &color_options,
69 );
70 println!("{heatmap_output}");
71 // Example 3: Error pattern heatmap
72 println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
73 let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
74 println!("{error_heatmap}");
75}
Sourcepub fn to_heatmap_with_options(
&self,
title: Option<&str>,
normalized: bool,
color_options: &ColorOptions,
) -> String
pub fn to_heatmap_with_options( &self, title: Option<&str>, normalized: bool, color_options: &ColorOptions, ) -> String
Convert the confusion matrix to a heatmap visualization with customizable options
§Arguments
title - Optional title for the heatmap
normalized - Whether to normalize the matrix
color_options - Color options for visualization
§Returns
String - ASCII heatmap representation with colors
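§Example
A minimal sketch, assuming ColorOptions is exported from scirs2_neural::utils alongside ConfusionMatrix (the repository examples below construct it with the same three fields):
use scirs2_core::ndarray::Array1;
use scirs2_neural::utils::{ColorOptions, ConfusionMatrix};
let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
// Bright foreground colors, no colored backgrounds
let color_options = ColorOptions {
    enabled: true,
    use_bright: true,
    use_background: false,
};
// Row-normalized heatmap rendered with the chosen color options
println!("{}", cm.to_heatmap_with_options(Some("Normalized Heatmap"), true, &color_options));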
Examples found in repository?
7fn main() {
8 // Create a reproducible random number generator
9 let mut rng = SmallRng::from_seed([42; 32]);
10 // Generate synthetic multiclass classification data with specific error patterns
11 let num_classes = 5;
12 // Create confusion matrix with controlled error patterns
13 let mut matrix = vec![vec![0; num_classes]; num_classes];
14 // Set diagonal elements (correct classifications) with high values
15 #[allow(clippy::needless_range_loop)]
16 for i in 0..num_classes {
17 matrix[i][i] = 70 + rng.random_range(0..15); // 70-85 correct per class
18 }
19 // Create specific error patterns:
20 // - Classes 0 and 1 often confused
21 matrix[0][1] = 25;
22 matrix[1][0] = 15;
23 // - Class 2 sometimes confused with Class 3
24 matrix[2][3] = 18;
25 // - Class 4 has some misclassifications to all other classes
26 matrix[4][0] = 8;
27 matrix[4][1] = 5;
28 matrix[4][2] = 10;
29 matrix[4][3] = 12;
30 // - Some minor errors scattered about
31 #[allow(clippy::needless_range_loop)]
32 for i in 0..num_classes {
33 for j in 0..num_classes {
34 if i != j && matrix[i][j] == 0 {
35 matrix[i][j] = rng.random_range(0..5);
36 }
37 }
38 }
39 // Convert to ndarray
40 let flat_matrix: Vec<f64> = matrix.iter().flatten().map(|&x| x as f64).collect();
41 let ndarray_matrix =
42 scirs2_core::ndarray::Array::from_shape_vec((num_classes, num_classes), flat_matrix)
43 .unwrap();
44 // Create class labels
45 let class_labels = vec![
46 "Class A".to_string(),
47 "Class B".to_string(),
48 "Class C".to_string(),
49 "Class D".to_string(),
50 "Class E".to_string(),
51 ];
52 // Create confusion matrix
53 let cm = ConfusionMatrix::from_matrix(ndarray_matrix, Some(class_labels)).unwrap();
54 // Example 1: Standard confusion matrix
55 println!("Example 1: Standard Confusion Matrix\n");
56 let regular_output = cm.to_ascii(Some("Classification Results"), false);
57 println!("{regular_output}");
58 // Example 2: Normal heatmap
59 println!("\n\nExample 2: Standard Heatmap Visualization\n");
60 let color_options = ColorOptions {
61 enabled: true,
62 use_bright: true,
63 use_background: false,
64 };
65 let heatmap_output = cm.to_heatmap_with_options(
66 Some("Classification Heatmap"),
67 true, // normalized
68 &color_options,
69 );
70 println!("{heatmap_output}");
71 // Example 3: Error pattern heatmap
72 println!("\n\nExample 3: Error Pattern Heatmap (highlighting misclassifications)\n");
73 let error_heatmap = cm.error_heatmap(Some("Misclassification Analysis"));
74 println!("{error_heatmap}");
75}
More examples
8fn main() {
9 // Create a reproducible random number generator
10 let mut rng = SmallRng::from_seed([42; 32]);
11 // Generate synthetic multiclass classification data
12 let num_classes = 5;
13 let n_samples = 500;
14 // Generate true labels (0 to num_classes-1)
15 let mut y_true = Vec::with_capacity(n_samples);
16 for _ in 0..n_samples {
17 y_true.push(rng.random_range(0..num_classes));
18 }
19 // Generate predicted labels with controlled accuracy
20 let mut y_pred = Vec::with_capacity(n_samples);
21 for &true_label in &y_true {
22 // 80% chance to predict correctly..20% chance of error
23 if rng.random::<f64>() < 0.8 {
24 y_pred.push(true_label);
25 } else {
26 // When wrong, tend to predict adjacent classes more often
27 let mut pred = true_label;
28 while pred == true_label {
29 // Generate error that's more likely to be close to true label
30 let error_margin = (rng.random::<f64>() * 2.0).round() as usize; // 0, 1, or 2
31 if rng.random::<bool>() {
32 pred = (true_label + error_margin) % num_classes;
33 } else {
34 pred = (true_label + num_classes - error_margin) % num_classes;
35 }
36 }
37 y_pred.push(pred);
38 }
39 }
40 // Convert to ndarray arrays
41 let y_true_array = Array1::from(y_true);
42 let y_pred_array = Array1::from(y_pred);
43 // Create class labels
44 let class_labels = vec![
45 "Cat".to_string(),
46 "Dog".to_string(),
47 "Bird".to_string(),
48 "Fish".to_string(),
49 "Rabbit".to_string(),
50 ];
51 // Create confusion matrix
52 let cm = ConfusionMatrix::<f64>::new(
53 &y_true_array.view(),
54 &y_pred_array.view(),
55 Some(num_classes),
56 Some(class_labels),
57 )
58 .unwrap();
59 // Example 1: Standard confusion matrix
60 println!("Example 1: Standard Confusion Matrix\n");
61 let regular_output = cm.to_ascii(Some("Animal Classification Results"), false);
62 println!("{regular_output}");
63 // Example 2: Confusion matrix with color
64 println!("\n\nExample 2: Colored Confusion Matrix\n");
65 let color_options = ColorOptions {
66 enabled: true,
67 use_bright: true,
68 use_background: false,
69 };
70 let colored_output = cm.to_ascii_with_options(
71 Some("Animal Classification Results (with color)"),
72 false,
73 &color_options,
74 );
75 println!("{colored_output}");
76 // Example 3: Normalized confusion matrix heatmap
77 println!("\n\nExample 3: Normalized Confusion Matrix Heatmap\n");
78 let heatmap_output = cm.to_heatmap_with_options(
79 Some("Animal Classification Heatmap (normalized)"),
80 true, // normalized
81 &color_options,
82 );
83 println!("{heatmap_output}");
84
85 // Example 4: Raw counts heatmap
86 println!("\n\nExample 4: Raw Counts Confusion Matrix Heatmap\n");
87 let raw_heatmap = cm.to_heatmap_with_options(
88 Some("Animal Classification Heatmap (raw counts)"),
89 false, // not normalized
90 &color_options,
91 );
92 println!("{raw_heatmap}");
93}
Trait Implementations§
Auto Trait Implementations§
impl<F> Freeze for ConfusionMatrix<F>
impl<F> RefUnwindSafe for ConfusionMatrix<F> where F: RefUnwindSafe,
impl<F> Send for ConfusionMatrix<F> where F: Send,
impl<F> Sync for ConfusionMatrix<F> where F: Sync,
impl<F> Unpin for ConfusionMatrix<F>
impl<F> UnwindSafe for ConfusionMatrix<F> where F: RefUnwindSafe,
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where T: ?Sized,
impl<T> BorrowMut<T> for T where T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> CloneToUninit for T where T: Clone,
impl<T> CloneToUninit for T where T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more