1use std::path::PathBuf;
2
3use ndarray::{Axis, Ix1};
4use ndarray_rand::rand::{SeedableRng, seq::SliceRandom};
5use polars::prelude::*;
6
7use crate::core::types::{Matrix, Vector};
8
9pub fn load_dataset(path: PathBuf) -> PolarsResult<DataFrame> {
10 CsvReadOptions::default()
11 .try_into_reader_with_file_path(Some(path))?
12 .finish()
13}
14
#[cfg(test)]
mod load_dataset_tests {
    use std::path::PathBuf;

    use crate::utils::data::load_dataset;

    /// Loading `./datasets/advertising.csv` should succeed; the file is
    /// expected to exist relative to the test working directory.
    #[test]
    fn test_load_dataset_existing_path() {
        let csv_path = PathBuf::from("./datasets/advertising.csv");
        // Printed so a failure caused by a wrong working directory is easy to spot.
        println!("path exists: {:?}", csv_path.exists());
        let outcome = load_dataset(csv_path);
        assert!(outcome.is_ok());
    }

    /// Loading a path that does not exist should surface an error, not panic.
    #[test]
    fn test_load_dataset_with_non_existing_path() {
        let outcome = load_dataset(PathBuf::from("./data/non_existing.csv"));
        assert!(outcome.is_err());
    }
}
36
37pub fn shuffle_split(
38 x: &Matrix,
39 y: &Vector,
40 train_perc: f64,
41 seed: i32,
42) -> (Matrix, Vector, Matrix, Vector) {
43 let mut rng = ndarray_rand::rand::rngs::StdRng::seed_from_u64(seed as u64);
45
46 let n_samples = x.nrows();
48 let indices: Vec<usize> = (0..n_samples).collect();
49 let shuffled_indices: Vec<usize> = indices
50 .choose_multiple(&mut rng, n_samples)
51 .cloned()
52 .collect();
53
54 let split_index = (n_samples as f64 * train_perc).round() as usize;
56
57 let x_train = x.select(Axis(0), &shuffled_indices[..split_index]);
59 let y_train = y.select(Axis(0), &shuffled_indices[..split_index]);
60 let x_test = x.select(Axis(0), &shuffled_indices[split_index..]);
61 let y_test = y.select(Axis(0), &shuffled_indices[split_index..]);
62
63 (x_train, y_train, x_test, y_test)
64}
65
#[cfg(test)]
mod shuffle_split_tests {
    use crate::utils::data::shuffle_split;
    use ndarray::{arr1, arr2};

    /// A 70% split of ten samples should give seven training and three test rows.
    #[test]
    fn test_shuffle_split_train_test_ratio() {
        let features = arr2(&[
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
            [11.0, 12.0],
            [13.0, 14.0],
            [15.0, 16.0],
            [17.0, 18.0],
            [19.0, 20.0],
        ]);
        let targets = arr1(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);

        let (train_x, train_y, test_x, test_y) = shuffle_split(&features, &targets, 0.7, 42);

        assert_eq!(train_x.nrows(), 7);
        assert_eq!(train_y.len(), 7);
        assert_eq!(test_x.nrows(), 3);
        assert_eq!(test_y.len(), 3);
    }

    /// Two different seeds should produce training sets that differ in at
    /// least one row position.
    #[test]
    fn test_shuffle_split_returns_sets_in_random_order() {
        let features = arr2(&[
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
            [11.0, 12.0],
            [13.0, 14.0],
            [15.0, 16.0],
            [17.0, 18.0],
            [19.0, 20.0],
        ]);
        let targets = arr1(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);

        let (first_x, first_y, _, _) = shuffle_split(&features, &targets, 0.7, 42);
        let (second_x, second_y, _, _) = shuffle_split(&features, &targets, 0.7, 100);

        // Compare row by row, stopping at the first position where the two sets disagree.
        let sets_are_different =
            (0..first_x.nrows()).any(|row| first_x.row(row) != second_x.row(row));

        assert_eq!(first_x.nrows(), 7);
        assert_eq!(second_x.nrows(), 7);
        assert_eq!(first_y.len(), 7);
        assert_eq!(second_y.len(), 7);
        assert!(
            sets_are_different,
            "Training sets should be different when using different seeds"
        );
    }
}
140
141pub fn get_features_and_target(
142 df: &DataFrame,
143 features: Vec<&str>,
144 target: &str,
145) -> PolarsResult<(Matrix, Vector)> {
146 let x = df
147 .select(features)
148 .unwrap()
149 .to_ndarray::<Float64Type>(IndexOrder::Fortran)
150 .unwrap();
151 let y = df
152 .select([target])
153 .unwrap()
154 .to_ndarray::<Float64Type>(IndexOrder::Fortran)
155 .unwrap()
156 .column(0)
157 .to_owned()
158 .into_dimensionality::<Ix1>()
159 .unwrap();
160
161 Ok((x, y))
162}
163
#[cfg(test)]
mod get_features_and_target_tests {
    use crate::utils::data::{get_features_and_target, load_dataset};
    use std::path::PathBuf;

    /// Splitting the advertising dataset into three feature columns and the
    /// `Sales` target should give a 200x3 matrix and a 200-element vector.
    #[test]
    fn test_get_features_and_target() {
        let frame = load_dataset(PathBuf::from("./datasets/advertising.csv")).unwrap();
        let feature_names = vec!["TV", "Radio", "Newspaper"];

        let (matrix, vector) = get_features_and_target(&frame, feature_names, "Sales").unwrap();

        assert_eq!(matrix.nrows(), 200);
        assert_eq!(matrix.ncols(), 3);
        assert_eq!(vector.len(), 200);
    }
}