use std::error::Error;

use super::{MetricVisualizer, PlotType, VisualizationData, VisualizationMetadata};
use crate::error::{MetricsError, Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2};
use scirs2_core::random::prelude::*;
/// Data for a learning curve: per-training-size scores collected across CV folds.
#[derive(Debug, Clone)]
pub struct LearningCurveData {
    /// Training set sizes at which the model was evaluated
    pub train_sizes: Vec<usize>,
    /// Training scores for each training size (one inner `Vec` of fold scores per size)
    pub train_scores: Vec<Vec<f64>>,
    /// Validation scores for each training size (one inner `Vec` of fold scores per size)
    pub validation_scores: Vec<Vec<f64>>,
}

/// Visualizer for learning curves.
#[derive(Debug, Clone)]
pub struct LearningCurveVisualizer {
    /// The learning curve data to plot
    data: LearningCurveData,
    /// Plot title
    title: String,
    /// Whether to include +/- one standard deviation bands around each mean curve
    show_std: bool,
    /// Label for the y-axis (the scoring metric name)
    scoring: String,
}

impl LearningCurveVisualizer {
    /// Creates a new learning curve visualizer from the given data.
    pub fn new(data: LearningCurveData) -> Self {
        LearningCurveVisualizer {
            data,
            title: "Learning Curve".to_string(),
            show_std: true,
            scoring: "Score".to_string(),
        }
    }

    /// Sets the plot title.
    pub fn with_title(mut self, title: String) -> Self {
        self.title = title;
        self
    }

    /// Sets whether to include +/- one standard deviation bands.
    pub fn with_show_std(mut self, show_std: bool) -> Self {
        self.show_std = show_std;
        self
    }

    /// Sets the scoring metric name used as the y-axis label.
    pub fn with_scoring(mut self, scoring: String) -> Self {
        self.scoring = scoring;
        self
    }

    /// Computes the mean and standard deviation of the fold scores for each
    /// training size. Returns `(means, standard_deviations)`.
    fn compute_statistics(&self, scores: &[Vec<f64>]) -> (Vec<f64>, Vec<f64>) {
        let n = scores.len();
        let mut mean_scores = Vec::with_capacity(n);
        let mut std_scores = Vec::with_capacity(n);

        for fold_scores in scores {
            let mean = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;
            mean_scores.push(mean);

            let variance = fold_scores.iter().map(|&s| (s - mean).powi(2)).sum::<f64>()
                / fold_scores.len() as f64;
            std_scores.push(variance.sqrt());
        }

        (mean_scores, std_scores)
    }
}
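
// A minimal usage sketch of the builder and the statistics helper, kept as a
// test so it stays compilable. The fold scores below are fabricated purely
// for demonstration.
#[cfg(test)]
mod builder_example {
    use super::*;

    #[test]
    fn builder_and_statistics() {
        let data = LearningCurveData {
            train_sizes: vec![10, 50, 100],
            train_scores: vec![vec![0.6, 0.62], vec![0.7, 0.72], vec![0.8, 0.82]],
            validation_scores: vec![vec![0.5, 0.52], vec![0.6, 0.62], vec![0.7, 0.72]],
        };
        let viz = LearningCurveVisualizer::new(data)
            .with_title("Demo".to_string())
            .with_show_std(false)
            .with_scoring("Accuracy".to_string());
        // One mean/std pair per training size.
        let (means, stds) = viz.compute_statistics(&viz.data.train_scores);
        assert_eq!(means.len(), 3);
        assert!((means[0] - 0.61).abs() < 1e-12);
        assert!(stds.iter().all(|&s| s >= 0.0));
    }
}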

impl MetricVisualizer for LearningCurveVisualizer {
    fn prepare_data(&self) -> std::result::Result<VisualizationData, Box<dyn Error>> {
        let (train_mean, train_std) = self.compute_statistics(&self.data.train_scores);
        let (val_mean, val_std) = self.compute_statistics(&self.data.validation_scores);

        let train_sizes: Vec<f64> = self.data.train_sizes.iter().map(|&s| s as f64).collect();

        // Series are stored flattened: each series contributes one copy of the
        // x values (training sizes) followed by its y values, in order.
        let mut x = Vec::new();
        let mut y = Vec::new();

        // Series 1: mean training score
        x.extend_from_slice(&train_sizes);
        y.extend_from_slice(&train_mean);

        // Series 2: mean validation score
        x.extend_from_slice(&train_sizes);
        y.extend_from_slice(&val_mean);

        let mut series_names = vec!["Training score".to_string(), "Validation score".to_string()];

        if self.show_std {
            // Series 3 and 4: training mean +/- one standard deviation
            x.extend_from_slice(&train_sizes);
            x.extend_from_slice(&train_sizes);

            let train_upper: Vec<f64> = train_mean
                .iter()
                .zip(train_std.iter())
                .map(|(&m, &s)| m + s)
                .collect();

            let train_lower: Vec<f64> = train_mean
                .iter()
                .zip(train_std.iter())
                .map(|(&m, &s)| m - s)
                .collect();

            y.extend_from_slice(&train_upper);
            y.extend_from_slice(&train_lower);

            // Series 5 and 6: validation mean +/- one standard deviation
            x.extend_from_slice(&train_sizes);
            x.extend_from_slice(&train_sizes);

            let val_upper: Vec<f64> = val_mean
                .iter()
                .zip(val_std.iter())
                .map(|(&m, &s)| m + s)
                .collect();

            let val_lower: Vec<f64> = val_mean
                .iter()
                .zip(val_std.iter())
                .map(|(&m, &s)| m - s)
                .collect();

            y.extend_from_slice(&val_upper);
            y.extend_from_slice(&val_lower);

            // Each band series gets a distinct name so the four series remain
            // distinguishable in the legend.
            series_names.push("Training score + std".to_string());
            series_names.push("Training score - std".to_string());
            series_names.push("Validation score + std".to_string());
            series_names.push("Validation score - std".to_string());
        }

        Ok(VisualizationData {
            x,
            y,
            z: None,
            series_names: Some(series_names),
            x_labels: None,
            y_labels: None,
            auxiliary_data: std::collections::HashMap::new(),
            auxiliary_metadata: std::collections::HashMap::new(),
            series: std::collections::HashMap::new(),
        })
    }

    fn get_metadata(&self) -> VisualizationMetadata {
        VisualizationMetadata {
            title: self.title.clone(),
            x_label: "Training examples".to_string(),
            y_label: self.scoring.clone(),
            plot_type: PlotType::Line,
            description: Some(
                "Learning curve showing model performance as a function of training set size"
                    .to_string(),
            ),
        }
    }
}

/// Creates a learning curve visualization from raw cross-validation scores.
///
/// # Arguments
///
/// * `train_sizes` - Training set sizes at which the model was evaluated
/// * `train_scores` - Fold scores on the training data, one `Vec` per size
/// * `validation_scores` - Fold scores on the validation data, one `Vec` per size
/// * `scoring` - Name of the scoring metric (used as the y-axis label)
#[allow(dead_code)]
pub fn learning_curve_visualization(
    train_sizes: Vec<usize>,
    train_scores: Vec<Vec<f64>>,
    validation_scores: Vec<Vec<f64>>,
    scoring: impl Into<String>,
) -> Result<LearningCurveVisualizer> {
    if train_sizes.is_empty() || train_scores.is_empty() || validation_scores.is_empty() {
        return Err(MetricsError::InvalidInput(
            "Learning curve data cannot be empty".to_string(),
        ));
    }

    if train_scores.len() != train_sizes.len() || validation_scores.len() != train_sizes.len() {
        return Err(MetricsError::InvalidInput(
            "Number of train/validation scores must match number of training sizes".to_string(),
        ));
    }

    let data = LearningCurveData {
        train_sizes,
        train_scores,
        validation_scores,
    };

    let scoring_string = scoring.into();
    Ok(LearningCurveVisualizer::new(data).with_scoring(scoring_string))
}
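
// Illustrative sketch: assembling a visualizer from precomputed CV scores and
// checking the flattened series layout. The score values are fabricated.
#[cfg(test)]
mod visualization_example {
    use super::*;

    #[test]
    fn from_raw_scores() {
        let viz = learning_curve_visualization(
            vec![20, 40, 80],
            vec![vec![0.70, 0.72, 0.71], vec![0.78, 0.80, 0.79], vec![0.85, 0.84, 0.86]],
            vec![vec![0.60, 0.58, 0.61], vec![0.68, 0.70, 0.69], vec![0.75, 0.74, 0.76]],
            "Accuracy",
        )
        .expect("valid input");
        let data = viz.prepare_data().expect("prepare_data succeeds");
        // Two mean series plus four +/- std band series (show_std defaults to
        // true), each over 3 training sizes.
        assert_eq!(data.x.len(), 6 * 3);
    }
}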

/// Learning curve scenarios used for generating realistic synthetic data.
#[derive(Debug, Clone, Copy)]
pub enum LearningCurveScenario {
    /// Training and validation curves converge toward a good score
    WellFitted,
    /// Both curves plateau at a low score (underfitting)
    HighBias,
    /// Large, persistent gap between training and validation curves (overfitting)
    HighVariance,
    /// Curves fluctuate due to noise in the data
    NoisyData,
    /// Curves flatten out quickly as training size grows
    PlateauEffect,
}

/// Configuration for generating realistic learning curve data.
#[derive(Debug, Clone)]
pub struct LearningCurveConfig {
    /// The scenario to simulate
    pub scenario: LearningCurveScenario,
    /// Number of cross-validation folds to simulate
    pub cv_folds: usize,
    /// Base performance level around which the curves are built
    pub base_performance: f64,
    /// Amount of per-fold noise to add
    pub noise_level: f64,
    /// Whether to add per-fold cross-validation variance at all
    pub add_cv_variance: bool,
}

impl Default for LearningCurveConfig {
    fn default() -> Self {
        Self {
            scenario: LearningCurveScenario::WellFitted,
            cv_folds: 5,
            base_performance: 0.75,
            noise_level: 0.05,
            add_cv_variance: true,
        }
    }
}
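
// Illustrative sketch: overriding a couple of config fields while keeping the
// rest at their defaults, using the usual struct-update pattern.
#[cfg(test)]
mod config_example {
    use super::*;

    #[test]
    fn struct_update_syntax() {
        let config = LearningCurveConfig {
            scenario: LearningCurveScenario::HighVariance,
            noise_level: 0.1,
            ..Default::default()
        };
        // Untouched fields keep their Default values.
        assert_eq!(config.cv_folds, 5);
        assert!(config.add_cv_variance);
    }
}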

/// Generates realistic synthetic learning curve data for a given scenario.
///
/// The `_x` and `_y` arguments are accepted for API compatibility but are not
/// read; scores are synthesized from the configured scenario instead.
#[allow(dead_code)]
pub fn learning_curve_realistic<T, S1, S2>(
    _x: &ArrayBase<S1, Ix2>,
    _y: &ArrayBase<S2, Ix1>,
    train_sizes: &[usize],
    config: LearningCurveConfig,
    scoring: impl Into<String>,
) -> Result<LearningCurveVisualizer>
where
    T: Clone + 'static,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
{
    use scirs2_core::random::Rng;
    let mut rng = scirs2_core::random::rng();

    let n_sizes = train_sizes.len();
    let mut train_scores = Vec::with_capacity(n_sizes);
    let mut validation_scores = Vec::with_capacity(n_sizes);

    for (i, &_size) in train_sizes.iter().enumerate() {
        // Progress through the curve, in [0, 1).
        let progress = i as f64 / n_sizes.max(1) as f64;

        let (base_train_score, base_val_score) = match config.scenario {
            LearningCurveScenario::WellFitted => {
                // Training score starts high; validation converges toward it.
                let train_score = config.base_performance + 0.15 * progress.powf(0.3);
                let val_score = config.base_performance - 0.1 + 0.2 * progress.powf(0.5);
                (train_score.min(0.95), val_score.min(train_score - 0.02))
            }
            LearningCurveScenario::HighBias => {
                // Both curves plateau at a low score with a small gap.
                let train_score = config.base_performance - 0.15 + 0.1 * progress.powf(0.8);
                let val_score = train_score - 0.05 + 0.03 * progress;
                (train_score.min(0.7), val_score.min(train_score))
            }
            LearningCurveScenario::HighVariance => {
                // High training score with a persistent validation gap.
                let train_score = config.base_performance + 0.2 * progress.powf(0.2);
                let val_score = config.base_performance - 0.2 + 0.15 * progress.powf(0.7);
                (train_score.min(0.98), val_score.min(train_score - 0.15))
            }
            LearningCurveScenario::NoisyData => {
                // Oscillating noise superimposed on slow improvement.
                let noise_factor = 0.1 * (progress * 10.0).sin();
                let train_score = config.base_performance + 0.1 * progress + noise_factor;
                let val_score =
                    config.base_performance - 0.05 + 0.12 * progress + noise_factor * 0.5;
                (train_score.min(0.9), val_score.min(train_score))
            }
            LearningCurveScenario::PlateauEffect => {
                // Exponential saturation: rapid early gains, then a plateau.
                let plateau_factor = 1.0 - (-5.0 * progress).exp();
                let train_score = config.base_performance + 0.15 * plateau_factor;
                let val_score = config.base_performance - 0.08 + 0.18 * plateau_factor;
                (train_score, val_score.min(train_score - 0.01))
            }
        };

        let fold_variance = if config.add_cv_variance {
            config.noise_level
        } else {
            0.0
        };

        // Only sample noise when the range is non-empty: `random_range`
        // panics on an empty range, which a zero fold variance would produce.
        let train_fold_scores: Vec<f64> = (0..config.cv_folds)
            .map(|_| {
                let noise = if fold_variance > 0.0 {
                    rng.random_range(-fold_variance..fold_variance)
                } else {
                    0.0
                };
                (base_train_score + noise).clamp(0.0, 1.0)
            })
            .collect();

        let val_fold_scores: Vec<f64> = (0..config.cv_folds)
            .map(|_| {
                let noise = if fold_variance > 0.0 {
                    rng.random_range(-fold_variance * 1.5..fold_variance * 1.5)
                } else {
                    0.0
                };
                (base_val_score + noise).clamp(0.0, 1.0)
            })
            .collect();

        train_scores.push(train_fold_scores);
        validation_scores.push(val_fold_scores);
    }

    learning_curve_visualization(
        train_sizes.to_vec(),
        train_scores,
        validation_scores,
        scoring,
    )
}
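
// Illustrative sketch: generating a synthetic high-variance curve. The input
// arrays are placeholders, since `learning_curve_realistic` ignores their
// contents by design.
#[cfg(test)]
mod realistic_example {
    use super::*;

    #[test]
    fn synthetic_high_variance_curve() {
        let x = Array2::<f64>::zeros((50, 3));
        let y = Array1::<f64>::zeros(50);
        let config = LearningCurveConfig {
            scenario: LearningCurveScenario::HighVariance,
            ..Default::default()
        };
        let viz = learning_curve_realistic(&x, &y, &[10, 20, 40], config, "Accuracy")
            .expect("valid sizes");
        let metadata = viz.get_metadata();
        assert_eq!(metadata.y_label, "Accuracy");
    }
}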

/// Computes a learning curve by training the supplied model on increasing
/// subsets of the data and scoring it with cross-validation.
///
/// For each training size, `cv` random train/validation splits are drawn
/// (shuffle-split style), the model is fit on the training subset, and both
/// subsets are scored with the requested metric.
#[allow(dead_code)]
pub fn learning_curve<T, S1, S2>(
    x: &ArrayBase<S1, Ix2>,
    y: &ArrayBase<S2, Ix1>,
    model: &impl ModelEvaluator<T>,
    train_sizes: &[usize],
    cv: usize,
    scoring: impl Into<String>,
) -> Result<LearningCurveVisualizer>
where
    T: Clone
        + 'static
        + scirs2_core::numeric::Float
        + Send
        + Sync
        + std::fmt::Debug
        + std::ops::Sub<Output = T>,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    for<'a> &'a T: std::ops::Sub<&'a T, Output = T>,
{
    let scoring_str = scoring.into();

    if x.nrows() != y.len() {
        return Err(MetricsError::InvalidInput(
            "Feature matrix and target vector must have same number of samples".to_string(),
        ));
    }

    if train_sizes.is_empty() {
        return Err(MetricsError::InvalidInput(
            "Training sizes cannot be empty".to_string(),
        ));
    }

    if cv < 2 {
        return Err(MetricsError::InvalidInput(
            "Number of CV folds must be at least 2".to_string(),
        ));
    }

    let max_size = train_sizes.iter().max().unwrap();
    if *max_size > x.nrows() {
        return Err(MetricsError::InvalidInput(format!(
            "Maximum training size ({}) exceeds available samples ({})",
            max_size,
            x.nrows()
        )));
    }

    let mut train_scores = Vec::new();
    let mut validation_scores = Vec::new();

    let mut rng = scirs2_core::random::rng();

    let fold_size = x.nrows() / cv;
    let mut indices: Vec<usize> = (0..x.nrows()).collect();

    for &size in train_sizes {
        let mut train_fold_scores = Vec::new();
        let mut val_fold_scores = Vec::new();

        for fold in 0..cv {
            // Fisher-Yates shuffle. Re-shuffling on every fold makes this a
            // Monte Carlo (shuffle-split) cross-validation rather than a
            // disjoint k-fold partitioning.
            for i in 0..indices.len() {
                let j = rng.random_range(i..indices.len());
                indices.swap(i, j);
            }

            let val_start = fold * fold_size;
            let val_end = std::cmp::min((fold + 1) * fold_size, x.nrows());

            // Assign the fold's slice of shuffled indices to validation and
            // fill the training set (up to `size` samples) from the rest.
            let mut train_indices = Vec::new();
            let mut val_indices = Vec::new();

            for (i, &idx) in indices.iter().enumerate() {
                if i >= val_start && i < val_end {
                    val_indices.push(idx);
                } else if train_indices.len() < size {
                    train_indices.push(idx);
                }
            }

            let train_x = extract_rows(x, &train_indices);
            let train_y = extract_elements(y, &train_indices);
            let val_x = extract_rows(x, &val_indices);
            let val_y = extract_elements(y, &val_indices);

            // Fit on the training subset, then score both subsets.
            let trained_model = model.fit(&train_x, &train_y)?;

            let train_pred = trained_model.predict(&train_x)?;
            let train_score = evaluate_predictions(&train_y, &train_pred, &scoring_str)?;
            train_fold_scores.push(train_score);

            let val_pred = trained_model.predict(&val_x)?;
            let val_score = evaluate_predictions(&val_y, &val_pred, &scoring_str)?;
            val_fold_scores.push(val_score);
        }

        train_scores.push(train_fold_scores);
        validation_scores.push(val_fold_scores);
    }

    learning_curve_visualization(
        train_sizes.to_vec(),
        train_scores,
        validation_scores,
        scoring_str,
    )
}

/// A model that can be fit to training data, producing a trained model.
pub trait ModelEvaluator<T> {
    /// The trained model type produced by `fit`
    type TrainedModel: ModelPredictor<T>;

    /// Fits the model to the given features and targets.
    fn fit(&self, x: &Array2<T>, y: &Array1<T>) -> Result<Self::TrainedModel>;
}

/// A trained model that can make predictions on new data.
pub trait ModelPredictor<T> {
    /// Predicts targets for the given feature matrix.
    fn predict(&self, x: &Array2<T>) -> Result<Array1<T>>;
}
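
// A minimal sketch of the trait pair: a baseline "model" that always predicts
// the mean of the training targets. `MeanModel` and `TrainedMeanModel` are
// hypothetical names introduced here purely for illustration.
#[cfg(test)]
mod model_trait_example {
    use super::*;

    struct MeanModel;

    struct TrainedMeanModel {
        mean: f64,
    }

    impl ModelEvaluator<f64> for MeanModel {
        type TrainedModel = TrainedMeanModel;

        fn fit(&self, _x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::TrainedModel> {
            let mean = y.iter().sum::<f64>() / y.len().max(1) as f64;
            Ok(TrainedMeanModel { mean })
        }
    }

    impl ModelPredictor<f64> for TrainedMeanModel {
        fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
            Ok(Array1::from_elem(x.nrows(), self.mean))
        }
    }

    #[test]
    fn mean_baseline_learning_curve() {
        let x = Array2::<f64>::zeros((30, 2));
        let y = Array1::<f64>::ones(30);
        let viz = learning_curve(&x, &y, &MeanModel, &[5, 10], 3, "mse")
            .expect("valid configuration");
        // A constant predictor on constant targets has zero error everywhere.
        let data = viz.prepare_data().expect("prepare_data succeeds");
        assert!(data.y.iter().all(|&v| v.abs() < 1e-12));
    }
}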
567
568#[allow(dead_code)]
570fn extract_rows<T, S>(arr: &ArrayBase<S, Ix2>, indices: &[usize]) -> Array2<T>
571where
572 T: Clone + scirs2_core::numeric::Zero,
573 S: Data<Elem = T>,
574{
575 let mut result = Array2::zeros((indices.len(), arr.ncols()));
576 for (i, &idx) in indices.iter().enumerate() {
577 result.row_mut(i).assign(&arr.row(idx));
578 }
579 result
580}
581
582#[allow(dead_code)]
584fn extract_elements<T, S>(arr: &ArrayBase<S, Ix1>, indices: &[usize]) -> Array1<T>
585where
586 T: Clone + scirs2_core::numeric::Zero,
587 S: Data<Elem = T>,
588{
589 let mut result = Array1::zeros(indices.len());
590 for (i, &idx) in indices.iter().enumerate() {
591 result[i] = arr[idx].clone();
592 }
593 result
594}

/// Scores predictions against ground truth with the named metric.
///
/// Supported metrics: `"accuracy"`, `"mse"`/`"mean_squared_error"`,
/// `"mae"`/`"mean_absolute_error"`, and `"r2"`/`"r2_score"`. Any other name
/// falls back to MSE.
#[allow(dead_code)]
fn evaluate_predictions<T>(y_true: &Array1<T>, y_pred: &Array1<T>, scoring: &str) -> Result<f64>
where
    T: Clone
        + scirs2_core::numeric::Float
        + Send
        + Sync
        + std::fmt::Debug
        + std::ops::Sub<Output = T>,
    for<'a> &'a T: std::ops::Sub<&'a T, Output = T>,
{
    match scoring.to_lowercase().as_str() {
        "accuracy" => {
            // A prediction counts as correct if it lands within 0.5 of the target.
            let correct = y_true
                .iter()
                .zip(y_pred.iter())
                .filter(|(t, p)| (*t - *p).abs() < T::from(0.5).unwrap())
                .count();
            Ok(correct as f64 / y_true.len() as f64)
        }
        "mse" | "mean_squared_error" => {
            let mse = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(t, p)| (*t - *p) * (*t - *p))
                .fold(T::zero(), |acc, x| acc + x)
                / T::from(y_true.len()).unwrap();
            Ok(mse.to_f64().unwrap_or(0.0))
        }
        "mae" | "mean_absolute_error" => {
            let mae = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(t, p)| (*t - *p).abs())
                .fold(T::zero(), |acc, x| acc + x)
                / T::from(y_true.len()).unwrap();
            Ok(mae.to_f64().unwrap_or(0.0))
        }
        "r2" | "r2_score" => {
            // R^2 = 1 - SS_res / SS_tot
            let mean_true = y_true.iter().cloned().fold(T::zero(), |acc, x| acc + x)
                / T::from(y_true.len()).unwrap();

            let ss_tot = y_true
                .iter()
                .map(|&t| (t - mean_true) * (t - mean_true))
                .fold(T::zero(), |acc, x| acc + x);

            let ss_res = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(&t, &p)| (t - p) * (t - p))
                .fold(T::zero(), |acc, x| acc + x);

            if ss_tot == T::zero() {
                Ok(0.0)
            } else {
                let r2 = T::one() - ss_res / ss_tot;
                Ok(r2.to_f64().unwrap_or(0.0))
            }
        }
        _ => {
            // Unknown metric name: fall back to mean squared error.
            let mse = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(t, p)| (*t - *p) * (*t - *p))
                .fold(T::zero(), |acc, x| acc + x)
                / T::from(y_true.len()).unwrap();
            Ok(mse.to_f64().unwrap_or(0.0))
        }
    }
}
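
// A small worked check of the metric arithmetic, with hand-computed values:
// for y_true = [1, 2, 3] and y_pred = [1, 2, 4], MSE = 1/3, and
// R^2 = 1 - SS_res/SS_tot = 1 - 1/2 = 0.5 (SS_res = 1, SS_tot = 2).
#[cfg(test)]
mod scoring_example {
    use super::*;

    #[test]
    fn metric_arithmetic() {
        let y_true = Array1::from(vec![1.0_f64, 2.0, 3.0]);
        let y_pred = Array1::from(vec![1.0_f64, 2.0, 4.0]);
        let mse = evaluate_predictions(&y_true, &y_pred, "mse").unwrap();
        assert!((mse - 1.0 / 3.0).abs() < 1e-12);
        let r2 = evaluate_predictions(&y_true, &y_pred, "r2").unwrap();
        assert!((r2 - 0.5).abs() < 1e-12);
        // Two of the three predictions land within 0.5 of the target.
        let acc = evaluate_predictions(&y_true, &y_pred, "accuracy").unwrap();
        assert!((acc - 2.0 / 3.0).abs() < 1e-12);
    }
}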

/// Generates learning curve visualizers for every predefined scenario,
/// returning `(scenario name, visualizer)` pairs.
#[allow(dead_code)]
pub fn learning_curve_scenarios(
    train_sizes: &[usize],
    scoring: impl Into<String>,
) -> Result<Vec<(String, LearningCurveVisualizer)>> {
    let scoring_str = scoring.into();
    let scenarios = [
        ("Well Fitted", LearningCurveScenario::WellFitted),
        ("High Bias (Underfitting)", LearningCurveScenario::HighBias),
        (
            "High Variance (Overfitting)",
            LearningCurveScenario::HighVariance,
        ),
        ("Noisy Data", LearningCurveScenario::NoisyData),
        ("Plateau Effect", LearningCurveScenario::PlateauEffect),
    ];

    let mut results = Vec::new();

    // Placeholder arrays: `learning_curve_realistic` synthesizes its scores
    // and never reads the feature or target values.
    let dummy_x = Array2::<f64>::zeros((100, 5));
    let dummy_y = Array1::<f64>::zeros(100);

    for (name, scenario) in scenarios.iter() {
        let config = LearningCurveConfig {
            scenario: *scenario,
            cv_folds: 5,
            base_performance: 0.75,
            noise_level: 0.03,
            add_cv_variance: true,
        };

        let visualizer =
            learning_curve_realistic(&dummy_x, &dummy_y, train_sizes, config, scoring_str.clone())?;

        results.push((name.to_string(), visualizer));
    }

    Ok(results)
}
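
// Illustrative sketch: producing all scenario curves at once and inspecting
// their metadata. The training sizes chosen here are arbitrary.
#[cfg(test)]
mod scenarios_example {
    use super::*;

    #[test]
    fn all_scenarios_generated() {
        let results =
            learning_curve_scenarios(&[10, 50, 100], "Accuracy").expect("valid sizes");
        // One visualizer per predefined scenario.
        assert_eq!(results.len(), 5);
        for (name, viz) in &results {
            assert!(!name.is_empty());
            assert_eq!(viz.get_metadata().y_label, "Accuracy");
        }
    }
}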