

Function stratified_train_test_split 

pub fn stratified_train_test_split<L: Eq + Hash + Clone>(
    labels: &[L],
    test_size: f64,
    seed: Option<u64>,
) -> CoreResult<SplitIndices>

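Stratified train/test split: the samples of each class are divided between the train and test sets in the same proportion, so both sets approximately preserve the original class distribution.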
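The idea behind a stratified split can be sketched without external crates: group the sample indices by label, shuffle each group with a seeded generator, and send a rounded test_size share of every group to the test set. This is a minimal sketch under stated assumptions, not the scirs2_core implementation. The function stratified_split_sketch, the SplitMix64 generator, the per-class rounding rule, the time-based fallback seed for None, and the String error type are all hypothetical stand-ins (the real function returns CoreResult<SplitIndices>).

```rust
use std::collections::HashMap;
use std::hash::Hash;
use std::time::{SystemTime, UNIX_EPOCH};

/// Small SplitMix64 generator so the sketch needs no external crates (assumption).
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Index in 0..n; the slight modulo bias is acceptable for a sketch.
    fn below(&mut self, n: usize) -> usize {
        (self.next_u64() % n as u64) as usize
    }
}

/// Hypothetical stand-in: returns (train, test) vectors of indices into `labels`.
fn stratified_split_sketch<L: Eq + Hash + Clone>(
    labels: &[L],
    test_size: f64,
    seed: Option<u64>,
) -> Result<(Vec<usize>, Vec<usize>), String> {
    // Rejects 0.0, 1.0, values outside the range, and NaN.
    if !(test_size > 0.0 && test_size < 1.0) {
        return Err(format!("test_size must be in (0, 1), got {test_size}"));
    }
    // Assumption: without a seed, fall back to the current time.
    let seed = seed.unwrap_or_else(|| {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(0)
    });

    // Group indices by label, recording first-seen label order so the result
    // does not depend on HashMap iteration order.
    let mut order: Vec<L> = Vec::new();
    let mut groups: HashMap<L, Vec<usize>> = HashMap::new();
    for (i, label) in labels.iter().enumerate() {
        let group = groups.entry(label.clone()).or_insert_with(Vec::new);
        if group.is_empty() {
            order.push(label.clone());
        }
        group.push(i);
    }

    let mut rng = SplitMix64(seed);
    let (mut train, mut test) = (Vec::new(), Vec::new());
    for label in &order {
        let mut idx = groups.remove(label).expect("every recorded label has a group");
        // Fisher-Yates shuffle within this class only.
        for j in (1..idx.len()).rev() {
            let k = rng.below(j + 1);
            idx.swap(j, k);
        }
        // Assumed rule: each class contributes round(class_size * test_size) test samples.
        let n_test = ((idx.len() as f64) * test_size).round() as usize;
        let n_test = n_test.min(idx.len());
        test.extend_from_slice(&idx[..n_test]);
        train.extend_from_slice(&idx[n_test..]);
    }
    train.sort_unstable();
    test.sort_unstable();
    Ok((train, test))
}

fn main() {
    let labels = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2];
    let (train, test) = stratified_split_sketch(&labels, 0.3, Some(42)).expect("split");
    // Per class: round(4 * 0.3) = 1, round(4 * 0.3) = 1, round(2 * 0.3) = 1.
    assert_eq!(test.len(), 3);
    assert_eq!(train.len() + test.len(), labels.len());
    println!("train = {train:?}, test = {test:?}");
}
```

Shuffling inside each class rather than over the whole slice is what makes the split stratified: a plain shuffle could, by chance, put both samples of a small class into the same set.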

Arguments

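  • labels - Class label of each sample; the returned indices refer to positions in this slice
  • test_size - Fraction of samples assigned to the test set (0.0 .. 1.0)
  • seed - Optional random seed; pass Some(seed) for a reproducible split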

Example

use scirs2_core::data_split::stratified_train_test_split;

// Ten samples across three classes (4, 4 and 2 samples).
let labels = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2];
let (train, test) = stratified_split_sketch_placeholder_never_used;
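For the labels in this example, the test set is filled class by class. Assuming each class contributes the rounded product of its size and test_size (an assumption; the crate may use a different rounding rule), the allocation works out as:

```latex
n_c^{\text{test}} = \operatorname{round}\!\left(n_c \cdot \texttt{test\_size}\right):\quad
\operatorname{round}(4 \cdot 0.3) = 1,\quad
\operatorname{round}(4 \cdot 0.3) = 1,\quad
\operatorname{round}(2 \cdot 0.3) = 1
\;\Rightarrow\; |\text{test}| = 3,\; |\text{train}| = 10 - 3 = 7
```

Under that rule class 2 makes up 1 of the 3 test samples (about 33%) versus 2 of 10 overall (20%), which shows why proportions are only approximately preserved when classes are small.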