rusty_machine/data/transforms/shuffle.rs
//! The Shuffler
//!
//! This module contains the `Shuffler` transformer. `Shuffler` implements the
//! `Transformer` trait and is used to shuffle the rows of an input matrix.
//! The random number generator is configurable: pass one to `Shuffler::new`,
//! or use `Shuffler::default()` to rely on `rand::thread_rng`.
//!
//! # Examples
//!
//! ```
//! use rusty_machine::linalg::Matrix;
//! use rusty_machine::data::transforms::Transformer;
//! use rusty_machine::data::transforms::shuffle::Shuffler;
//!
//! // Create an input matrix that we want to shuffle
//! let mat = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
//!
//! // Create a new shuffler
//! let mut shuffler = Shuffler::default();
//! let shuffled_mat = shuffler.transform(mat).unwrap();
//!
//! println!("{}", shuffled_mat);
//! ```
23
24use learning::LearningResult;
25use learning::error::Error;
26use linalg::{Matrix, BaseMatrix, BaseMatrixMut};
27use super::Transformer;
28
29use rand::{Rng, thread_rng, ThreadRng};
30
31/// The `Shuffler`
32///
33/// Provides an implementation of `Transformer` which shuffles
34/// the input rows in place.
35#[derive(Debug)]
36pub struct Shuffler<R: Rng> {
37 rng: R,
38}
39
40impl<R: Rng> Shuffler<R> {
41 /// Construct a new `Shuffler` with given random number generator.
42 ///
43 /// # Examples
44 ///
45 /// ```
46 /// # extern crate rand;
47 /// # extern crate rusty_machine;
48 ///
49 /// use rusty_machine::data::transforms::Transformer;
50 /// use rusty_machine::data::transforms::shuffle::Shuffler;
51 /// use rand::{StdRng, SeedableRng};
52 ///
53 /// # fn main() {
54 /// // We can create a seeded rng
55 /// let rng = StdRng::from_seed(&[1, 2, 3]);
56 ///
57 /// let shuffler = Shuffler::new(rng);
58 /// # }
59 /// ```
60 pub fn new(rng: R) -> Self {
61 Shuffler { rng: rng }
62 }
63}
64
65/// Create a new shuffler using the `rand::thread_rng` function
66/// to provide a randomly seeded random number generator.
67impl Default for Shuffler<ThreadRng> {
68 fn default() -> Self {
69 Shuffler { rng: thread_rng() }
70 }
71}
72
73/// The `Shuffler` will transform the input `Matrix` by shuffling
74/// its rows in place.
75///
76/// Under the hood this uses a Fisher-Yates shuffle.
77impl<R: Rng, T> Transformer<Matrix<T>> for Shuffler<R> {
78
79 #[allow(unused_variables)]
80 fn fit(&mut self, inputs: &Matrix<T>) -> Result<(), Error> {
81 Ok(())
82 }
83
84 fn transform(&mut self, mut inputs: Matrix<T>) -> LearningResult<Matrix<T>> {
85 let n = inputs.rows();
86
87 for i in 0..n {
88 // Swap i with a random point after it
89 let j = self.rng.gen_range(0, n - i);
90 inputs.swap_rows(i, i + j);
91 }
92
93 Ok(inputs)
94 }
95}
96
#[cfg(test)]
mod tests {
    use linalg::Matrix;
    use super::super::Transformer;
    use super::Shuffler;

    use rand::{StdRng, SeedableRng};

    /// A fixed seed must produce the recorded row order.
    #[test]
    fn seeded_shuffle() {
        let seeded = StdRng::from_seed(&[1, 2, 3]);
        let mut shuffler = Shuffler::new(seeded);

        let input = Matrix::new(4, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let expected = vec![3.0, 4.0, 1.0, 2.0, 7.0, 8.0, 5.0, 6.0];

        let output = shuffler.transform(input).unwrap();
        assert_eq!(output.into_vec(), expected);
    }

    /// A matrix with a single row has nothing to reorder, whatever the rng does.
    #[test]
    fn shuffle_single_row() {
        let mut shuffler = Shuffler::default();

        let input = Matrix::new(1, 8, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let expected = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let output = shuffler.transform(input).unwrap();
        assert_eq!(output.into_vec(), expected);
    }

    /// `fit` is a no-op that succeeds with the unit value.
    #[test]
    fn shuffle_fit() {
        let seeded = StdRng::from_seed(&[1, 2, 3]);
        let mut shuffler = Shuffler::new(seeded);

        let input = Matrix::new(4, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        assert_eq!(shuffler.fit(&input).unwrap(), ());
    }
}