sklears_model_selection/validation.rs

//! Model validation utilities

use crate::{CrossValidator, ParameterValue};
use scirs2_core::ndarray::{Array1, Array2};
// use scirs2_core::SliceRandomExt;
use sklears_core::{
    error::Result,
    prelude::{Predict, SklearsError},
    traits::Fit,
    traits::Score,
    types::Float,
};
use sklears_metrics::{
    classification::accuracy_score, get_scorer, regression::mean_squared_error, Scorer,
};
use std::collections::HashMap;

/// Helper function for scoring regression targets by metric name
fn compute_score_for_regression_val(
    metric_name: &str,
    y_true: &Array1<f64>,
    y_pred: &Array1<f64>,
) -> Result<f64> {
    match metric_name {
        "neg_mean_squared_error" => Ok(-mean_squared_error(y_true, y_pred)?),
        "mean_squared_error" => Ok(mean_squared_error(y_true, y_pred)?),
        _ => {
            // Unsupported metrics are reported as an error rather than a default score
            Err(sklears_core::error::SklearsError::InvalidInput(format!(
                "Metric '{}' not supported for regression",
                metric_name
            )))
        }
    }
}

/// Helper function for scoring classification data
fn compute_score_for_classification_val(
    metric_name: &str,
    y_true: &Array1<i32>,
    y_pred: &Array1<i32>,
) -> Result<f64> {
    match metric_name {
        "accuracy" => Ok(accuracy_score(y_true, y_pred)?),
        _ => {
            let scorer = get_scorer(metric_name)?;
            scorer.score(y_true.as_slice().unwrap(), y_pred.as_slice().unwrap())
        }
    }
}

/// Scoring method for cross-validation
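///
/// A minimal construction sketch (paths are assumptions; not compiled as a doctest):
///
/// ```ignore
/// use sklears_model_selection::validation::Scoring;
///
/// // Use the estimator's own `score` method.
/// let default_scoring = Scoring::EstimatorScore;
/// // Look up a named metric such as "accuracy" or "neg_mean_squared_error".
/// let metric_scoring = Scoring::Metric("accuracy".to_string());
/// ```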
#[derive(Debug, Clone)]
pub enum Scoring {
    /// Use the estimator's built-in score method
    EstimatorScore,
    /// Use a predefined scorer by name
    Metric(String),
    /// Use a specific scorer configuration
    Scorer(Scorer),
    /// Use multiple scoring metrics
    MultiMetric(Vec<String>),
    /// Use a custom scoring function
    Custom(fn(&Array1<Float>, &Array1<Float>) -> Result<f64>),
}

/// Enhanced scoring result that can handle multiple metrics
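///
/// A small illustrative sketch (not compiled as a doctest):
///
/// ```ignore
/// use sklears_model_selection::validation::ScoreResult;
///
/// let single = ScoreResult::Single(0.92);
/// assert_eq!(single.as_single(), 0.92);
/// // `as_multiple` wraps a single score under the key "score".
/// assert_eq!(single.as_multiple()["score"], 0.92);
/// ```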
#[derive(Debug, Clone)]
pub enum ScoreResult {
    /// Single score value
    Single(f64),
    /// Multiple score values with metric names
    Multiple(HashMap<String, f64>),
}

impl ScoreResult {
    /// Get a single score (an arbitrary entry if multiple, since map iteration order is unspecified)
    pub fn as_single(&self) -> f64 {
        match self {
            ScoreResult::Single(score) => *score,
            ScoreResult::Multiple(scores) => scores.values().next().copied().unwrap_or(0.0),
        }
    }

    /// Get scores as a map (a single score is keyed as "score")
    pub fn as_multiple(&self) -> HashMap<String, f64> {
        match self {
            ScoreResult::Single(score) => {
                let mut map = HashMap::new();
                map.insert("score".to_string(), *score);
                map
            }
            ScoreResult::Multiple(scores) => scores.clone(),
        }
    }
}

/// Evaluate metric(s) by cross-validation and also record fit/score times
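///
/// A usage sketch, assuming `MyEstimator` is a placeholder for any type
/// implementing `Fit` whose fitted type implements `Predict` and `Score`
/// (paths are assumptions; not compiled as a doctest):
///
/// ```ignore
/// use sklears_model_selection::{KFold, validation::{cross_validate, Scoring}};
///
/// let cv = KFold::new(5);
/// let result = cross_validate(
///     MyEstimator::default(),
///     &x,                      // Array2<Float> of features
///     &y,                      // Array1<Float> of targets
///     &cv,
///     Scoring::EstimatorScore,
///     true,  // return_train_score
///     false, // return_estimator
///     None,  // n_jobs (currently unused)
/// )?;
/// println!("mean test score: {}", result.test_scores.mean().unwrap());
/// ```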
#[allow(clippy::too_many_arguments)]
pub fn cross_validate<E, F, C>(
    estimator: E,
    x: &Array2<Float>,
    y: &Array1<Float>,
    cv: &C,
    scoring: Scoring,
    return_train_score: bool,
    return_estimator: bool,
    _n_jobs: Option<usize>,
) -> Result<CrossValidateResult<F>>
where
    E: Clone,
    F: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    C: CrossValidator,
{
    // Note: This assumes KFold or other CV that doesn't need y
    // For StratifiedKFold, you would need to pass integer labels
    let splits = cv.split(x.nrows(), None);
    let n_splits = splits.len();

    let mut test_scores = Vec::with_capacity(n_splits);
    let mut train_scores = if return_train_score {
        Some(Vec::with_capacity(n_splits))
    } else {
        None
    };
    let mut fit_times = Vec::with_capacity(n_splits);
    let mut score_times = Vec::with_capacity(n_splits);
    let mut estimators = if return_estimator {
        Some(Vec::with_capacity(n_splits))
    } else {
        None
    };

    // Process each fold
    for (train_idx, test_idx) in splits {
        // Extract train and test data
        let x_train = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
        let y_train = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
        let x_test = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
        let y_test = y.select(scirs2_core::ndarray::Axis(0), &test_idx);

        // Fit the estimator
        let start = std::time::Instant::now();
        let fitted = estimator.clone().fit(&x_train, &y_train)?;
        let fit_time = start.elapsed().as_secs_f64();
        fit_times.push(fit_time);

        // Score on test set
        let start = std::time::Instant::now();
        let test_score = match &scoring {
            Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
            Scoring::Custom(func) => {
                let y_pred = fitted.predict(&x_test)?;
                func(&y_test.to_owned(), &y_pred)?
            }
            Scoring::Metric(metric_name) => {
                let y_pred = fitted.predict(&x_test)?;
                // Determine if this is classification or regression based on the data type
                if y_test.iter().all(|&x| x.fract() == 0.0) {
                    // Integer-like values, likely classification
                    let y_true_int: Array1<i32> = y_test.mapv(|x| x as i32);
                    let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
                    compute_score_for_classification_val(metric_name, &y_true_int, &y_pred_int)?
                } else {
                    // Float values, likely regression
                    compute_score_for_regression_val(metric_name, &y_test, &y_pred)?
                }
            }
            Scoring::Scorer(scorer) => {
                let y_pred = fitted.predict(&x_test)?;
                scorer.score_float(y_test.as_slice().unwrap(), y_pred.as_slice().unwrap())?
            }
            Scoring::MultiMetric(_) => {
                return Err(SklearsError::InvalidInput(
                    "MultiMetric scoring not supported in single metric context".to_string(),
                ));
            }
        };
        let score_time = start.elapsed().as_secs_f64();
        score_times.push(score_time);
        test_scores.push(test_score);

        // Score on train set if requested
        if let Some(ref mut train_scores) = train_scores {
            let train_score = match &scoring {
                Scoring::EstimatorScore => fitted.score(&x_train, &y_train)?,
                Scoring::Custom(func) => {
                    let y_pred = fitted.predict(&x_train)?;
                    func(&y_train.to_owned(), &y_pred)?
                }
                Scoring::Metric(metric_name) => {
                    let y_pred = fitted.predict(&x_train)?;
                    // Determine if this is classification or regression based on the data type
                    if y_train.iter().all(|&x| x.fract() == 0.0) {
                        // Integer-like values, likely classification
                        let y_true_int: Array1<i32> = y_train.mapv(|x| x as i32);
                        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
                        compute_score_for_classification_val(metric_name, &y_true_int, &y_pred_int)?
                    } else {
                        // Float values, likely regression
                        compute_score_for_regression_val(metric_name, &y_train, &y_pred)?
                    }
                }
                Scoring::Scorer(scorer) => {
                    let y_pred = fitted.predict(&x_train)?;
                    scorer.score_float(y_train.as_slice().unwrap(), y_pred.as_slice().unwrap())?
                }
                Scoring::MultiMetric(_metrics) => {
                    // Multi-metric scoring is not implemented here yet; fall back to
                    // the estimator's built-in score
                    fitted.score(&x_train, &y_train)?
                }
            };
            train_scores.push(train_score);
        }

        // Store estimator if requested
        if let Some(ref mut estimators) = estimators {
            estimators.push(fitted);
        }
    }

    Ok(CrossValidateResult {
        test_scores: Array1::from_vec(test_scores),
        train_scores: train_scores.map(Array1::from_vec),
        fit_times: Array1::from_vec(fit_times),
        score_times: Array1::from_vec(score_times),
        estimators,
    })
}

/// Result of cross_validate
#[derive(Debug, Clone)]
pub struct CrossValidateResult<F> {
    /// Test-set score for each fold
    pub test_scores: Array1<f64>,
    /// Training-set score for each fold (if requested)
    pub train_scores: Option<Array1<f64>>,
    /// Time spent fitting on each fold, in seconds
    pub fit_times: Array1<f64>,
    /// Time spent scoring on each fold, in seconds
    pub score_times: Array1<f64>,
    /// Fitted estimator for each fold (if requested)
    pub estimators: Option<Vec<F>>,
}

/// Evaluate a score by cross-validation
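///
/// A usage sketch with the crate's `KFold` splitter; `MyEstimator` is a placeholder
/// for any estimator implementing `Fit`/`Predict`/`Score` (not compiled as a doctest):
///
/// ```ignore
/// use sklears_model_selection::{KFold, validation::cross_val_score};
///
/// let cv = KFold::new(5);
/// // `None` for scoring falls back to Scoring::EstimatorScore.
/// let scores = cross_val_score(MyEstimator::default(), &x, &y, &cv, None, None)?;
/// assert_eq!(scores.len(), 5);
/// ```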
pub fn cross_val_score<E, F, C>(
    estimator: E,
    x: &Array2<Float>,
    y: &Array1<Float>,
    cv: &C,
    scoring: Option<Scoring>,
    n_jobs: Option<usize>,
) -> Result<Array1<f64>>
where
    E: Clone,
    F: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    C: CrossValidator,
{
    let scoring = scoring.unwrap_or(Scoring::EstimatorScore);
    let result = cross_validate(
        estimator, x, y, cv, scoring, false, // return_train_score
        false, // return_estimator
        n_jobs,
    )?;

    Ok(result.test_scores)
}

/// Generate cross-validated estimates for each input data point
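///
/// A usage sketch (not compiled as a doctest); each sample's prediction comes from
/// the fold in which it was held out. `MyEstimator` is a placeholder:
///
/// ```ignore
/// use sklears_model_selection::{KFold, validation::cross_val_predict};
///
/// let cv = KFold::new(5);
/// let predictions = cross_val_predict(MyEstimator::default(), &x, &y, &cv, None)?;
/// assert_eq!(predictions.len(), x.nrows());
/// ```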
pub fn cross_val_predict<E, F, C>(
    estimator: E,
    x: &Array2<Float>,
    y: &Array1<Float>,
    cv: &C,
    _n_jobs: Option<usize>,
) -> Result<Array1<Float>>
where
    E: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    C: CrossValidator,
{
    // Note: This assumes KFold or other CV that doesn't need y
    // For StratifiedKFold, you would need to pass integer labels
    let splits = cv.split(x.nrows(), None);
    let n_samples = x.nrows();

    // Initialize predictions array
    let mut predictions = Array1::<Float>::zeros(n_samples);

    // Process each fold
    for (train_idx, test_idx) in splits {
        // Extract train and test data
        let x_train = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
        let y_train = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
        let x_test = x.select(scirs2_core::ndarray::Axis(0), &test_idx);

        // Fit and predict
        let fitted = estimator.clone().fit(&x_train, &y_train)?;
        let y_pred = fitted.predict(&x_test)?;

        // Store predictions at the correct indices
        for (i, &idx) in test_idx.iter().enumerate() {
            predictions[idx] = y_pred[i];
        }
    }

    Ok(predictions)
}

/// Learning curve results
#[derive(Debug, Clone)]
pub struct LearningCurveResult {
    /// Training set sizes used
    pub train_sizes: Array1<usize>,
    /// Training scores for each size
    pub train_scores: Array2<f64>,
    /// Validation scores for each size
    pub test_scores: Array2<f64>,
    /// Mean training scores for each size
    pub train_scores_mean: Array1<f64>,
    /// Mean validation scores for each size
    pub test_scores_mean: Array1<f64>,
    /// Standard deviation of training scores for each size
    pub train_scores_std: Array1<f64>,
    /// Standard deviation of validation scores for each size
    pub test_scores_std: Array1<f64>,
    /// Lower confidence bound for training scores (mean - confidence_interval)
    pub train_scores_lower: Array1<f64>,
    /// Upper confidence bound for training scores (mean + confidence_interval)
    pub train_scores_upper: Array1<f64>,
    /// Lower confidence bound for validation scores (mean - confidence_interval)
    pub test_scores_lower: Array1<f64>,
    /// Upper confidence bound for validation scores (mean + confidence_interval)
    pub test_scores_upper: Array1<f64>,
}

/// Compute learning curves for an estimator
///
/// Determines cross-validated training and test scores for different training
/// set sizes. This is useful for diagnosing whether a model suffers from high
/// bias or high variance as more training data becomes available.
///
/// # Arguments
/// * `estimator` - The estimator to evaluate
/// * `x` - Training data features
/// * `y` - Training data targets
/// * `cv` - Cross-validation splitter
/// * `train_sizes` - Fractions of the full training set (in `(0.0, 1.0]`) used to generate the learning curve; defaults to 10%, 30%, 50%, 70%, 90%, and 100%
/// * `scoring` - Scoring method to use
/// * `confidence_level` - Confidence level for confidence bands (default: 0.95 for a 95% confidence interval)
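///
/// # Example
///
/// A usage sketch (not compiled as a doctest); `MyEstimator` is a placeholder:
///
/// ```ignore
/// use sklears_model_selection::{KFold, validation::learning_curve};
///
/// let cv = KFold::new(5);
/// let curve = learning_curve(
///     MyEstimator::default(),
///     &x,
///     &y,
///     &cv,
///     Some(vec![0.25, 0.5, 0.75, 1.0]), // fractions of the training data
///     None,                             // default scoring
///     None,                             // default 95% confidence bands
/// )?;
/// // A persistent gap between train and test means suggests overfitting.
/// println!("{:?}", curve.test_scores_mean);
/// ```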
#[allow(clippy::too_many_arguments)]
pub fn learning_curve<E, F, C>(
    estimator: E,
    x: &Array2<Float>,
    y: &Array1<Float>,
    cv: &C,
    train_sizes: Option<Vec<f64>>,
    scoring: Option<Scoring>,
    confidence_level: Option<f64>,
) -> Result<LearningCurveResult>
where
    E: Clone,
    F: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    C: CrossValidator,
{
    let n_samples = x.nrows();
    let scoring = scoring.unwrap_or(Scoring::EstimatorScore);

    // Default train sizes: 10%, 30%, 50%, 70%, 90%, 100%
    let train_size_fractions = train_sizes.unwrap_or_else(|| vec![0.1, 0.3, 0.5, 0.7, 0.9, 1.0]);

    // Convert fractions to actual sizes
    let train_sizes_actual: Vec<usize> = train_size_fractions
        .iter()
        .map(|&frac| {
            let size = (frac * n_samples as f64).round() as usize;
            size.max(1).min(n_samples) // Ensure between 1 and n_samples
        })
        .collect();

    let n_splits = cv.n_splits();
    let n_train_sizes = train_sizes_actual.len();

    let mut train_scores = Array2::<f64>::zeros((n_train_sizes, n_splits));
    let mut test_scores = Array2::<f64>::zeros((n_train_sizes, n_splits));

    // Get CV splits
    let splits = cv.split(x.nrows(), None);

    for (size_idx, &train_size) in train_sizes_actual.iter().enumerate() {
        for (split_idx, (train_idx, test_idx)) in splits.iter().enumerate() {
            // Limit training set to the desired size
            let mut limited_train_idx = train_idx.clone();
            if limited_train_idx.len() > train_size {
                limited_train_idx.truncate(train_size);
            }

            // Extract data
            let x_train = x.select(scirs2_core::ndarray::Axis(0), &limited_train_idx);
            let y_train = y.select(scirs2_core::ndarray::Axis(0), &limited_train_idx);
            let x_test = x.select(scirs2_core::ndarray::Axis(0), test_idx);
            let y_test = y.select(scirs2_core::ndarray::Axis(0), test_idx);

            // Fit estimator
            let fitted = estimator.clone().fit(&x_train, &y_train)?;

            // Score on training set
            let train_score = match &scoring {
                Scoring::EstimatorScore => fitted.score(&x_train, &y_train)?,
                Scoring::Custom(func) => {
                    let y_pred = fitted.predict(&x_train)?;
                    func(&y_train.to_owned(), &y_pred)?
                }
                Scoring::Metric(metric_name) => {
                    let y_pred = fitted.predict(&x_train)?;
                    // Determine if this is classification or regression based on the data type
                    if y_train.iter().all(|&x| x.fract() == 0.0) {
                        // Integer-like values, likely classification
                        let y_true_int: Array1<i32> = y_train.mapv(|x| x as i32);
                        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
                        compute_score_for_classification_val(metric_name, &y_true_int, &y_pred_int)?
                    } else {
                        // Float values, likely regression
                        compute_score_for_regression_val(metric_name, &y_train, &y_pred)?
                    }
                }
                Scoring::Scorer(scorer) => {
                    let y_pred = fitted.predict(&x_train)?;
                    scorer.score_float(y_train.as_slice().unwrap(), y_pred.as_slice().unwrap())?
                }
                Scoring::MultiMetric(_metrics) => {
                    // Multi-metric scoring is not implemented here yet; fall back to
                    // the estimator's built-in score
                    fitted.score(&x_train, &y_train)?
                }
            };
            train_scores[[size_idx, split_idx]] = train_score;

            // Score on test set
            let test_score = match &scoring {
                Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
                Scoring::Custom(func) => {
                    let y_pred = fitted.predict(&x_test)?;
                    func(&y_test.to_owned(), &y_pred)?
                }
                Scoring::Metric(metric_name) => {
                    let y_pred = fitted.predict(&x_test)?;
                    // Determine if this is classification or regression based on the data type
                    if y_test.iter().all(|&x| x.fract() == 0.0) {
                        // Integer-like values, likely classification
                        let y_true_int: Array1<i32> = y_test.mapv(|x| x as i32);
                        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
                        compute_score_for_classification_val(metric_name, &y_true_int, &y_pred_int)?
                    } else {
                        // Float values, likely regression
                        compute_score_for_regression_val(metric_name, &y_test, &y_pred)?
                    }
                }
                Scoring::Scorer(scorer) => {
                    let y_pred = fitted.predict(&x_test)?;
                    scorer.score_float(y_test.as_slice().unwrap(), y_pred.as_slice().unwrap())?
                }
                Scoring::MultiMetric(_metrics) => {
                    // Multi-metric scoring is not implemented here yet; fall back to
                    // the estimator's built-in score
                    fitted.score(&x_test, &y_test)?
                }
            };
            test_scores[[size_idx, split_idx]] = test_score;
        }
    }

    // Calculate confidence level (default 95%)
    let confidence = confidence_level.unwrap_or(0.95);
    let _alpha = 1.0 - confidence;
    // Fixed z value for an approximate 95% interval; other confidence levels are
    // not yet mapped to their corresponding z values
    let z_score = 1.96;

    // Calculate statistics for each training size
    let mut train_scores_mean = Array1::<f64>::zeros(n_train_sizes);
    let mut test_scores_mean = Array1::<f64>::zeros(n_train_sizes);
    let mut train_scores_std = Array1::<f64>::zeros(n_train_sizes);
    let mut test_scores_std = Array1::<f64>::zeros(n_train_sizes);
    let mut train_scores_lower = Array1::<f64>::zeros(n_train_sizes);
    let mut train_scores_upper = Array1::<f64>::zeros(n_train_sizes);
    let mut test_scores_lower = Array1::<f64>::zeros(n_train_sizes);
    let mut test_scores_upper = Array1::<f64>::zeros(n_train_sizes);

    for size_idx in 0..n_train_sizes {
        // Extract scores for this training size across all CV folds
        let train_scores_for_size: Vec<f64> = (0..n_splits)
            .map(|split_idx| train_scores[[size_idx, split_idx]])
            .collect();
        let test_scores_for_size: Vec<f64> = (0..n_splits)
            .map(|split_idx| test_scores[[size_idx, split_idx]])
            .collect();

        // Calculate mean and std for training scores
        let train_mean = train_scores_for_size.iter().sum::<f64>() / n_splits as f64;
        let train_variance = train_scores_for_size
            .iter()
            .map(|&x| (x - train_mean).powi(2))
            .sum::<f64>()
            / (n_splits - 1).max(1) as f64;
        let train_std = train_variance.sqrt();
        let train_sem = train_std / (n_splits as f64).sqrt(); // Standard error of the mean

        // Calculate mean and std for test scores
        let test_mean = test_scores_for_size.iter().sum::<f64>() / n_splits as f64;
        let test_variance = test_scores_for_size
            .iter()
            .map(|&x| (x - test_mean).powi(2))
            .sum::<f64>()
            / (n_splits - 1).max(1) as f64;
        let test_std = test_variance.sqrt();
        let test_sem = test_std / (n_splits as f64).sqrt(); // Standard error of the mean

        // Calculate confidence intervals
        let train_margin = z_score * train_sem;
        let test_margin = z_score * test_sem;

        train_scores_mean[size_idx] = train_mean;
        test_scores_mean[size_idx] = test_mean;
        train_scores_std[size_idx] = train_std;
        test_scores_std[size_idx] = test_std;
        train_scores_lower[size_idx] = train_mean - train_margin;
        train_scores_upper[size_idx] = train_mean + train_margin;
        test_scores_lower[size_idx] = test_mean - test_margin;
        test_scores_upper[size_idx] = test_mean + test_margin;
    }

    Ok(LearningCurveResult {
        train_sizes: Array1::from_vec(train_sizes_actual),
        train_scores,
        test_scores,
        train_scores_mean,
        test_scores_mean,
        train_scores_std,
        test_scores_std,
        train_scores_lower,
        train_scores_upper,
        test_scores_lower,
        test_scores_upper,
    })
}

/// Validation curve results
#[derive(Debug, Clone)]
pub struct ValidationCurveResult {
    /// Parameter values used
    pub param_values: Vec<ParameterValue>,
    /// Training scores for each parameter value
    pub train_scores: Array2<f64>,
    /// Validation scores for each parameter value
    pub test_scores: Array2<f64>,
    /// Mean training scores for each parameter value
    pub train_scores_mean: Array1<f64>,
    /// Mean validation scores for each parameter value
    pub test_scores_mean: Array1<f64>,
    /// Standard deviation of training scores for each parameter value
    pub train_scores_std: Array1<f64>,
    /// Standard deviation of validation scores for each parameter value
    pub test_scores_std: Array1<f64>,
    /// Lower error bar for training scores (mean - std_error)
    pub train_scores_lower: Array1<f64>,
    /// Upper error bar for training scores (mean + std_error)
    pub train_scores_upper: Array1<f64>,
    /// Lower error bar for validation scores (mean - std_error)
    pub test_scores_lower: Array1<f64>,
    /// Upper error bar for validation scores (mean + std_error)
    pub test_scores_upper: Array1<f64>,
}

/// Parameter configuration function type
pub type ParamConfigFn<E> = Box<dyn Fn(E, &ParameterValue) -> Result<E>>;

/// Compute validation curves for an estimator
///
/// Determines training and test scores for a varying parameter value.
/// This is useful to understand the effect of a specific parameter on
/// model performance and to detect overfitting or underfitting.
///
/// # Arguments
/// * `estimator` - The estimator to evaluate
/// * `x` - Training data features
/// * `y` - Training data targets
/// * `_param_name` - Name of the parameter being varied (for documentation)
/// * `param_range` - Parameter values to test
/// * `param_config` - Function to configure estimator with parameter values
/// * `cv` - Cross-validation splitter
/// * `scoring` - Scoring method to use
/// * `confidence_level` - Accepted for API compatibility; error bars currently use ±1 standard error of the mean
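///
/// # Example
///
/// A usage sketch (not compiled as a doctest); `MyEstimator` and its `with_alpha`
/// setter are placeholders for how a real estimator might apply the parameter:
///
/// ```ignore
/// use sklears_model_selection::{KFold, ParameterValue, validation::{validation_curve, ParamConfigFn}};
///
/// let cv = KFold::new(5);
/// let param_range = vec![
///     ParameterValue::Float(0.01),
///     ParameterValue::Float(0.1),
///     ParameterValue::Float(1.0),
/// ];
/// let param_config: ParamConfigFn<MyEstimator> = Box::new(|est, value| match value {
///     ParameterValue::Float(alpha) => Ok(est.with_alpha(*alpha)),
///     _ => Ok(est),
/// });
/// let curve = validation_curve(
///     MyEstimator::default(), &x, &y, "alpha", param_range, param_config, &cv, None, None,
/// )?;
/// println!("{:?}", curve.test_scores_mean);
/// ```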
#[allow(clippy::too_many_arguments)]
pub fn validation_curve<E, F, C>(
    estimator: E,
    x: &Array2<Float>,
    y: &Array1<Float>,
    _param_name: &str,
    param_range: Vec<ParameterValue>,
    param_config: ParamConfigFn<E>,
    cv: &C,
    scoring: Option<Scoring>,
    confidence_level: Option<f64>,
) -> Result<ValidationCurveResult>
where
    E: Clone,
    F: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    C: CrossValidator,
{
    let scoring = scoring.unwrap_or(Scoring::EstimatorScore);
    let n_splits = cv.n_splits();
    let n_params = param_range.len();

    let mut train_scores = Array2::<f64>::zeros((n_params, n_splits));
    let mut test_scores = Array2::<f64>::zeros((n_params, n_splits));

    // Get CV splits
    let splits = cv.split(x.nrows(), None);

    for (param_idx, param_value) in param_range.iter().enumerate() {
        for (split_idx, (train_idx, test_idx)) in splits.iter().enumerate() {
            // Extract data
            let x_train = x.select(scirs2_core::ndarray::Axis(0), train_idx);
            let y_train = y.select(scirs2_core::ndarray::Axis(0), train_idx);
            let x_test = x.select(scirs2_core::ndarray::Axis(0), test_idx);
            let y_test = y.select(scirs2_core::ndarray::Axis(0), test_idx);

            // Configure estimator with current parameter value
            let configured_estimator = param_config(estimator.clone(), param_value)?;

            // Fit estimator
            let fitted = configured_estimator.fit(&x_train, &y_train)?;

            // Score on training set
            let train_score = match &scoring {
                Scoring::EstimatorScore => fitted.score(&x_train, &y_train)?,
                Scoring::Custom(func) => {
                    let y_pred = fitted.predict(&x_train)?;
                    func(&y_train.to_owned(), &y_pred)?
                }
                Scoring::Metric(metric_name) => {
                    let y_pred = fitted.predict(&x_train)?;
                    // Determine if this is classification or regression based on the data type
                    if y_train.iter().all(|&x| x.fract() == 0.0) {
                        // Integer-like values, likely classification
                        let y_true_int: Array1<i32> = y_train.mapv(|x| x as i32);
                        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
                        compute_score_for_classification_val(metric_name, &y_true_int, &y_pred_int)?
                    } else {
                        // Float values, likely regression
                        compute_score_for_regression_val(metric_name, &y_train, &y_pred)?
                    }
                }
                Scoring::Scorer(scorer) => {
                    let y_pred = fitted.predict(&x_train)?;
                    scorer.score_float(y_train.as_slice().unwrap(), y_pred.as_slice().unwrap())?
                }
                Scoring::MultiMetric(_metrics) => {
                    // Multi-metric scoring is not implemented here yet; fall back to
                    // the estimator's built-in score
                    fitted.score(&x_train, &y_train)?
                }
            };
            train_scores[[param_idx, split_idx]] = train_score;

            // Score on test set
            let test_score = match &scoring {
                Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
                Scoring::Custom(func) => {
                    let y_pred = fitted.predict(&x_test)?;
                    func(&y_test.to_owned(), &y_pred)?
                }
                Scoring::Metric(metric_name) => {
                    let y_pred = fitted.predict(&x_test)?;
                    // Determine if this is classification or regression based on the data type
                    if y_test.iter().all(|&x| x.fract() == 0.0) {
                        // Integer-like values, likely classification
                        let y_true_int: Array1<i32> = y_test.mapv(|x| x as i32);
                        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
                        compute_score_for_classification_val(metric_name, &y_true_int, &y_pred_int)?
                    } else {
                        // Float values, likely regression
                        compute_score_for_regression_val(metric_name, &y_test, &y_pred)?
                    }
                }
                Scoring::Scorer(scorer) => {
                    let y_pred = fitted.predict(&x_test)?;
                    scorer.score_float(y_test.as_slice().unwrap(), y_pred.as_slice().unwrap())?
                }
                Scoring::MultiMetric(_metrics) => {
                    // Multi-metric scoring is not implemented here yet; fall back to
                    // the estimator's built-in score
                    fitted.score(&x_test, &y_test)?
                }
            };
            test_scores[[param_idx, split_idx]] = test_score;
        }
    }

    // The confidence level is currently unused: the error bars below use
    // ±1 standard error of the mean rather than a z-scaled interval
    let _confidence = confidence_level.unwrap_or(0.95);
    let _z_score = 1.96;

    // Calculate statistics for each parameter value
    let mut train_scores_mean = Array1::<f64>::zeros(n_params);
    let mut test_scores_mean = Array1::<f64>::zeros(n_params);
    let mut train_scores_std = Array1::<f64>::zeros(n_params);
    let mut test_scores_std = Array1::<f64>::zeros(n_params);
    let mut train_scores_lower = Array1::<f64>::zeros(n_params);
    let mut train_scores_upper = Array1::<f64>::zeros(n_params);
    let mut test_scores_lower = Array1::<f64>::zeros(n_params);
    let mut test_scores_upper = Array1::<f64>::zeros(n_params);

    for param_idx in 0..n_params {
        // Extract scores for this parameter value across all CV folds
        let train_scores_for_param: Vec<f64> = (0..n_splits)
            .map(|split_idx| train_scores[[param_idx, split_idx]])
            .collect();
        let test_scores_for_param: Vec<f64> = (0..n_splits)
            .map(|split_idx| test_scores[[param_idx, split_idx]])
            .collect();

        // Calculate mean and std for training scores
        let train_mean = train_scores_for_param.iter().sum::<f64>() / n_splits as f64;
        let train_variance = train_scores_for_param
            .iter()
            .map(|&x| (x - train_mean).powi(2))
            .sum::<f64>()
            / (n_splits - 1).max(1) as f64;
        let train_std = train_variance.sqrt();
        let train_sem = train_std / (n_splits as f64).sqrt(); // Standard error of the mean

        // Calculate mean and std for test scores
        let test_mean = test_scores_for_param.iter().sum::<f64>() / n_splits as f64;
        let test_variance = test_scores_for_param
            .iter()
            .map(|&x| (x - test_mean).powi(2))
            .sum::<f64>()
            / (n_splits - 1).max(1) as f64;
        let test_std = test_variance.sqrt();
        let test_sem = test_std / (n_splits as f64).sqrt(); // Standard error of the mean

        // Calculate error bars (using standard error for error bars)
        let train_margin = train_sem;
        let test_margin = test_sem;

        train_scores_mean[param_idx] = train_mean;
        test_scores_mean[param_idx] = test_mean;
        train_scores_std[param_idx] = train_std;
        test_scores_std[param_idx] = test_std;
        train_scores_lower[param_idx] = train_mean - train_margin;
        train_scores_upper[param_idx] = train_mean + train_margin;
        test_scores_lower[param_idx] = test_mean - test_margin;
        test_scores_upper[param_idx] = test_mean + test_margin;
    }

    Ok(ValidationCurveResult {
        param_values: param_range,
        train_scores,
        test_scores,
        train_scores_mean,
        test_scores_mean,
        train_scores_std,
        test_scores_std,
        train_scores_lower,
        train_scores_upper,
        test_scores_lower,
        test_scores_upper,
    })
}

/// Evaluate the significance of a cross-validated score with permutations
///
/// This function tests whether the estimator performs significantly better than
/// random by computing cross-validation scores on permuted labels.
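///
/// A usage sketch (not compiled as a doctest); `MyEstimator` is a placeholder:
///
/// ```ignore
/// use sklears_model_selection::{KFold, validation::permutation_test_score};
///
/// let cv = KFold::new(5);
/// let result = permutation_test_score(
///     MyEstimator::default(), &x, &y, &cv,
///     None,      // default scoring
///     100,       // number of permutations
///     Some(42),  // random seed
///     None,      // n_jobs
/// )?;
/// // Small p-values indicate the score is unlikely under shuffled labels.
/// println!("score = {}, p = {}", result.statistic, result.pvalue);
/// ```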
#[allow(clippy::too_many_arguments)]
pub fn permutation_test_score<E, F, C>(
    estimator: E,
    x: &Array2<Float>,
    y: &Array1<Float>,
    cv: &C,
    scoring: Option<Scoring>,
    n_permutations: usize,
    random_state: Option<u64>,
    n_jobs: Option<usize>,
) -> Result<PermutationTestResult>
where
    E: Clone,
    F: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    C: CrossValidator,
{
    use scirs2_core::random::prelude::*;
    use scirs2_core::random::rngs::StdRng;

    let scoring = scoring.unwrap_or(Scoring::EstimatorScore);

    // Compute original score
    let original_scores =
        cross_val_score(estimator.clone(), x, y, cv, Some(scoring.clone()), n_jobs)?;
    let original_score = original_scores.mean().unwrap_or(0.0);

    // Initialize random number generator (a fixed default seed of 42 is used when
    // no random_state is supplied, so results are reproducible by default)
    let mut rng = if let Some(seed) = random_state {
        StdRng::seed_from_u64(seed)
    } else {
        StdRng::seed_from_u64(42)
    };

    // Compute permutation scores
    let mut permutation_scores = Vec::with_capacity(n_permutations);

    for _ in 0..n_permutations {
        // Create permuted labels
        let mut y_permuted = y.to_owned();
        let mut indices: Vec<usize> = (0..y.len()).collect();
        indices.shuffle(&mut rng);

        for (i, &perm_idx) in indices.iter().enumerate() {
            y_permuted[i] = y[perm_idx];
        }

        // Compute score with permuted labels
        let perm_scores = cross_val_score(
            estimator.clone(),
            x,
            &y_permuted,
            cv,
            Some(scoring.clone()),
            n_jobs,
        )?;
        let perm_score = perm_scores.mean().unwrap_or(0.0);
        permutation_scores.push(perm_score);
    }

    // Compute p-value: (permutations scoring at least as well + 1) / (n_permutations + 1)
    let n_better_or_equal = permutation_scores
        .iter()
        .filter(|&&score| score >= original_score)
        .count();
    let p_value = (n_better_or_equal + 1) as f64 / (n_permutations + 1) as f64;

    Ok(PermutationTestResult {
        statistic: original_score,
        pvalue: p_value,
        permutation_scores: Array1::from_vec(permutation_scores),
    })
}

/// Result of permutation test
#[derive(Debug, Clone)]
pub struct PermutationTestResult {
    /// The original cross-validation score
    pub statistic: f64,
    /// The p-value of the permutation test
    pub pvalue: f64,
    /// Scores obtained for each permutation
    pub permutation_scores: Array1<f64>,
}

/// Nested cross-validation for unbiased model evaluation with hyperparameter optimization
///
/// This implements nested cross-validation which provides an unbiased estimate of model
/// performance by using separate CV loops for hyperparameter optimization (inner loop)
/// and performance estimation (outer loop).
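///
/// A usage sketch (not compiled as a doctest); `MyEstimator` and its `with_alpha`
/// setter are placeholders:
///
/// ```ignore
/// use sklears_model_selection::{KFold, ParameterValue, validation::{nested_cross_validate, ParamConfigFn}};
///
/// let outer_cv = KFold::new(5);
/// let inner_cv = KFold::new(3);
/// let param_grid = vec![ParameterValue::Float(0.1), ParameterValue::Float(1.0)];
/// let param_config: ParamConfigFn<MyEstimator> = Box::new(|est, value| match value {
///     ParameterValue::Float(alpha) => Ok(est.with_alpha(*alpha)),
///     _ => Ok(est),
/// });
/// let result = nested_cross_validate(
///     MyEstimator::default(), &x, &y, &outer_cv, &inner_cv, &param_grid, param_config, None,
/// )?;
/// println!("unbiased estimate: {} ± {}", result.mean_outer_score, result.std_outer_score);
/// ```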
#[allow(clippy::too_many_arguments)]
pub fn nested_cross_validate<E, F, C>(
    estimator: E,
    x: &Array2<Float>,
    y: &Array1<Float>,
    outer_cv: &C,
    inner_cv: &C,
    param_grid: &[ParameterValue],
    param_config: ParamConfigFn<E>,
    scoring: Option<fn(&Array1<Float>, &Array1<Float>) -> f64>,
) -> Result<NestedCVResult>
where
    E: Clone,
    F: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    C: CrossValidator,
{
    let outer_splits = outer_cv.split(x.nrows(), None);
    let mut outer_scores = Vec::with_capacity(outer_splits.len());
    let mut best_params_per_fold = Vec::with_capacity(outer_splits.len());
    let mut inner_scores_per_fold = Vec::with_capacity(outer_splits.len());

    for (outer_train_idx, outer_test_idx) in outer_splits {
        // Extract outer train/test data
        let outer_train_x = extract_rows(x, &outer_train_idx);
        let outer_train_y = extract_elements(y, &outer_train_idx);
        let outer_test_x = extract_rows(x, &outer_test_idx);
        let outer_test_y = extract_elements(y, &outer_test_idx);

        // Inner cross-validation for hyperparameter optimization
        let mut best_score = f64::NEG_INFINITY;
        let mut best_param = param_grid[0].clone();
        let mut inner_scores = Vec::new();

        for param in param_grid {
            let param_estimator = param_config(estimator.clone(), param)?;

            // Inner CV evaluation
            let inner_splits = inner_cv.split(outer_train_x.nrows(), None);
            let mut param_scores = Vec::new();

            for (inner_train_idx, inner_test_idx) in inner_splits {
                let inner_train_x = extract_rows(&outer_train_x, &inner_train_idx);
                let inner_train_y = extract_elements(&outer_train_y, &inner_train_idx);
                let inner_test_x = extract_rows(&outer_train_x, &inner_test_idx);
                let inner_test_y = extract_elements(&outer_train_y, &inner_test_idx);

                // Fit and score on inner split
                let fitted = param_estimator
                    .clone()
                    .fit(&inner_train_x, &inner_train_y)?;
                let predictions = fitted.predict(&inner_test_x)?;

                let score = if let Some(scoring_fn) = scoring {
                    scoring_fn(&inner_test_y, &predictions)
                } else {
                    fitted.score(&inner_test_x, &inner_test_y)?
                };

                param_scores.push(score);
            }

            let mean_score = param_scores.iter().sum::<f64>() / param_scores.len() as f64;
            inner_scores.push(mean_score);

            if mean_score > best_score {
                best_score = mean_score;
                best_param = param.clone();
            }
        }

        // Train best model on full outer training set and evaluate on outer test set
        let best_estimator = param_config(estimator.clone(), &best_param)?;
        let final_fitted = best_estimator.fit(&outer_train_x, &outer_train_y)?;
        let outer_predictions = final_fitted.predict(&outer_test_x)?;

        let outer_score = if let Some(scoring_fn) = scoring {
            scoring_fn(&outer_test_y, &outer_predictions)
        } else {
            final_fitted.score(&outer_test_x, &outer_test_y)?
        };

        outer_scores.push(outer_score);
        best_params_per_fold.push(best_param);
        inner_scores_per_fold.push(inner_scores);
    }

    let mean_score = outer_scores.iter().sum::<f64>() / outer_scores.len() as f64;
    let std_score = {
        let variance = outer_scores
            .iter()
            .map(|&x| (x - mean_score).powi(2))
            .sum::<f64>()
            / outer_scores.len() as f64;
        variance.sqrt()
    };

    Ok(NestedCVResult {
        outer_scores: Array1::from_vec(outer_scores),
        best_params_per_fold,
        inner_scores_per_fold,
        mean_outer_score: mean_score,
        std_outer_score: std_score,
    })
}

/// Result of nested cross-validation
#[derive(Debug, Clone)]
pub struct NestedCVResult {
    /// Outer cross-validation scores (unbiased performance estimates)
    pub outer_scores: Array1<f64>,
    /// Best parameters found for each outer fold
    pub best_params_per_fold: Vec<ParameterValue>,
    /// Inner CV scores for each parameter in each outer fold
    pub inner_scores_per_fold: Vec<Vec<f64>>,
    /// Mean of outer scores
    pub mean_outer_score: f64,
    /// Standard deviation of outer scores
    pub std_outer_score: f64,
}

// Helper functions for data extraction
fn extract_rows(arr: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
    let mut result = Array2::zeros((indices.len(), arr.ncols()));
    for (i, &idx) in indices.iter().enumerate() {
        for j in 0..arr.ncols() {
            result[[i, j]] = arr[[idx, j]];
        }
    }
    result
}

fn extract_elements(arr: &Array1<Float>, indices: &[usize]) -> Array1<Float> {
    Array1::from_iter(indices.iter().map(|&i| arr[i]))
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::KFold;
    use scirs2_core::ndarray::array;

    // Mock estimator for testing
    #[derive(Clone)]
    struct MockEstimator;

    #[derive(Clone)]
    struct MockFitted {
        train_mean: f64,
    }

    impl Fit<Array2<Float>, Array1<Float>> for MockEstimator {
        type Fitted = MockFitted;

        fn fit(self, _x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
            Ok(MockFitted {
                train_mean: y.mean().unwrap_or(0.0),
            })
        }
    }

    impl Predict<Array2<Float>, Array1<Float>> for MockFitted {
        fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
            Ok(Array1::from_elem(x.nrows(), self.train_mean))
        }
    }

    impl Score<Array2<Float>, Array1<Float>> for MockFitted {
        type Float = Float;

        fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
            let y_pred = self.predict(x)?;
            let mse = (y - &y_pred).mapv(|e| e * e).mean().unwrap_or(0.0);
            Ok(1.0 - mse) // Simple R² approximation
        }
    }

    #[test]
    fn test_cross_val_score() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let estimator = MockEstimator;
        let cv = KFold::new(3);

        let scores = cross_val_score(estimator, &x, &y, &cv, None, None).unwrap();

        assert_eq!(scores.len(), 3);
        // Each score is 1 - MSE, so it can never exceed 1.0 (the mock just predicts the training mean)
        for score in scores.iter() {
            assert!(*score <= 1.0);
        }
    }

    #[test]
    fn test_cross_val_predict() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let estimator = MockEstimator;
        let cv = KFold::new(3);

        let predictions = cross_val_predict(estimator, &x, &y, &cv, None).unwrap();

        assert_eq!(predictions.len(), 6);
        // Each prediction should be the mean of the training fold
        // Since we're using KFold with 3 splits, each test set has 2 samples
        // and each train set has 4 samples
    }

    #[test]
    fn test_learning_curve() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0]
        ];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let estimator = MockEstimator;
        let cv = KFold::new(3);

        let result = learning_curve(
            estimator,
            &x,
            &y,
            &cv,
            Some(vec![0.3, 0.6, 1.0]), // 30%, 60%, 100% of training data
            None,
            None, // Use default confidence level
        )
        .unwrap();

        // Check dimensions
        assert_eq!(result.train_sizes.len(), 3);
        assert_eq!(result.train_scores.dim(), (3, 3)); // 3 sizes x 3 CV folds
        assert_eq!(result.test_scores.dim(), (3, 3));

        // Check that train sizes are reasonable
        assert_eq!(result.train_sizes[0], 3); // 30% of 10 = 3
        assert_eq!(result.train_sizes[1], 6); // 60% of 10 = 6
        assert_eq!(result.train_sizes[2], 10); // 100% of 10 = 10

        // Training scores should generally be better than test scores for our mock estimator
        let mean_train_score = result.train_scores.mean().unwrap();
        let mean_test_score = result.test_scores.mean().unwrap();
        // The mock predicts the training mean, so its training error (the fold variance)
        // should not exceed its error on held-out data
        assert!(mean_train_score >= mean_test_score);

        // Verify confidence bands are calculated
        assert_eq!(result.train_scores_mean.len(), 3);
        assert_eq!(result.test_scores_mean.len(), 3);
        assert_eq!(result.train_scores_std.len(), 3);
        assert_eq!(result.test_scores_std.len(), 3);
        assert_eq!(result.train_scores_lower.len(), 3);
        assert_eq!(result.train_scores_upper.len(), 3);
        assert_eq!(result.test_scores_lower.len(), 3);
        assert_eq!(result.test_scores_upper.len(), 3);

        // Verify confidence intervals are sensible (lower < mean < upper)
        for i in 0..3 {
            assert!(result.train_scores_lower[i] <= result.train_scores_mean[i]);
            assert!(result.train_scores_mean[i] <= result.train_scores_upper[i]);
            assert!(result.test_scores_lower[i] <= result.test_scores_mean[i]);
            assert!(result.test_scores_mean[i] <= result.test_scores_upper[i]);
        }
    }

    #[test]
    fn test_validation_curve() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let estimator = MockEstimator;
        let cv = KFold::new(3);

        // Mock parameter configuration function
        let param_config: ParamConfigFn<MockEstimator> = Box::new(|estimator, _param_value| {
            // For our mock estimator, parameters don't matter
            Ok(estimator)
        });

        let param_range = vec![
            ParameterValue::Float(0.1),
            ParameterValue::Float(0.5),
            ParameterValue::Float(1.0),
        ];

        let result = validation_curve(
            estimator,
            &x,
            &y,
            "mock_param",
            param_range.clone(),
            param_config,
            &cv,
            None,
            None, // Use default confidence level
        )
        .unwrap();

        // Check dimensions
        assert_eq!(result.param_values.len(), 3);
        assert_eq!(result.train_scores.dim(), (3, 3)); // 3 params x 3 CV folds
        assert_eq!(result.test_scores.dim(), (3, 3));

        // Check that parameter values match
        assert_eq!(result.param_values, param_range);

        // For our mock estimator, all parameter values should give similar results
        let train_score_std = {
            let mean = result.train_scores.mean().unwrap();
            let variance = result
                .train_scores
                .mapv(|x| (x - mean).powi(2))
                .mean()
                .unwrap();
            variance.sqrt()
        };

        // Standard deviation should be low since our mock estimator ignores parameters
        // But allow for some variation due to different CV folds
        assert!(train_score_std < 2.0);

        // Verify error bars are calculated
        assert_eq!(result.train_scores_mean.len(), 3);
        assert_eq!(result.test_scores_mean.len(), 3);
        assert_eq!(result.train_scores_std.len(), 3);
        assert_eq!(result.test_scores_std.len(), 3);
        assert_eq!(result.train_scores_lower.len(), 3);
        assert_eq!(result.train_scores_upper.len(), 3);
        assert_eq!(result.test_scores_lower.len(), 3);
        assert_eq!(result.test_scores_upper.len(), 3);

        // Verify error bars are sensible (lower <= mean <= upper)
        for i in 0..3 {
            assert!(result.train_scores_lower[i] <= result.train_scores_mean[i]);
            assert!(result.train_scores_mean[i] <= result.train_scores_upper[i]);
            assert!(result.test_scores_lower[i] <= result.test_scores_mean[i]);
            assert!(result.test_scores_mean[i] <= result.test_scores_upper[i]);
        }
    }

    #[test]
    fn test_learning_curve_default_sizes() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0]
        ];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let estimator = MockEstimator;
        let cv = KFold::new(2);

        let result = learning_curve(
            estimator, &x, &y, &cv, None, // Use default train sizes
            None, None, // Use default confidence level
        )
        .unwrap();

        // Should use default sizes: 10%, 30%, 50%, 70%, 90%, 100%
        assert_eq!(result.train_sizes.len(), 6);
        assert_eq!(result.train_scores.dim(), (6, 2)); // 6 sizes x 2 CV folds

        // Check that sizes are increasing
        for i in 1..result.train_sizes.len() {
            assert!(result.train_sizes[i] >= result.train_sizes[i - 1]);
        }
    }

    #[test]
    fn test_permutation_test_score() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let estimator = MockEstimator;
        let cv = KFold::new(4);

        let result = permutation_test_score(
            estimator,
            &x,
            &y,
            &cv,
            None,
            10, // 10 permutations
            Some(42),
            None,
        )
        .unwrap();

        // Check that we got reasonable results
        assert!(result.pvalue >= 0.0 && result.pvalue <= 1.0);
        assert_eq!(result.permutation_scores.len(), 10);

        // For our mock estimator, the original score should be reasonably good
        // compared to permuted scores
        assert!(result.statistic.is_finite());

        // Permutation scores should all be finite
        for &score in result.permutation_scores.iter() {
            assert!(score.is_finite());
        }

        // P-value should follow (n_better_or_equal + 1) / (n_permutations + 1)
        let n_better = result
            .permutation_scores
            .iter()
            .filter(|&&score| score >= result.statistic)
            .count();
        let expected_p = (n_better + 1) as f64 / 11.0; // 10 permutations + 1
        assert!((result.pvalue - expected_p).abs() < 1e-10);
    }

    #[test]
    fn test_nested_cross_validate() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0]
        ];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let estimator = MockEstimator;
        let outer_cv = KFold::new(3);
        let inner_cv = KFold::new(2);

        // Mock parameter configuration function
        let param_config: ParamConfigFn<MockEstimator> = Box::new(|estimator, _param_value| {
            // For our mock estimator, parameters don't matter
            Ok(estimator)
        });

        let param_grid = vec![
            ParameterValue::Float(0.1),
            ParameterValue::Float(0.5),
            ParameterValue::Float(1.0),
        ];

        let result = nested_cross_validate(
            estimator,
            &x,
            &y,
            &outer_cv,
            &inner_cv,
            &param_grid,
            param_config,
            None,
        )
        .unwrap();

        // Check dimensions
        assert_eq!(result.outer_scores.len(), 3); // 3 outer folds
        assert_eq!(result.best_params_per_fold.len(), 3);
        assert_eq!(result.inner_scores_per_fold.len(), 3);

        // Each inner fold should have scores for all parameters
        for inner_scores in &result.inner_scores_per_fold {
            assert_eq!(inner_scores.len(), 3); // 3 parameters
        }

        // Check that outer scores are finite
        for &score in result.outer_scores.iter() {
            assert!(score.is_finite());
        }

        // Check that mean and std are calculated correctly
        let manual_mean =
            result.outer_scores.iter().sum::<f64>() / result.outer_scores.len() as f64;
        assert!((result.mean_outer_score - manual_mean).abs() < 1e-10);

        assert!(result.std_outer_score >= 0.0);
        assert!(result.std_outer_score.is_finite());
    }
}