rust_lstm/layers/dropout.rs

1use ndarray::Array2;
2use ndarray_rand::RandomExt;
3use ndarray_rand::rand_distr::Uniform;
4
5/// Dropout layer for regularization
6/// 
7/// Implements different types of dropout:
8/// - Standard dropout: randomly sets elements to zero
9/// - Variational dropout: uses same mask across time steps (for RNNs)
10/// - Zoneout: keeps some hidden/cell state values from previous timestep
11#[derive(Clone)]
12pub struct Dropout {
13    pub dropout_rate: f64,
14    pub is_training: bool,
15    pub variational: bool,
16    mask: Option<Array2<f64>>,
17}
18
19impl Dropout {
20    pub fn new(dropout_rate: f64) -> Self {
21        assert!(dropout_rate >= 0.0 && dropout_rate <= 1.0, 
22                "Dropout rate must be between 0.0 and 1.0");
23        
24        Dropout {
25            dropout_rate,
26            is_training: true,
27            variational: false,
28            mask: None,
29        }
30    }
31
32    pub fn variational(dropout_rate: f64) -> Self {
33        let mut dropout = Self::new(dropout_rate);
34        dropout.variational = true;
35        dropout
36    }
37
38    pub fn train(&mut self) {
39        self.is_training = true;
40        if self.variational {
41            self.mask = None;
42        }
43    }
44
45    pub fn eval(&mut self) {
46        self.is_training = false;
47        self.mask = None;
48    }
49
50    pub fn forward(&mut self, input: &Array2<f64>) -> Array2<f64> {
51        if !self.is_training || self.dropout_rate == 0.0 {
52            return input.clone();
53        }
54
55        let keep_prob = 1.0 - self.dropout_rate;
56
57        let mask = if self.variational {
58            if let Some(ref mask) = self.mask {
59                mask.clone()
60            } else {
61                let new_mask = self.generate_mask(input.raw_dim(), keep_prob);
62                self.mask = Some(new_mask.clone());
63                new_mask
64            }
65        } else {
66            let new_mask = self.generate_mask(input.raw_dim(), keep_prob);
67            self.mask = Some(new_mask.clone());
68            new_mask
69        };
70
71        input * mask / keep_prob
72    }
73
74    pub fn get_last_mask(&self) -> Option<&Array2<f64>> {
75        self.mask.as_ref()
76    }
77
78    pub fn backward(&self, grad_output: &Array2<f64>) -> Array2<f64> {
79        if !self.is_training || self.dropout_rate == 0.0 {
80            return grad_output.clone();
81        }
82
83        let keep_prob = 1.0 - self.dropout_rate;
84        
85        if let Some(ref mask) = self.mask {
86            grad_output * mask / keep_prob
87        } else {
88            grad_output.clone()
89        }
90    }
91
92    fn generate_mask(&self, shape: ndarray::Dim<[usize; 2]>, keep_prob: f64) -> Array2<f64> {
93        let dist = Uniform::new(0.0, 1.0);
94        Array2::random(shape, dist).mapv(|x| if x < keep_prob { 1.0 } else { 0.0 })
95    }
96}
97
/// Zoneout implementation specifically for LSTM hidden and cell states.
///
/// During training, each unit independently keeps its previous-timestep value with
/// probability equal to the corresponding zoneout rate, and otherwise takes the newly
/// computed value. In evaluation mode the new state is passed through unchanged.
#[derive(Clone, Debug)]
pub struct Zoneout {
    /// Probability that a cell-state unit keeps its previous value.
    pub cell_zoneout_rate: f64,
    /// Probability that a hidden-state unit keeps its previous value.
    pub hidden_zoneout_rate: f64,
    /// Whether zoneout is applied; when `false` the new state is returned unchanged.
    pub is_training: bool,
}

106impl Zoneout {
107    pub fn new(cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self {
108        assert!(cell_zoneout_rate >= 0.0 && cell_zoneout_rate <= 1.0);
109        assert!(hidden_zoneout_rate >= 0.0 && hidden_zoneout_rate <= 1.0);
110        
111        Zoneout {
112            cell_zoneout_rate,
113            hidden_zoneout_rate,
114            is_training: true,
115        }
116    }
117
118    pub fn train(&mut self) {
119        self.is_training = true;
120    }
121
122    pub fn eval(&mut self) {
123        self.is_training = false;
124    }
125
126    pub fn apply_cell_zoneout(&self, new_cell: &Array2<f64>, prev_cell: &Array2<f64>) -> Array2<f64> {
127        if !self.is_training || self.cell_zoneout_rate == 0.0 {
128            return new_cell.clone();
129        }
130
131        let keep_prob = 1.0 - self.cell_zoneout_rate;
132        let dist = Uniform::new(0.0, 1.0);
133        let mask = Array2::random(new_cell.raw_dim(), dist);
134        
135        let keep_new = mask.mapv(|x| if x < keep_prob { 1.0 } else { 0.0 });
136        let keep_old = mask.mapv(|x| if x >= keep_prob { 1.0 } else { 0.0 });
137        
138        &keep_new * new_cell + &keep_old * prev_cell
139    }
140
141    pub fn apply_hidden_zoneout(&self, new_hidden: &Array2<f64>, prev_hidden: &Array2<f64>) -> Array2<f64> {
142        if !self.is_training || self.hidden_zoneout_rate == 0.0 {
143            return new_hidden.clone();
144        }
145
146        let keep_prob = 1.0 - self.hidden_zoneout_rate;
147        let dist = Uniform::new(0.0, 1.0);
148        let mask = Array2::random(new_hidden.raw_dim(), dist);
149        
150        let keep_new = mask.mapv(|x| if x < keep_prob { 1.0 } else { 0.0 });
151        let keep_old = mask.mapv(|x| if x >= keep_prob { 1.0 } else { 0.0 });
152        
153        &keep_new * new_hidden + &keep_old * prev_hidden
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_dropout_forward() {
        let mut dropout = Dropout::new(0.5);
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        dropout.train();
        let output_train = dropout.forward(&input);
        // Inverted dropout: each element is either dropped or scaled by 1 / keep_prob.
        for (&out, &x) in output_train.iter().zip(input.iter()) {
            assert!(out == 0.0 || out == x / 0.5, "unexpected dropout output {}", out);
        }
        assert!(dropout.get_last_mask().is_some());

        dropout.eval();
        let output_eval = dropout.forward(&input);
        assert_eq!(output_eval, input);
        assert!(dropout.get_last_mask().is_none());
    }

    #[test]
    fn test_dropout_backward_uses_forward_mask() {
        let mut dropout = Dropout::new(0.5);
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        let output = dropout.forward(&input);
        let grad = dropout.backward(&Array2::<f64>::ones((2, 2)));
        // Gradient flows exactly where the activation survived.
        for (&g, &out) in grad.iter().zip(output.iter()) {
            assert_eq!(g == 0.0, out == 0.0);
        }
    }

    #[test]
    fn test_variational_dropout() {
        let mut dropout = Dropout::variational(0.3);
        let input1 = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let input2 = arr2(&[[2.0, 3.0], [4.0, 5.0]]);

        dropout.train();
        let output1 = dropout.forward(&input1);
        let mask1 = dropout
            .get_last_mask()
            .cloned()
            .expect("training forward stores a mask");
        let output2 = dropout.forward(&input2);

        // The same mask must be reused for every time step.
        assert_eq!(dropout.get_last_mask(), Some(&mask1));
        for ((&a, &b), &m) in output1.iter().zip(output2.iter()).zip(mask1.iter()) {
            assert_eq!(a == 0.0, m == 0.0);
            assert_eq!(b == 0.0, m == 0.0);
        }
    }

    #[test]
    #[should_panic]
    fn test_dropout_rejects_invalid_rate() {
        let _ = Dropout::new(1.5);
    }

    #[test]
    fn test_zoneout() {
        let zoneout = Zoneout::new(0.2, 0.3);
        let new_state = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let prev_state = arr2(&[[0.5, 1.0], [1.5, 2.0]]);

        let result = zoneout.apply_cell_zoneout(&new_state, &prev_state);
        assert_eq!(result.shape(), new_state.shape());
        // Each unit either takes the new value or keeps the previous one.
        for ((&r, &n), &p) in result.iter().zip(new_state.iter()).zip(prev_state.iter()) {
            assert!(r == n || r == p, "unexpected zoneout value {}", r);
        }
    }

    #[test]
    fn test_zoneout_extremes_and_eval() {
        let new_state = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let prev_state = arr2(&[[0.5, 1.0], [1.5, 2.0]]);

        let mut zoneout = Zoneout::new(1.0, 0.0);
        assert_eq!(zoneout.apply_cell_zoneout(&new_state, &prev_state), prev_state);
        assert_eq!(zoneout.apply_hidden_zoneout(&new_state, &prev_state), new_state);

        zoneout.eval();
        assert_eq!(zoneout.apply_cell_zoneout(&new_state, &prev_state), new_state);
    }
}