// rust_lstm/layers/dropout.rs
use ndarray::{Array2, Ix2};
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;
5#[derive(Clone)]
12pub struct Dropout {
13 pub dropout_rate: f64,
14 pub is_training: bool,
15 pub variational: bool,
16 mask: Option<Array2<f64>>,
17}
18
19impl Dropout {
20 pub fn new(dropout_rate: f64) -> Self {
21 assert!(dropout_rate >= 0.0 && dropout_rate <= 1.0,
22 "Dropout rate must be between 0.0 and 1.0");
23
24 Dropout {
25 dropout_rate,
26 is_training: true,
27 variational: false,
28 mask: None,
29 }
30 }
31
32 pub fn variational(dropout_rate: f64) -> Self {
33 let mut dropout = Self::new(dropout_rate);
34 dropout.variational = true;
35 dropout
36 }
37
38 pub fn train(&mut self) {
39 self.is_training = true;
40 if self.variational {
41 self.mask = None;
42 }
43 }
44
45 pub fn eval(&mut self) {
46 self.is_training = false;
47 self.mask = None;
48 }
49
50 pub fn forward(&mut self, input: &Array2<f64>) -> Array2<f64> {
51 if !self.is_training || self.dropout_rate == 0.0 {
52 return input.clone();
53 }
54
55 let keep_prob = 1.0 - self.dropout_rate;
56
57 let mask = if self.variational {
58 if let Some(ref mask) = self.mask {
59 mask.clone()
60 } else {
61 let new_mask = self.generate_mask(input.raw_dim(), keep_prob);
62 self.mask = Some(new_mask.clone());
63 new_mask
64 }
65 } else {
66 let new_mask = self.generate_mask(input.raw_dim(), keep_prob);
67 self.mask = Some(new_mask.clone());
68 new_mask
69 };
70
71 input * mask / keep_prob
72 }
73
74 pub fn get_last_mask(&self) -> Option<&Array2<f64>> {
75 self.mask.as_ref()
76 }
77
78 pub fn backward(&self, grad_output: &Array2<f64>) -> Array2<f64> {
79 if !self.is_training || self.dropout_rate == 0.0 {
80 return grad_output.clone();
81 }
82
83 let keep_prob = 1.0 - self.dropout_rate;
84
85 if let Some(ref mask) = self.mask {
86 grad_output * mask / keep_prob
87 } else {
88 grad_output.clone()
89 }
90 }
91
92 fn generate_mask(&self, shape: ndarray::Dim<[usize; 2]>, keep_prob: f64) -> Array2<f64> {
93 let dist = Uniform::new(0.0, 1.0);
94 Array2::random(shape, dist).mapv(|x| if x < keep_prob { 1.0 } else { 0.0 })
95 }
96}
97
98#[derive(Clone)]
100pub struct Zoneout {
101 pub cell_zoneout_rate: f64,
102 pub hidden_zoneout_rate: f64,
103 pub is_training: bool,
104}
105
106impl Zoneout {
107 pub fn new(cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self {
108 assert!(cell_zoneout_rate >= 0.0 && cell_zoneout_rate <= 1.0);
109 assert!(hidden_zoneout_rate >= 0.0 && hidden_zoneout_rate <= 1.0);
110
111 Zoneout {
112 cell_zoneout_rate,
113 hidden_zoneout_rate,
114 is_training: true,
115 }
116 }
117
118 pub fn train(&mut self) {
119 self.is_training = true;
120 }
121
122 pub fn eval(&mut self) {
123 self.is_training = false;
124 }
125
126 pub fn apply_cell_zoneout(&self, new_cell: &Array2<f64>, prev_cell: &Array2<f64>) -> Array2<f64> {
127 if !self.is_training || self.cell_zoneout_rate == 0.0 {
128 return new_cell.clone();
129 }
130
131 let keep_prob = 1.0 - self.cell_zoneout_rate;
132 let dist = Uniform::new(0.0, 1.0);
133 let mask = Array2::random(new_cell.raw_dim(), dist);
134
135 let keep_new = mask.mapv(|x| if x < keep_prob { 1.0 } else { 0.0 });
136 let keep_old = mask.mapv(|x| if x >= keep_prob { 1.0 } else { 0.0 });
137
138 &keep_new * new_cell + &keep_old * prev_cell
139 }
140
141 pub fn apply_hidden_zoneout(&self, new_hidden: &Array2<f64>, prev_hidden: &Array2<f64>) -> Array2<f64> {
142 if !self.is_training || self.hidden_zoneout_rate == 0.0 {
143 return new_hidden.clone();
144 }
145
146 let keep_prob = 1.0 - self.hidden_zoneout_rate;
147 let dist = Uniform::new(0.0, 1.0);
148 let mask = Array2::random(new_hidden.raw_dim(), dist);
149
150 let keep_new = mask.mapv(|x| if x < keep_prob { 1.0 } else { 0.0 });
151 let keep_old = mask.mapv(|x| if x >= keep_prob { 1.0 } else { 0.0 });
152
153 &keep_new * new_hidden + &keep_old * prev_hidden
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_dropout_forward() {
        let mut dropout = Dropout::new(0.5);
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        dropout.train();
        let output_train = dropout.forward(&input);
        // Every element is either dropped or rescaled by 1 / keep_prob = 2.
        for (out, x) in output_train.iter().zip(input.iter()) {
            assert!(
                *out == 0.0 || *out == 2.0 * x,
                "unexpected value {out} for input {x}"
            );
        }
        // The recorded mask explains the training output exactly.
        let mask = dropout
            .get_last_mask()
            .expect("a training forward pass records its mask");
        assert_eq!(output_train, &input * mask / 0.5);

        dropout.eval();
        assert!(dropout.get_last_mask().is_none());
        let output_eval = dropout.forward(&input);
        assert_eq!(output_eval, input);
    }

    #[test]
    fn test_dropout_backward_uses_forward_mask() {
        let mut dropout = Dropout::new(0.5);
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        let _ = dropout.forward(&input);
        let mask = dropout.get_last_mask().expect("mask recorded").clone();
        let grad = dropout.backward(&Array2::ones((2, 2)));
        assert_eq!(grad, mask / 0.5);
    }

    #[test]
    #[should_panic]
    fn test_dropout_rejects_invalid_rate() {
        let _ = Dropout::new(1.5);
    }

    #[test]
    fn test_variational_dropout() {
        let mut dropout = Dropout::variational(0.3);
        let input1 = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let input2 = arr2(&[[2.0, 3.0], [4.0, 5.0]]);

        dropout.train();
        let _output1 = dropout.forward(&input1);
        let mask1 = dropout.get_last_mask().cloned();
        let _output2 = dropout.forward(&input2);
        let mask2 = dropout.get_last_mask().cloned();

        // Variational dropout reuses one mask across successive same-shaped calls.
        assert!(mask1.is_some());
        assert_eq!(mask1, mask2);
    }

    #[test]
    fn test_zoneout() {
        let zoneout = Zoneout::new(0.2, 0.3);
        let new_state = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let prev_state = arr2(&[[0.5, 1.0], [1.5, 2.0]]);

        let result = zoneout.apply_cell_zoneout(&new_state, &prev_state);
        assert_eq!(result.shape(), new_state.shape());
        // Every unit comes from either the new or the previous state.
        for ((r, n), p) in result.iter().zip(new_state.iter()).zip(prev_state.iter()) {
            assert!(r == n || r == p, "unit {r} is neither new {n} nor previous {p}");
        }
    }

    #[test]
    fn test_zoneout_extreme_rates_and_eval() {
        let mut zoneout = Zoneout::new(1.0, 0.0);
        let new_state = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let prev_state = arr2(&[[0.5, 1.0], [1.5, 2.0]]);

        // Rate 1.0 keeps every previous unit; rate 0.0 keeps every new unit.
        assert_eq!(zoneout.apply_cell_zoneout(&new_state, &prev_state), prev_state);
        assert_eq!(zoneout.apply_hidden_zoneout(&new_state, &prev_state), new_state);

        zoneout.eval();
        assert_eq!(zoneout.apply_cell_zoneout(&new_state, &prev_state), new_state);
    }
}