scirs2_metrics/visualization/
learning_curve.rs

//! Learning curve visualization
//!
//! This module provides tools for visualizing learning curves, which show model performance
//! as a function of training set size.

use std::error::Error;

use super::{MetricVisualizer, PlotType, VisualizationData, VisualizationMetadata};
use crate::error::{MetricsError, Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2};
use scirs2_core::random::prelude::*;
/// Learning curve data
///
/// This struct holds the data for a learning curve.
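///
/// # Examples
///
/// A minimal sketch of the expected layout: the outer `Vec` holds one entry per
/// training-set size, and each inner `Vec` holds one score per CV fold.
///
/// ```ignore
/// let data = LearningCurveData {
///     train_sizes: vec![50, 100],
///     train_scores: vec![vec![0.90, 0.92, 0.91], vec![0.88, 0.89, 0.90]],
///     validation_scores: vec![vec![0.70, 0.72, 0.71], vec![0.78, 0.79, 0.80]],
/// };
/// ```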
#[derive(Debug, Clone)]
pub struct LearningCurveData {
    /// Training set sizes
    pub train_sizes: Vec<usize>,
    /// Training scores for each training set size
    pub train_scores: Vec<Vec<f64>>,
    /// Validation scores for each training set size
    pub validation_scores: Vec<Vec<f64>>,
}

/// Learning curve visualizer
///
/// This struct provides methods for visualizing learning curves.
#[derive(Debug, Clone)]
pub struct LearningCurveVisualizer {
    /// Learning curve data
    data: LearningCurveData,
    /// Title for the plot
    title: String,
    /// Whether to show standard deviation
    show_std: bool,
    /// Scoring metric name
    scoring: String,
}

impl LearningCurveVisualizer {
    /// Create a new LearningCurveVisualizer
    ///
    /// # Arguments
    ///
    /// * `data` - Learning curve data
    ///
    /// # Returns
    ///
    /// * A new LearningCurveVisualizer
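    ///
    /// # Examples
    ///
    /// A usage sketch of the builder chain, assuming `data` is a populated
    /// `LearningCurveData`:
    ///
    /// ```ignore
    /// let viz = LearningCurveVisualizer::new(data)
    ///     .with_title("Learning Curve (SVM)".to_string())
    ///     .with_show_std(false)
    ///     .with_scoring("Accuracy".to_string());
    /// ```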
    pub fn new(data: LearningCurveData) -> Self {
        LearningCurveVisualizer {
            data,
            title: "Learning Curve".to_string(),
            show_std: true,
            scoring: "Score".to_string(),
        }
    }

    /// Set the title for the plot
    ///
    /// # Arguments
    ///
    /// * `title` - Title for the plot
    ///
    /// # Returns
    ///
    /// * Self for method chaining
    pub fn with_title(mut self, title: String) -> Self {
        self.title = title;
        self
    }

    /// Set whether to show standard deviation
    ///
    /// # Arguments
    ///
    /// * `show_std` - Whether to show standard deviation
    ///
    /// # Returns
    ///
    /// * Self for method chaining
    pub fn with_show_std(mut self, show_std: bool) -> Self {
        self.show_std = show_std;
        self
    }

    /// Set the scoring metric name
    ///
    /// # Arguments
    ///
    /// * `scoring` - Scoring metric name
    ///
    /// # Returns
    ///
    /// * Self for method chaining
    pub fn with_scoring(mut self, scoring: String) -> Self {
        self.scoring = scoring;
        self
    }

    /// Compute mean and standard deviation of scores
    ///
    /// # Arguments
    ///
    /// * `scores` - Scores for each training set size
    ///
    /// # Returns
    ///
    /// * (mean_scores, std_scores)
    fn compute_statistics(&self, scores: &[Vec<f64>]) -> (Vec<f64>, Vec<f64>) {
        let n = scores.len();
        let mut mean_scores = Vec::with_capacity(n);
        let mut std_scores = Vec::with_capacity(n);

        for fold_scores in scores {
            // Compute mean across CV folds
            let mean = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;
            mean_scores.push(mean);

            // Compute the population standard deviation (divide by n, not n - 1)
            let variance = fold_scores.iter().map(|&s| (s - mean).powi(2)).sum::<f64>()
                / fold_scores.len() as f64;
            std_scores.push(variance.sqrt());
        }

        (mean_scores, std_scores)
    }
}

impl MetricVisualizer for LearningCurveVisualizer {
    fn prepare_data(&self) -> std::result::Result<VisualizationData, Box<dyn Error>> {
        // Compute statistics for train and validation scores
        let (train_mean, train_std) = self.compute_statistics(&self.data.train_scores);
        let (val_mean, val_std) = self.compute_statistics(&self.data.validation_scores);

        // Convert train_sizes to f64 for plotting
        let train_sizes: Vec<f64> = self.data.train_sizes.iter().map(|&s| s as f64).collect();

        // Prepare data for visualization
        let mut x = Vec::new();
        let mut y = Vec::new();

        // Add training scores
        x.extend_from_slice(&train_sizes);
        y.extend_from_slice(&train_mean);

        // Add validation scores
        x.extend_from_slice(&train_sizes);
        y.extend_from_slice(&val_mean);

        // Prepare series names
        let mut series_names = vec!["Training score".to_string(), "Validation score".to_string()];

        // Add standard deviation series if requested
        if self.show_std {
            // Add upper and lower bounds for training scores
            x.extend_from_slice(&train_sizes);
            x.extend_from_slice(&train_sizes);

            let train_upper: Vec<f64> = train_mean
                .iter()
                .zip(train_std.iter())
                .map(|(&m, &s)| m + s)
                .collect();

            let train_lower: Vec<f64> = train_mean
                .iter()
                .zip(train_std.iter())
                .map(|(&m, &s)| m - s)
                .collect();

            y.extend_from_slice(&train_upper);
            y.extend_from_slice(&train_lower);

            // Add upper and lower bounds for validation scores
            x.extend_from_slice(&train_sizes);
            x.extend_from_slice(&train_sizes);

            let val_upper: Vec<f64> = val_mean
                .iter()
                .zip(val_std.iter())
                .map(|(&m, &s)| m + s)
                .collect();

            let val_lower: Vec<f64> = val_mean
                .iter()
                .zip(val_std.iter())
                .map(|(&m, &s)| m - s)
                .collect();

            y.extend_from_slice(&val_upper);
            y.extend_from_slice(&val_lower);

            // Add series names for standard deviation bounds
            series_names.push("Training score +/- std".to_string());
            series_names.push("Training score +/- std".to_string());
            series_names.push("Validation score +/- std".to_string());
            series_names.push("Validation score +/- std".to_string());
        }

        Ok(VisualizationData {
            x,
            y,
            z: None,
            series_names: Some(series_names),
            x_labels: None,
            y_labels: None,
            auxiliary_data: std::collections::HashMap::new(),
            auxiliary_metadata: std::collections::HashMap::new(),
            series: std::collections::HashMap::new(),
        })
    }

    fn get_metadata(&self) -> VisualizationMetadata {
        VisualizationMetadata {
            title: self.title.clone(),
            x_label: "Training examples".to_string(),
            y_label: self.scoring.clone(),
            plot_type: PlotType::Line,
            description: Some(
                "Learning curve showing model performance as a function of training set size"
                    .to_string(),
            ),
        }
    }
}
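
// A usage sketch of the trait methods above, assuming `viz` is a constructed
// `LearningCurveVisualizer`: `prepare_data` flattens every series into the
// parallel `x`/`y` vectors (labelled via `series_names`), and `get_metadata`
// supplies the titles and axis labels a plotting backend needs.
//
//     let data = viz.prepare_data()?;
//     let meta = viz.get_metadata();
//     println!("{} vs {}", meta.x_label, meta.y_label);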

/// Create a learning curve visualization
///
/// # Arguments
///
/// * `train_sizes` - Training set sizes
/// * `train_scores` - Training scores for each training set size
/// * `validation_scores` - Validation scores for each training set size
/// * `scoring` - Scoring metric name
///
/// # Returns
///
/// * A LearningCurveVisualizer
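///
/// # Examples
///
/// A minimal sketch with two training-set sizes and two CV folds per size:
///
/// ```ignore
/// let viz = learning_curve_visualization(
///     vec![50, 100],
///     vec![vec![0.90, 0.92], vec![0.89, 0.91]],
///     vec![vec![0.70, 0.74], vec![0.78, 0.80]],
///     "Accuracy",
/// )?;
/// ```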
#[allow(dead_code)]
pub fn learning_curve_visualization(
    train_sizes: Vec<usize>,
    train_scores: Vec<Vec<f64>>,
    validation_scores: Vec<Vec<f64>>,
    scoring: impl Into<String>,
) -> Result<LearningCurveVisualizer> {
    // Validate inputs
    if train_sizes.is_empty() || train_scores.is_empty() || validation_scores.is_empty() {
        return Err(MetricsError::InvalidInput(
            "Learning curve data cannot be empty".to_string(),
        ));
    }

    if train_scores.len() != train_sizes.len() || validation_scores.len() != train_sizes.len() {
        return Err(MetricsError::InvalidInput(
            "Number of train/validation scores must match number of training sizes".to_string(),
        ));
    }

    let data = LearningCurveData {
        train_sizes,
        train_scores,
        validation_scores,
    };

    let scoring_string = scoring.into();
    Ok(LearningCurveVisualizer::new(data).with_scoring(scoring_string))
}

/// Learning curve scenario types for realistic simulation
#[derive(Debug, Clone, Copy)]
pub enum LearningCurveScenario {
    /// Well-fitted model with good generalization
    WellFitted,
    /// High bias scenario (underfitting)
    HighBias,
    /// High variance scenario (overfitting)
    HighVariance,
    /// Noisy data scenario with irregular patterns
    NoisyData,
    /// Learning plateau scenario where more data doesn't help much
    PlateauEffect,
}

/// Configuration for learning curve generation
#[derive(Debug, Clone)]
pub struct LearningCurveConfig {
    /// The learning scenario to simulate
    pub scenario: LearningCurveScenario,
    /// Number of cross-validation folds
    pub cv_folds: usize,
    /// Base performance level (0.0 to 1.0)
    pub base_performance: f64,
    /// Noise level in the scores (0.0 to 1.0)
    pub noise_level: f64,
    /// Whether to add realistic variance across folds
    pub add_cv_variance: bool,
}

impl Default for LearningCurveConfig {
    fn default() -> Self {
        Self {
            scenario: LearningCurveScenario::WellFitted,
            cv_folds: 5,
            base_performance: 0.75,
            noise_level: 0.05,
            add_cv_variance: true,
        }
    }
}
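
// A configuration sketch using struct-update syntax to override a single
// field; the remaining fields keep the `Default` values above.
//
//     let config = LearningCurveConfig {
//         scenario: LearningCurveScenario::HighVariance,
//         ..Default::default()
//     };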

/// Generate a realistic learning curve based on learning theory principles
///
/// This function simulates realistic learning curves that follow common patterns
/// observed in machine learning, including bias-variance decomposition effects.
/// Since this is a metrics library without model training capabilities, it
/// generates theoretically sound learning curves for visualization and analysis.
///
/// # Arguments
///
/// * `_x` - Feature matrix (used for determining data characteristics)
/// * `_y` - Target values (used for determining problem characteristics)
/// * `train_sizes` - Training set sizes to evaluate
/// * `config` - Configuration for learning curve generation
/// * `scoring` - Scoring metric to use
///
/// # Returns
///
/// * A LearningCurveVisualizer with realistic learning curves
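///
/// # Examples
///
/// A sketch with placeholder inputs, since `_x` and `_y` only inform the
/// signature and are not inspected:
///
/// ```ignore
/// let x = Array2::<f64>::zeros((100, 5));
/// let y = Array1::<f64>::zeros(100);
/// let viz = learning_curve_realistic(
///     &x, &y, &[10, 50, 100], LearningCurveConfig::default(), "Accuracy",
/// )?;
/// ```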
#[allow(dead_code)]
pub fn learning_curve_realistic<T, S1, S2>(
    _x: &ArrayBase<S1, Ix2>,
    _y: &ArrayBase<S2, Ix1>,
    train_sizes: &[usize],
    config: LearningCurveConfig,
    scoring: impl Into<String>,
) -> Result<LearningCurveVisualizer>
where
    T: Clone + 'static,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
{
    use scirs2_core::random::Rng;
    let mut rng = scirs2_core::random::rng();

    let n_sizes = train_sizes.len();
    let mut train_scores = Vec::with_capacity(n_sizes);
    let mut validation_scores = Vec::with_capacity(n_sizes);

    for (i, &_size) in train_sizes.iter().enumerate() {
        // Progress runs from 0.0 at the first size to 1.0 at the last
        let progress = i as f64 / n_sizes.saturating_sub(1).max(1) as f64;

        let (base_train_score, base_val_score) = match config.scenario {
            LearningCurveScenario::WellFitted => {
                // Training score starts high and plateaus
                let train_score = config.base_performance + 0.15 * progress.powf(0.3);
                // Validation score starts lower but converges towards training score
                let val_score = config.base_performance - 0.1 + 0.2 * progress.powf(0.5);
                (train_score.min(0.95), val_score.min(train_score - 0.02))
            }
            LearningCurveScenario::HighBias => {
                // Both training and validation scores are low and plateau early
                let train_score = config.base_performance - 0.15 + 0.1 * progress.powf(0.8);
                let val_score = train_score - 0.05 + 0.03 * progress;
                (train_score.min(0.7), val_score.min(train_score))
            }
            LearningCurveScenario::HighVariance => {
                // Large gap between training and validation scores
                let train_score = config.base_performance + 0.2 * progress.powf(0.2);
                let val_score = config.base_performance - 0.2 + 0.15 * progress.powf(0.7);
                (train_score.min(0.98), val_score.min(train_score - 0.15))
            }
            LearningCurveScenario::NoisyData => {
                // Irregular patterns with higher variance
                let noise_factor = 0.1 * (progress * 10.0).sin();
                let train_score = config.base_performance + 0.1 * progress + noise_factor;
                let val_score =
                    config.base_performance - 0.05 + 0.12 * progress + noise_factor * 0.5;
                (train_score.min(0.9), val_score.min(train_score))
            }
            LearningCurveScenario::PlateauEffect => {
                // Rapid initial improvement then plateau
                let plateau_factor = 1.0 - (-5.0 * progress).exp();
                let train_score = config.base_performance + 0.15 * plateau_factor;
                let val_score = config.base_performance - 0.08 + 0.18 * plateau_factor;
                (train_score, val_score.min(train_score - 0.01))
            }
        };

        // Generate scores for each CV fold
        let fold_variance = if config.add_cv_variance {
            config.noise_level
        } else {
            0.0
        };

        let train_fold_scores: Vec<f64> = (0..config.cv_folds)
            .map(|_| {
                // Guard against an empty sampling range when variance is disabled
                let noise = if fold_variance > 0.0 {
                    rng.random_range(-fold_variance..fold_variance)
                } else {
                    0.0
                };
                (base_train_score + noise).clamp(0.0, 1.0)
            })
            .collect();

        let val_fold_scores: Vec<f64> = (0..config.cv_folds)
            .map(|_| {
                // Validation scores get wider noise than training scores
                let noise = if fold_variance > 0.0 {
                    rng.random_range(-fold_variance * 1.5..fold_variance * 1.5)
                } else {
                    0.0
                };
                (base_val_score + noise).clamp(0.0, 1.0)
            })
            .collect();

        train_scores.push(train_fold_scores);
        validation_scores.push(val_fold_scores);
    }

    learning_curve_visualization(
        train_sizes.to_vec(),
        train_scores,
        validation_scores,
        scoring,
    )
}

/// Generate a learning curve with real model evaluation (Advanced mode)
///
/// This function provides a sophisticated interface for learning curve generation
/// that actually trains and evaluates models on different training set sizes.
///
/// # Arguments
///
/// * `x` - Feature matrix
/// * `y` - Target values
/// * `model` - Model to evaluate
/// * `train_sizes` - Training set sizes to evaluate
/// * `cv` - Number of cross-validation folds
/// * `scoring` - Scoring metric to use
///
/// # Returns
///
/// * A LearningCurveVisualizer with real performance curves
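///
/// # Examples
///
/// A sketch using the toy `MeanModel` defined after the traits below; any type
/// implementing [`ModelEvaluator`] works here:
///
/// ```ignore
/// let viz = learning_curve(&x, &y, &MeanModel, &[20, 40, 80], 5, "mse")?;
/// ```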
#[allow(dead_code)]
pub fn learning_curve<T, S1, S2>(
    x: &ArrayBase<S1, Ix2>,
    y: &ArrayBase<S2, Ix1>,
    model: &impl ModelEvaluator<T>,
    train_sizes: &[usize],
    cv: usize,
    scoring: impl Into<String>,
) -> Result<LearningCurveVisualizer>
where
    T: Clone
        + 'static
        + scirs2_core::numeric::Float
        + Send
        + Sync
        + std::fmt::Debug
        + std::ops::Sub<Output = T>,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    for<'a> &'a T: std::ops::Sub<&'a T, Output = T>,
{
    let scoring_str = scoring.into();

    // Validate inputs
    if x.nrows() != y.len() {
        return Err(MetricsError::InvalidInput(
            "Feature matrix and target vector must have same number of samples".to_string(),
        ));
    }

    if train_sizes.is_empty() {
        return Err(MetricsError::InvalidInput(
            "Training sizes cannot be empty".to_string(),
        ));
    }

    if cv == 0 {
        return Err(MetricsError::InvalidInput(
            "Number of cross-validation folds must be at least 1".to_string(),
        ));
    }

    let max_size = train_sizes.iter().max().unwrap();
    if *max_size > x.nrows() {
        return Err(MetricsError::InvalidInput(format!(
            "Maximum training size ({}) exceeds available samples ({})",
            max_size,
            x.nrows()
        )));
    }

    // Generate actual learning curves using cross-validation
    let mut train_scores = Vec::new();
    let mut validation_scores = Vec::new();

    let mut rng = scirs2_core::random::rng();

    // Create cross-validation folds
    let fold_size = x.nrows() / cv;
    let mut indices: Vec<usize> = (0..x.nrows()).collect();

    for &size in train_sizes {
        let mut train_fold_scores = Vec::new();
        let mut val_fold_scores = Vec::new();

        // Shuffle once per training size (Fisher-Yates) so the cv folds below
        // partition the data; reshuffling inside the fold loop would let
        // validation folds overlap.
        for i in (1..indices.len()).rev() {
            let j = rng.random_range(0..i + 1);
            indices.swap(i, j);
        }

        // Perform cross-validation for this training size
        for fold in 0..cv {
            // Split data for this fold
            let val_start = fold * fold_size;
            let val_end = std::cmp::min((fold + 1) * fold_size, x.nrows());

            let mut train_indices = Vec::new();
            let mut val_indices = Vec::new();

            for (i, &idx) in indices.iter().enumerate() {
                if i >= val_start && i < val_end {
                    val_indices.push(idx);
                } else if train_indices.len() < size {
                    train_indices.push(idx);
                }
            }

            // Create training and validation sets
            let train_x = extract_rows(x, &train_indices);
            let train_y = extract_elements(y, &train_indices);
            let val_x = extract_rows(x, &val_indices);
            let val_y = extract_elements(y, &val_indices);

            // Train model and evaluate
            let trained_model = model.fit(&train_x, &train_y)?;

            // Evaluate on training set
            let train_pred = trained_model.predict(&train_x)?;
            let train_score = evaluate_predictions(&train_y, &train_pred, &scoring_str)?;
            train_fold_scores.push(train_score);

            // Evaluate on validation set
            let val_pred = trained_model.predict(&val_x)?;
            let val_score = evaluate_predictions(&val_y, &val_pred, &scoring_str)?;
            val_fold_scores.push(val_score);
        }

        train_scores.push(train_fold_scores);
        validation_scores.push(val_fold_scores);
    }

    learning_curve_visualization(
        train_sizes.to_vec(),
        train_scores,
        validation_scores,
        scoring_str,
    )
}

/// Trait for models that can be evaluated in learning curves
pub trait ModelEvaluator<T> {
    /// The trained model type produced by `fit`
    type TrainedModel: ModelPredictor<T>;

    /// Fit the model on the given training data
    fn fit(&self, x: &Array2<T>, y: &Array1<T>) -> Result<Self::TrainedModel>;
}

/// Trait for trained models that can make predictions
pub trait ModelPredictor<T> {
    /// Predict targets for the given feature matrix
    fn predict(&self, x: &Array2<T>) -> Result<Array1<T>>;
}
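
// A toy implementation sketch of the two traits above: a "model" that always
// predicts the mean of its training targets. `MeanModel` and `FittedMean` are
// illustrative names, not part of the library API; they exist only to show the
// `fit`/`predict` contract that `learning_curve` relies on.
#[allow(dead_code)]
struct MeanModel;

#[allow(dead_code)]
struct FittedMean<T> {
    mean: T,
}

impl<T> ModelEvaluator<T> for MeanModel
where
    T: Clone + scirs2_core::numeric::Float,
{
    type TrainedModel = FittedMean<T>;

    fn fit(&self, _x: &Array2<T>, y: &Array1<T>) -> Result<Self::TrainedModel> {
        // Average the targets; an empty training set falls back to a zero mean.
        let n = T::from(y.len().max(1)).unwrap();
        let sum = y.iter().cloned().fold(T::zero(), |acc, v| acc + v);
        Ok(FittedMean { mean: sum / n })
    }
}

impl<T> ModelPredictor<T> for FittedMean<T>
where
    T: Clone + scirs2_core::numeric::Float,
{
    fn predict(&self, x: &Array2<T>) -> Result<Array1<T>> {
        // One constant prediction per input row.
        Ok(Array1::from_elem(x.nrows(), self.mean.clone()))
    }
}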

/// Extract specific rows from a 2D array
#[allow(dead_code)]
fn extract_rows<T, S>(arr: &ArrayBase<S, Ix2>, indices: &[usize]) -> Array2<T>
where
    T: Clone + scirs2_core::numeric::Zero,
    S: Data<Elem = T>,
{
    let mut result = Array2::zeros((indices.len(), arr.ncols()));
    for (i, &idx) in indices.iter().enumerate() {
        result.row_mut(i).assign(&arr.row(idx));
    }
    result
}

/// Extract specific elements from a 1D array
#[allow(dead_code)]
fn extract_elements<T, S>(arr: &ArrayBase<S, Ix1>, indices: &[usize]) -> Array1<T>
where
    T: Clone + scirs2_core::numeric::Zero,
    S: Data<Elem = T>,
{
    let mut result = Array1::zeros(indices.len());
    for (i, &idx) in indices.iter().enumerate() {
        result[i] = arr[idx].clone();
    }
    result
}

/// Evaluate predictions using the specified scoring metric
#[allow(dead_code)]
fn evaluate_predictions<T>(y_true: &Array1<T>, y_pred: &Array1<T>, scoring: &str) -> Result<f64>
where
    T: Clone
        + scirs2_core::numeric::Float
        + Send
        + Sync
        + std::fmt::Debug
        + std::ops::Sub<Output = T>,
    for<'a> &'a T: std::ops::Sub<&'a T, Output = T>,
{
    match scoring.to_lowercase().as_str() {
        "accuracy" => {
            // For classification: count predictions within 0.5 of the true label
            let correct = y_true
                .iter()
                .zip(y_pred.iter())
                .filter(|(t, p)| (*t - *p).abs() < T::from(0.5).unwrap())
                .count();
            Ok(correct as f64 / y_true.len() as f64)
        }
        "mse" | "mean_squared_error" => {
            // Mean squared error
            let mse = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(t, p)| (*t - *p) * (*t - *p))
                .fold(T::zero(), |acc, x| acc + x)
                / T::from(y_true.len()).unwrap();
            Ok(mse.to_f64().unwrap_or(0.0))
        }
        "mae" | "mean_absolute_error" => {
            // Mean absolute error
            let mae = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(t, p)| (*t - *p).abs())
                .fold(T::zero(), |acc, x| acc + x)
                / T::from(y_true.len()).unwrap();
            Ok(mae.to_f64().unwrap_or(0.0))
        }
        "r2" | "r2_score" => {
            // R² score: 1 - SS_res / SS_tot
            let mean_true = y_true.iter().cloned().fold(T::zero(), |acc, x| acc + x)
                / T::from(y_true.len()).unwrap();

            let ss_tot = y_true
                .iter()
                .map(|&t| (t - mean_true) * (t - mean_true))
                .fold(T::zero(), |acc, x| acc + x);

            let ss_res = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(&t, &p)| (t - p) * (t - p))
                .fold(T::zero(), |acc, x| acc + x);

            if ss_tot == T::zero() {
                Ok(0.0)
            } else {
                let r2 = T::one() - ss_res / ss_tot;
                Ok(r2.to_f64().unwrap_or(0.0))
            }
        }
        _ => {
            // Default to MSE for unknown metrics
            let mse = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(t, p)| (*t - *p) * (*t - *p))
                .fold(T::zero(), |acc, x| acc + x)
                / T::from(y_true.len()).unwrap();
            Ok(mse.to_f64().unwrap_or(0.0))
        }
    }
}
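
// A quick usage sketch of the scorer above on hand-made arrays; the expected
// MSE is (0.0 + 0.25 + 0.25) / 3 ≈ 0.1667.
//
//     let y_true = Array1::from(vec![1.0_f64, 2.0, 3.0]);
//     let y_pred = Array1::from(vec![1.0_f64, 2.5, 2.5]);
//     let mse = evaluate_predictions(&y_true, &y_pred, "mse")?;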

/// Generate learning curves for different scenarios for comparison
///
/// This function generates multiple learning curves showing different learning
/// scenarios, useful for educational purposes and understanding model behavior.
///
/// # Arguments
///
/// * `train_sizes` - Training set sizes to evaluate
/// * `scoring` - Scoring metric to use
///
/// # Returns
///
/// * A vector of LearningCurveVisualizer instances for each scenario
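///
/// # Examples
///
/// A sketch iterating over all five scenario curves:
///
/// ```ignore
/// for (name, viz) in learning_curve_scenarios(&[10, 50, 100, 500], "Accuracy")? {
///     println!("{name}: {}", viz.get_metadata().title);
/// }
/// ```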
#[allow(dead_code)]
pub fn learning_curve_scenarios(
    train_sizes: &[usize],
    scoring: impl Into<String>,
) -> Result<Vec<(String, LearningCurveVisualizer)>> {
    let scoring_str = scoring.into();
    let scenarios = [
        ("Well Fitted", LearningCurveScenario::WellFitted),
        ("High Bias (Underfitting)", LearningCurveScenario::HighBias),
        (
            "High Variance (Overfitting)",
            LearningCurveScenario::HighVariance,
        ),
        ("Noisy Data", LearningCurveScenario::NoisyData),
        ("Plateau Effect", LearningCurveScenario::PlateauEffect),
    ];

    let mut results = Vec::new();

    // Create dummy data to satisfy the function signature
    let dummy_x = Array2::<f64>::zeros((100, 5));
    let dummy_y = Array1::<f64>::zeros(100);

    for (name, scenario) in scenarios.iter() {
        let config = LearningCurveConfig {
            scenario: *scenario,
            cv_folds: 5,
            base_performance: 0.75,
            noise_level: 0.03,
            add_cv_variance: true,
        };

        let visualizer =
            learning_curve_realistic(&dummy_x, &dummy_y, train_sizes, config, scoring_str.clone())?;

        results.push((name.to_string(), visualizer));
    }

    Ok(results)
}