use ndarray::{Array2, s};
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;
use crate::utils::sigmoid;
use crate::layers::dropout::{Dropout, Zoneout};

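/// Parameter gradients for an `LSTMCell`; each field has the same shape as
/// the corresponding weight or bias matrix.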
#[derive(Clone)]
pub struct LSTMCellGradients {
    pub w_ih: Array2<f64>,
    pub w_hh: Array2<f64>,
    pub b_ih: Array2<f64>,
    pub b_hh: Array2<f64>,
}

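/// Values saved during a single-sample forward pass so that `backward` can
/// compute gate derivatives without re-running the forward step.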
#[derive(Clone)]
pub struct LSTMCellCache {
    pub input: Array2<f64>,
    pub hx: Array2<f64>,
    pub cx: Array2<f64>,
    pub gates: Array2<f64>,
    pub input_gate: Array2<f64>,
    pub forget_gate: Array2<f64>,
    pub cell_gate: Array2<f64>,
    pub output_gate: Array2<f64>,
    pub cy: Array2<f64>,
    pub hy: Array2<f64>,
    pub input_dropout_mask: Option<Array2<f64>>,
    pub recurrent_dropout_mask: Option<Array2<f64>>,
    pub output_dropout_mask: Option<Array2<f64>>,
}

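/// Like `LSTMCellCache`, but for a batched forward pass where each column
/// of the matrices is one sample.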
#[derive(Clone)]
pub struct LSTMCellBatchCache {
    pub input: Array2<f64>,
    pub hx: Array2<f64>,
    pub cx: Array2<f64>,
    pub gates: Array2<f64>,
    pub input_gate: Array2<f64>,
    pub forget_gate: Array2<f64>,
    pub cell_gate: Array2<f64>,
    pub output_gate: Array2<f64>,
    pub cy: Array2<f64>,
    pub hy: Array2<f64>,
    pub input_dropout_mask: Option<Array2<f64>>,
    pub recurrent_dropout_mask: Option<Array2<f64>>,
    pub output_dropout_mask: Option<Array2<f64>>,
    pub batch_size: usize,
}

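/// A single LSTM cell with optional input, recurrent, and output dropout,
/// plus zoneout on the cell and hidden states.
///
/// The four gate blocks are packed row-wise into `w_ih`, `w_hh`, `b_ih`,
/// and `b_hh` in input/forget/cell/output order, so each weight matrix has
/// `4 * hidden_size` rows.
///
/// Illustrative usage (states are column vectors of shape `(n, 1)`; see
/// the tests below for concrete values):
///
/// ```ignore
/// let mut cell = LSTMCell::new(3, 2).with_recurrent_dropout(0.3, true);
/// let (hy, cy) = cell.forward(&input, &hx, &cx);
/// ```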
#[derive(Clone)]
pub struct LSTMCell {
    pub w_ih: Array2<f64>,
    pub w_hh: Array2<f64>,
    pub b_ih: Array2<f64>,
    pub b_hh: Array2<f64>,
    pub hidden_size: usize,
    pub input_dropout: Option<Dropout>,
    pub recurrent_dropout: Option<Dropout>,
    pub output_dropout: Option<Dropout>,
    pub zoneout: Option<Zoneout>,
    pub is_training: bool,
}

impl LSTMCell {
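    /// Creates a cell with weights drawn uniformly from [-0.1, 0.1] and
    /// zero biases. All regularization is off until enabled via the
    /// builder methods below.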
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        let dist = Uniform::new(-0.1, 0.1);

        let w_ih = Array2::random((4 * hidden_size, input_size), dist);
        let w_hh = Array2::random((4 * hidden_size, hidden_size), dist);
        let b_ih = Array2::zeros((4 * hidden_size, 1));
        let b_hh = Array2::zeros((4 * hidden_size, 1));

        LSTMCell {
            w_ih,
            w_hh,
            b_ih,
            b_hh,
            hidden_size,
            input_dropout: None,
            recurrent_dropout: None,
            output_dropout: None,
            zoneout: None,
            is_training: true,
        }
    }

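    /// Enables dropout on the cell input. When `variational` is true the
    /// cell uses `Dropout::variational`, which (in the usual variational
    /// dropout scheme) reuses a single mask across time steps instead of
    /// resampling it every step.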
    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        if variational {
            self.input_dropout = Some(Dropout::variational(dropout_rate));
        } else {
            self.input_dropout = Some(Dropout::new(dropout_rate));
        }
        self
    }

    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        if variational {
            self.recurrent_dropout = Some(Dropout::variational(dropout_rate));
        } else {
            self.recurrent_dropout = Some(Dropout::new(dropout_rate));
        }
        self
    }

    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
        self.output_dropout = Some(Dropout::new(dropout_rate));
        self
    }

    pub fn with_zoneout(mut self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self {
        self.zoneout = Some(Zoneout::new(cell_zoneout_rate, hidden_zoneout_rate));
        self
    }

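    /// Switches the cell and every attached regularizer to training mode.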
    pub fn train(&mut self) {
        self.is_training = true;
        if let Some(ref mut dropout) = self.input_dropout {
            dropout.train();
        }
        if let Some(ref mut dropout) = self.recurrent_dropout {
            dropout.train();
        }
        if let Some(ref mut dropout) = self.output_dropout {
            dropout.train();
        }
        if let Some(ref mut zoneout) = self.zoneout {
            zoneout.train();
        }
    }

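    /// Switches the cell and every attached regularizer to inference mode,
    /// turning dropout and zoneout into pass-throughs.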
    pub fn eval(&mut self) {
        self.is_training = false;
        if let Some(ref mut dropout) = self.input_dropout {
            dropout.eval();
        }
        if let Some(ref mut dropout) = self.recurrent_dropout {
            dropout.eval();
        }
        if let Some(ref mut dropout) = self.output_dropout {
            dropout.eval();
        }
        if let Some(ref mut zoneout) = self.zoneout {
            zoneout.eval();
        }
    }

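    /// Single-sample forward step. Returns the new hidden and cell states
    /// `(hy, cy)`, discarding the cache used for backpropagation.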
    pub fn forward(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
        let (hy, cy, _) = self.forward_with_cache(input, hx, cx);
        (hy, cy)
    }

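    /// Single-sample forward step that also returns the cache needed by
    /// `backward`. Computes the standard LSTM update `cy = f * cx + i * g`
    /// and `hy = o * tanh(cy)`, with dropout applied to the input and the
    /// previous hidden state, and zoneout blended into the new states.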
    pub fn forward_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMCellCache) {
        let (input_dropped, input_mask) = if let Some(ref mut dropout) = self.input_dropout {
            let dropped = dropout.forward(input);
            let mask = dropout.get_last_mask().map(|m| m.clone());
            (dropped, mask)
        } else {
            (input.clone(), None)
        };

        let (hx_dropped, recurrent_mask) = if let Some(ref mut dropout) = self.recurrent_dropout {
            let dropped = dropout.forward(hx);
            let mask = dropout.get_last_mask().map(|m| m.clone());
            (dropped, mask)
        } else {
            (hx.clone(), None)
        };

        // Pre-activation values for all four gates, packed row-wise.
        let gates = &self.w_ih.dot(&input_dropped) + &self.b_ih + &self.w_hh.dot(&hx_dropped) + &self.b_hh;

        let input_gate = gates.slice(s![0..self.hidden_size, ..]).map(|&x| sigmoid(x));
        let forget_gate = gates.slice(s![self.hidden_size..2 * self.hidden_size, ..]).map(|&x| sigmoid(x));
        let cell_gate = gates.slice(s![2 * self.hidden_size..3 * self.hidden_size, ..]).map(|&x| x.tanh());
        let output_gate = gates.slice(s![3 * self.hidden_size..4 * self.hidden_size, ..]).map(|&x| sigmoid(x));

        let mut cy = &forget_gate * cx + &input_gate * &cell_gate;

        if let Some(ref zoneout) = self.zoneout {
            cy = zoneout.apply_cell_zoneout(&cy, cx);
        }

        let mut hy = &output_gate * cy.map(|&x| x.tanh());

        if let Some(ref zoneout) = self.zoneout {
            hy = zoneout.apply_hidden_zoneout(&hy, hx);
        }

        let (hy_final, output_mask) = if let Some(ref mut dropout) = self.output_dropout {
            let dropped = dropout.forward(&hy);
            let mask = dropout.get_last_mask().map(|m| m.clone());
            (dropped, mask)
        } else {
            (hy, None)
        };

        let cache = LSTMCellCache {
            input: input.clone(),
            hx: hx.clone(),
            cx: cx.clone(),
            gates,
            input_gate: input_gate.to_owned(),
            forget_gate: forget_gate.to_owned(),
            cell_gate: cell_gate.to_owned(),
            output_gate: output_gate.to_owned(),
            cy: cy.clone(),
            hy: hy_final.clone(),
            input_dropout_mask: input_mask,
            recurrent_dropout_mask: recurrent_mask,
            output_dropout_mask: output_mask,
        };

        (hy_final, cy, cache)
    }

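    /// Batched forward step; each column of `input`, `hx`, and `cx` is one
    /// sample. Returns `(hy, cy)` without a cache, so dropout masks are
    /// discarded. Use `forward_batch_with_cache` when a backward pass is
    /// needed.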
    pub fn forward_batch(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
        let batch_size = input.ncols();
        assert_eq!(hx.ncols(), batch_size, "Hidden state batch size must match input batch size");
        assert_eq!(cx.ncols(), batch_size, "Cell state batch size must match input batch size");
        assert_eq!(input.nrows(), self.w_ih.ncols(), "Input feature size must match weight matrix");
        assert_eq!(hx.nrows(), self.hidden_size, "Hidden state size must match network hidden size");
        assert_eq!(cx.nrows(), self.hidden_size, "Cell state size must match network hidden size");

        let input_dropped = if let Some(ref mut dropout) = self.input_dropout {
            dropout.forward(input)
        } else {
            input.clone()
        };

        let hx_dropped = if let Some(ref mut dropout) = self.recurrent_dropout {
            dropout.forward(hx)
        } else {
            hx.clone()
        };

        // Broadcast the (4 * hidden_size, 1) biases across the batch columns.
        let gates = &self.w_ih.dot(&input_dropped) + &self.b_ih.broadcast((4 * self.hidden_size, batch_size)).unwrap()
            + &self.w_hh.dot(&hx_dropped) + &self.b_hh.broadcast((4 * self.hidden_size, batch_size)).unwrap();

        let input_gate = gates.slice(s![0..self.hidden_size, ..]).map(|&x| sigmoid(x));
        let forget_gate = gates.slice(s![self.hidden_size..2 * self.hidden_size, ..]).map(|&x| sigmoid(x));
        let cell_gate = gates.slice(s![2 * self.hidden_size..3 * self.hidden_size, ..]).map(|&x| x.tanh());
        let output_gate = gates.slice(s![3 * self.hidden_size..4 * self.hidden_size, ..]).map(|&x| sigmoid(x));

        let mut cy = &forget_gate * cx + &input_gate * &cell_gate;

        // Zoneout is applied column-by-column, treating each sample as a
        // (hidden_size, 1) state vector.
        if let Some(ref zoneout) = self.zoneout {
            for col_idx in 0..batch_size {
                let cy_col = cy.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
                let cx_col = cx.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
                let cy_zoneout = zoneout.apply_cell_zoneout(&cy_col, &cx_col);
                cy.column_mut(col_idx).assign(&cy_zoneout.column(0));
            }
        }

        let mut hy = &output_gate * cy.map(|&x| x.tanh());

        if let Some(ref zoneout) = self.zoneout {
            for col_idx in 0..batch_size {
                let hy_col = hy.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
                let hx_col = hx.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
                let hy_zoneout = zoneout.apply_hidden_zoneout(&hy_col, &hx_col);
                hy.column_mut(col_idx).assign(&hy_zoneout.column(0));
            }
        }

        let hy_final = if let Some(ref mut dropout) = self.output_dropout {
            dropout.forward(&hy)
        } else {
            hy
        };

        (hy_final, cy)
    }

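    /// Batched forward step that also returns an `LSTMCellBatchCache` for
    /// `backward_batch`. Unlike `forward_batch`, the dropout masks are kept
    /// so the backward pass can rescale gradients consistently.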
    pub fn forward_batch_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMCellBatchCache) {
        let batch_size = input.ncols();

        let (input_dropped, input_mask) = if let Some(ref mut dropout) = self.input_dropout {
            let dropped = dropout.forward(input);
            let mask = dropout.get_last_mask().map(|m| m.clone());
            (dropped, mask)
        } else {
            (input.clone(), None)
        };

        let (hx_dropped, recurrent_mask) = if let Some(ref mut dropout) = self.recurrent_dropout {
            let dropped = dropout.forward(hx);
            let mask = dropout.get_last_mask().map(|m| m.clone());
            (dropped, mask)
        } else {
            (hx.clone(), None)
        };

        // Broadcast the (4 * hidden_size, 1) biases across the batch columns.
        let gates = &self.w_ih.dot(&input_dropped) + &self.b_ih.broadcast((4 * self.hidden_size, batch_size)).unwrap()
            + &self.w_hh.dot(&hx_dropped) + &self.b_hh.broadcast((4 * self.hidden_size, batch_size)).unwrap();

        let input_gate = gates.slice(s![0..self.hidden_size, ..]).map(|&x| sigmoid(x));
        let forget_gate = gates.slice(s![self.hidden_size..2 * self.hidden_size, ..]).map(|&x| sigmoid(x));
        let cell_gate = gates.slice(s![2 * self.hidden_size..3 * self.hidden_size, ..]).map(|&x| x.tanh());
        let output_gate = gates.slice(s![3 * self.hidden_size..4 * self.hidden_size, ..]).map(|&x| sigmoid(x));

        let mut cy = &forget_gate * cx + &input_gate * &cell_gate;

        // Zoneout is applied column-by-column, treating each sample as a
        // (hidden_size, 1) state vector.
        if let Some(ref zoneout) = self.zoneout {
            for col_idx in 0..batch_size {
                let cy_col = cy.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
                let cx_col = cx.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
                let cy_zoneout = zoneout.apply_cell_zoneout(&cy_col, &cx_col);
                cy.column_mut(col_idx).assign(&cy_zoneout.column(0));
            }
        }

        let mut hy = &output_gate * cy.map(|&x| x.tanh());

        if let Some(ref zoneout) = self.zoneout {
            for col_idx in 0..batch_size {
                let hy_col = hy.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
                let hx_col = hx.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
                let hy_zoneout = zoneout.apply_hidden_zoneout(&hy_col, &hx_col);
                hy.column_mut(col_idx).assign(&hy_zoneout.column(0));
            }
        }

        let (hy_final, output_mask) = if let Some(ref mut dropout) = self.output_dropout {
            let dropped = dropout.forward(&hy);
            let mask = dropout.get_last_mask().map(|m| m.clone());
            (dropped, mask)
        } else {
            (hy, None)
        };

        let cache = LSTMCellBatchCache {
            input: input.clone(),
            hx: hx.clone(),
            cx: cx.clone(),
            gates,
            input_gate: input_gate.to_owned(),
            forget_gate: forget_gate.to_owned(),
            cell_gate: cell_gate.to_owned(),
            output_gate: output_gate.to_owned(),
            cy: cy.clone(),
            hy: hy_final.clone(),
            input_dropout_mask: input_mask,
            recurrent_dropout_mask: recurrent_mask,
            output_dropout_mask: output_mask,
            batch_size,
        };

        (hy_final, cy, cache)
    }

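    /// Single-sample backward step. Given the upstream gradients `dhy`
    /// (w.r.t. the hidden output) and `dcy` (w.r.t. the cell output), it
    /// first undoes output dropout, then applies the standard LSTM
    /// gradient equations (with `dhy` already rescaled by the output mask):
    ///
    /// ```text
    /// dcy_total = dcy + dhy * o * (1 - tanh(cy)^2)
    /// do = dhy * tanh(cy) * o * (1 - o)
    /// df = dcy_total * cx * f * (1 - f)
    /// di = dcy_total * g  * i * (1 - i)
    /// dg = dcy_total * i  * (1 - g^2)
    /// dcx = dcy_total * f
    /// ```
    ///
    /// Returns the parameter gradients together with `dx`, `dhx`, and
    /// `dcx`. Zoneout blending is not differentiated here.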
    pub fn backward(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMCellCache) -> (LSTMCellGradients, Array2<f64>, Array2<f64>, Array2<f64>) {
        let hidden_size = self.hidden_size;

        // Backpropagate through output dropout (inverted-dropout scaling).
        let dhy_dropped = if let Some(ref mask) = cache.output_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.output_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dhy * mask / keep_prob
        } else {
            dhy.clone()
        };

        let tanh_cy = cache.cy.map(|&x| x.tanh());
        let do_t = &dhy_dropped * &tanh_cy;
        let do_raw = &do_t * &cache.output_gate * &cache.output_gate.map(|&x| 1.0 - x);

        let dcy_from_tanh = &dhy_dropped * &cache.output_gate * cache.cy.map(|&x| 1.0 - x.tanh().powi(2));
        let dcy_total = dcy + dcy_from_tanh;

        let df_t = &dcy_total * &cache.cx;
        let df_raw = &df_t * &cache.forget_gate * cache.forget_gate.map(|&x| 1.0 - x);

        let di_t = &dcy_total * &cache.cell_gate;
        let di_raw = &di_t * &cache.input_gate * cache.input_gate.map(|&x| 1.0 - x);

        let dc_t = &dcy_total * &cache.input_gate;
        let dc_raw = &dc_t * cache.cell_gate.map(|&x| 1.0 - x.powi(2));

        // Repack the per-gate gradients in the same i, f, g, o layout as
        // the forward pass.
        let mut dgates = Array2::zeros((4 * hidden_size, 1));
        dgates.slice_mut(s![0..hidden_size, ..]).assign(&di_raw);
        dgates.slice_mut(s![hidden_size..2 * hidden_size, ..]).assign(&df_raw);
        dgates.slice_mut(s![2 * hidden_size..3 * hidden_size, ..]).assign(&dc_raw);
        dgates.slice_mut(s![3 * hidden_size..4 * hidden_size, ..]).assign(&do_raw);

        let dw_ih = dgates.dot(&cache.input.t());
        let dw_hh = dgates.dot(&cache.hx.t());
        let db_ih = dgates.clone();
        let db_hh = dgates.clone();

        let gradients = LSTMCellGradients {
            w_ih: dw_ih,
            w_hh: dw_hh,
            b_ih: db_ih,
            b_hh: db_hh,
        };

        let mut dx = self.w_ih.t().dot(&dgates);
        let mut dhx = self.w_hh.t().dot(&dgates);
        let dcx = &dcy_total * &cache.forget_gate;

        if let Some(ref mask) = cache.input_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.input_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dx = dx * mask / keep_prob;
        }

        if let Some(ref mask) = cache.recurrent_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.recurrent_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dhx = dhx * mask / keep_prob;
        }

        (gradients, dx, dhx, dcx)
    }

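    /// Batched version of `backward`. Gate gradients are computed for all
    /// columns at once; bias gradients are summed over the batch axis.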
    pub fn backward_batch(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMCellBatchCache) -> (LSTMCellGradients, Array2<f64>, Array2<f64>, Array2<f64>) {
        let batch_size = cache.batch_size;
        let hidden_size = self.hidden_size;

        // Backpropagate through output dropout (inverted-dropout scaling).
        let dhy_dropped = if let Some(ref mask) = cache.output_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.output_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dhy * mask / keep_prob
        } else {
            dhy.clone()
        };

        let tanh_cy = cache.cy.map(|&x| x.tanh());
        let do_t = &dhy_dropped * &tanh_cy;
        let do_raw = &do_t * &cache.output_gate * &cache.output_gate.map(|&x| 1.0 - x);

        let dcy_from_tanh = &dhy_dropped * &cache.output_gate * cache.cy.map(|&x| 1.0 - x.tanh().powi(2));
        let dcy_total = dcy + dcy_from_tanh;

        let df_t = &dcy_total * &cache.cx;
        let df_raw = &df_t * &cache.forget_gate * cache.forget_gate.map(|&x| 1.0 - x);

        let di_t = &dcy_total * &cache.cell_gate;
        let di_raw = &di_t * &cache.input_gate * cache.input_gate.map(|&x| 1.0 - x);

        let dc_t = &dcy_total * &cache.input_gate;
        let dc_raw = &dc_t * cache.cell_gate.map(|&x| 1.0 - x.powi(2));

        // Repack the per-gate gradients in the same i, f, g, o layout as
        // the forward pass.
        let mut dgates = Array2::zeros((4 * hidden_size, batch_size));
        dgates.slice_mut(s![0..hidden_size, ..]).assign(&di_raw);
        dgates.slice_mut(s![hidden_size..2 * hidden_size, ..]).assign(&df_raw);
        dgates.slice_mut(s![2 * hidden_size..3 * hidden_size, ..]).assign(&dc_raw);
        dgates.slice_mut(s![3 * hidden_size..4 * hidden_size, ..]).assign(&do_raw);

        let dw_ih = dgates.dot(&cache.input.t());
        let dw_hh = dgates.dot(&cache.hx.t());
        // Bias gradients are summed over the batch axis.
        let db_ih = dgates.sum_axis(ndarray::Axis(1)).insert_axis(ndarray::Axis(1));
        let db_hh = db_ih.clone();

        let gradients = LSTMCellGradients {
            w_ih: dw_ih,
            w_hh: dw_hh,
            b_ih: db_ih,
            b_hh: db_hh,
        };

        let mut dx = self.w_ih.t().dot(&dgates);
        let mut dhx = self.w_hh.t().dot(&dgates);
        let dcx = &dcy_total * &cache.forget_gate;

        if let Some(ref mask) = cache.input_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.input_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dx = dx * mask / keep_prob;
        }

        if let Some(ref mask) = cache.recurrent_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.recurrent_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dhx = dhx * mask / keep_prob;
        }

        (gradients, dx, dhx, dcx)
    }

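    /// Returns zero-filled gradients matching the cell's parameter shapes,
    /// e.g. for use as an accumulator when summing gradients over time steps.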
    pub fn zero_gradients(&self) -> LSTMCellGradients {
        LSTMCellGradients {
            w_ih: Array2::zeros(self.w_ih.raw_dim()),
            w_hh: Array2::zeros(self.w_hh.raw_dim()),
            b_ih: Array2::zeros(self.b_ih.raw_dim()),
            b_hh: Array2::zeros(self.b_hh.raw_dim()),
        }
    }

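    /// Applies one optimizer step to each parameter matrix. `prefix` keys
    /// the optimizer's per-parameter state (momentum buffers and the like,
    /// depending on the optimizer) so multiple cells can share one optimizer.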
    pub fn update_parameters<O: crate::optimizers::Optimizer>(&mut self, gradients: &LSTMCellGradients, optimizer: &mut O, prefix: &str) {
        optimizer.update(&format!("{}_w_ih", prefix), &mut self.w_ih, &gradients.w_ih);
        optimizer.update(&format!("{}_w_hh", prefix), &mut self.w_hh, &gradients.w_hh);
        optimizer.update(&format!("{}_b_ih", prefix), &mut self.b_ih, &gradients.b_ih);
        optimizer.update(&format!("{}_b_hh", prefix), &mut self.b_hh, &gradients.b_hh);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_lstm_cell_forward() {
        let input_size = 3;
        let hidden_size = 2;
        let mut cell = LSTMCell::new(input_size, hidden_size);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        let (hy, cy) = cell.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[hidden_size, 1]);
        assert_eq!(cy.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_lstm_cell_with_dropout() {
        let input_size = 3;
        let hidden_size = 2;
        let mut cell = LSTMCell::new(input_size, hidden_size)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1)
            .with_zoneout(0.1, 0.1);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        cell.train();
        // Training mode: dropout and zoneout are active.
        let (hy_train, cy_train) = cell.forward(&input, &hx, &cx);

        cell.eval();
        // Eval mode: regularization is disabled, so this pass is deterministic.
        let (hy_eval, cy_eval) = cell.forward(&input, &hx, &cx);

        assert_eq!(hy_train.shape(), &[hidden_size, 1]);
        assert_eq!(cy_train.shape(), &[hidden_size, 1]);
        assert_eq!(hy_eval.shape(), &[hidden_size, 1]);
        assert_eq!(cy_eval.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_dropout_mask_backward_pass() {
        let input_size = 2;
        let hidden_size = 3;
        let mut cell = LSTMCell::new(input_size, hidden_size)
            .with_input_dropout(0.5, false)
            .with_output_dropout(0.5);

        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.1], [0.2], [0.3]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        cell.train();
        let (_hy, _cy, cache) = cell.forward_with_cache(&input, &hx, &cx);

        // In training mode the cache records the masks backward() needs.
        assert!(cache.input_dropout_mask.is_some());
        assert!(cache.output_dropout_mask.is_some());

        let dhy = arr2(&[[1.0], [1.0], [1.0]]);
        let dcy = arr2(&[[0.0], [0.0], [0.0]]);

        let (gradients, dx, dhx, dcx) = cell.backward(&dhy, &dcy, &cache);

        assert_eq!(gradients.w_ih.shape(), &[4 * hidden_size, input_size]);
        assert_eq!(gradients.w_hh.shape(), &[4 * hidden_size, hidden_size]);
        assert_eq!(dx.shape(), &[input_size, 1]);
        assert_eq!(dhx.shape(), &[hidden_size, 1]);
        assert_eq!(dcx.shape(), &[hidden_size, 1]);

        // In eval mode dropout is a no-op and no masks are recorded.
        cell.eval();
        let (_, _, cache_eval) = cell.forward_with_cache(&input, &hx, &cx);
        assert!(cache_eval.input_dropout_mask.is_none());
        assert!(cache_eval.output_dropout_mask.is_none());
    }
}