1use nalgebra::{DMatrix, DVector};
2use num_traits::{Float, FromPrimitive, Num, ToPrimitive};
3use rand::seq::SliceRandom;
4use rand::Rng;
5use rand::{rngs::StdRng, SeedableRng};
6use std::cmp::PartialOrd;
7use std::error::Error;
8use std::fmt::{self, Display};
9use std::fmt::{Debug, Formatter};
10use std::hash::Hash;
11use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
12
13pub trait DataValue:
14 Debug
15 + Clone
16 + Copy
17 + Num
18 + FromPrimitive
19 + ToPrimitive
20 + AddAssign
21 + SubAssign
22 + MulAssign
23 + DivAssign
24 + Send
25 + Sync
26 + Display
27 + 'static
28{
29}
30
31impl<T> DataValue for T where
32 T: Debug
33 + Clone
34 + Copy
35 + Num
36 + FromPrimitive
37 + ToPrimitive
38 + AddAssign
39 + SubAssign
40 + MulAssign
41 + DivAssign
42 + Send
43 + Sync
44 + Display
45 + 'static
46{
47}
48
49pub trait Number: DataValue + PartialOrd {}
50impl<T> Number for T where T: DataValue + PartialOrd {}
51
52pub trait WholeNumber: Number + Eq + Hash {}
53impl<T> WholeNumber for T where T: Number + Eq + Hash {}
54
55pub trait RealNumber: Number + Float {}
56impl<T> RealNumber for T where T: Number + Float {}
57
58pub trait TargetValue: DataValue {}
59impl<T> TargetValue for T where T: DataValue {}
60
61pub struct Dataset<XT: Number, YT: TargetValue> {
62 pub x: DMatrix<XT>,
63 pub y: DVector<YT>,
64}
65
66impl<XT: Number, YT: TargetValue> Debug for Dataset<XT, YT> {
67 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
68 write!(f, "Dataset {{\n x: [\n")?;
69
70 for i in 0..self.x.nrows() {
71 write!(f, " [")?;
72 for j in 0..self.x.ncols() {
73 write!(f, "{:?}, ", self.x[(i, j)])?;
74 }
75 writeln!(f, "],")?;
76 }
77
78 write!(f, " ],\n y: [")?;
79 for i in 0..self.y.len() {
80 write!(f, "{:?}, ", self.y[i])?;
81 }
82 write!(f, "]\n}}")
83 }
84}
85
86impl<XT: Number, YT: TargetValue> Dataset<XT, YT> {
122 pub fn new(x: DMatrix<XT>, y: DVector<YT>) -> Self {
133 Self { x, y }
134 }
135
136 pub fn into_parts(&self) -> (&DMatrix<XT>, &DVector<YT>) {
142 (&self.x, &self.y)
143 }
144
145 pub fn is_not_empty(&self) -> bool {
151 !(self.x.is_empty() || self.y.is_empty())
152 }
153
154 pub fn nrows(&self) -> usize {
160 self.x.nrows()
161 }
162
163 pub fn standardize(&mut self)
172 where
173 XT: RealNumber,
174 {
175 let (nrows, _) = self.x.shape();
176
177 let means = self
178 .x
179 .column_iter()
180 .map(|col| col.sum() / XT::from_usize(col.len()).unwrap())
181 .collect::<Vec<_>>();
182 let std_devs = self
183 .x
184 .column_iter()
185 .zip(means.iter())
186 .map(|(col, mean)| {
187 let mut sum = XT::from_f64(0.0).unwrap();
188 for val in col.iter() {
189 sum += (*val - *mean) * (*val - *mean);
190 }
191 (sum / XT::from_usize(nrows).unwrap()).sqrt()
192 })
193 .collect::<Vec<_>>();
194 let standardized_cols = self
195 .x
196 .column_iter()
197 .zip(means.iter())
198 .zip(std_devs.iter())
199 .map(|((col, &mean), &std_dev)| col.map(|val| (val - mean) / std_dev))
200 .collect::<Vec<_>>();
201 self.x = DMatrix::from_columns(&standardized_cols);
202 }
203
204 pub fn train_test_split(
215 &self,
216 train_size: f64,
217 seed: Option<u64>,
218 ) -> Result<(Self, Self), Box<dyn Error>> {
219 if !(0.0..=1.0).contains(&train_size) {
220 return Err("Train size should be between 0.0 and 1.0".into());
221 }
222 let mut rng = match seed {
223 Some(seed) => StdRng::seed_from_u64(seed),
224 None => StdRng::from_entropy(),
225 };
226
227 let mut indices = (0..self.x.nrows()).collect::<Vec<_>>();
228 indices.shuffle(&mut rng);
229 let train_size = (self.x.nrows() as f64 * train_size).floor() as usize;
230 let train_indices = &indices[..train_size];
231 let test_indices = &indices[train_size..];
232
233 let train_x = train_indices
234 .iter()
235 .map(|&index| self.x.row(index))
236 .collect::<Vec<_>>();
237 let train_y = train_indices
238 .iter()
239 .map(|&index| self.y[index])
240 .collect::<Vec<_>>();
241
242 let test_x = test_indices
243 .iter()
244 .map(|&index| self.x.row(index))
245 .collect::<Vec<_>>();
246 let test_y = test_indices
247 .iter()
248 .map(|&index| self.y[index])
249 .collect::<Vec<_>>();
250
251 let train_dataset = Self::new(DMatrix::from_rows(&train_x), DVector::from_vec(train_y));
252 let test_dataset = Self::new(DMatrix::from_rows(&test_x), DVector::from_vec(test_y));
253
254 Ok((train_dataset, test_dataset))
255 }
256
257 pub fn split_on_threshold(&self, feature_index: usize, threshold: XT) -> (Self, Self) {
272 let (left_indices, right_indices): (Vec<_>, Vec<_>) = self
273 .x
274 .row_iter()
275 .enumerate()
276 .partition(|(_, row)| row[feature_index] <= threshold);
277
278 let left_x: Vec<_> = left_indices
279 .iter()
280 .map(|&(index, _)| self.x.row(index))
281 .collect();
282 let left_y: Vec<_> = left_indices
283 .iter()
284 .map(|&(index, _)| self.y.row(index))
285 .collect();
286
287 let right_x: Vec<_> = right_indices
288 .iter()
289 .map(|&(index, _)| self.x.row(index))
290 .collect();
291 let right_y: Vec<_> = right_indices
292 .iter()
293 .map(|&(index, _)| self.y.row(index))
294 .collect();
295
296 let left_dataset = if left_x.is_empty() {
297 Self::new(DMatrix::zeros(0, self.x.ncols()), DVector::zeros(0))
298 } else {
299 Self::new(DMatrix::from_rows(&left_x), DVector::from_rows(&left_y))
300 };
301
302 let right_dataset = if right_x.is_empty() {
303 Self::new(DMatrix::zeros(0, self.x.ncols()), DVector::zeros(0))
304 } else {
305 Self::new(DMatrix::from_rows(&right_x), DVector::from_rows(&right_y))
306 };
307
308 (left_dataset, right_dataset)
309 }
310
311 pub fn samples(&self, sample_size: usize, seed: Option<u64>) -> Self {
324 let mut rng = match seed {
325 Some(seed) => StdRng::seed_from_u64(seed),
326 None => StdRng::from_entropy(),
327 };
328
329 let nrows = self.x.nrows();
330 let sample_indices = (0..sample_size)
331 .map(|_| rng.gen_range(0..nrows))
332 .collect::<Vec<_>>();
333
334 let sample_x = sample_indices
335 .iter()
336 .map(|&index| self.x.row(index))
337 .collect::<Vec<_>>();
338 let sample_y = sample_indices
339 .iter()
340 .map(|&index| self.y[index])
341 .collect::<Vec<_>>();
342
343 Self::new(DMatrix::from_rows(&sample_x), DVector::from_vec(sample_y))
344 }
345}
346
// Unit tests for `Dataset`: construction, accessors, `Debug` formatting,
// standardization, and the row-splitting / sampling helpers.
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    // `new` stores the given matrix and vector unchanged.
    #[test]
    fn test_dataset_new() {
        let x = DMatrix::from_row_slice(2, 2, &[1, 2, 3, 4]);
        let y = DVector::from_vec(vec![5, 6]);
        let dataset = Dataset::new(x.clone(), y.clone());
        assert_eq!(dataset.x, x);
        assert_eq!(dataset.y, y);
    }

    // `into_parts` returns references equal to the original `x` and `y`.
    #[test]
    fn test_dataset_into_parts() {
        let x = DMatrix::from_row_slice(2, 2, &[1, 2, 3, 4]);
        let y = DVector::from_vec(vec![5, 6]);
        let dataset = Dataset::new(x.clone(), y.clone());
        let (x_parts, y_parts) = dataset.into_parts();
        assert_eq!(x_parts, &x);
        assert_eq!(y_parts, &y);
    }

    // Pins the exact `{:?}` layout, including the trailing ", " after every
    // value. The `"\` continuation strips the newline and leading whitespace
    // before `Dataset {`; every later line of the literal is taken verbatim.
    #[test]
    fn test_dataset_formatting() {
        let x = DMatrix::from_row_slice(2, 2, &[1, 2, 3, 4]);
        let y = DVector::from_vec(vec![5, 6]);
        let dataset = Dataset::new(x, y);

        let dataset_str = format!("{:?}", dataset);

        let expected_str = "\
Dataset {
 x: [
 [1, 2, ],
 [3, 4, ],
 ],
 y: [5, 6, ]
}";

        assert_eq!(dataset_str, expected_str);
    }

    // A populated dataset is non-empty; a 0x2 matrix with an empty target is empty.
    #[test]
    fn test_dataset_is_not_empty() {
        let x = DMatrix::from_row_slice(2, 2, &[1, 2, 3, 4]);
        let y = DVector::from_vec(vec![5, 6]);
        let dataset = Dataset::new(x, y);
        assert!(dataset.is_not_empty());

        let empty_x = DMatrix::<f64>::from_row_slice(0, 2, &[]);
        let empty_y = DVector::<f64>::from_vec(vec![]);
        let empty_dataset = Dataset::new(empty_x, empty_y);
        assert!(!empty_dataset.is_not_empty());
    }

    // Columns [1, 3, 5] and [2, 4, 6] both have population std sqrt(8/3), so the
    // standardized values are -2/sqrt(8/3), 0, 2/sqrt(8/3) (about ±1.2247).
    #[test]
    fn test_dataset_standardize() {
        let x = DMatrix::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let y = DVector::from_vec(vec![7.0, 8.0, 9.0]);
        let mut dataset = Dataset::new(x, y);
        println!("{}", dataset.x);
        dataset.standardize();
        println!("{}", dataset.x);

        let expected_x = DMatrix::from_row_slice(
            3,
            2,
            &[
                -1.224744871391589,
                -1.224744871391589,
                0.0,
                0.0,
                1.224744871391589,
                1.224744871391589,
            ],
        );
        assert_relative_eq!(dataset.x, expected_x, epsilon = 1e-6);
    }

    // floor(4 * 0.75) = 3 rows go to train, the remaining 1 row to test.
    #[test]
    fn test_dataset_train_test_split() {
        let x = DMatrix::from_row_slice(4, 2, &[1, 2, 3, 4, 5, 6, 7, 8]);
        let y = DVector::from_vec(vec![9, 10, 11, 12]);
        let dataset = Dataset::new(x, y);

        let (train_dataset, test_dataset) = dataset.train_test_split(0.75, None).unwrap();
        assert_eq!(train_dataset.x.nrows(), 3);
        assert_eq!(test_dataset.x.nrows(), 1);
    }

    // Feature 0 holds [1, 3, 5, 7]; threshold 4 puts 1 and 3 on the left.
    #[test]
    fn test_dataset_split_on_threshold() {
        let x = DMatrix::from_row_slice(4, 2, &[1, 2, 3, 4, 5, 6, 7, 8]);
        let y = DVector::from_vec(vec![9, 10, 11, 12]);
        let dataset = Dataset::new(x, y);

        let (left_dataset, right_dataset) = dataset.split_on_threshold(0, 4);
        assert_eq!(left_dataset.x.nrows(), 2);
        assert_eq!(right_dataset.x.nrows(), 2);
    }

    // Threshold below every feature-0 value: the left side must come back empty.
    #[test]
    fn test_dataset_split_on_threshold_left_empty() {
        let x = DMatrix::from_row_slice(4, 2, &[1, 2, 3, 4, 5, 6, 7, 8]);
        let y = DVector::from_vec(vec![9, 10, 11, 12]);
        let dataset = Dataset::new(x, y);

        let (left_dataset, right_dataset) = dataset.split_on_threshold(0, -1);
        assert_eq!(left_dataset.x.nrows(), 0);
        assert_eq!(right_dataset.x.nrows(), 4);
    }

    // Threshold above every feature-0 value: the right side must come back empty.
    #[test]
    fn test_dataset_split_on_threshold_right_empty() {
        let x = DMatrix::from_row_slice(4, 2, &[1, 2, 3, 4, 5, 6, 7, 8]);
        let y = DVector::from_vec(vec![9, 10, 11, 12]);
        let dataset = Dataset::new(x, y);

        let (left_dataset, right_dataset) = dataset.split_on_threshold(0, 9);
        assert_eq!(left_dataset.x.nrows(), 4);
        assert_eq!(right_dataset.x.nrows(), 0);
    }

    // Sampling (with replacement) returns exactly the requested number of rows.
    #[test]
    fn test_dataset_samples() {
        let x = DMatrix::from_row_slice(4, 2, &[1, 2, 3, 4, 5, 6, 7, 8]);
        let y = DVector::from_vec(vec![9, 10, 11, 12]);
        let dataset = Dataset::new(x, y);

        let sampled_dataset = dataset.samples(2, None);
        assert_eq!(sampled_dataset.x.nrows(), 2);
    }

    // Same as above, but through the seeded RNG path.
    #[test]
    fn test_dataset_samples_with_seed() {
        let x = DMatrix::from_row_slice(4, 2, &[1, 2, 3, 4, 5, 6, 7, 8]);
        let y = DVector::from_vec(vec![9, 10, 11, 12]);
        let dataset = Dataset::new(x, y);

        let sampled_dataset = dataset.samples(2, Some(1000));
        assert_eq!(sampled_dataset.x.nrows(), 2);
    }
}