rust_lstm/layers/lstm_cell.rs

use ndarray::{Array2, Axis, s};
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;
use crate::utils::sigmoid;
use crate::layers::dropout::{Dropout, Zoneout};

/// Holds gradients for all LSTM cell parameters during backpropagation
#[derive(Clone)]
pub struct LSTMCellGradients {
    pub w_ih: Array2<f64>,
    pub w_hh: Array2<f64>,
    pub b_ih: Array2<f64>,
    pub b_hh: Array2<f64>,
}

/// Caches intermediate values during forward pass for efficient backward computation
#[derive(Clone)]
pub struct LSTMCellCache {
    pub input: Array2<f64>,
    pub hx: Array2<f64>,
    pub cx: Array2<f64>,
    pub gates: Array2<f64>,
    pub input_gate: Array2<f64>,
    pub forget_gate: Array2<f64>,
    pub cell_gate: Array2<f64>,
    pub output_gate: Array2<f64>,
    pub cy: Array2<f64>,
    pub hy: Array2<f64>,
    pub input_dropout_mask: Option<Array2<f64>>,
    pub recurrent_dropout_mask: Option<Array2<f64>>,
    pub output_dropout_mask: Option<Array2<f64>>,
}

/// LSTM cell with trainable parameters and dropout support
#[derive(Clone)]
pub struct LSTMCell {
    pub w_ih: Array2<f64>,
    pub w_hh: Array2<f64>,
    pub b_ih: Array2<f64>,
    pub b_hh: Array2<f64>,
    pub hidden_size: usize,
    pub input_dropout: Option<Dropout>,
    pub recurrent_dropout: Option<Dropout>,
    pub output_dropout: Option<Dropout>,
    pub zoneout: Option<Zoneout>,
    pub is_training: bool,
}

impl LSTMCell {
    /// Creates a new LSTM cell with weights drawn uniformly from [-0.1, 0.1] and zero-initialized biases
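    ///
    /// # Example
    ///
    /// A minimal usage sketch (the `rust_lstm` module path is an assumption here):
    ///
    /// ```ignore
    /// use ndarray::Array2;
    /// use rust_lstm::layers::lstm_cell::LSTMCell;
    ///
    /// let mut cell = LSTMCell::new(3, 2);
    /// let x: Array2<f64> = Array2::zeros((3, 1)); // (input_size, batch)
    /// let h: Array2<f64> = Array2::zeros((2, 1)); // (hidden_size, batch)
    /// let c: Array2<f64> = Array2::zeros((2, 1));
    /// let (hy, cy) = cell.forward(&x, &h, &c);
    /// assert_eq!(hy.shape(), &[2, 1]);
    /// ```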
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        let dist = Uniform::new(-0.1, 0.1);

        let w_ih = Array2::random((4 * hidden_size, input_size), dist);
        let w_hh = Array2::random((4 * hidden_size, hidden_size), dist);
        let b_ih = Array2::zeros((4 * hidden_size, 1));
        let b_hh = Array2::zeros((4 * hidden_size, 1));

        LSTMCell {
            w_ih,
            w_hh,
            b_ih,
            b_hh,
            hidden_size,
            input_dropout: None,
            recurrent_dropout: None,
            output_dropout: None,
            zoneout: None,
            is_training: true,
        }
    }

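    /// Enables dropout on the cell input `x_t`; `variational` selects the
    /// variational `Dropout` variant (conventionally, one mask reused across time steps).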
    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        if variational {
            self.input_dropout = Some(Dropout::variational(dropout_rate));
        } else {
            self.input_dropout = Some(Dropout::new(dropout_rate));
        }
        self
    }

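    /// Enables dropout on the previous hidden state `h_{t-1}` before it enters
    /// the recurrent matrix multiply; `variational` as in `with_input_dropout`.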
    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        if variational {
            self.recurrent_dropout = Some(Dropout::variational(dropout_rate));
        } else {
            self.recurrent_dropout = Some(Dropout::new(dropout_rate));
        }
        self
    }

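    /// Enables dropout on the output hidden state `h_t`.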
    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
        self.output_dropout = Some(Dropout::new(dropout_rate));
        self
    }

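    /// Enables zoneout: entries of the new cell/hidden state are stochastically
    /// kept at their previous values, with the given rates per state.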
    pub fn with_zoneout(mut self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self {
        self.zoneout = Some(Zoneout::new(cell_zoneout_rate, hidden_zoneout_rate));
        self
    }

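    /// Puts the cell and all attached dropout/zoneout layers into training mode.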
    pub fn train(&mut self) {
        self.is_training = true;
        if let Some(ref mut dropout) = self.input_dropout {
            dropout.train();
        }
        if let Some(ref mut dropout) = self.recurrent_dropout {
            dropout.train();
        }
        if let Some(ref mut dropout) = self.output_dropout {
            dropout.train();
        }
        if let Some(ref mut zoneout) = self.zoneout {
            zoneout.train();
        }
    }

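    /// Puts the cell and all attached dropout/zoneout layers into evaluation mode.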
    pub fn eval(&mut self) {
        self.is_training = false;
        if let Some(ref mut dropout) = self.input_dropout {
            dropout.eval();
        }
        if let Some(ref mut dropout) = self.recurrent_dropout {
            dropout.eval();
        }
        if let Some(ref mut dropout) = self.output_dropout {
            dropout.eval();
        }
        if let Some(ref mut zoneout) = self.zoneout {
            zoneout.eval();
        }
    }

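    /// Runs one time step and returns `(h_t, c_t)`.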
    pub fn forward(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
        let (hy, cy, _) = self.forward_with_cache(input, hx, cx);
        (hy, cy)
    }

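    /// Like `forward`, but also returns the cache consumed by `backward`.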
    pub fn forward_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMCellCache) {
        let (input_dropped, input_mask) = if let Some(ref mut dropout) = self.input_dropout {
            let dropped = dropout.forward(input);
            let mask = dropout.get_last_mask().cloned();
            (dropped, mask)
        } else {
            (input.clone(), None)
        };

        let (hx_dropped, recurrent_mask) = if let Some(ref mut dropout) = self.recurrent_dropout {
            let dropped = dropout.forward(hx);
            let mask = dropout.get_last_mask().cloned();
            (dropped, mask)
        } else {
            (hx.clone(), None)
        };

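        // Standard LSTM update, with gate rows stacked in the order [i, f, g, o]:
        //   i_t, f_t, o_t = σ(W x_t + U h_{t-1} + b)   (sigmoid row blocks)
        //   g_t           = tanh(W x_t + U h_{t-1} + b)
        //   c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t
        //   h_t = o_t ⊙ tanh(c_t)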
        // Pre-activations for all four gates in one fused matrix multiply: rows ordered [input, forget, cell, output]
        let gates = &self.w_ih.dot(&input_dropped) + &self.b_ih + &self.w_hh.dot(&hx_dropped) + &self.b_hh;

        let input_gate = gates.slice(s![0..self.hidden_size, ..]).map(|&x| sigmoid(x));
        let forget_gate = gates.slice(s![self.hidden_size..2*self.hidden_size, ..]).map(|&x| sigmoid(x));
        let cell_gate = gates.slice(s![2*self.hidden_size..3*self.hidden_size, ..]).map(|&x| x.tanh());
        let output_gate = gates.slice(s![3*self.hidden_size..4*self.hidden_size, ..]).map(|&x| sigmoid(x));

        let mut cy = &forget_gate * cx + &input_gate * &cell_gate;

        if let Some(ref zoneout) = self.zoneout {
            cy = zoneout.apply_cell_zoneout(&cy, cx);
        }

        let mut hy = &output_gate * cy.map(|&x| x.tanh());

        if let Some(ref zoneout) = self.zoneout {
            hy = zoneout.apply_hidden_zoneout(&hy, hx);
        }

        let (hy_final, output_mask) = if let Some(ref mut dropout) = self.output_dropout {
            let dropped = dropout.forward(&hy);
            let mask = dropout.get_last_mask().cloned();
            (dropped, mask)
        } else {
            (hy, None)
        };

        let cache = LSTMCellCache {
            input: input.clone(),
            hx: hx.clone(),
            cx: cx.clone(),
            gates,
            input_gate: input_gate.to_owned(),
            forget_gate: forget_gate.to_owned(),
            cell_gate: cell_gate.to_owned(),
            output_gate: output_gate.to_owned(),
            cy: cy.clone(),
            hy: hy_final.clone(),
            input_dropout_mask: input_mask,
            recurrent_dropout_mask: recurrent_mask,
            output_dropout_mask: output_mask,
        };

        (hy_final, cy, cache)
    }

    /// Backward pass implementing LSTM gradient computation with dropout
    ///
    /// Returns (parameter_gradients, input_gradient, hidden_gradient, cell_gradient)
    pub fn backward(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMCellCache) -> (LSTMCellGradients, Array2<f64>, Array2<f64>, Array2<f64>) {
        let hidden_size = self.hidden_size;

        // Apply output dropout backward pass using saved mask
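        // Note: the forward pass scales by mask / keep_prob (inverted dropout), so
        // the same factor is applied to incoming gradients here. Zoneout masks are
        // not cached, so zoneout is treated as identity during backprop.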
        let dhy_dropped = if let Some(ref mask) = cache.output_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.output_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dhy * mask / keep_prob
        } else {
            dhy.clone()
        };

        // Output gate gradients: ∂L/∂o_t = ∂L/∂h_t ⊙ tanh(c_t)
        let tanh_cy = cache.cy.map(|&x| x.tanh());
        let do_t = &dhy_dropped * &tanh_cy;
        let do_raw = &do_t * &cache.output_gate * (&cache.output_gate.map(|&x| 1.0 - x));

        // Cell state gradients from both tanh and direct paths
        let dcy_from_tanh = &dhy_dropped * &cache.output_gate * cache.cy.map(|&x| 1.0 - x.tanh().powi(2));
        let dcy_total = dcy + dcy_from_tanh;

        // Forget gate gradients: ∂L/∂f_t = ∂L/∂c_t ⊙ c_{t-1}
        let df_t = &dcy_total * &cache.cx;
        let df_raw = &df_t * &cache.forget_gate * cache.forget_gate.map(|&x| 1.0 - x);

        // Input gate gradients: ∂L/∂i_t = ∂L/∂c_t ⊙ g_t
        let di_t = &dcy_total * &cache.cell_gate;
        let di_raw = &di_t * &cache.input_gate * cache.input_gate.map(|&x| 1.0 - x);

        // Cell gate gradients: ∂L/∂g_t = ∂L/∂c_t ⊙ i_t
        let dc_t = &dcy_total * &cache.input_gate;
        let dc_raw = &dc_t * cache.cell_gate.map(|&x| 1.0 - x.powi(2));

        // Concatenate gate gradients in the same order as the forward pass.
        // The batch size is taken from dhy rather than hardcoded to 1.
        let mut dgates = Array2::zeros((4 * hidden_size, dhy.ncols()));
        dgates.slice_mut(s![0..hidden_size, ..]).assign(&di_raw);
        dgates.slice_mut(s![hidden_size..2*hidden_size, ..]).assign(&df_raw);
        dgates.slice_mut(s![2*hidden_size..3*hidden_size, ..]).assign(&dc_raw);
        dgates.slice_mut(s![3*hidden_size..4*hidden_size, ..]).assign(&do_raw);

        // Weight gradients must see the same (masked, rescaled) activations the
        // forward pass used, so the saved masks are re-applied here.
        let input_eff = match (&cache.input_dropout_mask, &self.input_dropout) {
            (Some(mask), Some(dropout)) => &cache.input * mask / (1.0 - dropout.dropout_rate),
            _ => cache.input.clone(),
        };
        let hx_eff = match (&cache.recurrent_dropout_mask, &self.recurrent_dropout) {
            (Some(mask), Some(dropout)) => &cache.hx * mask / (1.0 - dropout.dropout_rate),
            _ => cache.hx.clone(),
        };

        // Parameter gradients using chain rule; bias gradients sum over the batch
        let dw_ih = dgates.dot(&input_eff.t());
        let dw_hh = dgates.dot(&hx_eff.t());
        let db_ih = dgates.sum_axis(Axis(1)).insert_axis(Axis(1));
        let db_hh = db_ih.clone();

        let gradients = LSTMCellGradients {
            w_ih: dw_ih,
            w_hh: dw_hh,
            b_ih: db_ih,
            b_hh: db_hh,
        };

        let mut dx = self.w_ih.t().dot(&dgates);
        let mut dhx = self.w_hh.t().dot(&dgates);
        let dcx = &dcy_total * &cache.forget_gate;

        if let Some(ref mask) = cache.input_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.input_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dx = dx * mask / keep_prob;
        }

        if let Some(ref mask) = cache.recurrent_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.recurrent_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dhx = dhx * mask / keep_prob;
        }

        (gradients, dx, dhx, dcx)
    }

    /// Initialize zero gradients for accumulation
    pub fn zero_gradients(&self) -> LSTMCellGradients {
        LSTMCellGradients {
            w_ih: Array2::zeros(self.w_ih.raw_dim()),
            w_hh: Array2::zeros(self.w_hh.raw_dim()),
            b_ih: Array2::zeros(self.b_ih.raw_dim()),
            b_hh: Array2::zeros(self.b_hh.raw_dim()),
        }
    }

    /// Apply gradients using the provided optimizer
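    ///
    /// # Example
    ///
    /// A sketch of one update step; `opt` stands for any `crate::optimizers::Optimizer`
    /// implementation, and the variable names are hypothetical:
    ///
    /// ```ignore
    /// let (grads, _dx, _dhx, _dcx) = cell.backward(&dhy, &dcy, &cache);
    /// cell.update_parameters(&grads, &mut opt, "layer0");
    /// ```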
    pub fn update_parameters<O: crate::optimizers::Optimizer>(&mut self, gradients: &LSTMCellGradients, optimizer: &mut O, prefix: &str) {
        optimizer.update(&format!("{}_w_ih", prefix), &mut self.w_ih, &gradients.w_ih);
        optimizer.update(&format!("{}_w_hh", prefix), &mut self.w_hh, &gradients.w_hh);
        optimizer.update(&format!("{}_b_ih", prefix), &mut self.b_ih, &gradients.b_ih);
        optimizer.update(&format!("{}_b_hh", prefix), &mut self.b_hh, &gradients.b_hh);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_lstm_cell_forward() {
        let input_size = 3;
        let hidden_size = 2;
        let mut cell = LSTMCell::new(input_size, hidden_size);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        let (hy, cy) = cell.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[hidden_size, 1]);
        assert_eq!(cy.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_lstm_cell_with_dropout() {
        let input_size = 3;
        let hidden_size = 2;
        let mut cell = LSTMCell::new(input_size, hidden_size)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1)
            .with_zoneout(0.1, 0.1);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        // Test training mode
        cell.train();
        let (hy_train, cy_train) = cell.forward(&input, &hx, &cx);

        // Test evaluation mode
        cell.eval();
        let (hy_eval, cy_eval) = cell.forward(&input, &hx, &cx);

        assert_eq!(hy_train.shape(), &[hidden_size, 1]);
        assert_eq!(cy_train.shape(), &[hidden_size, 1]);
        assert_eq!(hy_eval.shape(), &[hidden_size, 1]);
        assert_eq!(cy_eval.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_dropout_mask_backward_pass() {
        let input_size = 2;
        let hidden_size = 3;
        let mut cell = LSTMCell::new(input_size, hidden_size)
            .with_input_dropout(0.5, false)
            .with_output_dropout(0.5);

        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.1], [0.2], [0.3]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        cell.train();
        let (_hy, _cy, cache) = cell.forward_with_cache(&input, &hx, &cx);

        assert!(cache.input_dropout_mask.is_some());
        assert!(cache.output_dropout_mask.is_some());

        let dhy = arr2(&[[1.0], [1.0], [1.0]]);
        let dcy = arr2(&[[0.0], [0.0], [0.0]]);

        let (gradients, dx, dhx, dcx) = cell.backward(&dhy, &dcy, &cache);

        assert_eq!(gradients.w_ih.shape(), &[4 * hidden_size, input_size]);
        assert_eq!(gradients.w_hh.shape(), &[4 * hidden_size, hidden_size]);
        assert_eq!(dx.shape(), &[input_size, 1]);
        assert_eq!(dhx.shape(), &[hidden_size, 1]);
        assert_eq!(dcx.shape(), &[hidden_size, 1]);

        cell.eval();
        let (_, _, cache_eval) = cell.forward_with_cache(&input, &hx, &cx);
        assert!(cache_eval.input_dropout_mask.is_none());
        assert!(cache_eval.output_dropout_mask.is_none());
    }
}