sklears_model_selection/
train_test_split.rs

1//! Train-test split functionality
2
3use scirs2_core::ndarray::{Array1, Array2, Axis};
4use scirs2_core::random::prelude::*;
5use scirs2_core::random::rngs::StdRng;
6use scirs2_core::SliceRandomExt;
7use sklears_core::error::Result;
8
9/// Split arrays or matrices into random train and test subsets
10#[allow(clippy::type_complexity)]
11pub fn train_test_split<X, Y>(
12    x: &Array2<X>,
13    y: &Array1<Y>,
14    test_size: f64,
15    random_state: Option<u64>,
16) -> Result<(Array2<X>, Array2<X>, Array1<Y>, Array1<Y>)>
17where
18    X: Clone,
19    Y: Clone,
20{
21    let n_samples = x.nrows();
22    let n_test = (n_samples as f64 * test_size).round() as usize;
23    let n_train = n_samples - n_test;
24
25    // Generate random indices
26    let mut rng = match random_state {
27        Some(seed) => StdRng::seed_from_u64(seed),
28        None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
29    };
30
31    let mut indices: Vec<usize> = (0..n_samples).collect();
32    indices.shuffle(&mut rng);
33
34    let train_indices = &indices[..n_train];
35    let test_indices = &indices[n_train..];
36
37    // Create train and test arrays
38    let x_train = x.select(Axis(0), train_indices);
39    let x_test = x.select(Axis(0), test_indices);
40    let y_train = y.select(Axis(0), train_indices);
41    let y_test = y.select(Axis(0), test_indices);
42
43    Ok((x_train, x_test, y_train, y_test))
44}