sklears_model_selection/cv/
mod.rs

1//! Cross-validation iterators and utilities
2
3pub mod basic_cv;
4pub mod custom_cv;
5pub mod group_cv;
6pub mod regression_cv;
7pub mod repeated_cv;
8pub mod shuffle_cv;
9pub mod time_series_cv;
10
11use scirs2_core::ndarray::Array1;
12use sklears_core::types::Float;
13
14/// Trait for cross-validation iterators
15pub trait CrossValidator: Send + Sync {
16    /// Returns the number of splits
17    fn n_splits(&self) -> usize;
18
19    /// Generate train/test indices for cross-validation
20    ///
21    /// For cross-validators that don't need y (like KFold), pass None.
22    /// For stratified cross-validators, y should contain integer class labels.
23    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)>;
24}
25
26/// Extended trait for regression cross-validation that works with continuous targets
27pub trait RegressionCrossValidator: Send + Sync {
28    /// Returns the number of splits
29    fn n_splits(&self) -> usize;
30
31    /// Generate train/test indices for cross-validation with continuous targets
32    fn split_regression(
33        &self,
34        n_samples: usize,
35        y: &Array1<Float>,
36    ) -> Vec<(Vec<usize>, Vec<usize>)>;
37}
38
39// Re-export all cross-validators
40pub use basic_cv::{KFold, LeaveOneOut, LeavePOut, StratifiedKFold};
41pub use custom_cv::{BlockCrossValidator, CustomCrossValidator, PredefinedSplit};
42pub use group_cv::{
43    GroupKFold, GroupShuffleSplit, GroupStrategy, LeaveOneGroupOut, LeavePGroupsOut,
44    StratifiedGroupKFold,
45};
46pub use regression_cv::StratifiedRegressionKFold;
47pub use repeated_cv::{RepeatedKFold, RepeatedStratifiedKFold};
48pub use shuffle_cv::{BootstrapCV, MonteCarloCV, ShuffleSplit, StratifiedShuffleSplit};
49pub use time_series_cv::{BlockedTimeSeriesCV, PurgedGroupTimeSeriesSplit, TimeSeriesSplit};