scirs2_neural/utils/evaluation/confusion_matrix.rs

1//! Confusion matrix for classification problems
2
3use crate::error::{NeuralError, Result};
4use crate::utils::colors::{
5    colored_metric_cell, colorize, colorize_and_style, gradient_color, heatmap_cell,
6    heatmap_color_legend, stylize, Color, ColorOptions, Style,
7};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
9use scirs2_core::numeric::Float;
10use std::collections::HashMap;
11use std::fmt::{Debug, Display};
12
13/// Confusion matrix for classification problems
14#[derive(Debug, Clone)]
15pub struct ConfusionMatrix<F: Float + Debug + Display> {
16    /// The raw confusion matrix data
17    pub matrix: Array2<F>,
18    /// Class labels (optional)
19    pub labels: Option<Vec<String>>,
20    /// Number of classes
21    pub num_classes: usize,
22}
23
24impl<F: Float + Debug + Display> ConfusionMatrix<F> {
25    /// Create a new confusion matrix from predictions and true labels
26    ///
27    /// # Arguments
28    /// * `y_true` - True class labels as integers
29    /// * `y_pred` - Predicted class labels as integers
30    /// * `num_classes` - Number of classes (if None, determined from data)
31    /// * `labels` - Optional class labels as strings
32    ///
33    /// # Returns
34    /// * `Result<ConfusionMatrix<F>>` - The confusion matrix
35    ///
36    /// # Example
37    /// ```
38    /// use scirs2_neural::utils::evaluation::ConfusionMatrix;
39    /// use scirs2_core::ndarray::Array1;
40    /// let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0]);
41    /// let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 0]);
42    /// let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, None).unwrap();
43    /// ```
44    pub fn new(
45        y_true: &ArrayView1<usize>,
46        y_pred: &ArrayView1<usize>,
47        num_classes: Option<usize>,
48        labels: Option<Vec<String>>,
49    ) -> Result<Self> {
50        if y_true.len() != y_pred.len() {
51            return Err(NeuralError::ValidationError(
52                "Predictions and _true labels must have the same length".to_string(),
53            ));
54        }
55
56        // Determine number of _classes
57        let n_classes = num_classes.unwrap_or_else(|| {
58            let max_true = y_true.iter().max().copied().unwrap_or(0);
59            let max_pred = y_pred.iter().max().copied().unwrap_or(0);
60            std::cmp::max(max_true, max_pred) + 1
61        });
62
63        // Initialize confusion matrix with zeros
64        let mut matrix = Array2::zeros((n_classes, n_classes));
65
66        // Fill confusion matrix
67        for (true_label, pred_label) in y_true.iter().zip(y_pred.iter()) {
68            if *true_label < n_classes && *pred_label < n_classes {
69                matrix[[*true_label, *pred_label]] = matrix[[*true_label, *pred_label]] + F::one();
70            } else {
71                return Err(NeuralError::ValidationError(format!(
72                    "Class index out of bounds: _true={true_label}, _pred={pred_label}, n_classes={n_classes}"
73                )));
74            }
75        }
76
77        // Validate labels if provided
78        let validated_labels = if let Some(label_vec) = labels {
79            if label_vec.len() != n_classes {
80                return Err(NeuralError::ValidationError(format!(
81                    "Number of labels ({}) does not match number of _classes ({})",
82                    label_vec.len(),
83                    n_classes
84                )));
85            }
86            Some(label_vec)
87        } else {
88            None
89        };
90
91        Ok(ConfusionMatrix {
92            matrix,
93            labels: validated_labels,
94            num_classes: n_classes,
95        })
96    }
97
98    /// Create a confusion matrix from raw matrix data
99    ///
100    /// # Arguments
101    /// * `matrix` - Raw confusion matrix data
102    /// * `labels` - Optional class labels
103    pub fn from_matrix(matrix: Array2<F>, labels: Option<Vec<String>>) -> Result<Self> {
104        let shape = matrix.shape();
105        if shape[0] != shape[1] {
106            return Err(NeuralError::ValidationError(
107                "Confusion _matrix must be square".to_string(),
108            ));
109        }
110
111        let n_classes = shape[0];
112
113        // Validate labels if provided
114        if let Some(ref label_vec) = labels {
115            if label_vec.len() != n_classes {
116                return Err(NeuralError::ValidationError(format!(
117                    "Number of labels ({}) does not match _matrix size ({})",
118                    label_vec.len(),
119                    n_classes
120                )));
121            }
122        }
123
124        Ok(ConfusionMatrix {
125            matrix,
126            labels,
127            num_classes: n_classes,
128        })
129    }
130
131    /// Get the normalized confusion matrix (rows sum to 1)
132    pub fn normalized(&self) -> Array2<F> {
133        let mut norm_matrix = self.matrix.clone();
134        // Normalize rows to sum to 1
135        for row in 0..self.num_classes {
136            let row_sum = self.matrix.row(row).sum();
137            if row_sum > F::zero() {
138                for col in 0..self.num_classes {
139                    norm_matrix[[row, col]] = self.matrix[[row, col]] / row_sum;
140                }
141            }
142        }
143        norm_matrix
144    }
145
146    /// Calculate accuracy from the confusion matrix
147    pub fn accuracy(&self) -> F {
148        let total: F = self.matrix.sum();
149        if total > F::zero() {
150            let diagonal_sum: F = (0..self.num_classes)
151                .map(|i| self.matrix[[i, i]])
152                .fold(F::zero(), |acc, x| acc + x);
153            diagonal_sum / total
154        } else {
155            F::zero()
156        }
157    }
158
159    /// Calculate precision for each class
160    pub fn precision(&self) -> Array1<F> {
161        let mut precision = Array1::zeros(self.num_classes);
162        for i in 0..self.num_classes {
163            let col_sum = self.matrix.column(i).sum();
164            if col_sum > F::zero() {
165                precision[i] = self.matrix[[i, i]] / col_sum;
166            }
167        }
168        precision
169    }
170
171    /// Calculate recall for each class
172    pub fn recall(&self) -> Array1<F> {
173        let mut recall = Array1::zeros(self.num_classes);
174        for i in 0..self.num_classes {
175            let row_sum = self.matrix.row(i).sum();
176            if row_sum > F::zero() {
177                recall[i] = self.matrix[[i, i]] / row_sum;
178            }
179        }
180        recall
181    }
182
183    /// Calculate F1 score for each class
184    pub fn f1_score(&self) -> Array1<F> {
185        let precision = self.precision();
186        let recall = self.recall();
187        let mut f1 = Array1::zeros(self.num_classes);
188        for i in 0..self.num_classes {
189            let denom = precision[i] + recall[i];
190            if denom > F::zero() {
191                f1[i] = F::from(2.0).unwrap() * precision[i] * recall[i] / denom;
192            }
193        }
194        f1
195    }
196
197    /// Calculate macro-averaged F1 score
198    pub fn macro_f1(&self) -> F {
199        let f1 = self.f1_score();
200        let sum = f1.sum();
201        sum / F::from(self.num_classes).unwrap()
202    }
203
204    /// Get class-wise metrics as a HashMap
205    pub fn class_metrics(&self) -> HashMap<String, Vec<F>> {
206        let mut metrics = HashMap::new();
207        let precision = self.precision();
208        let recall = self.recall();
209        let f1 = self.f1_score();
210
211        metrics.insert("precision".to_string(), precision.to_vec());
212        metrics.insert("recall".to_string(), recall.to_vec());
213        metrics.insert("f1".to_string(), f1.to_vec());
214        metrics
215    }
216
217    /// Convert the confusion matrix to an ASCII representation
218    pub fn to_ascii(&self, title: Option<&str>, normalized: bool) -> String {
219        self.to_ascii_with_options(title, normalized, &ColorOptions::default())
220    }
221
222    /// Convert the confusion matrix to an ASCII representation with color options
223    pub fn to_ascii_with_options(
224        &self,
225        title: Option<&str>,
226        normalized: bool,
227        color_options: &ColorOptions,
228    ) -> String {
229        let matrix = if normalized {
230            self.normalized()
231        } else {
232            self.matrix.clone()
233        };
234
235        let mut result = String::new();
236
237        // Add title if provided
238        if let Some(titletext) = title {
239            if color_options.enabled {
240                result.push_str(&stylize(titletext, Style::Bold));
241            } else {
242                result.push_str(titletext);
243            }
244            result.push_str("\n\n");
245        }
246
247        // Get class labels
248        let labels: Vec<String> = match &self.labels {
249            Some(label_vec) => label_vec.clone(),
250            None => (0..self.num_classes).map(|i| i.to_string()).collect(),
251        };
252
253        // Determine column widths
254        let label_width = labels.iter().map(|l| l.len()).max().unwrap_or(2).max(5);
255        let cell_width = if normalized {
256            6 // Width for normalized values (0.XX)
257        } else {
258            matrix
259                .iter()
260                .map(|&v| format!("{v:.0}").len())
261                .max()
262                .unwrap_or(2)
263                .max(5)
264        };
265
266        // Header row with class labels
267        if color_options.enabled {
268            result.push_str(&format!(
269                "{:<width$} |",
270                stylize("Pred→", Style::Bold),
271                width = label_width + 8
272            ));
273        } else {
274            result.push_str(&format!("{:<width$} |", "Pred→", width = label_width));
275        }
276
277        for label in &labels {
278            if color_options.enabled {
279                let styled_label = stylize(label, Style::Bold);
280                result.push_str(&format!(
281                    " {:<width$} |",
282                    styled_label,
283                    width = cell_width + 8
284                ));
285            } else {
286                result.push_str(&format!(" {label:<cell_width$} |"));
287            }
288        }
289
290        if color_options.enabled {
291            result.push_str(&format!(" {}\n", stylize("Recall", Style::Bold)));
292        } else {
293            result.push_str(" Recall\n");
294        }
295
296        // Separator
297        result.push_str(&"-".repeat(label_width + 2));
298        for _ in 0..self.num_classes {
299            result.push_str(&format!("{}-", "-".repeat(cell_width + 2)));
300        }
301        result.push_str(&"-".repeat(8));
302        result.push('\n');
303
304        // Data rows
305        let precision = self.precision();
306        let recall = self.recall();
307        let f1 = self.f1_score();
308
309        for i in 0..self.num_classes {
310            // Row label
311            if color_options.enabled {
312                result.push_str(&format!(
313                    "{:<width$} |",
314                    stylize(&labels[i], Style::Bold),
315                    width = label_width + 8
316                ));
317            } else {
318                result.push_str(&format!("{:<width$} |", labels[i], width = label_width));
319            }
320
321            for j in 0..self.num_classes {
322                let value = matrix[[i, j]];
323                let formatted = if normalized {
324                    format!("{value:.3}")
325                } else {
326                    format!("{value:.0}")
327                };
328
329                // Color cells based on value (if enabled)
330                if i == j {
331                    // Diagonal cells (true positives)
332                    if color_options.enabled {
333                        // Get normalized value for coloring
334                        let norm_value = if normalized {
335                            value.to_f64().unwrap_or(0.0)
336                        } else {
337                            // For non-normalized matrices, normalize by row sum
338                            let row_sum = matrix.row(i).sum().to_f64().unwrap_or(1.0);
339                            if row_sum > 0.0 {
340                                value.to_f64().unwrap_or(0.0) / row_sum
341                            } else {
342                                0.0
343                            }
344                        };
345
346                        // Use gradient colors based on value
347                        if let Some(color) = gradient_color(norm_value, color_options) {
348                            // Apply both bold style and color
349                            let colored_value = colorize(stylize(&formatted, Style::Bold), color);
350                            result.push_str(&format!(
351                                " {:<width$} |",
352                                colored_value,
353                                width = cell_width + 9
354                            ));
355                        } else {
356                            // Just use bold if no color
357                            result.push_str(&format!(
358                                " {:<width$} |",
359                                stylize(&formatted, Style::Bold),
360                                width = cell_width + 8
361                            ));
362                        }
363                    } else {
364                        // No color, just bold
365                        result.push_str(&format!(" \x1b[1m{formatted:<cell_width$}\x1b[0m |"));
366                    }
367                } else if color_options.enabled && normalized {
368                    // Color non-diagonal cells by intensity
369                    let norm_value = value.to_f64().unwrap_or(0.0);
370                    if norm_value > 0.1 {
371                        result.push_str(&format!(
372                            " {:<width$} |",
373                            colorize(&formatted, Color::BrightRed),
374                            width = cell_width + 9
375                        ));
376                    } else {
377                        result.push_str(&format!(" {formatted:<cell_width$} |"));
378                    }
379                } else {
380                    // No color for non-diagonal cells
381                    result.push_str(&format!(" {formatted:<cell_width$} |"));
382                }
383            }
384
385            // Add recall for this class with coloring
386            if color_options.enabled {
387                let recall_val = recall[i].to_f64().unwrap_or(0.0);
388                let colored_recall =
389                    colored_metric_cell(format!("{:.3}", recall[i]), recall_val, color_options);
390                result.push_str(&format!(" {colored_recall}\n"));
391            } else {
392                let recall_val = recall[i];
393                result.push_str(&format!(" {recall_val:.3}\n"));
394            }
395        }
396
397        // Precision row
398        if color_options.enabled {
399            result.push_str(&format!(
400                "{:<width$} |",
401                stylize("Precision", Style::Bold),
402                width = label_width + 8
403            ));
404        } else {
405            result.push_str(&format!("{:<width$} |", "Precision", width = label_width));
406        }
407
408        for j in 0..self.num_classes {
409            if color_options.enabled {
410                let precision_val = precision[j].to_f64().unwrap_or(0.0);
411                let colored_precision = colored_metric_cell(
412                    format!("{:.3}", precision[j]),
413                    precision_val,
414                    color_options,
415                );
416                result.push_str(&format!(" {colored_precision} |"));
417            } else {
418                let prec_val = precision[j];
419                result.push_str(&format!(" {prec_val:.3} |"));
420            }
421        }
422
423        // Add overall accuracy
424        let accuracy = self.accuracy();
425        if color_options.enabled {
426            let accuracy_val = accuracy.to_f64().unwrap_or(0.0);
427            let colored_accuracy =
428                colored_metric_cell(format!("{accuracy:.3}"), accuracy_val, color_options);
429            result.push_str(&format!(" {colored_accuracy}\n"));
430        } else {
431            result.push_str(&format!(" {accuracy:.3}\n"));
432        }
433
434        // Add F1 scores
435        if color_options.enabled {
436            result.push_str(&format!(
437                "{:<width$} |",
438                stylize("F1-score", Style::Bold),
439                width = label_width + 8
440            ));
441        } else {
442            result.push_str(&format!("{:<width$} |", "F1-score", width = label_width));
443        }
444
445        for j in 0..self.num_classes {
446            if color_options.enabled {
447                let f1_val = f1[j].to_f64().unwrap_or(0.0);
448                let colored_f1 =
449                    colored_metric_cell(format!("{:.3}", f1[j]), f1_val, color_options);
450                result.push_str(&format!(" {colored_f1} |"));
451            } else {
452                result.push_str(&format!(" {:.3} |", f1[j]));
453            }
454        }
455
456        // Add macro F1
457        let macro_f1 = self.macro_f1();
458        if color_options.enabled {
459            let macro_f1_val = macro_f1.to_f64().unwrap_or(0.0);
460            let colored_macro_f1 =
461                colored_metric_cell(format!("{macro_f1:.3}"), macro_f1_val, color_options);
462            result.push_str(&format!(" {colored_macro_f1}\n"));
463        } else {
464            result.push_str(&format!(" {macro_f1:.3}\n"));
465        }
466
467        result
468    }
469
470    /// Convert the confusion matrix to a heatmap visualization
471    /// This creates a colorful heatmap visualization of the confusion matrix
472    /// where cell colors represent the intensity of values using a detailed color gradient.
473    ///
474    /// # Arguments
475    /// * `title` - Optional title for the heatmap
476    /// * `normalized` - Whether to normalize the matrix (row values sum to 1)
477    ///
478    /// # Returns
479    /// * `String` - ASCII heatmap representation
480    pub fn to_heatmap(&self, title: Option<&str>, normalized: bool) -> String {
481        self.to_heatmap_with_options(title, normalized, &ColorOptions::default())
482    }
483
484    /// Create a confusion matrix heatmap that focuses on misclassification patterns
485    /// This visualization is specialized to highlight where the model makes mistakes,
486    /// with emphasis on the off-diagonal elements to help identify error patterns.
487    ///
488    /// The key features of this visualization are:
489    /// - Diagonal elements (correct classifications) are de-emphasized with dim styling
490    /// - Off-diagonal elements (errors) are highlighted with a color gradient
491    /// - Colors are normalized relative to the maximum off-diagonal value
492    /// - A specialized legend explains error intensity levels
493    ///
494    /// # Arguments
495    /// * `title` - Optional title for the error heatmap
496    ///
497    /// # Returns
498    /// * `String` - ASCII error pattern heatmap
499    ///
500    /// # Example
501    /// ```
502    /// use scirs2_core::ndarray::Array1;
503    /// use scirs2_neural::utils::ConfusionMatrix;
504    /// // Create some example data
505    /// let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0]);
506    /// let y_pred = Array1::from_vec(vec![0, 1, 1, 0, 1, 2, 1, 1, 0, 0]);
507    /// let class_labels = vec!["Class A".to_string(), "Class B".to_string(), "Class C".to_string()];
508    /// let cm = ConfusionMatrix::<f32>::new(&y_true.view(), &y_pred.view(), None, Some(class_labels)).unwrap();
509    /// // Generate the error pattern heatmap
510    /// let error_viz = cm.error_heatmap(Some("Misclassification Analysis"));
511    /// println!("{}", error_viz);
512    /// ```
513    pub fn error_heatmap(&self, title: Option<&str>) -> String {
514        // Always use normalized values for error heatmap
515        let _normalized = true;
516        let matrix = self.normalized();
517
518        // Create custom color options for error visualization
519        let color_options = ColorOptions {
520            enabled: true,
521            use_background: false,
522            use_bright: true,
523        };
524
525        let mut result = String::new();
526
527        // Add title if provided
528        if let Some(titletext) = title {
529            result.push_str(&stylize(titletext, Style::Bold));
530            result.push_str("\n\n");
531        }
532
533        // Get class labels
534        let labels: Vec<String> = match &self.labels {
535            Some(label_vec) => label_vec.clone(),
536            None => (0..self.num_classes).map(|i| i.to_string()).collect(),
537        };
538
539        let label_width = labels.iter().map(|l| l.len()).max().unwrap_or(2).max(5);
540        let cell_width = 6; // Width for normalized values
541
542        // Find the maximum off-diagonal value for normalization
543        let mut max_off_diag = F::zero();
544        for i in 0..self.num_classes {
545            for j in 0..self.num_classes {
546                if i != j && matrix[[i, j]] > max_off_diag {
547                    max_off_diag = matrix[[i, j]];
548                }
549            }
550        }
551
552        // If there are no off-diagonal values (perfect classifier), use max value
553        if max_off_diag == F::zero() {
554            max_off_diag = matrix
555                .iter()
556                .fold(F::zero(), |acc, &v| if v > acc { v } else { acc });
557        }
558
559        // Header with error emphasis title
560        if color_options.enabled {
561            result.push_str(&format!(
562                "{:<width$} |",
563                stylize("True↓ / Pred→", Style::Bold),
564                width = label_width + 8
565            ));
566        } else {
567            result.push_str(&format!(
568                "{:<width$} |",
569                "True↓ / Pred→",
570                width = label_width
571            ));
572        }
573
574        for label in &labels {
575            if color_options.enabled {
576                let styled_label = stylize(label, Style::Bold);
577                result.push_str(&format!(
578                    " {:<width$} |",
579                    styled_label,
580                    width = cell_width + 8
581                ));
582            } else {
583                result.push_str(&format!(" {label:<cell_width$} |"));
584            }
585        }
586        result.push('\n');
587
588        // Separator
589        result.push_str(&"-".repeat(label_width + 2));
590        for _ in 0..self.num_classes {
591            result.push_str(&format!("{}-", "-".repeat(cell_width + 2)));
592        }
593        result.push('\n');
594
595        // Data rows - using error-focused coloring
596        for i in 0..self.num_classes {
597            // Row label
598            if color_options.enabled {
599                result.push_str(&format!(
600                    "{:<width$} |",
601                    stylize(&labels[i], Style::Bold),
602                    width = label_width + 8
603                ));
604            } else {
605                result.push_str(&format!("{:<width$} |", labels[i], width = label_width));
606            }
607
608            for j in 0..self.num_classes {
609                let value = matrix[[i, j]];
610                let formatted = format!("{value:.3}");
611
612                // Format and color each cell - but emphasize errors (off-diagonal)
613                if i == j {
614                    // Diagonal elements (correct classifications) - de-emphasize
615                    if color_options.enabled {
616                        // Dim style for diagonal elements
617                        result.push_str(&format!(
618                            " {:<width$} |",
619                            colorize_and_style(&formatted, None, None, Some(Style::Dim)),
620                            width = cell_width + 8
621                        ));
622                    } else {
623                        result.push_str(&format!(" {formatted:<cell_width$} |"));
624                    }
625                } else {
626                    // Off-diagonal elements (errors) - emphasize with color gradient
627                    let norm_value = if max_off_diag > F::zero() {
628                        (value / max_off_diag).to_f64().unwrap_or(0.0)
629                    } else {
630                        0.0
631                    };
632
633                    if color_options.enabled && norm_value > 0.0 {
634                        // Use a specialized color scheme for errors
635                        let error_color = if norm_value < 0.25 {
636                            Color::BrightBlue
637                        } else if norm_value < 0.5 {
638                            Color::BrightCyan
639                        } else if norm_value < 0.75 {
640                            Color::BrightRed
641                        } else {
642                            Color::BrightMagenta
643                        };
644
645                        // Bold style for the most significant errors
646                        if norm_value > 0.5 {
647                            result.push_str(&format!(
648                                " {:<width$} |",
649                                colorize_and_style(
650                                    &formatted,
651                                    Some(error_color),
652                                    None,
653                                    Some(Style::Bold)
654                                ),
655                                width = cell_width + 9
656                            ));
657                        } else {
658                            result.push_str(&format!(
659                                " {:<width$} |",
660                                colorize(&formatted, error_color),
661                                width = cell_width + 9
662                            ));
663                        }
664                    } else {
665                        result.push_str(&format!(" {formatted:<cell_width$} |"));
666                    }
667                }
668            }
669            result.push('\n');
670        }
671
672        // Add specialized error heatmap legend
673        if color_options.enabled {
674            result.push('\n');
675            let mut legend = String::from("Error Pattern Legend: ");
676
677            // Custom legend showing error intensity levels
678            let error_levels = [
679                (Color::BrightBlue, "Low Error (0-25%)"),
680                (Color::BrightCyan, "Moderate Error (25-50%)"),
681                (Color::BrightRed, "High Error (50-75%)"),
682                (Color::BrightMagenta, "Critical Error (75-100%)"),
683            ];
684
685            for (i, (color, label)) in error_levels.iter().enumerate() {
686                if i > 0 {
687                    legend.push(' ');
688                }
689                legend.push_str(&format!("{} {label}", colorize("■", *color)));
690            }
691
692            // Add note about diagonal elements
693            legend.push_str(&format!(
694                " {} Correct Classification",
695                colorize_and_style("■", None, None, Some(Style::Dim))
696            ));
697
698            result.push_str(&legend);
699        }
700
701        result
702    }
703
704    /// Convert the confusion matrix to a heatmap visualization with customizable options
705    ///
706    /// # Arguments
707    /// * `title` - Optional title for the heatmap
708    /// * `normalized` - Whether to normalize the matrix
709    /// * `color_options` - Color options for visualization
710    ///
711    /// # Returns
712    /// * `String` - ASCII heatmap representation with colors
713    pub fn to_heatmap_with_options(
714        &self,
715        title: Option<&str>,
716        normalized: bool,
717        color_options: &ColorOptions,
718    ) -> String {
719        let matrix = if normalized {
720            self.normalized()
721        } else {
722            self.matrix.clone()
723        };
724
725        let mut result = String::new();
726
727        // Add title if provided
728        if let Some(titletext) = title {
729            if color_options.enabled {
730                result.push_str(&stylize(titletext, Style::Bold));
731            } else {
732                result.push_str(titletext);
733            }
734            result.push_str("\n\n");
735        }
736
737        // Get class labels
738        let labels: Vec<String> = match &self.labels {
739            Some(label_vec) => label_vec.clone(),
740            None => (0..self.num_classes).map(|i| i.to_string()).collect(),
741        };
742
743        let label_width = labels.iter().map(|l| l.len()).max().unwrap_or(2).max(5);
744        let cell_width = if normalized { 6 } else { 5 };
745
746        // Find the maximum value for normalization
747        let max_value = if normalized {
748            F::one() // Normalized values are already between 0 and 1
749        } else {
750            matrix
751                .iter()
752                .fold(F::zero(), |acc, &v| if v > acc { v } else { acc })
753        };
754
755        // Header
756        if color_options.enabled {
757            result.push_str(&format!(
758                "{:<width$} |",
759                stylize("True↓", Style::Bold),
760                width = label_width + 8
761            ));
762        } else {
763            result.push_str(&format!("{:<width$} |", "True↓", width = label_width));
764        }
765
766        for label in &labels {
767            if color_options.enabled {
768                let styled_label = stylize(label, Style::Bold);
769                result.push_str(&format!(
770                    " {:<width$} |",
771                    styled_label,
772                    width = cell_width + 8
773                ));
774            } else {
775                result.push_str(&format!(" {label:<cell_width$} |"));
776            }
777        }
778        result.push('\n');
779
780        // Separator
781        result.push_str(&"-".repeat(label_width + 2));
782        for _ in 0..self.num_classes {
783            result.push_str(&format!("{}-", "-".repeat(cell_width + 2)));
784        }
785        result.push('\n');
786
787        // Data rows - using heatmap coloring
788        for i in 0..self.num_classes {
789            // Row label
790            if color_options.enabled {
791                result.push_str(&format!(
792                    "{:<width$} |",
793                    stylize(&labels[i], Style::Bold),
794                    width = label_width + 8
795                ));
796            } else {
797                result.push_str(&format!("{:<width$} |", labels[i], width = label_width));
798            }
799
800            for j in 0..self.num_classes {
801                let value = matrix[[i, j]];
802                let formatted = if normalized {
803                    format!("{value:.3}")
804                } else {
805                    format!("{value:.0}")
806                };
807
808                // Format and color each cell
809                // Get normalized value for coloring
810                let norm_value = if normalized {
811                    value.to_f64().unwrap_or(0.0)
812                } else if max_value > F::zero() {
813                    (value / max_value).to_f64().unwrap_or(0.0)
814                } else {
815                    0.0
816                };
817
818                // Apply heatmap coloring
819                if color_options.enabled {
820                    let heatmap_value = heatmap_cell(&formatted, norm_value, color_options);
821                    // Add extra space for ANSI color codes
822                    result.push_str(&format!(
823                        " {:<width$} |",
824                        heatmap_value,
825                        width = cell_width + 9
826                    ));
827                } else {
828                    result.push_str(&format!(" {formatted:<cell_width$} |"));
829                }
830            }
831            result.push('\n');
832        }
833
834        // Add heatmap legend
835        if color_options.enabled {
836            if let Some(legend) = heatmap_color_legend(color_options) {
837                result.push('\n');
838                result.push_str(&legend);
839            }
840        }
841
842        result
843    }
844}