1use crate::error::{NeuralError, Result};
4use crate::utils::colors::{
5 colored_metric_cell, colorize, colorize_and_style, gradient_color, heatmap_cell,
6 heatmap_color_legend, stylize, Color, ColorOptions, Style,
7};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
9use scirs2_core::numeric::Float;
10use std::collections::HashMap;
11use std::fmt::{Debug, Display};
12
13#[derive(Debug, Clone)]
15pub struct ConfusionMatrix<F: Float + Debug + Display> {
16 pub matrix: Array2<F>,
18 pub labels: Option<Vec<String>>,
20 pub num_classes: usize,
22}
23
24impl<F: Float + Debug + Display> ConfusionMatrix<F> {
25 pub fn new(
45 y_true: &ArrayView1<usize>,
46 y_pred: &ArrayView1<usize>,
47 num_classes: Option<usize>,
48 labels: Option<Vec<String>>,
49 ) -> Result<Self> {
50 if y_true.len() != y_pred.len() {
51 return Err(NeuralError::ValidationError(
52 "Predictions and _true labels must have the same length".to_string(),
53 ));
54 }
55
56 let n_classes = num_classes.unwrap_or_else(|| {
58 let max_true = y_true.iter().max().copied().unwrap_or(0);
59 let max_pred = y_pred.iter().max().copied().unwrap_or(0);
60 std::cmp::max(max_true, max_pred) + 1
61 });
62
63 let mut matrix = Array2::zeros((n_classes, n_classes));
65
66 for (true_label, pred_label) in y_true.iter().zip(y_pred.iter()) {
68 if *true_label < n_classes && *pred_label < n_classes {
69 matrix[[*true_label, *pred_label]] = matrix[[*true_label, *pred_label]] + F::one();
70 } else {
71 return Err(NeuralError::ValidationError(format!(
72 "Class index out of bounds: _true={true_label}, _pred={pred_label}, n_classes={n_classes}"
73 )));
74 }
75 }
76
77 let validated_labels = if let Some(label_vec) = labels {
79 if label_vec.len() != n_classes {
80 return Err(NeuralError::ValidationError(format!(
81 "Number of labels ({}) does not match number of _classes ({})",
82 label_vec.len(),
83 n_classes
84 )));
85 }
86 Some(label_vec)
87 } else {
88 None
89 };
90
91 Ok(ConfusionMatrix {
92 matrix,
93 labels: validated_labels,
94 num_classes: n_classes,
95 })
96 }
97
98 pub fn from_matrix(matrix: Array2<F>, labels: Option<Vec<String>>) -> Result<Self> {
104 let shape = matrix.shape();
105 if shape[0] != shape[1] {
106 return Err(NeuralError::ValidationError(
107 "Confusion _matrix must be square".to_string(),
108 ));
109 }
110
111 let n_classes = shape[0];
112
113 if let Some(ref label_vec) = labels {
115 if label_vec.len() != n_classes {
116 return Err(NeuralError::ValidationError(format!(
117 "Number of labels ({}) does not match _matrix size ({})",
118 label_vec.len(),
119 n_classes
120 )));
121 }
122 }
123
124 Ok(ConfusionMatrix {
125 matrix,
126 labels,
127 num_classes: n_classes,
128 })
129 }
130
131 pub fn normalized(&self) -> Array2<F> {
133 let mut norm_matrix = self.matrix.clone();
134 for row in 0..self.num_classes {
136 let row_sum = self.matrix.row(row).sum();
137 if row_sum > F::zero() {
138 for col in 0..self.num_classes {
139 norm_matrix[[row, col]] = self.matrix[[row, col]] / row_sum;
140 }
141 }
142 }
143 norm_matrix
144 }
145
146 pub fn accuracy(&self) -> F {
148 let total: F = self.matrix.sum();
149 if total > F::zero() {
150 let diagonal_sum: F = (0..self.num_classes)
151 .map(|i| self.matrix[[i, i]])
152 .fold(F::zero(), |acc, x| acc + x);
153 diagonal_sum / total
154 } else {
155 F::zero()
156 }
157 }
158
159 pub fn precision(&self) -> Array1<F> {
161 let mut precision = Array1::zeros(self.num_classes);
162 for i in 0..self.num_classes {
163 let col_sum = self.matrix.column(i).sum();
164 if col_sum > F::zero() {
165 precision[i] = self.matrix[[i, i]] / col_sum;
166 }
167 }
168 precision
169 }
170
171 pub fn recall(&self) -> Array1<F> {
173 let mut recall = Array1::zeros(self.num_classes);
174 for i in 0..self.num_classes {
175 let row_sum = self.matrix.row(i).sum();
176 if row_sum > F::zero() {
177 recall[i] = self.matrix[[i, i]] / row_sum;
178 }
179 }
180 recall
181 }
182
183 pub fn f1_score(&self) -> Array1<F> {
185 let precision = self.precision();
186 let recall = self.recall();
187 let mut f1 = Array1::zeros(self.num_classes);
188 for i in 0..self.num_classes {
189 let denom = precision[i] + recall[i];
190 if denom > F::zero() {
191 f1[i] = F::from(2.0).unwrap() * precision[i] * recall[i] / denom;
192 }
193 }
194 f1
195 }
196
197 pub fn macro_f1(&self) -> F {
199 let f1 = self.f1_score();
200 let sum = f1.sum();
201 sum / F::from(self.num_classes).unwrap()
202 }
203
204 pub fn class_metrics(&self) -> HashMap<String, Vec<F>> {
206 let mut metrics = HashMap::new();
207 let precision = self.precision();
208 let recall = self.recall();
209 let f1 = self.f1_score();
210
211 metrics.insert("precision".to_string(), precision.to_vec());
212 metrics.insert("recall".to_string(), recall.to_vec());
213 metrics.insert("f1".to_string(), f1.to_vec());
214 metrics
215 }
216
217 pub fn to_ascii(&self, title: Option<&str>, normalized: bool) -> String {
219 self.to_ascii_with_options(title, normalized, &ColorOptions::default())
220 }
221
222 pub fn to_ascii_with_options(
224 &self,
225 title: Option<&str>,
226 normalized: bool,
227 color_options: &ColorOptions,
228 ) -> String {
229 let matrix = if normalized {
230 self.normalized()
231 } else {
232 self.matrix.clone()
233 };
234
235 let mut result = String::new();
236
237 if let Some(titletext) = title {
239 if color_options.enabled {
240 result.push_str(&stylize(titletext, Style::Bold));
241 } else {
242 result.push_str(titletext);
243 }
244 result.push_str("\n\n");
245 }
246
247 let labels: Vec<String> = match &self.labels {
249 Some(label_vec) => label_vec.clone(),
250 None => (0..self.num_classes).map(|i| i.to_string()).collect(),
251 };
252
253 let label_width = labels.iter().map(|l| l.len()).max().unwrap_or(2).max(5);
255 let cell_width = if normalized {
256 6 } else {
258 matrix
259 .iter()
260 .map(|&v| format!("{v:.0}").len())
261 .max()
262 .unwrap_or(2)
263 .max(5)
264 };
265
266 if color_options.enabled {
268 result.push_str(&format!(
269 "{:<width$} |",
270 stylize("Pred→", Style::Bold),
271 width = label_width + 8
272 ));
273 } else {
274 result.push_str(&format!("{:<width$} |", "Pred→", width = label_width));
275 }
276
277 for label in &labels {
278 if color_options.enabled {
279 let styled_label = stylize(label, Style::Bold);
280 result.push_str(&format!(
281 " {:<width$} |",
282 styled_label,
283 width = cell_width + 8
284 ));
285 } else {
286 result.push_str(&format!(" {label:<cell_width$} |"));
287 }
288 }
289
290 if color_options.enabled {
291 result.push_str(&format!(" {}\n", stylize("Recall", Style::Bold)));
292 } else {
293 result.push_str(" Recall\n");
294 }
295
296 result.push_str(&"-".repeat(label_width + 2));
298 for _ in 0..self.num_classes {
299 result.push_str(&format!("{}-", "-".repeat(cell_width + 2)));
300 }
301 result.push_str(&"-".repeat(8));
302 result.push('\n');
303
304 let precision = self.precision();
306 let recall = self.recall();
307 let f1 = self.f1_score();
308
309 for i in 0..self.num_classes {
310 if color_options.enabled {
312 result.push_str(&format!(
313 "{:<width$} |",
314 stylize(&labels[i], Style::Bold),
315 width = label_width + 8
316 ));
317 } else {
318 result.push_str(&format!("{:<width$} |", labels[i], width = label_width));
319 }
320
321 for j in 0..self.num_classes {
322 let value = matrix[[i, j]];
323 let formatted = if normalized {
324 format!("{value:.3}")
325 } else {
326 format!("{value:.0}")
327 };
328
329 if i == j {
331 if color_options.enabled {
333 let norm_value = if normalized {
335 value.to_f64().unwrap_or(0.0)
336 } else {
337 let row_sum = matrix.row(i).sum().to_f64().unwrap_or(1.0);
339 if row_sum > 0.0 {
340 value.to_f64().unwrap_or(0.0) / row_sum
341 } else {
342 0.0
343 }
344 };
345
346 if let Some(color) = gradient_color(norm_value, color_options) {
348 let colored_value = colorize(stylize(&formatted, Style::Bold), color);
350 result.push_str(&format!(
351 " {:<width$} |",
352 colored_value,
353 width = cell_width + 9
354 ));
355 } else {
356 result.push_str(&format!(
358 " {:<width$} |",
359 stylize(&formatted, Style::Bold),
360 width = cell_width + 8
361 ));
362 }
363 } else {
364 result.push_str(&format!(" \x1b[1m{formatted:<cell_width$}\x1b[0m |"));
366 }
367 } else if color_options.enabled && normalized {
368 let norm_value = value.to_f64().unwrap_or(0.0);
370 if norm_value > 0.1 {
371 result.push_str(&format!(
372 " {:<width$} |",
373 colorize(&formatted, Color::BrightRed),
374 width = cell_width + 9
375 ));
376 } else {
377 result.push_str(&format!(" {formatted:<cell_width$} |"));
378 }
379 } else {
380 result.push_str(&format!(" {formatted:<cell_width$} |"));
382 }
383 }
384
385 if color_options.enabled {
387 let recall_val = recall[i].to_f64().unwrap_or(0.0);
388 let colored_recall =
389 colored_metric_cell(format!("{:.3}", recall[i]), recall_val, color_options);
390 result.push_str(&format!(" {colored_recall}\n"));
391 } else {
392 let recall_val = recall[i];
393 result.push_str(&format!(" {recall_val:.3}\n"));
394 }
395 }
396
397 if color_options.enabled {
399 result.push_str(&format!(
400 "{:<width$} |",
401 stylize("Precision", Style::Bold),
402 width = label_width + 8
403 ));
404 } else {
405 result.push_str(&format!("{:<width$} |", "Precision", width = label_width));
406 }
407
408 for j in 0..self.num_classes {
409 if color_options.enabled {
410 let precision_val = precision[j].to_f64().unwrap_or(0.0);
411 let colored_precision = colored_metric_cell(
412 format!("{:.3}", precision[j]),
413 precision_val,
414 color_options,
415 );
416 result.push_str(&format!(" {colored_precision} |"));
417 } else {
418 let prec_val = precision[j];
419 result.push_str(&format!(" {prec_val:.3} |"));
420 }
421 }
422
423 let accuracy = self.accuracy();
425 if color_options.enabled {
426 let accuracy_val = accuracy.to_f64().unwrap_or(0.0);
427 let colored_accuracy =
428 colored_metric_cell(format!("{accuracy:.3}"), accuracy_val, color_options);
429 result.push_str(&format!(" {colored_accuracy}\n"));
430 } else {
431 result.push_str(&format!(" {accuracy:.3}\n"));
432 }
433
434 if color_options.enabled {
436 result.push_str(&format!(
437 "{:<width$} |",
438 stylize("F1-score", Style::Bold),
439 width = label_width + 8
440 ));
441 } else {
442 result.push_str(&format!("{:<width$} |", "F1-score", width = label_width));
443 }
444
445 for j in 0..self.num_classes {
446 if color_options.enabled {
447 let f1_val = f1[j].to_f64().unwrap_or(0.0);
448 let colored_f1 =
449 colored_metric_cell(format!("{:.3}", f1[j]), f1_val, color_options);
450 result.push_str(&format!(" {colored_f1} |"));
451 } else {
452 result.push_str(&format!(" {:.3} |", f1[j]));
453 }
454 }
455
456 let macro_f1 = self.macro_f1();
458 if color_options.enabled {
459 let macro_f1_val = macro_f1.to_f64().unwrap_or(0.0);
460 let colored_macro_f1 =
461 colored_metric_cell(format!("{macro_f1:.3}"), macro_f1_val, color_options);
462 result.push_str(&format!(" {colored_macro_f1}\n"));
463 } else {
464 result.push_str(&format!(" {macro_f1:.3}\n"));
465 }
466
467 result
468 }
469
470 pub fn to_heatmap(&self, title: Option<&str>, normalized: bool) -> String {
481 self.to_heatmap_with_options(title, normalized, &ColorOptions::default())
482 }
483
484 pub fn error_heatmap(&self, title: Option<&str>) -> String {
514 let _normalized = true;
516 let matrix = self.normalized();
517
518 let color_options = ColorOptions {
520 enabled: true,
521 use_background: false,
522 use_bright: true,
523 };
524
525 let mut result = String::new();
526
527 if let Some(titletext) = title {
529 result.push_str(&stylize(titletext, Style::Bold));
530 result.push_str("\n\n");
531 }
532
533 let labels: Vec<String> = match &self.labels {
535 Some(label_vec) => label_vec.clone(),
536 None => (0..self.num_classes).map(|i| i.to_string()).collect(),
537 };
538
539 let label_width = labels.iter().map(|l| l.len()).max().unwrap_or(2).max(5);
540 let cell_width = 6; let mut max_off_diag = F::zero();
544 for i in 0..self.num_classes {
545 for j in 0..self.num_classes {
546 if i != j && matrix[[i, j]] > max_off_diag {
547 max_off_diag = matrix[[i, j]];
548 }
549 }
550 }
551
552 if max_off_diag == F::zero() {
554 max_off_diag = matrix
555 .iter()
556 .fold(F::zero(), |acc, &v| if v > acc { v } else { acc });
557 }
558
559 if color_options.enabled {
561 result.push_str(&format!(
562 "{:<width$} |",
563 stylize("True↓ / Pred→", Style::Bold),
564 width = label_width + 8
565 ));
566 } else {
567 result.push_str(&format!(
568 "{:<width$} |",
569 "True↓ / Pred→",
570 width = label_width
571 ));
572 }
573
574 for label in &labels {
575 if color_options.enabled {
576 let styled_label = stylize(label, Style::Bold);
577 result.push_str(&format!(
578 " {:<width$} |",
579 styled_label,
580 width = cell_width + 8
581 ));
582 } else {
583 result.push_str(&format!(" {label:<cell_width$} |"));
584 }
585 }
586 result.push('\n');
587
588 result.push_str(&"-".repeat(label_width + 2));
590 for _ in 0..self.num_classes {
591 result.push_str(&format!("{}-", "-".repeat(cell_width + 2)));
592 }
593 result.push('\n');
594
595 for i in 0..self.num_classes {
597 if color_options.enabled {
599 result.push_str(&format!(
600 "{:<width$} |",
601 stylize(&labels[i], Style::Bold),
602 width = label_width + 8
603 ));
604 } else {
605 result.push_str(&format!("{:<width$} |", labels[i], width = label_width));
606 }
607
608 for j in 0..self.num_classes {
609 let value = matrix[[i, j]];
610 let formatted = format!("{value:.3}");
611
612 if i == j {
614 if color_options.enabled {
616 result.push_str(&format!(
618 " {:<width$} |",
619 colorize_and_style(&formatted, None, None, Some(Style::Dim)),
620 width = cell_width + 8
621 ));
622 } else {
623 result.push_str(&format!(" {formatted:<cell_width$} |"));
624 }
625 } else {
626 let norm_value = if max_off_diag > F::zero() {
628 (value / max_off_diag).to_f64().unwrap_or(0.0)
629 } else {
630 0.0
631 };
632
633 if color_options.enabled && norm_value > 0.0 {
634 let error_color = if norm_value < 0.25 {
636 Color::BrightBlue
637 } else if norm_value < 0.5 {
638 Color::BrightCyan
639 } else if norm_value < 0.75 {
640 Color::BrightRed
641 } else {
642 Color::BrightMagenta
643 };
644
645 if norm_value > 0.5 {
647 result.push_str(&format!(
648 " {:<width$} |",
649 colorize_and_style(
650 &formatted,
651 Some(error_color),
652 None,
653 Some(Style::Bold)
654 ),
655 width = cell_width + 9
656 ));
657 } else {
658 result.push_str(&format!(
659 " {:<width$} |",
660 colorize(&formatted, error_color),
661 width = cell_width + 9
662 ));
663 }
664 } else {
665 result.push_str(&format!(" {formatted:<cell_width$} |"));
666 }
667 }
668 }
669 result.push('\n');
670 }
671
672 if color_options.enabled {
674 result.push('\n');
675 let mut legend = String::from("Error Pattern Legend: ");
676
677 let error_levels = [
679 (Color::BrightBlue, "Low Error (0-25%)"),
680 (Color::BrightCyan, "Moderate Error (25-50%)"),
681 (Color::BrightRed, "High Error (50-75%)"),
682 (Color::BrightMagenta, "Critical Error (75-100%)"),
683 ];
684
685 for (i, (color, label)) in error_levels.iter().enumerate() {
686 if i > 0 {
687 legend.push(' ');
688 }
689 legend.push_str(&format!("{} {label}", colorize("■", *color)));
690 }
691
692 legend.push_str(&format!(
694 " {} Correct Classification",
695 colorize_and_style("■", None, None, Some(Style::Dim))
696 ));
697
698 result.push_str(&legend);
699 }
700
701 result
702 }
703
704 pub fn to_heatmap_with_options(
714 &self,
715 title: Option<&str>,
716 normalized: bool,
717 color_options: &ColorOptions,
718 ) -> String {
719 let matrix = if normalized {
720 self.normalized()
721 } else {
722 self.matrix.clone()
723 };
724
725 let mut result = String::new();
726
727 if let Some(titletext) = title {
729 if color_options.enabled {
730 result.push_str(&stylize(titletext, Style::Bold));
731 } else {
732 result.push_str(titletext);
733 }
734 result.push_str("\n\n");
735 }
736
737 let labels: Vec<String> = match &self.labels {
739 Some(label_vec) => label_vec.clone(),
740 None => (0..self.num_classes).map(|i| i.to_string()).collect(),
741 };
742
743 let label_width = labels.iter().map(|l| l.len()).max().unwrap_or(2).max(5);
744 let cell_width = if normalized { 6 } else { 5 };
745
746 let max_value = if normalized {
748 F::one() } else {
750 matrix
751 .iter()
752 .fold(F::zero(), |acc, &v| if v > acc { v } else { acc })
753 };
754
755 if color_options.enabled {
757 result.push_str(&format!(
758 "{:<width$} |",
759 stylize("True↓", Style::Bold),
760 width = label_width + 8
761 ));
762 } else {
763 result.push_str(&format!("{:<width$} |", "True↓", width = label_width));
764 }
765
766 for label in &labels {
767 if color_options.enabled {
768 let styled_label = stylize(label, Style::Bold);
769 result.push_str(&format!(
770 " {:<width$} |",
771 styled_label,
772 width = cell_width + 8
773 ));
774 } else {
775 result.push_str(&format!(" {label:<cell_width$} |"));
776 }
777 }
778 result.push('\n');
779
780 result.push_str(&"-".repeat(label_width + 2));
782 for _ in 0..self.num_classes {
783 result.push_str(&format!("{}-", "-".repeat(cell_width + 2)));
784 }
785 result.push('\n');
786
787 for i in 0..self.num_classes {
789 if color_options.enabled {
791 result.push_str(&format!(
792 "{:<width$} |",
793 stylize(&labels[i], Style::Bold),
794 width = label_width + 8
795 ));
796 } else {
797 result.push_str(&format!("{:<width$} |", labels[i], width = label_width));
798 }
799
800 for j in 0..self.num_classes {
801 let value = matrix[[i, j]];
802 let formatted = if normalized {
803 format!("{value:.3}")
804 } else {
805 format!("{value:.0}")
806 };
807
808 let norm_value = if normalized {
811 value.to_f64().unwrap_or(0.0)
812 } else if max_value > F::zero() {
813 (value / max_value).to_f64().unwrap_or(0.0)
814 } else {
815 0.0
816 };
817
818 if color_options.enabled {
820 let heatmap_value = heatmap_cell(&formatted, norm_value, color_options);
821 result.push_str(&format!(
823 " {:<width$} |",
824 heatmap_value,
825 width = cell_width + 9
826 ));
827 } else {
828 result.push_str(&format!(" {formatted:<cell_width$} |"));
829 }
830 }
831 result.push('\n');
832 }
833
834 if color_options.enabled {
836 if let Some(legend) = heatmap_color_legend(color_options) {
837 result.push('\n');
838 result.push_str(&legend);
839 }
840 }
841
842 result
843 }
844}