scirs2_neural/utils/evaluation/
learning_curve.rs1use crate::error::{NeuralError, Result};
8use crate::utils::colors::{colorize, stylize, Color, ColorOptions, Style};
9use crate::utils::evaluation::helpers::draw_line_with_coords;
10use scirs2_core::ndarray::{Array1, Array2, Axis};
11use scirs2_core::numeric::{Float, FromPrimitive};
12use std::fmt::{Debug, Display};
/// Scores of a model evaluated at several training-set sizes, together with
/// per-size summary statistics.
///
/// `LearningCurve::new` fills the `*_mean` and `*_std` fields from the raw
/// score matrices; all fields are public, so they can also be read directly.
pub struct LearningCurve<F: Float + Debug + Display> {
    /// Training-set sizes, one entry per evaluated size.
    pub train_sizes: Array1<usize>,
    /// Training scores. `new` requires one row per training size and treats
    /// the columns as cross-validation folds.
    pub train_scores: Array2<F>,
    /// Validation scores. `new` requires the same shape as `train_scores`.
    pub val_scores: Array2<F>,
    /// Mean of `train_scores` along axis 1 (one value per training size).
    pub train_mean: Array1<F>,
    /// Standard deviation of `train_scores` per training size, as computed by
    /// `compute_std` (squared deviations divided by the number of folds).
    pub train_std: Array1<F>,
    /// Mean of `val_scores` along axis 1 (one value per training size).
    pub val_mean: Array1<F>,
    /// Standard deviation of `val_scores` per training size, as computed by
    /// `compute_std` (squared deviations divided by the number of folds).
    pub val_std: Array1<F>,
}
35impl<F: Float + Debug + Display + FromPrimitive> LearningCurve<F> {
36 pub fn new(
68 train_sizes: Array1<usize>,
69 train_scores: Array2<F>,
70 val_scores: Array2<F>,
71 ) -> Result<Self> {
72 let n_sizes = train_sizes.len();
73 if train_scores.shape()[0] != n_sizes || val_scores.shape()[0] != n_sizes {
74 return Err(NeuralError::ValidationError(
75 "Number of _scores must match number of training _sizes".to_string(),
76 ));
77 }
78 if train_scores.shape()[1] != val_scores.shape()[1] {
79 return Err(NeuralError::ValidationError(
80 "Training and validation _scores must have the same number of CV folds".to_string(),
81 ));
82 }
83 let train_mean = train_scores.mean_axis(Axis(1)).unwrap();
85 let val_mean = val_scores.mean_axis(Axis(1)).unwrap();
86 let train_std = compute_std(&train_scores, &train_mean, n_sizes);
88 let val_std = compute_std(&val_scores, &val_mean, n_sizes);
89 Ok(LearningCurve {
90 train_sizes,
91 train_scores,
92 val_scores,
93 train_mean,
94 train_std,
95 val_mean,
96 val_std,
97 })
98 }
99 pub fn to_ascii(
106 &self,
107 title: Option<&str>,
108 width: usize,
109 height: usize,
110 metric_name: &str,
111 ) -> String {
112 self.to_ascii_with_options(title, width, height, metric_name, &ColorOptions::default())
113 }
114
115 pub fn to_ascii_with_options(
121 &self,
122 title: Option<&str>,
123 width: usize,
124 height: usize,
125 metric_name: &str,
126 color_options: &ColorOptions,
127 ) -> String {
128 let mut result = String::with_capacity(width * height * 2);
130 if let Some(titletext) = title {
132 if color_options.enabled {
133 let styled_title = stylize(titletext, Style::Bold);
134 result.push_str(&format!("{styled_title}\n\n"));
135 } else {
136 result.push_str(&format!("{titletext}\n\n"));
137 }
138 } else if color_options.enabled {
139 let styled_metric = stylize(metric_name, Style::Bold);
140 let title = format!("Learning Curve ({styled_metric})");
141 let styled_title = stylize(title, Style::Bold);
142 result.push_str(&format!("{styled_title}\n\n"));
143 } else {
144 result.push_str(&format!("Learning Curve ({metric_name})\n\n"));
145 }
146
147 let min_score = self
149 .val_mean
150 .iter()
151 .fold(F::infinity(), |acc, &v| if v < acc { v } else { acc });
152 let max_score =
153 self.train_mean
154 .iter()
155 .fold(F::neg_infinity(), |acc, &v| if v > acc { v } else { acc });
156 let y_margin = F::from(0.1).unwrap() * (max_score - min_score);
158 let y_min = min_score - y_margin;
159 let y_max = max_score + y_margin;
160 let mut grid = vec![vec![' '; width]; height];
162 let mut grid_markers = vec![vec![(false, false); width]; height]; let y_coord = |value: F| -> usize {
165 let norm = (value - y_min) / (y_max - y_min);
166 let y = height - 1 - (norm.to_f64().unwrap() * (height - 1) as f64).round() as usize;
167 std::cmp::min(y, height - 1)
168 };
169 let x_coord = |size_idx: usize| -> usize {
171 ((size_idx as f64) / ((self.train_sizes.len() - 1) as f64) * (width - 1) as f64).round()
172 as usize
173 };
174
175 for i in 0..self.train_sizes.len() - 1 {
177 let x1 = x_coord(i);
178 let y1 = y_coord(self.train_mean[i]);
179 let x2 = x_coord(i + 1);
180 let y2 = y_coord(self.train_mean[i + 1]);
181 for (x, y) in draw_line_with_coords(x1, y1, x2, y2, Some(width), Some(height)) {
183 grid[y][x] = '●';
184 grid_markers[y][x].0 = true; }
186 }
187
188 for i in 0..self.train_sizes.len() - 1 {
190 let x1 = x_coord(i);
191 let y1 = y_coord(self.val_mean[i]);
192 let x2 = x_coord(i + 1);
193 let y2 = y_coord(self.val_mean[i + 1]);
194 for (x, y) in draw_line_with_coords(x1, y1, x2, y2, Some(width), Some(height)) {
196 grid[y][x] = '○';
197 grid_markers[y][x].1 = true; }
199 }
200 for y in 0..height {
202 if y == 0 {
204 if color_options.enabled {
205 let value = format!("{y_max:.2}");
206 result.push_str(&format!("{} |", colorize(value, Color::BrightCyan)));
207 } else {
208 result.push_str(&format!("{y_max:.2} |"));
209 }
210 } else if y == height - 1 {
211 if color_options.enabled {
212 let value = format!("{y_min:.2}");
213 result.push_str(&format!("{} |", colorize(value, Color::BrightCyan)));
214 } else {
215 result.push_str(&format!("{y_min:.2} |"));
216 }
217 } else if y == height / 2 {
218 let mid = y_min + (y_max - y_min) * F::from(0.5).unwrap();
219 if color_options.enabled {
220 let value = format!("{mid:.2}");
221 result.push_str(&format!("{} |", colorize(value, Color::BrightCyan)));
222 } else {
223 result.push_str(&format!("{mid:.2} |"));
224 }
225 } else {
226 result.push_str(" |");
227 }
228 for x in 0..width {
230 let char = grid[y][x];
231 let (is_train, is_val) = grid_markers[y][x];
232 if color_options.enabled {
233 if is_train {
234 result.push_str(&colorize("●", Color::BrightBlue));
236 } else if is_val {
237 result.push_str(&colorize("○", Color::BrightGreen));
239 } else {
240 result.push(char);
241 }
242 } else {
243 result.push(char);
244 }
245 }
246 result.push('\n');
247 }
248 result.push_str(" +");
250 result.push_str(&"-".repeat(width));
251 result.push('\n');
252 result.push_str(" ");
254 let n_labels = std::cmp::min(5, self.train_sizes.len());
256 for i in 0..n_labels {
257 let idx = i * (self.train_sizes.len() - 1) / (n_labels - 1);
258 let size = self.train_sizes[idx];
259 let label = format!("{size}");
260 let x = x_coord(idx);
261 if i == 0 {
263 if color_options.enabled {
264 result.push_str(&colorize(label, Color::BrightCyan));
265 } else {
266 result.push_str(&label);
267 }
268 } else {
269 let prev_end = result.len();
270 let spaces = x.saturating_sub(prev_end - 7);
271 result.push_str(&" ".repeat(spaces));
272 if color_options.enabled {
273 result.push_str(&colorize(label, Color::BrightCyan));
274 } else {
275 result.push_str(&label);
276 }
277 }
278 }
279 result.push('\n');
280
281 if color_options.enabled {
283 result.push_str(&format!(
284 " {}\n\n",
285 stylize("Training Set Size", Style::Bold)
286 ));
287 } else {
288 result.push_str(" Training Set Size\n\n");
289 }
290 if color_options.enabled {
292 result.push_str(&format!(
293 " {} Training score {} Validation score\n",
294 colorize("●", Color::BrightBlue),
295 colorize("○", Color::BrightGreen)
296 ));
297 } else {
298 result.push_str(" ● Training score ○ Validation score\n");
299 }
300
301 result
302 }
303}
304#[allow(dead_code)]
306fn compute_std<F: Float + Debug + Display + FromPrimitive>(
307 scores: &Array2<F>,
308 mean: &Array1<F>,
309 n_sizes: usize,
310) -> Array1<F> {
311 let mut std_arr = Array1::zeros(n_sizes);
312 let n_folds = scores.shape()[1];
313 for i in 0..n_sizes {
314 let mut sum_sq_diff = F::zero();
315 for j in 0..n_folds {
316 let diff = scores[[i, j]] - mean[i];
317 sum_sq_diff = sum_sq_diff + diff * diff;
318 }
319 std_arr[i] = (sum_sq_diff / F::from(n_folds).unwrap()).sqrt();
320 }
321 std_arr
322}