rusty_machine/data/transforms/
shuffle.rs

//! The Shuffler
//!
//! This module contains the `Shuffler` transformer. `Shuffler` implements the
//! `Transformer` trait and is used to shuffle the rows of an input matrix.
//! You can control the random number generator used by the `Shuffler`.
//!
//! # Examples
//!
//! ```
//! use rusty_machine::linalg::Matrix;
//! use rusty_machine::data::transforms::Transformer;
//! use rusty_machine::data::transforms::shuffle::Shuffler;
//!
//! // Create an input matrix that we want to shuffle
//! let mat = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
//!
//! // Create a new shuffler
//! let mut shuffler = Shuffler::default();
//! let shuffled_mat = shuffler.transform(mat).unwrap();
//!
//! println!("{}", shuffled_mat);
//! ```
23
24use learning::LearningResult;
25use learning::error::Error;
26use linalg::{Matrix, BaseMatrix, BaseMatrixMut};
27use super::Transformer;
28
29use rand::{Rng, thread_rng, ThreadRng};
30
31/// The `Shuffler`
32///
33/// Provides an implementation of `Transformer` which shuffles
34/// the input rows in place.
35#[derive(Debug)]
36pub struct Shuffler<R: Rng> {
37    rng: R,
38}
39
40impl<R: Rng> Shuffler<R> {
41    /// Construct a new `Shuffler` with given random number generator.
42    ///
43    /// # Examples
44    ///
45    /// ```
46    /// # extern crate rand;
47    /// # extern crate rusty_machine;
48    ///
49    /// use rusty_machine::data::transforms::Transformer;
50    /// use rusty_machine::data::transforms::shuffle::Shuffler;
51    /// use rand::{StdRng, SeedableRng};
52    ///
53    /// # fn main() {
54    /// // We can create a seeded rng
55    /// let rng = StdRng::from_seed(&[1, 2, 3]);
56    ///
57    /// let shuffler = Shuffler::new(rng);
58    /// # }
59    /// ```
60    pub fn new(rng: R) -> Self {
61        Shuffler { rng: rng }
62    }
63}
64
65/// Create a new shuffler using the `rand::thread_rng` function
66/// to provide a randomly seeded random number generator.
67impl Default for Shuffler<ThreadRng> {
68    fn default() -> Self {
69        Shuffler { rng: thread_rng() }
70    }
71}
72
73/// The `Shuffler` will transform the input `Matrix` by shuffling
74/// its rows in place.
75///
76/// Under the hood this uses a Fisher-Yates shuffle.
77impl<R: Rng, T> Transformer<Matrix<T>> for Shuffler<R> {
78
79    #[allow(unused_variables)]
80    fn fit(&mut self, inputs: &Matrix<T>) -> Result<(), Error> {
81        Ok(())
82    }
83
84    fn transform(&mut self, mut inputs: Matrix<T>) -> LearningResult<Matrix<T>> {
85        let n = inputs.rows();
86
87        for i in 0..n {
88            // Swap i with a random point after it
89            let j = self.rng.gen_range(0, n - i);
90            inputs.swap_rows(i, i + j);
91        }
92
93        Ok(inputs)
94    }
95}
96
#[cfg(test)]
mod tests {
    use linalg::Matrix;
    use super::super::Transformer;
    use super::Shuffler;

    use rand::{StdRng, SeedableRng};

    /// A 4x2 matrix filled with the values 1.0 through 8.0.
    fn four_by_two() -> Matrix<f64> {
        Matrix::new(4, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    }

    /// A shuffler backed by a fixed-seed generator, so results are repeatable.
    fn seeded_shuffler() -> Shuffler<StdRng> {
        Shuffler::new(StdRng::from_seed(&[1, 2, 3]))
    }

    #[test]
    fn seeded_shuffle() {
        let mut shuffler = seeded_shuffler();

        let shuffled = shuffler.transform(four_by_two()).unwrap();

        let expected = vec![3.0, 4.0, 1.0, 2.0, 7.0, 8.0, 5.0, 6.0];
        assert_eq!(shuffled.into_vec(), expected);
    }

    #[test]
    fn shuffle_single_row() {
        let mut shuffler = Shuffler::default();

        // With only one row there is nothing to permute.
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mat = Matrix::new(1, 8, values.clone());
        let shuffled = shuffler.transform(mat).unwrap();

        assert_eq!(shuffled.into_vec(), values);
    }

    #[test]
    fn shuffle_fit() {
        let mut shuffler = seeded_shuffler();

        // Fitting is a no-op that must still report success.
        let mat = four_by_two();
        let res = shuffler.fit(&mat).unwrap();

        assert_eq!(res, ());
    }
}