1use ndarray::Array2;
2use ndarray_rand::RandomExt;
3use ndarray_rand::rand_distr::Uniform;
4use crate::utils::sigmoid;
5use crate::layers::dropout::Dropout;
6
7#[derive(Clone)]
9pub struct GRUCellGradients {
10 pub w_ir: Array2<f64>,
11 pub w_hr: Array2<f64>,
12 pub b_ir: Array2<f64>,
13 pub b_hr: Array2<f64>,
14 pub w_iz: Array2<f64>,
15 pub w_hz: Array2<f64>,
16 pub b_iz: Array2<f64>,
17 pub b_hz: Array2<f64>,
18 pub w_ih: Array2<f64>,
19 pub w_hh: Array2<f64>,
20 pub b_ih: Array2<f64>,
21 pub b_hh: Array2<f64>,
22}
23
24#[derive(Clone)]
26pub struct GRUCellCache {
27 pub input: Array2<f64>,
28 pub hx: Array2<f64>,
29 pub reset_gate: Array2<f64>,
30 pub update_gate: Array2<f64>,
31 pub new_gate: Array2<f64>,
32 pub reset_hidden: Array2<f64>,
33 pub hy: Array2<f64>,
34 pub input_dropout_mask: Option<Array2<f64>>,
35 pub recurrent_dropout_mask: Option<Array2<f64>>,
36 pub output_dropout_mask: Option<Array2<f64>>,
37}
38
39#[derive(Clone)]
41pub struct GRUCell {
42 pub w_ir: Array2<f64>,
44 pub w_hr: Array2<f64>,
45 pub b_ir: Array2<f64>,
46 pub b_hr: Array2<f64>,
47
48 pub w_iz: Array2<f64>,
50 pub w_hz: Array2<f64>,
51 pub b_iz: Array2<f64>,
52 pub b_hz: Array2<f64>,
53
54 pub w_ih: Array2<f64>,
56 pub w_hh: Array2<f64>,
57 pub b_ih: Array2<f64>,
58 pub b_hh: Array2<f64>,
59
60 pub hidden_size: usize,
61 pub input_dropout: Option<Dropout>,
62 pub recurrent_dropout: Option<Dropout>,
63 pub output_dropout: Option<Dropout>,
64 pub is_training: bool,
65}
66
67impl GRUCell {
68 pub fn new(input_size: usize, hidden_size: usize) -> Self {
70 let dist = Uniform::new(-0.1, 0.1);
71
72 let w_ir = Array2::random((hidden_size, input_size), dist);
74 let w_hr = Array2::random((hidden_size, hidden_size), dist);
75 let b_ir = Array2::zeros((hidden_size, 1));
76 let b_hr = Array2::zeros((hidden_size, 1));
77
78 let w_iz = Array2::random((hidden_size, input_size), dist);
80 let w_hz = Array2::random((hidden_size, hidden_size), dist);
81 let b_iz = Array2::zeros((hidden_size, 1));
82 let b_hz = Array2::zeros((hidden_size, 1));
83
84 let w_ih = Array2::random((hidden_size, input_size), dist);
86 let w_hh = Array2::random((hidden_size, hidden_size), dist);
87 let b_ih = Array2::zeros((hidden_size, 1));
88 let b_hh = Array2::zeros((hidden_size, 1));
89
90 GRUCell {
91 w_ir, w_hr, b_ir, b_hr,
92 w_iz, w_hz, b_iz, b_hz,
93 w_ih, w_hh, b_ih, b_hh,
94 hidden_size,
95 input_dropout: None,
96 recurrent_dropout: None,
97 output_dropout: None,
98 is_training: true,
99 }
100 }
101
102 pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
103 if variational {
104 self.input_dropout = Some(Dropout::variational(dropout_rate));
105 } else {
106 self.input_dropout = Some(Dropout::new(dropout_rate));
107 }
108 self
109 }
110
111 pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
112 if variational {
113 self.recurrent_dropout = Some(Dropout::variational(dropout_rate));
114 } else {
115 self.recurrent_dropout = Some(Dropout::new(dropout_rate));
116 }
117 self
118 }
119
120 pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
121 self.output_dropout = Some(Dropout::new(dropout_rate));
122 self
123 }
124
125 pub fn train(&mut self) {
126 self.is_training = true;
127 if let Some(ref mut dropout) = self.input_dropout {
128 dropout.train();
129 }
130 if let Some(ref mut dropout) = self.recurrent_dropout {
131 dropout.train();
132 }
133 if let Some(ref mut dropout) = self.output_dropout {
134 dropout.train();
135 }
136 }
137
138 pub fn eval(&mut self) {
139 self.is_training = false;
140 if let Some(ref mut dropout) = self.input_dropout {
141 dropout.eval();
142 }
143 if let Some(ref mut dropout) = self.recurrent_dropout {
144 dropout.eval();
145 }
146 if let Some(ref mut dropout) = self.output_dropout {
147 dropout.eval();
148 }
149 }
150
151 pub fn forward(&mut self, input: &Array2<f64>, hx: &Array2<f64>) -> Array2<f64> {
152 let (hy, _) = self.forward_with_cache(input, hx);
153 hy
154 }
155
156 pub fn forward_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>) -> (Array2<f64>, GRUCellCache) {
157 let (input_dropped, input_mask) = if let Some(ref mut dropout) = self.input_dropout {
159 let dropped = dropout.forward(input);
160 let mask = dropout.get_last_mask().map(|m| m.clone());
161 (dropped, mask)
162 } else {
163 (input.clone(), None)
164 };
165
166 let (hx_dropped, recurrent_mask) = if let Some(ref mut dropout) = self.recurrent_dropout {
168 let dropped = dropout.forward(hx);
169 let mask = dropout.get_last_mask().map(|m| m.clone());
170 (dropped, mask)
171 } else {
172 (hx.clone(), None)
173 };
174
175 let reset_gate = (&self.w_ir.dot(&input_dropped) + &self.b_ir + &self.w_hr.dot(&hx_dropped) + &self.b_hr)
177 .map(|&x| sigmoid(x));
178
179 let update_gate = (&self.w_iz.dot(&input_dropped) + &self.b_iz + &self.w_hz.dot(&hx_dropped) + &self.b_hz)
181 .map(|&x| sigmoid(x));
182
183 let reset_hidden = &reset_gate * &hx_dropped;
185
186 let new_gate = (&self.w_ih.dot(&input_dropped) + &self.b_ih + &self.w_hh.dot(&reset_hidden) + &self.b_hh)
188 .map(|&x| x.tanh());
189
190 let hy = &update_gate.map(|&x| 1.0 - x) * &hx_dropped + &update_gate * &new_gate;
192
193 let (hy_final, output_mask) = if let Some(ref mut dropout) = self.output_dropout {
195 let dropped = dropout.forward(&hy);
196 let mask = dropout.get_last_mask().map(|m| m.clone());
197 (dropped, mask)
198 } else {
199 (hy, None)
200 };
201
202 let cache = GRUCellCache {
203 input: input.clone(),
204 hx: hx.clone(),
205 reset_gate: reset_gate.clone(),
206 update_gate: update_gate.clone(),
207 new_gate: new_gate.clone(),
208 reset_hidden: reset_hidden,
209 hy: hy_final.clone(),
210 input_dropout_mask: input_mask,
211 recurrent_dropout_mask: recurrent_mask,
212 output_dropout_mask: output_mask,
213 };
214
215 (hy_final, cache)
216 }
217
218 pub fn backward(&self, dhy: &Array2<f64>, cache: &GRUCellCache) -> (GRUCellGradients, Array2<f64>, Array2<f64>) {
222 let dhy_dropped = if let Some(ref mask) = cache.output_dropout_mask {
224 let keep_prob = if let Some(ref dropout) = self.output_dropout {
225 1.0 - dropout.dropout_rate
226 } else {
227 1.0
228 };
229 dhy * mask / keep_prob
230 } else {
231 dhy.clone()
232 };
233
234 let d_update_gate = &dhy_dropped * (&cache.new_gate - &cache.hx);
236 let d_new_gate = &dhy_dropped * &cache.update_gate;
237 let dhx_from_output = &dhy_dropped * cache.update_gate.map(|&x| 1.0 - x);
238
239 let d_new_gate_raw = &d_new_gate * cache.new_gate.map(|&x| 1.0 - x.powi(2));
241
242 let d_reset_hidden = self.w_hh.t().dot(&d_new_gate_raw);
244 let d_reset_gate = &d_reset_hidden * &cache.hx;
245 let dhx_from_reset = &d_reset_hidden * &cache.reset_gate;
246
247 let d_reset_gate_raw = &d_reset_gate * &cache.reset_gate * cache.reset_gate.map(|&x| 1.0 - x);
249
250 let d_update_gate_raw = &d_update_gate * &cache.update_gate * cache.update_gate.map(|&x| 1.0 - x);
252
253 let dw_ir = d_reset_gate_raw.dot(&cache.input.t());
255 let dw_hr = d_reset_gate_raw.dot(&cache.hx.t());
256 let db_ir = d_reset_gate_raw.clone();
257 let db_hr = d_reset_gate_raw.clone();
258
259 let dw_iz = d_update_gate_raw.dot(&cache.input.t());
260 let dw_hz = d_update_gate_raw.dot(&cache.hx.t());
261 let db_iz = d_update_gate_raw.clone();
262 let db_hz = d_update_gate_raw.clone();
263
264 let dw_ih = d_new_gate_raw.dot(&cache.input.t());
265 let dw_hh = d_new_gate_raw.dot(&cache.reset_hidden.t());
266 let db_ih = d_new_gate_raw.clone();
267 let db_hh = d_new_gate_raw.clone();
268
269 let gradients = GRUCellGradients {
270 w_ir: dw_ir, w_hr: dw_hr, b_ir: db_ir, b_hr: db_hr,
271 w_iz: dw_iz, w_hz: dw_hz, b_iz: db_iz, b_hz: db_hz,
272 w_ih: dw_ih, w_hh: dw_hh, b_ih: db_ih, b_hh: db_hh,
273 };
274
275 let mut dx = self.w_ir.t().dot(&d_reset_gate_raw) +
277 self.w_iz.t().dot(&d_update_gate_raw) +
278 self.w_ih.t().dot(&d_new_gate_raw);
279
280 let mut dhx = dhx_from_output + dhx_from_reset +
281 self.w_hr.t().dot(&d_reset_gate_raw) +
282 self.w_hz.t().dot(&d_update_gate_raw);
283
284 if let Some(ref mask) = cache.input_dropout_mask {
286 let keep_prob = if let Some(ref dropout) = self.input_dropout {
287 1.0 - dropout.dropout_rate
288 } else {
289 1.0
290 };
291 dx = dx * mask / keep_prob;
292 }
293
294 if let Some(ref mask) = cache.recurrent_dropout_mask {
295 let keep_prob = if let Some(ref dropout) = self.recurrent_dropout {
296 1.0 - dropout.dropout_rate
297 } else {
298 1.0
299 };
300 dhx = dhx * mask / keep_prob;
301 }
302
303 (gradients, dx, dhx)
304 }
305
306 pub fn zero_gradients(&self) -> GRUCellGradients {
308 GRUCellGradients {
309 w_ir: Array2::zeros(self.w_ir.raw_dim()),
310 w_hr: Array2::zeros(self.w_hr.raw_dim()),
311 b_ir: Array2::zeros(self.b_ir.raw_dim()),
312 b_hr: Array2::zeros(self.b_hr.raw_dim()),
313 w_iz: Array2::zeros(self.w_iz.raw_dim()),
314 w_hz: Array2::zeros(self.w_hz.raw_dim()),
315 b_iz: Array2::zeros(self.b_iz.raw_dim()),
316 b_hz: Array2::zeros(self.b_hz.raw_dim()),
317 w_ih: Array2::zeros(self.w_ih.raw_dim()),
318 w_hh: Array2::zeros(self.w_hh.raw_dim()),
319 b_ih: Array2::zeros(self.b_ih.raw_dim()),
320 b_hh: Array2::zeros(self.b_hh.raw_dim()),
321 }
322 }
323
324 pub fn update_parameters<O: crate::optimizers::Optimizer>(&mut self, gradients: &GRUCellGradients, optimizer: &mut O, prefix: &str) {
326 optimizer.update(&format!("{}_w_ir", prefix), &mut self.w_ir, &gradients.w_ir);
327 optimizer.update(&format!("{}_w_hr", prefix), &mut self.w_hr, &gradients.w_hr);
328 optimizer.update(&format!("{}_b_ir", prefix), &mut self.b_ir, &gradients.b_ir);
329 optimizer.update(&format!("{}_b_hr", prefix), &mut self.b_hr, &gradients.b_hr);
330 optimizer.update(&format!("{}_w_iz", prefix), &mut self.w_iz, &gradients.w_iz);
331 optimizer.update(&format!("{}_w_hz", prefix), &mut self.w_hz, &gradients.w_hz);
332 optimizer.update(&format!("{}_b_iz", prefix), &mut self.b_iz, &gradients.b_iz);
333 optimizer.update(&format!("{}_b_hz", prefix), &mut self.b_hz, &gradients.b_hz);
334 optimizer.update(&format!("{}_w_ih", prefix), &mut self.w_ih, &gradients.w_ih);
335 optimizer.update(&format!("{}_w_hh", prefix), &mut self.w_hh, &gradients.w_hh);
336 optimizer.update(&format!("{}_b_ih", prefix), &mut self.b_ih, &gradients.b_ih);
337 optimizer.update(&format!("{}_b_hh", prefix), &mut self.b_hh, &gradients.b_hh);
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr2, Array2};

    /// Scalar loss for gradient checks: the sum of all outputs, so `dloss/dhy` is all ones.
    fn sum_loss(cell: &mut GRUCell, input: &Array2<f64>, hx: &Array2<f64>) -> f64 {
        cell.forward(input, hx).sum()
    }

    /// Asserts that `analytic` matches a central finite-difference estimate of
    /// `d(sum_loss)/d(param)`, where `param` is the matrix chosen by `select`.
    fn assert_param_gradient(
        cell: &mut GRUCell,
        input: &Array2<f64>,
        hx: &Array2<f64>,
        select: fn(&mut GRUCell) -> &mut Array2<f64>,
        analytic: &Array2<f64>,
        name: &str,
    ) {
        let eps = 1e-6;
        let (rows, cols) = select(cell).dim();
        assert_eq!(analytic.dim(), (rows, cols), "shape mismatch for {}", name);
        for i in 0..rows {
            for j in 0..cols {
                let original = select(cell)[[i, j]];
                select(cell)[[i, j]] = original + eps;
                let plus = sum_loss(cell, input, hx);
                select(cell)[[i, j]] = original - eps;
                let minus = sum_loss(cell, input, hx);
                select(cell)[[i, j]] = original;

                let numeric = (plus - minus) / (2.0 * eps);
                assert!(
                    (numeric - analytic[[i, j]]).abs() < 1e-6,
                    "{}[{}, {}]: analytic {} vs numeric {}",
                    name, i, j, analytic[[i, j]], numeric
                );
            }
        }
    }

    #[test]
    fn test_gru_cell_forward() {
        let input_size = 3;
        let hidden_size = 2;
        let mut cell = GRUCell::new(input_size, hidden_size);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.1], [0.2]]);

        let hy = cell.forward(&input, &hx);

        assert_eq!(hy.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_gru_cell_with_dropout() {
        let input_size = 3;
        let hidden_size = 2;
        let mut cell = GRUCell::new(input_size, hidden_size)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.1], [0.2]]);

        cell.train();
        let hy_train = cell.forward(&input, &hx);

        cell.eval();
        let hy_eval = cell.forward(&input, &hx);

        assert_eq!(hy_train.shape(), &[hidden_size, 1]);
        assert_eq!(hy_eval.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_gru_backward_pass() {
        let input_size = 2;
        let hidden_size = 3;
        let mut cell = GRUCell::new(input_size, hidden_size);

        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.1], [0.2], [0.3]]);

        let (_hy, cache) = cell.forward_with_cache(&input, &hx);

        let dhy = arr2(&[[1.0], [1.0], [1.0]]);
        let (gradients, dx, dhx) = cell.backward(&dhy, &cache);

        assert_eq!(gradients.w_ir.shape(), &[hidden_size, input_size]);
        assert_eq!(gradients.w_hr.shape(), &[hidden_size, hidden_size]);
        assert_eq!(dx.shape(), &[input_size, 1]);
        assert_eq!(dhx.shape(), &[hidden_size, 1]);
    }

    /// Checks every analytic gradient from `backward` against central finite differences
    /// (no dropout, single column, so the forward pass is deterministic).
    #[test]
    fn test_gru_backward_matches_finite_differences() {
        let mut cell = GRUCell::new(2, 3);
        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.1], [0.2], [0.3]]);

        let (hy, cache) = cell.forward_with_cache(&input, &hx);
        let dhy = Array2::<f64>::ones(hy.raw_dim());
        let (grads, dx, dhx) = cell.backward(&dhy, &cache);

        assert_param_gradient(&mut cell, &input, &hx, |c| &mut c.w_ir, &grads.w_ir, "w_ir");
        assert_param_gradient(&mut cell, &input, &hx, |c| &mut c.w_hr, &grads.w_hr, "w_hr");
        assert_param_gradient(&mut cell, &input, &hx, |c| &mut c.b_ir, &grads.b_ir, "b_ir");
        assert_param_gradient(&mut cell, &input, &hx, |c| &mut c.b_hr, &grads.b_hr, "b_hr");
        assert_param_gradient(&mut cell, &input, &hx, |c| &mut c.w_iz, &grads.w_iz, "w_iz");
        assert_param_gradient(&mut cell, &input, &hx, |c| &mut c.w_hz, &grads.w_hz, "w_hz");
        assert_param_gradient(&mut cell, &input, &hx, |c| &mut c.b_iz, &grads.b_iz, "b_iz");
        assert_param_gradient(&mut cell, &input, &hx, |c| &mut c.b_hz, &grads.b_hz, "b_hz");
        assert_param_gradient(&mut cell, &input, &hx, |c| &mut c.w_ih, &grads.w_ih, "w_ih");
        assert_param_gradient(&mut cell, &input, &hx, |c| &mut c.w_hh, &grads.w_hh, "w_hh");
        assert_param_gradient(&mut cell, &input, &hx, |c| &mut c.b_ih, &grads.b_ih, "b_ih");
        assert_param_gradient(&mut cell, &input, &hx, |c| &mut c.b_hh, &grads.b_hh, "b_hh");

        let eps = 1e-6;
        for i in 0..input.nrows() {
            let mut plus = input.clone();
            plus[[i, 0]] += eps;
            let mut minus = input.clone();
            minus[[i, 0]] -= eps;
            let numeric =
                (sum_loss(&mut cell, &plus, &hx) - sum_loss(&mut cell, &minus, &hx)) / (2.0 * eps);
            assert!((numeric - dx[[i, 0]]).abs() < 1e-6, "dx[{}]: {} vs {}", i, dx[[i, 0]], numeric);
        }
        for i in 0..hx.nrows() {
            let mut plus = hx.clone();
            plus[[i, 0]] += eps;
            let mut minus = hx.clone();
            minus[[i, 0]] -= eps;
            let numeric =
                (sum_loss(&mut cell, &input, &plus) - sum_loss(&mut cell, &input, &minus)) / (2.0 * eps);
            assert!((numeric - dhx[[i, 0]]).abs() < 1e-6, "dhx[{}]: {} vs {}", i, dhx[[i, 0]], numeric);
        }
    }
}