use ndarray::{Array2, Axis, s};
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;
use crate::utils::sigmoid;
use crate::layers::dropout::{Dropout, Zoneout};
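
/// Gradients of the LSTM cell parameters, with the same shapes as the
/// corresponding fields of [`LSTMCell`].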
#[derive(Clone)]
pub struct LSTMCellGradients {
    pub w_ih: Array2<f64>,
    pub w_hh: Array2<f64>,
    pub b_ih: Array2<f64>,
    pub b_hh: Array2<f64>,
}
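
/// Values recorded during `forward_with_cache` that `backward` needs: the
/// step inputs and states, the pre-activation gates, the gate activations,
/// the outputs, and any dropout masks that were sampled (`None` when the
/// corresponding dropout is absent or the cell is in eval mode).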
#[derive(Clone)]
pub struct LSTMCellCache {
    pub input: Array2<f64>,
    pub hx: Array2<f64>,
    pub cx: Array2<f64>,
    pub gates: Array2<f64>,
    pub input_gate: Array2<f64>,
    pub forget_gate: Array2<f64>,
    pub cell_gate: Array2<f64>,
    pub output_gate: Array2<f64>,
    pub cy: Array2<f64>,
    pub hy: Array2<f64>,
    pub input_dropout_mask: Option<Array2<f64>>,
    pub recurrent_dropout_mask: Option<Array2<f64>>,
    pub output_dropout_mask: Option<Array2<f64>>,
}
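
/// A single LSTM cell with optional input, recurrent, and output dropout,
/// plus optional zoneout on the cell and hidden states.
///
/// One step computes, for input `x_t`, hidden state `h_{t-1}`, and cell
/// state `c_{t-1}`:
///
/// ```text
/// i_t = sigmoid(W_i x_t + U_i h_{t-1} + b_i)    // input gate   (rows 0..H)
/// f_t = sigmoid(W_f x_t + U_f h_{t-1} + b_f)    // forget gate  (rows H..2H)
/// g_t = tanh(W_g x_t + U_g h_{t-1} + b_g)       // cell gate    (rows 2H..3H)
/// o_t = sigmoid(W_o x_t + U_o h_{t-1} + b_o)    // output gate  (rows 3H..4H)
/// c_t = f_t * c_{t-1} + i_t * g_t
/// h_t = o_t * tanh(c_t)
/// ```
///
/// The per-gate blocks `W_*` and `U_*` are stored row-stacked in `w_ih` and
/// `w_hh`, with the bias `b_*` split across `b_ih` and `b_hh`.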
#[derive(Clone)]
pub struct LSTMCell {
    pub w_ih: Array2<f64>,
    pub w_hh: Array2<f64>,
    pub b_ih: Array2<f64>,
    pub b_hh: Array2<f64>,
    pub hidden_size: usize,
    pub input_dropout: Option<Dropout>,
    pub recurrent_dropout: Option<Dropout>,
    pub output_dropout: Option<Dropout>,
    pub zoneout: Option<Zoneout>,
    pub is_training: bool,
}

impl LSTMCell {
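    /// Creates a cell with weights drawn uniformly from [-0.1, 0.1] and
    /// biases initialized to zero.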
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        let dist = Uniform::new(-0.1, 0.1);

        let w_ih = Array2::random((4 * hidden_size, input_size), dist);
        let w_hh = Array2::random((4 * hidden_size, hidden_size), dist);
        let b_ih = Array2::zeros((4 * hidden_size, 1));
        let b_hh = Array2::zeros((4 * hidden_size, 1));

        LSTMCell {
            w_ih,
            w_hh,
            b_ih,
            b_hh,
            hidden_size,
            input_dropout: None,
            recurrent_dropout: None,
            output_dropout: None,
            zoneout: None,
            is_training: true,
        }
    }
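
    /// Adds dropout on the input `x_t`; `variational` selects the
    /// variational-dropout constructor of [`Dropout`].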
    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        if variational {
            self.input_dropout = Some(Dropout::variational(dropout_rate));
        } else {
            self.input_dropout = Some(Dropout::new(dropout_rate));
        }
        self
    }
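
    /// Adds dropout on the previous hidden state `h_{t-1}`; `variational`
    /// selects the variational-dropout constructor of [`Dropout`].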
    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        if variational {
            self.recurrent_dropout = Some(Dropout::variational(dropout_rate));
        } else {
            self.recurrent_dropout = Some(Dropout::new(dropout_rate));
        }
        self
    }
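
    /// Adds standard dropout on the new hidden state `h_t`.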
    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
        self.output_dropout = Some(Dropout::new(dropout_rate));
        self
    }
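
    /// Adds zoneout with separate rates for the cell and hidden states.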
    pub fn with_zoneout(mut self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self {
        self.zoneout = Some(Zoneout::new(cell_zoneout_rate, hidden_zoneout_rate));
        self
    }
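
    /// Puts the cell and every attached dropout/zoneout layer into training
    /// mode, so stochastic regularization is active.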
    pub fn train(&mut self) {
        self.is_training = true;
        if let Some(ref mut dropout) = self.input_dropout {
            dropout.train();
        }
        if let Some(ref mut dropout) = self.recurrent_dropout {
            dropout.train();
        }
        if let Some(ref mut dropout) = self.output_dropout {
            dropout.train();
        }
        if let Some(ref mut zoneout) = self.zoneout {
            zoneout.train();
        }
    }
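
    /// Puts the cell and every attached dropout/zoneout layer into
    /// evaluation mode, disabling stochastic regularization.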
    pub fn eval(&mut self) {
        self.is_training = false;
        if let Some(ref mut dropout) = self.input_dropout {
            dropout.eval();
        }
        if let Some(ref mut dropout) = self.recurrent_dropout {
            dropout.eval();
        }
        if let Some(ref mut dropout) = self.output_dropout {
            dropout.eval();
        }
        if let Some(ref mut zoneout) = self.zoneout {
            zoneout.eval();
        }
    }
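
    /// Runs one step and returns `(hy, cy)`, discarding the cache. Use
    /// [`forward_with_cache`](Self::forward_with_cache) when `backward`
    /// will be called afterwards.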
    pub fn forward(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
        let (hy, cy, _) = self.forward_with_cache(input, hx, cx);
        (hy, cy)
    }
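
    /// Runs one step and additionally returns an [`LSTMCellCache`] holding
    /// everything `backward` needs. The returned hidden state already has
    /// output dropout applied when it is enabled.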
    pub fn forward_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMCellCache) {
        let (input_dropped, input_mask) = if let Some(ref mut dropout) = self.input_dropout {
            let dropped = dropout.forward(input);
            let mask = dropout.get_last_mask().cloned();
            (dropped, mask)
        } else {
            (input.clone(), None)
        };

        let (hx_dropped, recurrent_mask) = if let Some(ref mut dropout) = self.recurrent_dropout {
            let dropped = dropout.forward(hx);
            let mask = dropout.get_last_mask().cloned();
            (dropped, mask)
        } else {
            (hx.clone(), None)
        };

        let gates = &self.w_ih.dot(&input_dropped) + &self.b_ih + &self.w_hh.dot(&hx_dropped) + &self.b_hh;

        let input_gate = gates.slice(s![0..self.hidden_size, ..]).map(|&x| sigmoid(x));
        let forget_gate = gates.slice(s![self.hidden_size..2 * self.hidden_size, ..]).map(|&x| sigmoid(x));
        let cell_gate = gates.slice(s![2 * self.hidden_size..3 * self.hidden_size, ..]).map(|&x| x.tanh());
        let output_gate = gates.slice(s![3 * self.hidden_size..4 * self.hidden_size, ..]).map(|&x| sigmoid(x));

        let mut cy = &forget_gate * cx + &input_gate * &cell_gate;

        if let Some(ref zoneout) = self.zoneout {
            cy = zoneout.apply_cell_zoneout(&cy, cx);
        }

        let mut hy = &output_gate * cy.map(|&x| x.tanh());

        if let Some(ref zoneout) = self.zoneout {
            hy = zoneout.apply_hidden_zoneout(&hy, hx);
        }

        let (hy_final, output_mask) = if let Some(ref mut dropout) = self.output_dropout {
            let dropped = dropout.forward(&hy);
            let mask = dropout.get_last_mask().cloned();
            (dropped, mask)
        } else {
            (hy, None)
        };

        let cache = LSTMCellCache {
            input: input.clone(),
            hx: hx.clone(),
            cx: cx.clone(),
            gates,
            input_gate,
            forget_gate,
            cell_gate,
            output_gate,
            cy: cy.clone(),
            hy: hy_final.clone(),
            input_dropout_mask: input_mask,
            recurrent_dropout_mask: recurrent_mask,
            output_dropout_mask: output_mask,
        };

        (hy_final, cy, cache)
    }
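
    /// Backpropagates through one step, given the gradients `dhy` and `dcy`
    /// flowing into this step's outputs.
    ///
    /// Returns `(gradients, dx, dhx, dcx)`: the parameter gradients and the
    /// gradients with respect to the input, previous hidden state, and
    /// previous cell state. Dropout masks recorded in the cache are
    /// reapplied with the same inverted-dropout scaling as the forward
    /// pass. Zoneout is treated as the identity here, since its masks are
    /// not cached.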
    pub fn backward(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMCellCache) -> (LSTMCellGradients, Array2<f64>, Array2<f64>, Array2<f64>) {
        let hidden_size = self.hidden_size;

        // Backprop through output dropout (inverted dropout: mask / keep_prob).
        let dhy_dropped = if let Some(ref mask) = cache.output_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.output_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dhy * mask / keep_prob
        } else {
            dhy.clone()
        };

        // h_t = o_t * tanh(c_t)
        let tanh_cy = cache.cy.map(|&x| x.tanh());
        let do_t = &dhy_dropped * &tanh_cy;
        let do_raw = &do_t * &cache.output_gate * cache.output_gate.map(|&x| 1.0 - x);

        let dcy_from_tanh = &dhy_dropped * &cache.output_gate * cache.cy.map(|&x| 1.0 - x.tanh().powi(2));
        let dcy_total = dcy + dcy_from_tanh;

        // c_t = f_t * c_{t-1} + i_t * g_t
        let df_t = &dcy_total * &cache.cx;
        let df_raw = &df_t * &cache.forget_gate * cache.forget_gate.map(|&x| 1.0 - x);

        let di_t = &dcy_total * &cache.cell_gate;
        let di_raw = &di_t * &cache.input_gate * cache.input_gate.map(|&x| 1.0 - x);

        let dc_t = &dcy_total * &cache.input_gate;
        let dc_raw = &dc_t * cache.cell_gate.map(|&x| 1.0 - x.powi(2));

        // Stack the gate gradients in the same (i, f, g, o) row order as the
        // forward pass, sized for the actual batch rather than a fixed batch of 1.
        let batch_size = cache.input.ncols();
        let mut dgates = Array2::zeros((4 * hidden_size, batch_size));
        dgates.slice_mut(s![0..hidden_size, ..]).assign(&di_raw);
        dgates.slice_mut(s![hidden_size..2 * hidden_size, ..]).assign(&df_raw);
        dgates.slice_mut(s![2 * hidden_size..3 * hidden_size, ..]).assign(&dc_raw);
        dgates.slice_mut(s![3 * hidden_size..4 * hidden_size, ..]).assign(&do_raw);

        let dw_ih = dgates.dot(&cache.input.t());
        let dw_hh = dgates.dot(&cache.hx.t());
        // Bias gradients are summed over the batch dimension to match the
        // (4 * hidden_size, 1) bias shape.
        let db_ih = dgates.sum_axis(Axis(1)).insert_axis(Axis(1));
        let db_hh = db_ih.clone();

        let gradients = LSTMCellGradients {
            w_ih: dw_ih,
            w_hh: dw_hh,
            b_ih: db_ih,
            b_hh: db_hh,
        };

        let mut dx = self.w_ih.t().dot(&dgates);
        let mut dhx = self.w_hh.t().dot(&dgates);
        let dcx = &dcy_total * &cache.forget_gate;

        // Backprop through input dropout.
        if let Some(ref mask) = cache.input_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.input_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dx = dx * mask / keep_prob;
        }

        // Backprop through recurrent dropout.
        if let Some(ref mask) = cache.recurrent_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.recurrent_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dhx = dhx * mask / keep_prob;
        }

        (gradients, dx, dhx, dcx)
    }
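
    /// Returns zero-valued gradients with the same shapes as the cell's
    /// parameters, useful as an accumulator when summing gradients over time.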
    pub fn zero_gradients(&self) -> LSTMCellGradients {
        LSTMCellGradients {
            w_ih: Array2::zeros(self.w_ih.raw_dim()),
            w_hh: Array2::zeros(self.w_hh.raw_dim()),
            b_ih: Array2::zeros(self.b_ih.raw_dim()),
            b_hh: Array2::zeros(self.b_hh.raw_dim()),
        }
    }
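
    /// Applies one optimizer step to each parameter; `prefix` namespaces the
    /// per-parameter optimizer state (e.g. when several cells share one
    /// optimizer).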
    pub fn update_parameters<O: crate::optimizers::Optimizer>(&mut self, gradients: &LSTMCellGradients, optimizer: &mut O, prefix: &str) {
        optimizer.update(&format!("{}_w_ih", prefix), &mut self.w_ih, &gradients.w_ih);
        optimizer.update(&format!("{}_w_hh", prefix), &mut self.w_hh, &gradients.w_hh);
        optimizer.update(&format!("{}_b_ih", prefix), &mut self.b_ih, &gradients.b_ih);
        optimizer.update(&format!("{}_b_hh", prefix), &mut self.b_hh, &gradients.b_hh);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_lstm_cell_forward() {
        let input_size = 3;
        let hidden_size = 2;
        let mut cell = LSTMCell::new(input_size, hidden_size);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        let (hy, cy) = cell.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[hidden_size, 1]);
        assert_eq!(cy.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_lstm_cell_with_dropout() {
        let input_size = 3;
        let hidden_size = 2;
        let mut cell = LSTMCell::new(input_size, hidden_size)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1)
            .with_zoneout(0.1, 0.1);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        cell.train();
        let (hy_train, cy_train) = cell.forward(&input, &hx, &cx);

        cell.eval();
        let (hy_eval, cy_eval) = cell.forward(&input, &hx, &cx);

        assert_eq!(hy_train.shape(), &[hidden_size, 1]);
        assert_eq!(cy_train.shape(), &[hidden_size, 1]);
        assert_eq!(hy_eval.shape(), &[hidden_size, 1]);
        assert_eq!(cy_eval.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_dropout_mask_backward_pass() {
        let input_size = 2;
        let hidden_size = 3;
        let mut cell = LSTMCell::new(input_size, hidden_size)
            .with_input_dropout(0.5, false)
            .with_output_dropout(0.5);

        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.1], [0.2], [0.3]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        cell.train();
        let (_hy, _cy, cache) = cell.forward_with_cache(&input, &hx, &cx);

        assert!(cache.input_dropout_mask.is_some());
        assert!(cache.output_dropout_mask.is_some());

        let dhy = arr2(&[[1.0], [1.0], [1.0]]);
        let dcy = arr2(&[[0.0], [0.0], [0.0]]);

        let (gradients, dx, dhx, dcx) = cell.backward(&dhy, &dcy, &cache);

        assert_eq!(gradients.w_ih.shape(), &[4 * hidden_size, input_size]);
        assert_eq!(gradients.w_hh.shape(), &[4 * hidden_size, hidden_size]);
        assert_eq!(dx.shape(), &[input_size, 1]);
        assert_eq!(dhx.shape(), &[hidden_size, 1]);
        assert_eq!(dcx.shape(), &[hidden_size, 1]);

        cell.eval();
        let (_, _, cache_eval) = cell.forward_with_cache(&input, &hx, &cx);
        assert!(cache_eval.input_dropout_mask.is_none());
        assert!(cache_eval.output_dropout_mask.is_none());
    }
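
    // A finite-difference check of `backward` against the forward pass on a
    // dropout-free cell; a sanity-check sketch, with eps and the tolerance
    // chosen loosely for f64 central differences.
    #[test]
    fn test_backward_matches_finite_differences() {
        let mut cell = LSTMCell::new(2, 2);
        let input = arr2(&[[0.4], [-0.6]]);
        let hx = arr2(&[[0.1], [-0.2]]);
        let cx = arr2(&[[0.05], [0.0]]);

        // With dhy = 1 and dcy = 0, the analytic gradients correspond to the
        // scalar loss L = sum(hy).
        let (_, _, cache) = cell.forward_with_cache(&input, &hx, &cx);
        let dhy = Array2::ones((2, 1));
        let dcy = Array2::zeros((2, 1));
        let (grads, _, _, _) = cell.backward(&dhy, &dcy, &cache);

        let eps = 1e-5;
        for i in 0..cell.w_ih.nrows() {
            for j in 0..cell.w_ih.ncols() {
                let orig = cell.w_ih[(i, j)];
                cell.w_ih[(i, j)] = orig + eps;
                let (hy_plus, _) = cell.forward(&input, &hx, &cx);
                cell.w_ih[(i, j)] = orig - eps;
                let (hy_minus, _) = cell.forward(&input, &hx, &cx);
                cell.w_ih[(i, j)] = orig;

                let numeric = (hy_plus.sum() - hy_minus.sum()) / (2.0 * eps);
                assert!((grads.w_ih[(i, j)] - numeric).abs() < 1e-6);
            }
        }
    }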
}