// scirs2_neural/utils/evaluation/learning_curve.rs

1//! Learning curve for model performance analysis
2//!
3//! This module provides the LearningCurve data structure for visualizing model
4//! performance across different training set sizes, comparing training and
5//! validation metrics.
6
7use crate::error::{NeuralError, Result};
8use crate::utils::colors::{colorize, stylize, Color, ColorOptions, Style};
9use crate::utils::evaluation::helpers::draw_line_with_coords;
10use scirs2_core::ndarray::{Array1, Array2, Axis};
11use scirs2_core::numeric::{Float, FromPrimitive};
12use std::fmt::{Debug, Display};
13// Removed problematic type alias - use trait bounds directly in implementations
/// Learning curve data structure for visualizing model performance
///
/// This structure represents learning curves that show how model performance
/// changes as the training set size increases, comparing training and validation
/// metrics to help diagnose overfitting, underfitting, and other training issues.
///
/// Invariant: every array shares the same first dimension, which equals
/// `train_sizes.len()` (one entry per evaluated training-set size).
pub struct LearningCurve<F: Float + Debug + Display> {
    /// Training set sizes used for evaluation
    pub train_sizes: Array1<usize>,
    /// Training scores for each size and fold (rows=sizes, cols=folds)
    pub train_scores: Array2<F>,
    /// Validation scores for each size and fold (rows=sizes, cols=folds)
    pub val_scores: Array2<F>,
    /// Mean training scores across folds (length = number of sizes)
    pub train_mean: Array1<F>,
    /// Standard deviation of training scores across folds (population std,
    /// i.e. divided by the fold count — see `compute_std`)
    pub train_std: Array1<F>,
    /// Mean validation scores across folds (length = number of sizes)
    pub val_mean: Array1<F>,
    /// Standard deviation of validation scores across folds (population std)
    pub val_std: Array1<F>,
}
35impl<F: Float + Debug + Display + FromPrimitive> LearningCurve<F> {
36    /// Create a new learning curve from training and validation scores
37    ///
38    /// # Arguments
39    /// * `train_sizes` - Array of training set sizes
40    /// * `train_scores` - 2D array of training scores (rows=sizes, cols=cv folds)
41    /// * `val_scores` - 2D array of validation scores (rows=sizes, cols=cv folds)
42    /// # Returns
43    /// * `Result<LearningCurve<F>>` - Learning curve data
44    /// # Example
45    /// ```
46    /// use scirs2_core::ndarray::{Array1, Array2};
47    /// use scirs2_neural::utils::evaluation::LearningCurve;
48    /// // Create sample data
49    /// let train_sizes = Array1::from_vec(vec![100, 200, 300, 400, 500]);
50    /// let train_scores = Array2::from_shape_vec((5, 3), vec![
51    ///     0.6, 0.62, 0.58,    // 100 samples, 3 folds
52    ///     0.7, 0.72, 0.68,    // 200 samples, 3 folds
53    ///     0.8, 0.78, 0.79,    // 300 samples, 3 folds
54    ///     0.85, 0.83, 0.84,   // 400 samples, 3 folds
55    ///     0.87, 0.88, 0.86,   // 500 samples, 3 folds
56    /// ]).unwrap();
57    /// let val_scores = Array2::from_shape_vec((5, 3), vec![
58    ///     0.55, 0.53, 0.54,   // 100 samples, 3 folds
59    ///     0.65, 0.63, 0.64,   // 200 samples, 3 folds
60    ///     0.75, 0.73, 0.74,   // 300 samples, 3 folds
61    ///     0.76, 0.74, 0.75,   // 400 samples, 3 folds
62    ///     0.77, 0.76, 0.76,   // 500 samples, 3 folds
63    /// ]).unwrap();
64    /// // Create learning curve
65    /// let curve = LearningCurve::<f64>::new(train_sizes, train_scores, val_scores).unwrap();
66    /// ```
67    pub fn new(
68        train_sizes: Array1<usize>,
69        train_scores: Array2<F>,
70        val_scores: Array2<F>,
71    ) -> Result<Self> {
72        let n_sizes = train_sizes.len();
73        if train_scores.shape()[0] != n_sizes || val_scores.shape()[0] != n_sizes {
74            return Err(NeuralError::ValidationError(
75                "Number of _scores must match number of training _sizes".to_string(),
76            ));
77        }
78        if train_scores.shape()[1] != val_scores.shape()[1] {
79            return Err(NeuralError::ValidationError(
80                "Training and validation _scores must have the same number of CV folds".to_string(),
81            ));
82        }
83        // Compute means and standard deviations
84        let train_mean = train_scores.mean_axis(Axis(1)).unwrap();
85        let val_mean = val_scores.mean_axis(Axis(1)).unwrap();
86        // Compute standard deviations using helper function
87        let train_std = compute_std(&train_scores, &train_mean, n_sizes);
88        let val_std = compute_std(&val_scores, &val_mean, n_sizes);
89        Ok(LearningCurve {
90            train_sizes,
91            train_scores,
92            val_scores,
93            train_mean,
94            train_std,
95            val_mean,
96            val_std,
97        })
98    }
99    /// Create an ASCII line plot of the learning curve
100    /// * `title` - Optional title for the plot
101    /// * `width` - Width of the plot
102    /// * `height` - Height of the plot
103    /// * `metric_name` - Name of the metric (e.g., "Accuracy")
104    /// * `String` - ASCII line plot
105    pub fn to_ascii(
106        &self,
107        title: Option<&str>,
108        width: usize,
109        height: usize,
110        metric_name: &str,
111    ) -> String {
112        self.to_ascii_with_options(title, width, height, metric_name, &ColorOptions::default())
113    }
114
115    /// Create an ASCII line plot of the learning curve with customizable colors
116    /// This method allows fine-grained control over the color scheme using the
117    /// provided ColorOptions parameter.
118    /// * `color_options` - Color options for visualization
119    /// * `String` - ASCII line plot with colors
120    pub fn to_ascii_with_options(
121        &self,
122        title: Option<&str>,
123        width: usize,
124        height: usize,
125        metric_name: &str,
126        color_options: &ColorOptions,
127    ) -> String {
128        // Pre-allocate result string with estimated capacity
129        let mut result = String::with_capacity(width * height * 2);
130        // Add title with styling if provided
131        if let Some(titletext) = title {
132            if color_options.enabled {
133                let styled_title = stylize(titletext, Style::Bold);
134                result.push_str(&format!("{styled_title}\n\n"));
135            } else {
136                result.push_str(&format!("{titletext}\n\n"));
137            }
138        } else if color_options.enabled {
139            let styled_metric = stylize(metric_name, Style::Bold);
140            let title = format!("Learning Curve ({styled_metric})");
141            let styled_title = stylize(title, Style::Bold);
142            result.push_str(&format!("{styled_title}\n\n"));
143        } else {
144            result.push_str(&format!("Learning Curve ({metric_name})\n\n"));
145        }
146
147        // Find min and max values for y-axis scaling
148        let min_score = self
149            .val_mean
150            .iter()
151            .fold(F::infinity(), |acc, &v| if v < acc { v } else { acc });
152        let max_score =
153            self.train_mean
154                .iter()
155                .fold(F::neg_infinity(), |acc, &v| if v > acc { v } else { acc });
156        // Add a small margin to the y-range
157        let y_margin = F::from(0.1).unwrap() * (max_score - min_score);
158        let y_min = min_score - y_margin;
159        let y_max = max_score + y_margin;
160        // Create a 2D grid for the plot
161        let mut grid = vec![vec![' '; width]; height];
162        let mut grid_markers = vec![vec![(false, false); width]; height]; // Track which points are training vs. validation
163                                                                          // Function to map a value to a y-coordinate
164        let y_coord = |value: F| -> usize {
165            let norm = (value - y_min) / (y_max - y_min);
166            let y = height - 1 - (norm.to_f64().unwrap() * (height - 1) as f64).round() as usize;
167            std::cmp::min(y, height - 1)
168        };
169        // Function to map a training size to an x-coordinate
170        let x_coord = |size_idx: usize| -> usize {
171            ((size_idx as f64) / ((self.train_sizes.len() - 1) as f64) * (width - 1) as f64).round()
172                as usize
173        };
174
175        // Draw training curve and mark as training points
176        for i in 0..self.train_sizes.len() - 1 {
177            let x1 = x_coord(i);
178            let y1 = y_coord(self.train_mean[i]);
179            let x2 = x_coord(i + 1);
180            let y2 = y_coord(self.train_mean[i + 1]);
181            // Draw a line between points and mark as training points
182            for (x, y) in draw_line_with_coords(x1, y1, x2, y2, Some(width), Some(height)) {
183                grid[y][x] = '●';
184                grid_markers[y][x].0 = true; // Mark as training point
185            }
186        }
187
188        // Draw validation curve and mark as validation points
189        for i in 0..self.train_sizes.len() - 1 {
190            let x1 = x_coord(i);
191            let y1 = y_coord(self.val_mean[i]);
192            let x2 = x_coord(i + 1);
193            let y2 = y_coord(self.val_mean[i + 1]);
194            // Draw a line between points and mark as validation points
195            for (x, y) in draw_line_with_coords(x1, y1, x2, y2, Some(width), Some(height)) {
196                grid[y][x] = '○';
197                grid_markers[y][x].1 = true; // Mark as validation point
198            }
199        }
200        // Draw the grid
201        for y in 0..height {
202            // Y-axis labels with styling
203            if y == 0 {
204                if color_options.enabled {
205                    let value = format!("{y_max:.2}");
206                    result.push_str(&format!("{} |", colorize(value, Color::BrightCyan)));
207                } else {
208                    result.push_str(&format!("{y_max:.2} |"));
209                }
210            } else if y == height - 1 {
211                if color_options.enabled {
212                    let value = format!("{y_min:.2}");
213                    result.push_str(&format!("{} |", colorize(value, Color::BrightCyan)));
214                } else {
215                    result.push_str(&format!("{y_min:.2} |"));
216                }
217            } else if y == height / 2 {
218                let mid = y_min + (y_max - y_min) * F::from(0.5).unwrap();
219                if color_options.enabled {
220                    let value = format!("{mid:.2}");
221                    result.push_str(&format!("{} |", colorize(value, Color::BrightCyan)));
222                } else {
223                    result.push_str(&format!("{mid:.2} |"));
224                }
225            } else {
226                result.push_str("      |");
227            }
228            // Grid content with coloring
229            for x in 0..width {
230                let char = grid[y][x];
231                let (is_train, is_val) = grid_markers[y][x];
232                if color_options.enabled {
233                    if is_train {
234                        // Training point
235                        result.push_str(&colorize("●", Color::BrightBlue));
236                    } else if is_val {
237                        // Validation point
238                        result.push_str(&colorize("○", Color::BrightGreen));
239                    } else {
240                        result.push(char);
241                    }
242                } else {
243                    result.push(char);
244                }
245            }
246            result.push('\n');
247        }
248        // X-axis
249        result.push_str("      +");
250        result.push_str(&"-".repeat(width));
251        result.push('\n');
252        // X-axis labels with styling
253        result.push_str("       ");
254        // Put a few size labels along the x-axis
255        let n_labels = std::cmp::min(5, self.train_sizes.len());
256        for i in 0..n_labels {
257            let idx = i * (self.train_sizes.len() - 1) / (n_labels - 1);
258            let size = self.train_sizes[idx];
259            let label = format!("{size}");
260            let x = x_coord(idx);
261            // Position the label with styling
262            if i == 0 {
263                if color_options.enabled {
264                    result.push_str(&colorize(label, Color::BrightCyan));
265                } else {
266                    result.push_str(&label);
267                }
268            } else {
269                let prev_end = result.len();
270                let spaces = x.saturating_sub(prev_end - 7);
271                result.push_str(&" ".repeat(spaces));
272                if color_options.enabled {
273                    result.push_str(&colorize(label, Color::BrightCyan));
274                } else {
275                    result.push_str(&label);
276                }
277            }
278        }
279        result.push('\n');
280
281        // X-axis title with styling
282        if color_options.enabled {
283            result.push_str(&format!(
284                "       {}\n\n",
285                stylize("Training Set Size", Style::Bold)
286            ));
287        } else {
288            result.push_str("       Training Set Size\n\n");
289        }
290        // Add legend with colors
291        if color_options.enabled {
292            result.push_str(&format!(
293                "       {} Training score   {} Validation score\n",
294                colorize("●", Color::BrightBlue),
295                colorize("○", Color::BrightGreen)
296            ));
297        } else {
298            result.push_str("       ● Training score   ○ Validation score\n");
299        }
300
301        result
302    }
303}
304/// Helper function to compute standard deviation for scores
305#[allow(dead_code)]
306fn compute_std<F: Float + Debug + Display + FromPrimitive>(
307    scores: &Array2<F>,
308    mean: &Array1<F>,
309    n_sizes: usize,
310) -> Array1<F> {
311    let mut std_arr = Array1::zeros(n_sizes);
312    let n_folds = scores.shape()[1];
313    for i in 0..n_sizes {
314        let mut sum_sq_diff = F::zero();
315        for j in 0..n_folds {
316            let diff = scores[[i, j]] - mean[i];
317            sum_sq_diff = sum_sq_diff + diff * diff;
318        }
319        std_arr[i] = (sum_sq_diff / F::from(n_folds).unwrap()).sqrt();
320    }
321    std_arr
322}