scirs2_neural/utils/evaluation/learning_curve.rs

//! Learning curve for model performance analysis
//!
//! This module provides the LearningCurve data structure for visualizing model
//! performance across different training set sizes, comparing training and
//! validation metrics.

use crate::error::{NeuralError, Result};
use crate::utils::colors::{colorize, stylize, Color, ColorOptions, Style};
use crate::utils::evaluation::helpers::draw_line_with_coords;
use ndarray::{Array1, Array2, Axis};
use num_traits::{Float, FromPrimitive};
use std::fmt::{Debug, Display};

// No shared type alias for the trait bounds; each impl states its bounds directly.

/// Learning curve data structure for visualizing model performance
///
/// This structure represents learning curves that show how model performance
/// changes as the training set size increases, comparing training and validation
/// metrics to help diagnose overfitting, underfitting, and other training issues.
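///
/// # Example
///
/// A minimal sketch of how the summary fields might be used to flag a large
/// gap between training and validation performance (a common overfitting
/// signal); the `0.1` threshold is purely illustrative and not part of this API.
///
/// ```
/// use ndarray::{Array1, Array2};
/// use scirs2_neural::utils::evaluation::LearningCurve;
///
/// let sizes = Array1::from_vec(vec![100, 200]);
/// let train = Array2::from_shape_vec((2, 2), vec![0.90, 0.92, 0.95, 0.96]).unwrap();
/// let val = Array2::from_shape_vec((2, 2), vec![0.70, 0.72, 0.73, 0.74]).unwrap();
/// let curve = LearningCurve::<f64>::new(sizes, train, val).unwrap();
///
/// let last = curve.train_mean.len() - 1;
/// let gap = curve.train_mean[last] - curve.val_mean[last];
/// if gap > 0.1 {
///     println!("Large train/validation gap ({:.2}): possible overfitting", gap);
/// }
/// ```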
pub struct LearningCurve<F: Float + Debug + Display> {
    /// Training set sizes used for evaluation
    pub train_sizes: Array1<usize>,
    /// Training scores for each size and fold (rows=sizes, cols=folds)
    pub train_scores: Array2<F>,
    /// Validation scores for each size and fold (rows=sizes, cols=folds)
    pub val_scores: Array2<F>,
    /// Mean training scores across folds
    pub train_mean: Array1<F>,
    /// Standard deviation of training scores
    pub train_std: Array1<F>,
    /// Mean validation scores across folds
    pub val_mean: Array1<F>,
    /// Standard deviation of validation scores
    pub val_std: Array1<F>,
}

impl<F: Float + Debug + Display + FromPrimitive> LearningCurve<F> {
    /// Create a new learning curve from training and validation scores
    ///
    /// # Arguments
    ///
    /// * `train_sizes` - Array of training set sizes
    /// * `train_scores` - 2D array of training scores (rows=sizes, cols=cv folds)
    /// * `val_scores` - 2D array of validation scores (rows=sizes, cols=cv folds)
    ///
    /// # Returns
    ///
    /// * `Result<LearningCurve<F>>` - Learning curve data
    ///
    /// # Example
    ///
    /// ```
    /// use ndarray::{Array1, Array2};
    /// use scirs2_neural::utils::evaluation::LearningCurve;
    ///
    /// // Create sample data
    /// let train_sizes = Array1::from_vec(vec![100, 200, 300, 400, 500]);
    /// let train_scores = Array2::from_shape_vec((5, 3), vec![
    ///     0.6, 0.62, 0.58,    // 100 samples, 3 folds
    ///     0.7, 0.72, 0.68,    // 200 samples, 3 folds
    ///     0.8, 0.78, 0.79,    // 300 samples, 3 folds
    ///     0.85, 0.83, 0.84,   // 400 samples, 3 folds
    ///     0.87, 0.88, 0.86,   // 500 samples, 3 folds
    /// ]).unwrap();
    /// let val_scores = Array2::from_shape_vec((5, 3), vec![
    ///     0.55, 0.53, 0.54,   // 100 samples, 3 folds
    ///     0.65, 0.63, 0.64,   // 200 samples, 3 folds
    ///     0.75, 0.73, 0.74,   // 300 samples, 3 folds
    ///     0.76, 0.74, 0.75,   // 400 samples, 3 folds
    ///     0.77, 0.76, 0.76,   // 500 samples, 3 folds
    /// ]).unwrap();
    ///
    /// // Create learning curve
    /// let curve = LearningCurve::<f64>::new(train_sizes, train_scores, val_scores).unwrap();
    /// ```
    pub fn new(
        train_sizes: Array1<usize>,
        train_scores: Array2<F>,
        val_scores: Array2<F>,
    ) -> Result<Self> {
        let n_sizes = train_sizes.len();

        if train_scores.shape()[0] != n_sizes || val_scores.shape()[0] != n_sizes {
            return Err(NeuralError::ValidationError(
                "Number of scores must match number of training sizes".to_string(),
            ));
        }

        if train_scores.shape()[1] != val_scores.shape()[1] {
            return Err(NeuralError::ValidationError(
                "Training and validation scores must have the same number of CV folds".to_string(),
            ));
        }

        // Compute means and standard deviations
        let train_mean = train_scores.mean_axis(Axis(1)).unwrap();
        let val_mean = val_scores.mean_axis(Axis(1)).unwrap();

        // Compute standard deviations using helper function
        let train_std = compute_std(&train_scores, &train_mean, n_sizes);
        let val_std = compute_std(&val_scores, &val_mean, n_sizes);

        Ok(LearningCurve {
            train_sizes,
            train_scores,
            val_scores,
            train_mean,
            train_std,
            val_mean,
            val_std,
        })
    }

    /// Create an ASCII line plot of the learning curve
    ///
    /// # Arguments
    ///
    /// * `title` - Optional title for the plot
    /// * `width` - Width of the plot
    /// * `height` - Height of the plot
    /// * `metric_name` - Name of the metric (e.g., "Accuracy")
    ///
    /// # Returns
    ///
    /// * `String` - ASCII line plot
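    ///
    /// # Example
    ///
    /// A minimal usage sketch (the exact ASCII layout depends on the chosen
    /// width and height; the metric values here are made up for illustration):
    ///
    /// ```
    /// use ndarray::{Array1, Array2};
    /// use scirs2_neural::utils::evaluation::LearningCurve;
    ///
    /// let train_sizes = Array1::from_vec(vec![100, 200, 300]);
    /// let train_scores = Array2::from_shape_vec((3, 2),
    ///     vec![0.70, 0.72, 0.80, 0.82, 0.88, 0.86]).unwrap();
    /// let val_scores = Array2::from_shape_vec((3, 2),
    ///     vec![0.60, 0.62, 0.70, 0.72, 0.75, 0.74]).unwrap();
    /// let curve = LearningCurve::<f64>::new(train_sizes, train_scores, val_scores).unwrap();
    ///
    /// let plot = curve.to_ascii(Some("Accuracy vs. Training Set Size"), 60, 15, "Accuracy");
    /// println!("{}", plot);
    /// ```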
    pub fn to_ascii(
        &self,
        title: Option<&str>,
        width: usize,
        height: usize,
        metric_name: &str,
    ) -> String {
        self.to_ascii_with_options(title, width, height, metric_name, &ColorOptions::default())
    }

    /// Create an ASCII line plot of the learning curve with customizable colors
    ///
    /// This method allows fine-grained control over the color scheme using the
    /// provided ColorOptions parameter.
    ///
    /// # Arguments
    ///
    /// * `title` - Optional title for the plot
    /// * `width` - Width of the plot
    /// * `height` - Height of the plot
    /// * `metric_name` - Name of the metric (e.g., "Accuracy")
    /// * `color_options` - Color options for visualization
    ///
    /// # Returns
    ///
    /// * `String` - ASCII line plot with colors
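    ///
    /// # Example
    ///
    /// A hedged sketch that simply passes the default color settings (equivalent
    /// to calling `to_ascii`); the `scirs2_neural::utils::colors` import mirrors
    /// the internal module path used above, and customizing individual fields
    /// depends on how `ColorOptions` is defined there:
    ///
    /// ```
    /// use ndarray::{Array1, Array2};
    /// use scirs2_neural::utils::colors::ColorOptions;
    /// use scirs2_neural::utils::evaluation::LearningCurve;
    ///
    /// let train_sizes = Array1::from_vec(vec![50, 100, 150]);
    /// let train_scores = Array2::from_shape_vec((3, 2),
    ///     vec![0.65, 0.67, 0.75, 0.77, 0.82, 0.80]).unwrap();
    /// let val_scores = Array2::from_shape_vec((3, 2),
    ///     vec![0.58, 0.60, 0.66, 0.68, 0.71, 0.70]).unwrap();
    /// let curve = LearningCurve::<f64>::new(train_sizes, train_scores, val_scores).unwrap();
    ///
    /// let options = ColorOptions::default();
    /// let plot = curve.to_ascii_with_options(None, 60, 12, "Accuracy", &options);
    /// assert!(plot.contains("Training Set Size"));
    /// ```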
    pub fn to_ascii_with_options(
        &self,
        title: Option<&str>,
        width: usize,
        height: usize,
        metric_name: &str,
        color_options: &ColorOptions,
    ) -> String {
        // Pre-allocate result string with estimated capacity
        let mut result = String::with_capacity(width * height * 2);

        // Add title with styling if provided
        if let Some(title_text) = title {
            if color_options.enabled {
                result.push_str(&format!("{}\n\n", stylize(title_text, Style::Bold)));
            } else {
                result.push_str(&format!("{}\n\n", title_text));
            }
        } else if color_options.enabled {
            let title = format!("Learning Curve ({})", stylize(metric_name, Style::Bold));
            result.push_str(&format!("{}\n\n", stylize(title, Style::Bold)));
        } else {
            result.push_str(&format!("Learning Curve ({})\n\n", metric_name));
        }

        // Find min and max values across both curves for y-axis scaling
        let min_score = self
            .train_mean
            .iter()
            .chain(self.val_mean.iter())
            .fold(F::infinity(), |acc, &v| if v < acc { v } else { acc });

        let max_score = self
            .train_mean
            .iter()
            .chain(self.val_mean.iter())
            .fold(F::neg_infinity(), |acc, &v| if v > acc { v } else { acc });

        // Add a small margin to the y-range
        let y_margin = F::from(0.1).unwrap() * (max_score - min_score);
        let y_min = min_score - y_margin;
        let y_max = max_score + y_margin;

        // Create a 2D grid for the plot
        let mut grid = vec![vec![' '; width]; height];
        let mut grid_markers = vec![vec![(false, false); width]; height]; // Track which points are training vs. validation

        // Function to map a value to a y-coordinate
        let y_coord = |value: F| -> usize {
            let norm = (value - y_min) / (y_max - y_min);
            let y = height - 1 - (norm.to_f64().unwrap() * (height - 1) as f64).round() as usize;
            std::cmp::min(y, height - 1)
        };

        // Function to map a training size to an x-coordinate
        let x_coord = |size_idx: usize| -> usize {
            ((size_idx as f64) / ((self.train_sizes.len() - 1) as f64) * (width - 1) as f64).round()
                as usize
        };

        // Draw training curve and mark as training points
        for i in 0..self.train_sizes.len() - 1 {
            let x1 = x_coord(i);
            let y1 = y_coord(self.train_mean[i]);
            let x2 = x_coord(i + 1);
            let y2 = y_coord(self.train_mean[i + 1]);

            // Draw a line between points and mark as training points
            for (x, y) in draw_line_with_coords(x1, y1, x2, y2, Some(width), Some(height)) {
                grid[y][x] = '●';
                grid_markers[y][x].0 = true; // Mark as training point
            }
        }

        // Draw validation curve and mark as validation points
        for i in 0..self.train_sizes.len() - 1 {
            let x1 = x_coord(i);
            let y1 = y_coord(self.val_mean[i]);
            let x2 = x_coord(i + 1);
            let y2 = y_coord(self.val_mean[i + 1]);

            // Draw a line between points and mark as validation points
            for (x, y) in draw_line_with_coords(x1, y1, x2, y2, Some(width), Some(height)) {
                grid[y][x] = '○';
                grid_markers[y][x].1 = true; // Mark as validation point
            }
        }

        // Draw the grid
        for y in 0..height {
            // Y-axis labels with styling (padded to a fixed width so rows align)
            if y == 0 {
                if color_options.enabled {
                    let value = format!("{:5.2}", y_max);
                    result.push_str(&format!("{} |", colorize(value, Color::BrightCyan)));
                } else {
                    result.push_str(&format!("{:5.2} |", y_max));
                }
            } else if y == height - 1 {
                if color_options.enabled {
                    let value = format!("{:5.2}", y_min);
                    result.push_str(&format!("{} |", colorize(value, Color::BrightCyan)));
                } else {
                    result.push_str(&format!("{:5.2} |", y_min));
                }
            } else if y == height / 2 {
                let mid = y_min + (y_max - y_min) * F::from(0.5).unwrap();
                if color_options.enabled {
                    let value = format!("{:5.2}", mid);
                    result.push_str(&format!("{} |", colorize(value, Color::BrightCyan)));
                } else {
                    result.push_str(&format!("{:5.2} |", mid));
                }
            } else {
                result.push_str("      |");
            }

            // Grid content with coloring
            for x in 0..width {
                let char = grid[y][x];
                let (is_train, is_val) = grid_markers[y][x];

                if color_options.enabled {
                    if is_train {
                        // Training point
                        result.push_str(&colorize("●", Color::BrightBlue));
                    } else if is_val {
                        // Validation point
                        result.push_str(&colorize("○", Color::BrightGreen));
                    } else {
                        result.push(char);
                    }
                } else {
                    result.push(char);
                }
            }

            result.push('\n');
        }

        // X-axis
        result.push_str("      +");
        result.push_str(&"-".repeat(width));
        result.push('\n');

        // X-axis labels with styling
        result.push_str("       ");

        // Put a few size labels along the x-axis, tracking the current column so
        // each label lands under the corresponding tick position
        let n_labels = std::cmp::min(5, self.train_sizes.len());
        let mut cursor = 0;
        for i in 0..n_labels {
            let idx = if n_labels > 1 {
                i * (self.train_sizes.len() - 1) / (n_labels - 1)
            } else {
                0
            };
            let size = self.train_sizes[idx];
            let label = format!("{}", size);
            let label_len = label.len();
            let x = x_coord(idx);

            // Pad with spaces up to the tick position, then write the label
            let spaces = x.saturating_sub(cursor);
            result.push_str(&" ".repeat(spaces));
            if color_options.enabled {
                result.push_str(&colorize(label, Color::BrightCyan));
            } else {
                result.push_str(&label);
            }
            cursor += spaces + label_len;
        }

        result.push('\n');

        // X-axis title with styling
        if color_options.enabled {
            result.push_str(&format!(
                "       {}\n\n",
                stylize("Training Set Size", Style::Bold)
            ));
        } else {
            result.push_str("       Training Set Size\n\n");
        }

        // Add legend with colors
        if color_options.enabled {
            result.push_str(&format!(
                "       {} Training score   {} Validation score\n",
                colorize("●", Color::BrightBlue),
                colorize("○", Color::BrightGreen)
            ));
        } else {
            result.push_str("       ● Training score   ○ Validation score\n");
        }

        result
    }
}

/// Helper function to compute the population standard deviation of scores
/// across CV folds for each training set size
fn compute_std<F: Float + Debug + Display + FromPrimitive>(
    scores: &Array2<F>,
    mean: &Array1<F>,
    n_sizes: usize,
) -> Array1<F> {
    let mut std_arr = Array1::zeros(n_sizes);
    let n_folds = scores.shape()[1];

    for i in 0..n_sizes {
        let mut sum_sq_diff = F::zero();
        for j in 0..n_folds {
            let diff = scores[[i, j]] - mean[i];
            sum_sq_diff = sum_sq_diff + diff * diff;
        }
        std_arr[i] = (sum_sq_diff / F::from(n_folds).unwrap()).sqrt();
    }

    std_arr
}
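
// A minimal sanity-check sketch for the statistics computed in `new` and
// `compute_std`; the expected values below follow from the population standard
// deviation used above (sum of squared deviations divided by the fold count).
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2};

    #[test]
    fn new_computes_population_mean_and_std() {
        let sizes = Array1::from_vec(vec![10, 20]);
        // Two training sizes (rows), two CV folds (columns).
        let train = Array2::from_shape_vec((2, 2), vec![0.4, 0.6, 0.7, 0.9]).unwrap();
        let val = Array2::from_shape_vec((2, 2), vec![0.3, 0.5, 0.6, 0.8]).unwrap();

        let curve = LearningCurve::<f64>::new(sizes, train, val).unwrap();

        // Means are plain fold averages.
        assert!((curve.train_mean[0] - 0.5).abs() < 1e-12);
        assert!((curve.val_mean[1] - 0.7).abs() < 1e-12);
        // Population standard deviation of {0.4, 0.6} around 0.5 is 0.1.
        assert!((curve.train_std[0] - 0.1).abs() < 1e-12);
    }
}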