// sklears_model_selection/train_test_split.rs
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::SliceRandomExt;
use sklears_core::error::{Result, SklearsError};

9#[allow(clippy::type_complexity)]
11pub fn train_test_split<X, Y>(
12 x: &Array2<X>,
13 y: &Array1<Y>,
14 test_size: f64,
15 random_state: Option<u64>,
16) -> Result<(Array2<X>, Array2<X>, Array1<Y>, Array1<Y>)>
17where
18 X: Clone,
19 Y: Clone,
20{
21 let n_samples = x.nrows();
22 let n_test = (n_samples as f64 * test_size).round() as usize;
23 let n_train = n_samples - n_test;
24
25 let mut rng = match random_state {
27 Some(seed) => StdRng::seed_from_u64(seed),
28 None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
29 };
30
31 let mut indices: Vec<usize> = (0..n_samples).collect();
32 indices.shuffle(&mut rng);
33
34 let train_indices = &indices[..n_train];
35 let test_indices = &indices[n_train..];
36
37 let x_train = x.select(Axis(0), train_indices);
39 let x_test = x.select(Axis(0), test_indices);
40 let y_train = y.select(Axis(0), train_indices);
41 let y_test = y.select(Axis(0), test_indices);
42
43 Ok((x_train, x_test, y_train, y_test))
44}