rust_lstm/layers/
lstm_cell.rs

1use ndarray::{Array2, s};
2use ndarray_rand::RandomExt;
3use ndarray_rand::rand_distr::Uniform;
4use crate::utils::sigmoid;
5use crate::layers::dropout::{Dropout, Zoneout};
6
/// Holds gradients for all LSTM cell parameters during backpropagation
#[derive(Clone)]
pub struct LSTMCellGradients {
    /// Gradient w.r.t. input-to-hidden weights, shape (4 * hidden_size, input_size)
    pub w_ih: Array2<f64>,
    /// Gradient w.r.t. hidden-to-hidden weights, shape (4 * hidden_size, hidden_size)
    pub w_hh: Array2<f64>,
    /// Gradient w.r.t. input-to-hidden bias, shape (4 * hidden_size, 1)
    pub b_ih: Array2<f64>,
    /// Gradient w.r.t. hidden-to-hidden bias, shape (4 * hidden_size, 1)
    pub b_hh: Array2<f64>,
}
15
/// Caches intermediate values during forward pass for efficient backward computation
#[derive(Clone)]
pub struct LSTMCellCache {
    /// Input as passed to the forward pass (before input dropout)
    pub input: Array2<f64>,
    /// Previous hidden state (before recurrent dropout)
    pub hx: Array2<f64>,
    /// Previous cell state
    pub cx: Array2<f64>,
    /// Raw gate pre-activations, shape (4 * hidden_size, cols),
    /// stacked in [input, forget, cell, output] order
    pub gates: Array2<f64>,
    /// Sigmoid-activated input gate i_t
    pub input_gate: Array2<f64>,
    /// Sigmoid-activated forget gate f_t
    pub forget_gate: Array2<f64>,
    /// Tanh-activated candidate cell gate g_t
    pub cell_gate: Array2<f64>,
    /// Sigmoid-activated output gate o_t
    pub output_gate: Array2<f64>,
    /// New cell state (after zoneout, when configured)
    pub cy: Array2<f64>,
    /// New hidden state (after zoneout and output dropout, when configured)
    pub hy: Array2<f64>,
    /// Mask the input dropout used this step, if any (None in eval mode)
    pub input_dropout_mask: Option<Array2<f64>>,
    /// Mask the recurrent dropout used this step, if any
    pub recurrent_dropout_mask: Option<Array2<f64>>,
    /// Mask the output dropout used this step, if any
    pub output_dropout_mask: Option<Array2<f64>>,
}
33
/// Batch cache for multiple sequences processed simultaneously
///
/// Same layout as `LSTMCellCache`, with one column per sequence in the batch.
#[derive(Clone)]
pub struct LSTMCellBatchCache {
    /// Batch input, shape (input_size, batch_size)
    pub input: Array2<f64>,
    /// Previous hidden states, shape (hidden_size, batch_size)
    pub hx: Array2<f64>,
    /// Previous cell states, shape (hidden_size, batch_size)
    pub cx: Array2<f64>,
    /// Raw gate pre-activations, shape (4 * hidden_size, batch_size)
    pub gates: Array2<f64>,
    /// Sigmoid-activated input gate i_t
    pub input_gate: Array2<f64>,
    /// Sigmoid-activated forget gate f_t
    pub forget_gate: Array2<f64>,
    /// Tanh-activated candidate cell gate g_t
    pub cell_gate: Array2<f64>,
    /// Sigmoid-activated output gate o_t
    pub output_gate: Array2<f64>,
    /// New cell states (after zoneout, when configured)
    pub cy: Array2<f64>,
    /// New hidden states (after zoneout and output dropout, when configured)
    pub hy: Array2<f64>,
    /// Mask the input dropout used this step, if any
    pub input_dropout_mask: Option<Array2<f64>>,
    /// Mask the recurrent dropout used this step, if any
    pub recurrent_dropout_mask: Option<Array2<f64>>,
    /// Mask the output dropout used this step, if any
    pub output_dropout_mask: Option<Array2<f64>>,
    /// Number of sequences (columns) in this batch
    pub batch_size: usize,
}
52
/// LSTM cell with trainable parameters and dropout support
#[derive(Clone)]
pub struct LSTMCell {
    /// Input-to-hidden weights, shape (4 * hidden_size, input_size)
    pub w_ih: Array2<f64>,
    /// Hidden-to-hidden weights, shape (4 * hidden_size, hidden_size)
    pub w_hh: Array2<f64>,
    /// Input-to-hidden bias, shape (4 * hidden_size, 1)
    pub b_ih: Array2<f64>,
    /// Hidden-to-hidden bias, shape (4 * hidden_size, 1)
    pub b_hh: Array2<f64>,
    /// Number of hidden units per gate
    pub hidden_size: usize,
    /// Optional dropout applied to the input before the gate computation
    pub input_dropout: Option<Dropout>,
    /// Optional dropout applied to the previous hidden state
    pub recurrent_dropout: Option<Dropout>,
    /// Optional dropout applied to the new hidden state
    pub output_dropout: Option<Dropout>,
    /// Optional zoneout applied to the cell and hidden states
    pub zoneout: Option<Zoneout>,
    /// Whether the cell (and its regularizers) are in training mode
    pub is_training: bool,
}
67
68impl LSTMCell {
    /// Creates a new LSTM cell with weights drawn uniformly from [-0.1, 0.1]
    /// and biases initialized to zero.
    ///
    /// Weight matrices stack the four gates row-wise in
    /// [input, forget, cell, output] order, giving 4 * hidden_size rows.
    ///
    /// NOTE(review): a previous comment called this "Xavier-uniform", but the
    /// range is fixed and does not depend on fan-in/fan-out.
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        let dist = Uniform::new(-0.1, 0.1);

        let w_ih = Array2::random((4 * hidden_size, input_size), dist);
        let w_hh = Array2::random((4 * hidden_size, hidden_size), dist);
        let b_ih = Array2::zeros((4 * hidden_size, 1));
        let b_hh = Array2::zeros((4 * hidden_size, 1));

        // Dropout/zoneout stay disabled until configured via the with_*
        // builder methods; cells start in training mode.
        LSTMCell { 
            w_ih, 
            w_hh, 
            b_ih, 
            b_hh, 
            hidden_size,
            input_dropout: None,
            recurrent_dropout: None,
            output_dropout: None,
            zoneout: None,
            is_training: true,
        }
    }
91
92    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
93        if variational {
94            self.input_dropout = Some(Dropout::variational(dropout_rate));
95        } else {
96            self.input_dropout = Some(Dropout::new(dropout_rate));
97        }
98        self
99    }
100
101    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
102        if variational {
103            self.recurrent_dropout = Some(Dropout::variational(dropout_rate));
104        } else {
105            self.recurrent_dropout = Some(Dropout::new(dropout_rate));
106        }
107        self
108    }
109
110    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
111        self.output_dropout = Some(Dropout::new(dropout_rate));
112        self
113    }
114
115    pub fn with_zoneout(mut self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self {
116        self.zoneout = Some(Zoneout::new(cell_zoneout_rate, hidden_zoneout_rate));
117        self
118    }
119
120    pub fn train(&mut self) {
121        self.is_training = true;
122        if let Some(ref mut dropout) = self.input_dropout {
123            dropout.train();
124        }
125        if let Some(ref mut dropout) = self.recurrent_dropout {
126            dropout.train();
127        }
128        if let Some(ref mut dropout) = self.output_dropout {
129            dropout.train();
130        }
131        if let Some(ref mut zoneout) = self.zoneout {
132            zoneout.train();
133        }
134    }
135
136    pub fn eval(&mut self) {
137        self.is_training = false;
138        if let Some(ref mut dropout) = self.input_dropout {
139            dropout.eval();
140        }
141        if let Some(ref mut dropout) = self.recurrent_dropout {
142            dropout.eval();
143        }
144        if let Some(ref mut dropout) = self.output_dropout {
145            dropout.eval();
146        }
147        if let Some(ref mut zoneout) = self.zoneout {
148            zoneout.eval();
149        }
150    }
151
152    pub fn forward(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
153        let (hy, cy, _) = self.forward_with_cache(input, hx, cx);
154        (hy, cy)
155    }
156
157    pub fn forward_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMCellCache) {
158        let (input_dropped, input_mask) = if let Some(ref mut dropout) = self.input_dropout {
159            let dropped = dropout.forward(input);
160            let mask = dropout.get_last_mask().map(|m| m.clone());
161            (dropped, mask)
162        } else {
163            (input.clone(), None)
164        };
165
166        let (hx_dropped, recurrent_mask) = if let Some(ref mut dropout) = self.recurrent_dropout {
167            let dropped = dropout.forward(hx);
168            let mask = dropout.get_last_mask().map(|m| m.clone());
169            (dropped, mask)
170        } else {
171            (hx.clone(), None)
172        };
173
174        // Compute all gates in parallel: [input_gate, forget_gate, cell_gate, output_gate]
175        let gates = &self.w_ih.dot(&input_dropped) + &self.b_ih + &self.w_hh.dot(&hx_dropped) + &self.b_hh;
176
177        let input_gate = gates.slice(s![0..self.hidden_size, ..]).map(|&x| sigmoid(x));
178        let forget_gate = gates.slice(s![self.hidden_size..2*self.hidden_size, ..]).map(|&x| sigmoid(x));
179        let cell_gate = gates.slice(s![2*self.hidden_size..3*self.hidden_size, ..]).map(|&x| x.tanh());
180        let output_gate = gates.slice(s![3*self.hidden_size..4*self.hidden_size, ..]).map(|&x| sigmoid(x));
181
182        let mut cy = &forget_gate * cx + &input_gate * &cell_gate;
183
184        if let Some(ref zoneout) = self.zoneout {
185            cy = zoneout.apply_cell_zoneout(&cy, cx);
186        }
187
188        let mut hy = &output_gate * cy.map(|&x| x.tanh());
189
190        if let Some(ref zoneout) = self.zoneout {
191            hy = zoneout.apply_hidden_zoneout(&hy, hx);
192        }
193
194        let (hy_final, output_mask) = if let Some(ref mut dropout) = self.output_dropout {
195            let dropped = dropout.forward(&hy);
196            let mask = dropout.get_last_mask().map(|m| m.clone());
197            (dropped, mask)
198        } else {
199            (hy, None)
200        };
201
202        let cache = LSTMCellCache {
203            input: input.clone(),
204            hx: hx.clone(),
205            cx: cx.clone(),
206            gates: gates,
207            input_gate: input_gate.to_owned(),
208            forget_gate: forget_gate.to_owned(),
209            cell_gate: cell_gate.to_owned(),
210            output_gate: output_gate.to_owned(),
211            cy: cy.clone(),
212            hy: hy_final.clone(),
213            input_dropout_mask: input_mask,
214            recurrent_dropout_mask: recurrent_mask,
215            output_dropout_mask: output_mask,
216        };
217
218        (hy_final, cy, cache)
219    }
220
221    /// Batch forward pass for multiple sequences simultaneously
222    /// 
223    /// # Arguments
224    /// * `input` - Input tensor of shape (input_size, batch_size)
225    /// * `hx` - Hidden state tensor of shape (hidden_size, batch_size)
226    /// * `cx` - Cell state tensor of shape (hidden_size, batch_size)
227    /// 
228    /// # Returns
229    /// * Tuple of (new_hidden_state, new_cell_state) with same batch dimensions
230    pub fn forward_batch(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
231        let batch_size = input.ncols();
232        assert_eq!(hx.ncols(), batch_size, "Hidden state batch size must match input batch size");
233        assert_eq!(cx.ncols(), batch_size, "Cell state batch size must match input batch size");
234        assert_eq!(input.nrows(), self.w_ih.ncols(), "Input feature size must match weight matrix");
235        assert_eq!(hx.nrows(), self.hidden_size, "Hidden state size must match network hidden size");
236        assert_eq!(cx.nrows(), self.hidden_size, "Cell state size must match network hidden size");
237
238        // Apply input dropout across the entire batch
239        let (input_dropped, _input_mask) = if let Some(ref mut dropout) = self.input_dropout {
240            let dropped = dropout.forward(input);
241            let mask = dropout.get_last_mask().map(|m| m.clone());
242            (dropped, mask)
243        } else {
244            (input.clone(), None)
245        };
246
247        // Apply recurrent dropout across the entire batch
248        let (hx_dropped, _recurrent_mask) = if let Some(ref mut dropout) = self.recurrent_dropout {
249            let dropped = dropout.forward(hx);
250            let mask = dropout.get_last_mask().map(|m| m.clone());
251            (dropped, mask)
252        } else {
253            (hx.clone(), None)
254        };
255
256        // Compute all gates in parallel for the entire batch
257        // gates shape: (4 * hidden_size, batch_size)
258        let gates = &self.w_ih.dot(&input_dropped) + &self.b_ih.broadcast((4 * self.hidden_size, batch_size)).unwrap() 
259                  + &self.w_hh.dot(&hx_dropped) + &self.b_hh.broadcast((4 * self.hidden_size, batch_size)).unwrap();
260
261        // Extract and compute gate activations for the entire batch
262        let input_gate = gates.slice(s![0..self.hidden_size, ..]).map(|&x| sigmoid(x));
263        let forget_gate = gates.slice(s![self.hidden_size..2*self.hidden_size, ..]).map(|&x| sigmoid(x));
264        let cell_gate = gates.slice(s![2*self.hidden_size..3*self.hidden_size, ..]).map(|&x| x.tanh());
265        let output_gate = gates.slice(s![3*self.hidden_size..4*self.hidden_size, ..]).map(|&x| sigmoid(x));
266
267        // Update cell state for entire batch
268        let mut cy = &forget_gate * cx + &input_gate * &cell_gate;
269
270        // Apply zoneout to cell state if configured
271        if let Some(ref zoneout) = self.zoneout {
272            for col_idx in 0..batch_size {
273                let cy_col = cy.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
274                let cx_col = cx.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
275                let cy_zoneout = zoneout.apply_cell_zoneout(&cy_col, &cx_col);
276                cy.column_mut(col_idx).assign(&cy_zoneout.column(0));
277            }
278        }
279
280        // Compute hidden state for entire batch
281        let mut hy = &output_gate * cy.map(|&x| x.tanh());
282
283        // Apply zoneout to hidden state if configured
284        if let Some(ref zoneout) = self.zoneout {
285            for col_idx in 0..batch_size {
286                let hy_col = hy.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
287                let hx_col = hx.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
288                let hy_zoneout = zoneout.apply_hidden_zoneout(&hy_col, &hx_col);
289                hy.column_mut(col_idx).assign(&hy_zoneout.column(0));
290            }
291        }
292
293        // Apply output dropout to the entire batch
294        let hy_final = if let Some(ref mut dropout) = self.output_dropout {
295            dropout.forward(&hy)
296        } else {
297            hy
298        };
299
300        (hy_final, cy)
301    }
302
303    /// Batch forward pass with caching for training
304    /// 
305    /// Similar to forward_batch but caches intermediate values needed for backpropagation
306    pub fn forward_batch_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMCellBatchCache) {
307        let batch_size = input.ncols();
308
309        // Apply dropout and track masks
310        let (input_dropped, input_mask) = if let Some(ref mut dropout) = self.input_dropout {
311            let dropped = dropout.forward(input);
312            let mask = dropout.get_last_mask().map(|m| m.clone());
313            (dropped, mask)
314        } else {
315            (input.clone(), None)
316        };
317
318        let (hx_dropped, recurrent_mask) = if let Some(ref mut dropout) = self.recurrent_dropout {
319            let dropped = dropout.forward(hx);
320            let mask = dropout.get_last_mask().map(|m| m.clone());
321            (dropped, mask)
322        } else {
323            (hx.clone(), None)
324        };
325
326        // Compute gates for entire batch
327        let gates = &self.w_ih.dot(&input_dropped) + &self.b_ih.broadcast((4 * self.hidden_size, batch_size)).unwrap()
328                  + &self.w_hh.dot(&hx_dropped) + &self.b_hh.broadcast((4 * self.hidden_size, batch_size)).unwrap();
329
330        let input_gate = gates.slice(s![0..self.hidden_size, ..]).map(|&x| sigmoid(x));
331        let forget_gate = gates.slice(s![self.hidden_size..2*self.hidden_size, ..]).map(|&x| sigmoid(x));
332        let cell_gate = gates.slice(s![2*self.hidden_size..3*self.hidden_size, ..]).map(|&x| x.tanh());
333        let output_gate = gates.slice(s![3*self.hidden_size..4*self.hidden_size, ..]).map(|&x| sigmoid(x));
334
335        let mut cy = &forget_gate * cx + &input_gate * &cell_gate;
336
337        // Apply zoneout if configured
338        if let Some(ref zoneout) = self.zoneout {
339            for col_idx in 0..batch_size {
340                let cy_col = cy.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
341                let cx_col = cx.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
342                let cy_zoneout = zoneout.apply_cell_zoneout(&cy_col, &cx_col);
343                cy.column_mut(col_idx).assign(&cy_zoneout.column(0));
344            }
345        }
346
347        let mut hy = &output_gate * cy.map(|&x| x.tanh());
348
349        if let Some(ref zoneout) = self.zoneout {
350            for col_idx in 0..batch_size {
351                let hy_col = hy.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
352                let hx_col = hx.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
353                let hy_zoneout = zoneout.apply_hidden_zoneout(&hy_col, &hx_col);
354                hy.column_mut(col_idx).assign(&hy_zoneout.column(0));
355            }
356        }
357
358        let (hy_final, output_mask) = if let Some(ref mut dropout) = self.output_dropout {
359            let dropped = dropout.forward(&hy);
360            let mask = dropout.get_last_mask().map(|m| m.clone());
361            (dropped, mask)
362        } else {
363            (hy, None)
364        };
365
366        // Create cache for backpropagation
367        let cache = LSTMCellBatchCache {
368            input: input.clone(),
369            hx: hx.clone(),
370            cx: cx.clone(),
371            gates: gates.to_owned(),
372            input_gate: input_gate.to_owned(),
373            forget_gate: forget_gate.to_owned(),
374            cell_gate: cell_gate.to_owned(),
375            output_gate: output_gate.to_owned(),
376            cy: cy.clone(),
377            hy: hy_final.clone(),
378            input_dropout_mask: input_mask,
379            recurrent_dropout_mask: recurrent_mask,
380            output_dropout_mask: output_mask,
381            batch_size,
382        };
383
384        (hy_final, cy, cache)
385    }
386
387    /// Backward pass implementing LSTM gradient computation with dropout
388    /// 
389    /// Returns (parameter_gradients, input_gradient, hidden_gradient, cell_gradient)
390    pub fn backward(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMCellCache) -> (LSTMCellGradients, Array2<f64>, Array2<f64>, Array2<f64>) {
391        let hidden_size = self.hidden_size;
392
393        // Apply output dropout backward pass using saved mask
394        let dhy_dropped = if let Some(ref mask) = cache.output_dropout_mask {
395            let keep_prob = if let Some(ref dropout) = self.output_dropout {
396                1.0 - dropout.dropout_rate
397            } else {
398                1.0
399            };
400            dhy * mask / keep_prob
401        } else {
402            dhy.clone()
403        };
404
405        // Output gate gradients: ∂L/∂o_t = ∂L/∂h_t ⊙ tanh(c_t)
406        let tanh_cy = cache.cy.map(|&x| x.tanh());
407        let do_t = &dhy_dropped * &tanh_cy;
408        let do_raw = &do_t * &cache.output_gate * (&cache.output_gate.map(|&x| 1.0 - x));
409
410        // Cell state gradients from both tanh and direct paths
411        let dcy_from_tanh = &dhy_dropped * &cache.output_gate * cache.cy.map(|&x| 1.0 - x.tanh().powi(2));
412        let dcy_total = dcy + dcy_from_tanh;
413
414        // Forget gate gradients: ∂L/∂f_t = ∂L/∂c_t ⊙ c_t-1
415        let df_t = &dcy_total * &cache.cx;
416        let df_raw = &df_t * &cache.forget_gate * cache.forget_gate.map(|&x| 1.0 - x);
417
418        // Input gate gradients: ∂L/∂i_t = ∂L/∂c_t ⊙ g_t
419        let di_t = &dcy_total * &cache.cell_gate;
420        let di_raw = &di_t * &cache.input_gate * cache.input_gate.map(|&x| 1.0 - x);
421
422        // Cell gate gradients: ∂L/∂g_t = ∂L/∂c_t ⊙ i_t
423        let dc_t = &dcy_total * &cache.input_gate;
424        let dc_raw = &dc_t * cache.cell_gate.map(|&x| 1.0 - x.powi(2));
425
426        // Concatenate gate gradients in the same order as forward pass
427        let mut dgates = Array2::zeros((4 * hidden_size, 1));
428        dgates.slice_mut(s![0..hidden_size, ..]).assign(&di_raw);
429        dgates.slice_mut(s![hidden_size..2*hidden_size, ..]).assign(&df_raw);
430        dgates.slice_mut(s![2*hidden_size..3*hidden_size, ..]).assign(&dc_raw);
431        dgates.slice_mut(s![3*hidden_size..4*hidden_size, ..]).assign(&do_raw);
432
433        // Parameter gradients using chain rule
434        let dw_ih = dgates.dot(&cache.input.t());
435        let dw_hh = dgates.dot(&cache.hx.t());
436        let db_ih = dgates.clone();
437        let db_hh = dgates.clone();
438
439        let gradients = LSTMCellGradients {
440            w_ih: dw_ih,
441            w_hh: dw_hh,
442            b_ih: db_ih,
443            b_hh: db_hh,
444        };
445
446        let mut dx = self.w_ih.t().dot(&dgates);
447        let mut dhx = self.w_hh.t().dot(&dgates);
448        let dcx = &dcy_total * &cache.forget_gate;
449
450        if let Some(ref mask) = cache.input_dropout_mask {
451            let keep_prob = if let Some(ref dropout) = self.input_dropout {
452                1.0 - dropout.dropout_rate
453            } else {
454                1.0
455            };
456            dx = dx * mask / keep_prob;
457        }
458
459        if let Some(ref mask) = cache.recurrent_dropout_mask {
460            let keep_prob = if let Some(ref dropout) = self.recurrent_dropout {
461                1.0 - dropout.dropout_rate
462            } else {
463                1.0
464            };
465            dhx = dhx * mask / keep_prob;
466        }
467
468        (gradients, dx, dhx, dcx)
469    }
470
    /// Batch backward pass for training with multiple sequences
    ///
    /// Computes gradients for an entire batch simultaneously.
    ///
    /// # Arguments
    /// * `dhy` - gradient w.r.t. the batch hidden output, shape (hidden_size, batch_size)
    /// * `dcy` - gradient w.r.t. the batch cell state, shape (hidden_size, batch_size)
    /// * `cache` - values saved by `forward_batch_with_cache`
    ///
    /// Returns (parameter_gradients, input_gradient, hidden_gradient, cell_gradient);
    /// weight/bias gradients are summed over the batch dimension.
    pub fn backward_batch(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMCellBatchCache) -> (LSTMCellGradients, Array2<f64>, Array2<f64>, Array2<f64>) {
        let batch_size = cache.batch_size;
        let hidden_size = self.hidden_size;

        // Apply output dropout backward pass using saved mask
        // (inverted-dropout scaling with 1 / keep_prob).
        let dhy_dropped = if let Some(ref mask) = cache.output_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.output_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dhy * mask / keep_prob
        } else {
            dhy.clone()
        };

        // Output gate gradients for entire batch: ∂L/∂o_t = ∂L/∂h_t ⊙ tanh(c_t)
        let tanh_cy = cache.cy.map(|&x| x.tanh());
        let do_t = &dhy_dropped * &tanh_cy;
        let do_raw = &do_t * &cache.output_gate * &cache.output_gate.map(|&x| 1.0 - x);

        // Cell state gradients from both tanh and direct paths
        let dcy_from_tanh = &dhy_dropped * &cache.output_gate * cache.cy.map(|&x| 1.0 - x.tanh().powi(2));
        let dcy_total = dcy + dcy_from_tanh;

        // Gate gradients for entire batch: ∂L/∂f_t = ∂L/∂c_t ⊙ c_{t-1}
        let df_t = &dcy_total * &cache.cx;
        let df_raw = &df_t * &cache.forget_gate * cache.forget_gate.map(|&x| 1.0 - x);

        // ∂L/∂i_t = ∂L/∂c_t ⊙ g_t
        let di_t = &dcy_total * &cache.cell_gate;
        let di_raw = &di_t * &cache.input_gate * cache.input_gate.map(|&x| 1.0 - x);

        // ∂L/∂g_t = ∂L/∂c_t ⊙ i_t
        let dc_t = &dcy_total * &cache.input_gate;
        let dc_raw = &dc_t * cache.cell_gate.map(|&x| 1.0 - x.powi(2));

        // Concatenate gate gradients in the same [i, f, g, o] row order used
        // by the forward pass.
        let mut dgates = Array2::zeros((4 * hidden_size, batch_size));
        dgates.slice_mut(s![0..hidden_size, ..]).assign(&di_raw);
        dgates.slice_mut(s![hidden_size..2*hidden_size, ..]).assign(&df_raw);
        dgates.slice_mut(s![2*hidden_size..3*hidden_size, ..]).assign(&dc_raw);
        dgates.slice_mut(s![3*hidden_size..4*hidden_size, ..]).assign(&do_raw);

        // Parameter gradients - the matmuls against the transposed cached
        // inputs sum across the batch dimension implicitly; bias gradients
        // sum explicitly and keep the (4 * hidden_size, 1) parameter shape.
        let dw_ih = dgates.dot(&cache.input.t());
        let dw_hh = dgates.dot(&cache.hx.t());
        let db_ih = dgates.sum_axis(ndarray::Axis(1)).insert_axis(ndarray::Axis(1));
        // Both biases enter the pre-activation additively, so they share the
        // same gradient.
        let db_hh = db_ih.clone();

        let gradients = LSTMCellGradients {
            w_ih: dw_ih,
            w_hh: dw_hh,
            b_ih: db_ih,
            b_hh: db_hh,
        };

        // Input and hidden gradients for entire batch
        let mut dx = self.w_ih.t().dot(&dgates);
        let mut dhx = self.w_hh.t().dot(&dgates);
        // Direct cell-state path: ∂L/∂c_{t-1} = ∂L/∂c_t ⊙ f_t
        let dcx = &dcy_total * &cache.forget_gate;

        // Apply dropout gradients if masks exist (undo the forward scaling)
        if let Some(ref mask) = cache.input_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.input_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dx = dx * mask / keep_prob;
        }

        if let Some(ref mask) = cache.recurrent_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.recurrent_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dhx = dhx * mask / keep_prob;
        }

        (gradients, dx, dhx, dcx)
    }
555
556    /// Initialize zero gradients for accumulation
557    pub fn zero_gradients(&self) -> LSTMCellGradients {
558        LSTMCellGradients {
559            w_ih: Array2::zeros(self.w_ih.raw_dim()),
560            w_hh: Array2::zeros(self.w_hh.raw_dim()),
561            b_ih: Array2::zeros(self.b_ih.raw_dim()),
562            b_hh: Array2::zeros(self.b_hh.raw_dim()),
563        }
564    }
565
566    /// Apply gradients using the provided optimizer
567    pub fn update_parameters<O: crate::optimizers::Optimizer>(&mut self, gradients: &LSTMCellGradients, optimizer: &mut O, prefix: &str) {
568        optimizer.update(&format!("{}_w_ih", prefix), &mut self.w_ih, &gradients.w_ih);
569        optimizer.update(&format!("{}_w_hh", prefix), &mut self.w_hh, &gradients.w_hh);
570        optimizer.update(&format!("{}_b_ih", prefix), &mut self.b_ih, &gradients.b_ih);
571        optimizer.update(&format!("{}_b_hh", prefix), &mut self.b_hh, &gradients.b_hh);
572    }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    /// Plain forward pass (no regularizers) produces correctly shaped outputs.
    #[test]
    fn test_lstm_cell_forward() {
        let input_size = 3;
        let hidden_size = 2;
        let mut cell = LSTMCell::new(input_size, hidden_size);

        // Single-column (batch of one) input and zeroed initial states.
        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        let (hy, cy) = cell.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[hidden_size, 1]);
        assert_eq!(cy.shape(), &[hidden_size, 1]);
    }

    /// Forward pass keeps output shapes with every regularizer enabled,
    /// in both training and evaluation modes.
    #[test]
    fn test_lstm_cell_with_dropout() {
        let input_size = 3;
        let hidden_size = 2;
        let mut cell = LSTMCell::new(input_size, hidden_size)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1)
            .with_zoneout(0.1, 0.1);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        // Test training mode
        cell.train();
        let (hy_train, cy_train) = cell.forward(&input, &hx, &cx);

        // Test evaluation mode
        cell.eval();
        let (hy_eval, cy_eval) = cell.forward(&input, &hx, &cx);

        assert_eq!(hy_train.shape(), &[hidden_size, 1]);
        assert_eq!(cy_train.shape(), &[hidden_size, 1]);
        assert_eq!(hy_eval.shape(), &[hidden_size, 1]);
        assert_eq!(cy_eval.shape(), &[hidden_size, 1]);
    }

    /// Dropout masks are cached during training, consumed by backward without
    /// shape errors, and absent in evaluation mode.
    #[test]
    fn test_dropout_mask_backward_pass() {
        let input_size = 2;
        let hidden_size = 3;
        let mut cell = LSTMCell::new(input_size, hidden_size)
            .with_input_dropout(0.5, false)
            .with_output_dropout(0.5);

        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.1], [0.2], [0.3]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        cell.train();
        let (_hy, _cy, cache) = cell.forward_with_cache(&input, &hx, &cx);

        // Training mode must record the masks for the configured dropouts.
        assert!(cache.input_dropout_mask.is_some());
        assert!(cache.output_dropout_mask.is_some());

        let dhy = arr2(&[[1.0], [1.0], [1.0]]);
        let dcy = arr2(&[[0.0], [0.0], [0.0]]);
        
        let (gradients, dx, dhx, dcx) = cell.backward(&dhy, &dcy, &cache);

        // Gradient shapes must mirror the parameter/state shapes.
        assert_eq!(gradients.w_ih.shape(), &[4 * hidden_size, input_size]);
        assert_eq!(gradients.w_hh.shape(), &[4 * hidden_size, hidden_size]);
        assert_eq!(dx.shape(), &[input_size, 1]);
        assert_eq!(dhx.shape(), &[hidden_size, 1]);
        assert_eq!(dcx.shape(), &[hidden_size, 1]);

        // Evaluation mode must not record any masks.
        cell.eval();
        let (_, _, cache_eval) = cell.forward_with_cache(&input, &hx, &cx);
        assert!(cache_eval.input_dropout_mask.is_none());
        assert!(cache_eval.output_dropout_mask.is_none());
    }
}