// scirs2_neural/utils/evaluation/learning_curve.rs
1//! Learning curve for model performance analysis
2//!
3//! This module provides the LearningCurve data structure for visualizing model
4//! performance across different training set sizes, comparing training and
5//! validation metrics.
6
7use crate::error::{NeuralError, Result};
8use crate::utils::colors::{colorize, stylize, Color, ColorOptions, Style};
9use crate::utils::evaluation::helpers::draw_line_with_coords;
10use ndarray::{Array1, Array2, Axis};
11use num_traits::{Float, FromPrimitive};
12use std::fmt::{Debug, Display};
13
14// Removed problematic type alias - use trait bounds directly in implementations
15
16/// Learning curve data structure for visualizing model performance
17///
18/// This structure represents learning curves that show how model performance
19/// changes as the training set size increases, comparing training and validation
20/// metrics to help diagnose overfitting, underfitting, and other training issues.
/// Learning curve data structure for visualizing model performance
///
/// This structure represents learning curves that show how model performance
/// changes as the training set size increases, comparing training and validation
/// metrics to help diagnose overfitting, underfitting, and other training issues.
///
/// All score arrays are laid out with one row per training size and one column
/// per cross-validation fold; the mean/std arrays are aggregated across folds.
pub struct LearningCurve<F: Float + Debug + Display> {
    /// Training set sizes used for evaluation
    pub train_sizes: Array1<usize>,
    /// Raw training scores for each size and fold (rows=sizes, cols=folds)
    pub train_scores: Array2<F>,
    /// Raw validation scores for each size and fold (rows=sizes, cols=folds)
    pub val_scores: Array2<F>,
    /// Mean training score across folds, one entry per training size
    pub train_mean: Array1<F>,
    /// Population standard deviation of training scores across folds
    pub train_std: Array1<F>,
    /// Mean validation score across folds, one entry per training size
    pub val_mean: Array1<F>,
    /// Population standard deviation of validation scores across folds
    pub val_std: Array1<F>,
}
37
38impl<F: Float + Debug + Display + FromPrimitive> LearningCurve<F> {
39 /// Create a new learning curve from training and validation scores
40 ///
41 /// # Arguments
42 ///
43 /// * `train_sizes` - Array of training set sizes
44 /// * `train_scores` - 2D array of training scores (rows=sizes, cols=cv folds)
45 /// * `val_scores` - 2D array of validation scores (rows=sizes, cols=cv folds)
46 ///
47 /// # Returns
48 ///
49 /// * `Result<LearningCurve<F>>` - Learning curve data
50 ///
51 /// # Example
52 ///
53 /// ```
54 /// use ndarray::{Array1, Array2};
55 /// use scirs2_neural::utils::evaluation::LearningCurve;
56 ///
57 /// // Create sample data
58 /// let train_sizes = Array1::from_vec(vec![100, 200, 300, 400, 500]);
59 /// let train_scores = Array2::from_shape_vec((5, 3), vec![
60 /// 0.6, 0.62, 0.58, // 100 samples, 3 folds
61 /// 0.7, 0.72, 0.68, // 200 samples, 3 folds
62 /// 0.8, 0.78, 0.79, // 300 samples, 3 folds
63 /// 0.85, 0.83, 0.84, // 400 samples, 3 folds
64 /// 0.87, 0.88, 0.86, // 500 samples, 3 folds
65 /// ]).unwrap();
66 /// let val_scores = Array2::from_shape_vec((5, 3), vec![
67 /// 0.55, 0.53, 0.54, // 100 samples, 3 folds
68 /// 0.65, 0.63, 0.64, // 200 samples, 3 folds
69 /// 0.75, 0.73, 0.74, // 300 samples, 3 folds
70 /// 0.76, 0.74, 0.75, // 400 samples, 3 folds
71 /// 0.77, 0.76, 0.76, // 500 samples, 3 folds
72 /// ]).unwrap();
73 ///
74 /// // Create learning curve
75 /// let curve = LearningCurve::<f64>::new(train_sizes, train_scores, val_scores).unwrap();
76 /// ```
77 pub fn new(
78 train_sizes: Array1<usize>,
79 train_scores: Array2<F>,
80 val_scores: Array2<F>,
81 ) -> Result<Self> {
82 let n_sizes = train_sizes.len();
83
84 if train_scores.shape()[0] != n_sizes || val_scores.shape()[0] != n_sizes {
85 return Err(NeuralError::ValidationError(
86 "Number of scores must match number of training sizes".to_string(),
87 ));
88 }
89
90 if train_scores.shape()[1] != val_scores.shape()[1] {
91 return Err(NeuralError::ValidationError(
92 "Training and validation scores must have the same number of CV folds".to_string(),
93 ));
94 }
95
96 // Compute means and standard deviations
97 let train_mean = train_scores.mean_axis(Axis(1)).unwrap();
98 let val_mean = val_scores.mean_axis(Axis(1)).unwrap();
99
100 // Compute standard deviations using helper function
101 let train_std = compute_std(&train_scores, &train_mean, n_sizes);
102 let val_std = compute_std(&val_scores, &val_mean, n_sizes);
103
104 Ok(LearningCurve {
105 train_sizes,
106 train_scores,
107 val_scores,
108 train_mean,
109 train_std,
110 val_mean,
111 val_std,
112 })
113 }
114
115 /// Create an ASCII line plot of the learning curve
116 ///
117 /// # Arguments
118 ///
119 /// * `title` - Optional title for the plot
120 /// * `width` - Width of the plot
121 /// * `height` - Height of the plot
122 /// * `metric_name` - Name of the metric (e.g., "Accuracy")
123 ///
124 /// # Returns
125 ///
126 /// * `String` - ASCII line plot
127 pub fn to_ascii(
128 &self,
129 title: Option<&str>,
130 width: usize,
131 height: usize,
132 metric_name: &str,
133 ) -> String {
134 self.to_ascii_with_options(title, width, height, metric_name, &ColorOptions::default())
135 }
136
137 /// Create an ASCII line plot of the learning curve with customizable colors
138 ///
139 /// This method allows fine-grained control over the color scheme using the
140 /// provided ColorOptions parameter.
141 ///
142 /// # Arguments
143 ///
144 /// * `title` - Optional title for the plot
145 /// * `width` - Width of the plot
146 /// * `height` - Height of the plot
147 /// * `metric_name` - Name of the metric (e.g., "Accuracy")
148 /// * `color_options` - Color options for visualization
149 ///
150 /// # Returns
151 ///
152 /// * `String` - ASCII line plot with colors
153 pub fn to_ascii_with_options(
154 &self,
155 title: Option<&str>,
156 width: usize,
157 height: usize,
158 metric_name: &str,
159 color_options: &ColorOptions,
160 ) -> String {
161 // Pre-allocate result string with estimated capacity
162 let mut result = String::with_capacity(width * height * 2);
163
164 // Add title with styling if provided
165 if let Some(title_text) = title {
166 if color_options.enabled {
167 result.push_str(&format!("{}\n\n", stylize(title_text, Style::Bold)));
168 } else {
169 result.push_str(&format!("{}\n\n", title_text));
170 }
171 } else if color_options.enabled {
172 let title = format!("Learning Curve ({})", stylize(metric_name, Style::Bold));
173 result.push_str(&format!("{}\n\n", stylize(title, Style::Bold)));
174 } else {
175 result.push_str(&format!("Learning Curve ({})\n\n", metric_name));
176 }
177
178 // Find min and max values for y-axis scaling
179 let min_score = self
180 .val_mean
181 .iter()
182 .fold(F::infinity(), |acc, &v| if v < acc { v } else { acc });
183
184 let max_score =
185 self.train_mean
186 .iter()
187 .fold(F::neg_infinity(), |acc, &v| if v > acc { v } else { acc });
188
189 // Add a small margin to the y-range
190 let y_margin = F::from(0.1).unwrap() * (max_score - min_score);
191 let y_min = min_score - y_margin;
192 let y_max = max_score + y_margin;
193
194 // Create a 2D grid for the plot
195 let mut grid = vec![vec![' '; width]; height];
196 let mut grid_markers = vec![vec![(false, false); width]; height]; // Track which points are training vs. validation
197
198 // Function to map a value to a y-coordinate
199 let y_coord = |value: F| -> usize {
200 let norm = (value - y_min) / (y_max - y_min);
201 let y = height - 1 - (norm.to_f64().unwrap() * (height - 1) as f64).round() as usize;
202 std::cmp::min(y, height - 1)
203 };
204
205 // Function to map a training size to an x-coordinate
206 let x_coord = |size_idx: usize| -> usize {
207 ((size_idx as f64) / ((self.train_sizes.len() - 1) as f64) * (width - 1) as f64).round()
208 as usize
209 };
210
211 // Draw training curve and mark as training points
212 for i in 0..self.train_sizes.len() - 1 {
213 let x1 = x_coord(i);
214 let y1 = y_coord(self.train_mean[i]);
215 let x2 = x_coord(i + 1);
216 let y2 = y_coord(self.train_mean[i + 1]);
217
218 // Draw a line between points and mark as training points
219 for (x, y) in draw_line_with_coords(x1, y1, x2, y2, Some(width), Some(height)) {
220 grid[y][x] = '●';
221 grid_markers[y][x].0 = true; // Mark as training point
222 }
223 }
224
225 // Draw validation curve and mark as validation points
226 for i in 0..self.train_sizes.len() - 1 {
227 let x1 = x_coord(i);
228 let y1 = y_coord(self.val_mean[i]);
229 let x2 = x_coord(i + 1);
230 let y2 = y_coord(self.val_mean[i + 1]);
231
232 // Draw a line between points and mark as validation points
233 for (x, y) in draw_line_with_coords(x1, y1, x2, y2, Some(width), Some(height)) {
234 grid[y][x] = '○';
235 grid_markers[y][x].1 = true; // Mark as validation point
236 }
237 }
238
239 // Draw the grid
240 for y in 0..height {
241 // Y-axis labels with styling
242 if y == 0 {
243 if color_options.enabled {
244 let value = format!("{:.2}", y_max);
245 result.push_str(&format!("{} |", colorize(value, Color::BrightCyan)));
246 } else {
247 result.push_str(&format!("{:.2} |", y_max));
248 }
249 } else if y == height - 1 {
250 if color_options.enabled {
251 let value = format!("{:.2}", y_min);
252 result.push_str(&format!("{} |", colorize(value, Color::BrightCyan)));
253 } else {
254 result.push_str(&format!("{:.2} |", y_min));
255 }
256 } else if y == height / 2 {
257 let mid = y_min + (y_max - y_min) * F::from(0.5).unwrap();
258 if color_options.enabled {
259 let value = format!("{:.2}", mid);
260 result.push_str(&format!("{} |", colorize(value, Color::BrightCyan)));
261 } else {
262 result.push_str(&format!("{:.2} |", mid));
263 }
264 } else {
265 result.push_str(" |");
266 }
267
268 // Grid content with coloring
269 for x in 0..width {
270 let char = grid[y][x];
271 let (is_train, is_val) = grid_markers[y][x];
272
273 if color_options.enabled {
274 if is_train {
275 // Training point
276 result.push_str(&colorize("●", Color::BrightBlue));
277 } else if is_val {
278 // Validation point
279 result.push_str(&colorize("○", Color::BrightGreen));
280 } else {
281 result.push(char);
282 }
283 } else {
284 result.push(char);
285 }
286 }
287
288 result.push('\n');
289 }
290
291 // X-axis
292 result.push_str(" +");
293 result.push_str(&"-".repeat(width));
294 result.push('\n');
295
296 // X-axis labels with styling
297 result.push_str(" ");
298
299 // Put a few size labels along the x-axis
300 let n_labels = std::cmp::min(5, self.train_sizes.len());
301 for i in 0..n_labels {
302 let idx = i * (self.train_sizes.len() - 1) / (n_labels - 1);
303 let size = self.train_sizes[idx];
304 let label = format!("{}", size);
305 let x = x_coord(idx);
306
307 // Position the label with styling
308 if i == 0 {
309 if color_options.enabled {
310 result.push_str(&colorize(label, Color::BrightCyan));
311 } else {
312 result.push_str(&label);
313 }
314 } else {
315 let prev_end = result.len();
316 let spaces = x.saturating_sub(prev_end - 7);
317 result.push_str(&" ".repeat(spaces));
318 if color_options.enabled {
319 result.push_str(&colorize(label, Color::BrightCyan));
320 } else {
321 result.push_str(&label);
322 }
323 }
324 }
325
326 result.push('\n');
327
328 // X-axis title with styling
329 if color_options.enabled {
330 result.push_str(&format!(
331 " {}\n\n",
332 stylize("Training Set Size", Style::Bold)
333 ));
334 } else {
335 result.push_str(" Training Set Size\n\n");
336 }
337
338 // Add legend with colors
339 if color_options.enabled {
340 result.push_str(&format!(
341 " {} Training score {} Validation score\n",
342 colorize("●", Color::BrightBlue),
343 colorize("○", Color::BrightGreen)
344 ));
345 } else {
346 result.push_str(" ● Training score ○ Validation score\n");
347 }
348
349 result
350 }
351}
352
353/// Helper function to compute standard deviation for scores
354fn compute_std<F: Float + Debug + Display + FromPrimitive>(
355 scores: &Array2<F>,
356 mean: &Array1<F>,
357 n_sizes: usize,
358) -> Array1<F> {
359 let mut std_arr = Array1::zeros(n_sizes);
360 let n_folds = scores.shape()[1];
361
362 for i in 0..n_sizes {
363 let mut sum_sq_diff = F::zero();
364 for j in 0..n_folds {
365 let diff = scores[[i, j]] - mean[i];
366 sum_sq_diff = sum_sq_diff + diff * diff;
367 }
368 std_arr[i] = (sum_sq_diff / F::from(n_folds).unwrap()).sqrt();
369 }
370
371 std_arr
372}