rust_ml/utils/data.rs

1use std::path::PathBuf;
2
3use ndarray::{Axis, Ix1};
4use ndarray_rand::rand::{SeedableRng, seq::SliceRandom};
5use polars::prelude::*;
6
7use crate::core::types::{Matrix, Vector};
8
9pub fn load_dataset(path: PathBuf) -> PolarsResult<DataFrame> {
10    CsvReadOptions::default()
11        .try_into_reader_with_file_path(Some(path))?
12        .finish()
13}
14
#[cfg(test)]
mod load_dataset_tests {
    use std::path::PathBuf;

    use crate::utils::data::load_dataset;

    /// Loading the bundled advertising dataset should succeed.
    ///
    /// The path is relative to the package root, which is the working directory under
    /// `cargo test`.
    #[test]
    fn test_load_dataset_existing_path() {
        let path = PathBuf::from("./datasets/advertising.csv");
        // Report the path and the underlying polars error instead of a bare `is_ok()` failure.
        if let Err(err) = load_dataset(path.clone()) {
            panic!("failed to load {}: {}", path.display(), err);
        }
    }

    /// A path that does not exist must produce an error rather than a panic.
    #[test]
    fn test_load_dataset_with_non_existing_path() {
        let path = PathBuf::from("./data/non_existing.csv");
        let df = load_dataset(path);
        assert!(df.is_err());
    }
}
36
37pub fn shuffle_split(
38    x: &Matrix,
39    y: &Vector,
40    train_perc: f64,
41    seed: i32,
42) -> (Matrix, Vector, Matrix, Vector) {
43    // Create a seedable range and use the provided seed
44    let mut rng = ndarray_rand::rand::rngs::StdRng::seed_from_u64(seed as u64);
45
46    // Shuffle the indices of the dataset
47    let n_samples = x.nrows();
48    let indices: Vec<usize> = (0..n_samples).collect();
49    let shuffled_indices: Vec<usize> = indices
50        .choose_multiple(&mut rng, n_samples)
51        .cloned()
52        .collect();
53
54    // Calculate the split index
55    let split_index = (n_samples as f64 * train_perc).round() as usize;
56
57    // Split the dataset into training and testing sets
58    let x_train = x.select(Axis(0), &shuffled_indices[..split_index]);
59    let y_train = y.select(Axis(0), &shuffled_indices[..split_index]);
60    let x_test = x.select(Axis(0), &shuffled_indices[split_index..]);
61    let y_test = y.select(Axis(0), &shuffled_indices[split_index..]);
62
63    (x_train, y_train, x_test, y_test)
64}
65
#[cfg(test)]
mod shuffle_split_tests {
    use crate::core::types::{Matrix, Vector};
    use crate::utils::data::shuffle_split;
    use ndarray::{arr1, arr2};

    /// Builds a (10, 2) feature matrix and a matching (10,) target vector.
    ///
    /// Row `i` of `x` is `[2i + 1, 2i + 2]` and `y[i]` is `i + 1`, so every row's second
    /// value is exactly twice its target, which lets tests check row/target pairing.
    fn sample_data() -> (Matrix, Vector) {
        let x = arr2(&[
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
            [11.0, 12.0],
            [13.0, 14.0],
            [15.0, 16.0],
            [17.0, 18.0],
            [19.0, 20.0],
        ]);
        let y = arr1(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        (x, y)
    }

    #[test]
    fn test_shuffle_split_train_test_ratio() {
        let (x, y) = sample_data();
        // A 70% split of 10 samples should put 7 values in the train sets and 3 in the test sets.
        let (x_train, y_train, x_test, y_test) = shuffle_split(&x, &y, 0.7, 42);

        assert_eq!(x_train.nrows(), 7);
        assert_eq!(y_train.len(), 7);
        assert_eq!(x_test.nrows(), 3);
        assert_eq!(y_test.len(), 3);
    }

    #[test]
    fn test_shuffle_split_returns_sets_in_random_order() {
        let (x, y) = sample_data();

        // Split the same data with two different seeds.
        let (x_train_1, y_train_1, _, _) = shuffle_split(&x, &y, 0.7, 42);
        let (x_train_2, y_train_2, _, _) = shuffle_split(&x, &y, 0.7, 100);

        // Same-size training sets should have been created with different content.
        assert_eq!(x_train_1.nrows(), 7);
        assert_eq!(x_train_2.nrows(), 7);
        assert_eq!(y_train_1.len(), 7);
        assert_eq!(y_train_2.len(), 7);
        assert_ne!(
            x_train_1, x_train_2,
            "Training sets should be different when using different seeds"
        );
    }

    #[test]
    fn test_shuffle_split_is_reproducible_for_same_seed() {
        let (x, y) = sample_data();
        let first = shuffle_split(&x, &y, 0.7, 7);
        let second = shuffle_split(&x, &y, 0.7, 7);
        assert_eq!(first, second);
    }

    #[test]
    fn test_shuffle_split_partitions_all_samples_and_keeps_pairs() {
        let (x, y) = sample_data();
        let (x_train, y_train, x_test, y_test) = shuffle_split(&x, &y, 0.7, 42);

        // Every target appears exactly once across the train and test sets.
        let mut seen: Vec<f64> = y_train.iter().chain(y_test.iter()).copied().collect();
        seen.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(seen, y.to_vec());

        // Rows stay paired with their targets (second feature == 2 * target by construction).
        for i in 0..x_train.nrows() {
            assert_eq!(x_train[[i, 1]], 2.0 * y_train[i]);
        }
        for i in 0..x_test.nrows() {
            assert_eq!(x_test[[i, 1]], 2.0 * y_test[i]);
        }
    }

    #[test]
    #[should_panic]
    fn test_shuffle_split_panics_when_train_perc_above_one() {
        let (x, y) = sample_data();
        let _ = shuffle_split(&x, &y, 1.5, 42);
    }

    #[test]
    #[should_panic]
    fn test_shuffle_split_panics_when_y_is_shorter_than_x() {
        let (x, _) = sample_data();
        let y = arr1(&[1.0, 2.0, 3.0]);
        let _ = shuffle_split(&x, &y, 0.7, 42);
    }
}
140
141pub fn get_features_and_target(
142    df: &DataFrame,
143    features: Vec<&str>,
144    target: &str,
145) -> PolarsResult<(Matrix, Vector)> {
146    let x = df
147        .select(features)
148        .unwrap()
149        .to_ndarray::<Float64Type>(IndexOrder::Fortran)
150        .unwrap();
151    let y = df
152        .select([target])
153        .unwrap()
154        .to_ndarray::<Float64Type>(IndexOrder::Fortran)
155        .unwrap()
156        .column(0)
157        .to_owned()
158        .into_dimensionality::<Ix1>()
159        .unwrap();
160
161    Ok((x, y))
162}
163
#[cfg(test)]
mod get_features_and_target_tests {
    use std::path::PathBuf;

    use crate::utils::data::{get_features_and_target, load_dataset};

    /// Selecting the three advertising-spend columns and the `Sales` target from the
    /// 200-row advertising dataset should give a (200, 3) matrix and a length-200 vector.
    #[test]
    fn test_get_features_and_target() {
        let df = load_dataset(PathBuf::from("./datasets/advertising.csv")).unwrap();

        let (x, y) =
            get_features_and_target(&df, vec!["TV", "Radio", "Newspaper"], "Sales").unwrap();

        assert_eq!((x.nrows(), x.ncols()), (200, 3));
        assert_eq!(y.len(), 200);
    }
}