rust_lstm/layers/gru_cell.rs

use ndarray::Array2;
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;
use crate::utils::sigmoid;
use crate::layers::dropout::Dropout;

/// Holds gradients for all GRU cell parameters during backpropagation
#[derive(Clone)]
pub struct GRUCellGradients {
    pub w_ir: Array2<f64>,
    pub w_hr: Array2<f64>,
    pub b_ir: Array2<f64>,
    pub b_hr: Array2<f64>,
    pub w_iz: Array2<f64>,
    pub w_hz: Array2<f64>,
    pub b_iz: Array2<f64>,
    pub b_hz: Array2<f64>,
    pub w_ih: Array2<f64>,
    pub w_hh: Array2<f64>,
    pub b_ih: Array2<f64>,
    pub b_hh: Array2<f64>,
}

/// Caches intermediate values during forward pass for efficient backward computation
#[derive(Clone)]
pub struct GRUCellCache {
    pub input: Array2<f64>,
    pub hx: Array2<f64>,
    /// Post-dropout input actually seen by the gates; the backward pass must
    /// differentiate against this, not the raw input, when dropout is active.
    pub input_dropped: Array2<f64>,
    /// Post-dropout hidden state actually seen by the gates.
    pub hx_dropped: Array2<f64>,
    pub reset_gate: Array2<f64>,
    pub update_gate: Array2<f64>,
    pub new_gate: Array2<f64>,
    pub reset_hidden: Array2<f64>,
    pub hy: Array2<f64>,
    pub input_dropout_mask: Option<Array2<f64>>,
    pub recurrent_dropout_mask: Option<Array2<f64>>,
    pub output_dropout_mask: Option<Array2<f64>>,
}
/// GRU cell with trainable parameters and dropout support
#[derive(Clone)]
pub struct GRUCell {
    // Reset gate parameters
    pub w_ir: Array2<f64>,
    pub w_hr: Array2<f64>,
    pub b_ir: Array2<f64>,
    pub b_hr: Array2<f64>,

    // Update gate parameters
    pub w_iz: Array2<f64>,
    pub w_hz: Array2<f64>,
    pub b_iz: Array2<f64>,
    pub b_hz: Array2<f64>,

    // New gate parameters
    pub w_ih: Array2<f64>,
    pub w_hh: Array2<f64>,
    pub b_ih: Array2<f64>,
    pub b_hh: Array2<f64>,

    pub hidden_size: usize,
    pub input_dropout: Option<Dropout>,
    pub recurrent_dropout: Option<Dropout>,
    pub output_dropout: Option<Dropout>,
    pub is_training: bool,
}
impl GRUCell {
    /// Creates a new GRU cell with weights drawn uniformly from [-0.1, 0.1)
    /// and zero-initialized biases
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        let dist = Uniform::new(-0.1, 0.1);

        // Reset gate weights
        let w_ir = Array2::random((hidden_size, input_size), dist);
        let w_hr = Array2::random((hidden_size, hidden_size), dist);
        let b_ir = Array2::zeros((hidden_size, 1));
        let b_hr = Array2::zeros((hidden_size, 1));

        // Update gate weights
        let w_iz = Array2::random((hidden_size, input_size), dist);
        let w_hz = Array2::random((hidden_size, hidden_size), dist);
        let b_iz = Array2::zeros((hidden_size, 1));
        let b_hz = Array2::zeros((hidden_size, 1));

        // New gate weights
        let w_ih = Array2::random((hidden_size, input_size), dist);
        let w_hh = Array2::random((hidden_size, hidden_size), dist);
        let b_ih = Array2::zeros((hidden_size, 1));
        let b_hh = Array2::zeros((hidden_size, 1));

        GRUCell {
            w_ir, w_hr, b_ir, b_hr,
            w_iz, w_hz, b_iz, b_hz,
            w_ih, w_hh, b_ih, b_hh,
            hidden_size,
            input_dropout: None,
            recurrent_dropout: None,
            output_dropout: None,
            is_training: true,
        }
    }

    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        if variational {
            self.input_dropout = Some(Dropout::variational(dropout_rate));
        } else {
            self.input_dropout = Some(Dropout::new(dropout_rate));
        }
        self
    }

    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        if variational {
            self.recurrent_dropout = Some(Dropout::variational(dropout_rate));
        } else {
            self.recurrent_dropout = Some(Dropout::new(dropout_rate));
        }
        self
    }

    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
        self.output_dropout = Some(Dropout::new(dropout_rate));
        self
    }

    pub fn train(&mut self) {
        self.is_training = true;
        if let Some(ref mut dropout) = self.input_dropout {
            dropout.train();
        }
        if let Some(ref mut dropout) = self.recurrent_dropout {
            dropout.train();
        }
        if let Some(ref mut dropout) = self.output_dropout {
            dropout.train();
        }
    }

    pub fn eval(&mut self) {
        self.is_training = false;
        if let Some(ref mut dropout) = self.input_dropout {
            dropout.eval();
        }
        if let Some(ref mut dropout) = self.recurrent_dropout {
            dropout.eval();
        }
        if let Some(ref mut dropout) = self.output_dropout {
            dropout.eval();
        }
    }

    pub fn forward(&mut self, input: &Array2<f64>, hx: &Array2<f64>) -> Array2<f64> {
        let (hy, _) = self.forward_with_cache(input, hx);
        hy
    }

    pub fn forward_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>) -> (Array2<f64>, GRUCellCache) {
        // Apply input dropout
        let (input_dropped, input_mask) = if let Some(ref mut dropout) = self.input_dropout {
            let dropped = dropout.forward(input);
            let mask = dropout.get_last_mask().cloned();
            (dropped, mask)
        } else {
            (input.clone(), None)
        };

        // Apply recurrent dropout to hidden state
        let (hx_dropped, recurrent_mask) = if let Some(ref mut dropout) = self.recurrent_dropout {
            let dropped = dropout.forward(hx);
            let mask = dropout.get_last_mask().cloned();
            (dropped, mask)
        } else {
            (hx.clone(), None)
        };

        // Reset gate: r_t = σ(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
        let reset_gate = (&self.w_ir.dot(&input_dropped) + &self.b_ir + &self.w_hr.dot(&hx_dropped) + &self.b_hr)
            .map(|&x| sigmoid(x));

        // Update gate: z_t = σ(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
        let update_gate = (&self.w_iz.dot(&input_dropped) + &self.b_iz + &self.w_hz.dot(&hx_dropped) + &self.b_hz)
            .map(|&x| sigmoid(x));

        // Reset hidden state: reset_hidden = r_t ⊙ h_{t-1}
        let reset_hidden = &reset_gate * &hx_dropped;

        // New gate (candidate state): h_tilde_t = tanh(W_ih * x_t + b_ih + W_hh * reset_hidden + b_hh)
        let new_gate = (&self.w_ih.dot(&input_dropped) + &self.b_ih + &self.w_hh.dot(&reset_hidden) + &self.b_hh)
            .map(|&x| x.tanh());

        // Output: h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h_tilde_t
        let hy = &update_gate.map(|&x| 1.0 - x) * &hx_dropped + &update_gate * &new_gate;

        // Apply output dropout
        let (hy_final, output_mask) = if let Some(ref mut dropout) = self.output_dropout {
            let dropped = dropout.forward(&hy);
            let mask = dropout.get_last_mask().cloned();
            (dropped, mask)
        } else {
            (hy, None)
        };

        let cache = GRUCellCache {
            input: input.clone(),
            hx: hx.clone(),
            input_dropped,
            hx_dropped,
            reset_gate: reset_gate.clone(),
            update_gate: update_gate.clone(),
            new_gate: new_gate.clone(),
            reset_hidden,
            hy: hy_final.clone(),
            input_dropout_mask: input_mask,
            recurrent_dropout_mask: recurrent_mask,
            output_dropout_mask: output_mask,
        };

        (hy_final, cache)
    }

    /// Backward pass implementing GRU gradient computation with dropout
    ///
    /// Returns (parameter_gradients, input_gradient, hidden_gradient)
    pub fn backward(&self, dhy: &Array2<f64>, cache: &GRUCellCache) -> (GRUCellGradients, Array2<f64>, Array2<f64>) {
        // Apply output dropout backward pass using saved mask
        let dhy_dropped = if let Some(ref mask) = cache.output_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.output_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dhy * mask / keep_prob
        } else {
            dhy.clone()
        };

        // Gradients for output computation: h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h_tilde_t
        // The forward pass combined the *dropped* hidden state, so the cached
        // post-dropout values are used here as well.
        let d_update_gate = &dhy_dropped * (&cache.new_gate - &cache.hx_dropped);
        let d_new_gate = &dhy_dropped * &cache.update_gate;
        let dhx_from_output = &dhy_dropped * cache.update_gate.map(|&x| 1.0 - x);

        // Gradients for new gate: h_tilde_t = tanh(W_ih * x_t + b_ih + W_hh * reset_hidden + b_hh)
        let d_new_gate_raw = &d_new_gate * cache.new_gate.map(|&x| 1.0 - x.powi(2));

        // Gradients for reset hidden: reset_hidden = r_t ⊙ h_{t-1}
        let d_reset_hidden = self.w_hh.t().dot(&d_new_gate_raw);
        let d_reset_gate = &d_reset_hidden * &cache.hx_dropped;
        let dhx_from_reset = &d_reset_hidden * &cache.reset_gate;

        // Gradients for reset gate: r_t = σ(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
        let d_reset_gate_raw = &d_reset_gate * &cache.reset_gate * cache.reset_gate.map(|&x| 1.0 - x);

        // Gradients for update gate: z_t = σ(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
        let d_update_gate_raw = &d_update_gate * &cache.update_gate * cache.update_gate.map(|&x| 1.0 - x);

        // Parameter gradients, taken against the post-dropout activations the
        // gates actually saw. Bias gradients assume column-vector activations
        // (batch size 1); a batched layout would need a sum over columns.
        let dw_ir = d_reset_gate_raw.dot(&cache.input_dropped.t());
        let dw_hr = d_reset_gate_raw.dot(&cache.hx_dropped.t());
        let db_ir = d_reset_gate_raw.clone();
        let db_hr = d_reset_gate_raw.clone();

        let dw_iz = d_update_gate_raw.dot(&cache.input_dropped.t());
        let dw_hz = d_update_gate_raw.dot(&cache.hx_dropped.t());
        let db_iz = d_update_gate_raw.clone();
        let db_hz = d_update_gate_raw.clone();

        let dw_ih = d_new_gate_raw.dot(&cache.input_dropped.t());
        let dw_hh = d_new_gate_raw.dot(&cache.reset_hidden.t());
        let db_ih = d_new_gate_raw.clone();
        let db_hh = d_new_gate_raw.clone();

        let gradients = GRUCellGradients {
            w_ir: dw_ir, w_hr: dw_hr, b_ir: db_ir, b_hr: db_hr,
            w_iz: dw_iz, w_hz: dw_hz, b_iz: db_iz, b_hz: db_hz,
            w_ih: dw_ih, w_hh: dw_hh, b_ih: db_ih, b_hh: db_hh,
        };

        // Input and hidden gradients
        let mut dx = self.w_ir.t().dot(&d_reset_gate_raw) +
                     self.w_iz.t().dot(&d_update_gate_raw) +
                     self.w_ih.t().dot(&d_new_gate_raw);

        let mut dhx = dhx_from_output + dhx_from_reset +
                      self.w_hr.t().dot(&d_reset_gate_raw) +
                      self.w_hz.t().dot(&d_update_gate_raw);

        // Backpropagate through input/recurrent dropout: the forward op was
        // x ↦ x ⊙ mask / keep_prob, whose Jacobian is that same scaled mask.
        if let Some(ref mask) = cache.input_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.input_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dx = dx * mask / keep_prob;
        }

        if let Some(ref mask) = cache.recurrent_dropout_mask {
            let keep_prob = if let Some(ref dropout) = self.recurrent_dropout {
                1.0 - dropout.dropout_rate
            } else {
                1.0
            };
            dhx = dhx * mask / keep_prob;
        }

        (gradients, dx, dhx)
    }

    /// Initialize zero gradients for accumulation
    pub fn zero_gradients(&self) -> GRUCellGradients {
        GRUCellGradients {
            w_ir: Array2::zeros(self.w_ir.raw_dim()),
            w_hr: Array2::zeros(self.w_hr.raw_dim()),
            b_ir: Array2::zeros(self.b_ir.raw_dim()),
            b_hr: Array2::zeros(self.b_hr.raw_dim()),
            w_iz: Array2::zeros(self.w_iz.raw_dim()),
            w_hz: Array2::zeros(self.w_hz.raw_dim()),
            b_iz: Array2::zeros(self.b_iz.raw_dim()),
            b_hz: Array2::zeros(self.b_hz.raw_dim()),
            w_ih: Array2::zeros(self.w_ih.raw_dim()),
            w_hh: Array2::zeros(self.w_hh.raw_dim()),
            b_ih: Array2::zeros(self.b_ih.raw_dim()),
            b_hh: Array2::zeros(self.b_hh.raw_dim()),
        }
    }

    /// Apply gradients using the provided optimizer
    pub fn update_parameters<O: crate::optimizers::Optimizer>(&mut self, gradients: &GRUCellGradients, optimizer: &mut O, prefix: &str) {
        optimizer.update(&format!("{}_w_ir", prefix), &mut self.w_ir, &gradients.w_ir);
        optimizer.update(&format!("{}_w_hr", prefix), &mut self.w_hr, &gradients.w_hr);
        optimizer.update(&format!("{}_b_ir", prefix), &mut self.b_ir, &gradients.b_ir);
        optimizer.update(&format!("{}_b_hr", prefix), &mut self.b_hr, &gradients.b_hr);
        optimizer.update(&format!("{}_w_iz", prefix), &mut self.w_iz, &gradients.w_iz);
        optimizer.update(&format!("{}_w_hz", prefix), &mut self.w_hz, &gradients.w_hz);
        optimizer.update(&format!("{}_b_iz", prefix), &mut self.b_iz, &gradients.b_iz);
        optimizer.update(&format!("{}_b_hz", prefix), &mut self.b_hz, &gradients.b_hz);
        optimizer.update(&format!("{}_w_ih", prefix), &mut self.w_ih, &gradients.w_ih);
        optimizer.update(&format!("{}_w_hh", prefix), &mut self.w_hh, &gradients.w_hh);
        optimizer.update(&format!("{}_b_ih", prefix), &mut self.b_ih, &gradients.b_ih);
        optimizer.update(&format!("{}_b_hh", prefix), &mut self.b_hh, &gradients.b_hh);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_gru_cell_forward() {
        let input_size = 3;
        let hidden_size = 2;
        let mut cell = GRUCell::new(input_size, hidden_size);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.1], [0.2]]);

        let hy = cell.forward(&input, &hx);

        assert_eq!(hy.shape(), &[hidden_size, 1]);
    }
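
    // A sketch of unrolling the cell over a short sequence: run
    // forward_with_cache per step, keep the caches, then walk them in reverse,
    // chaining dhx between steps and accumulating parameter gradients into a
    // zeroed accumulator. Only two of the twelve gradient fields are
    // accumulated here to keep the example short; a real training loop would
    // accumulate all of them.
    #[test]
    fn test_gru_sequence_unroll_and_accumulate() {
        let mut cell = GRUCell::new(2, 3);
        let sequence = [
            arr2(&[[0.5], [-0.2]]),
            arr2(&[[0.1], [0.4]]),
            arr2(&[[-0.3], [0.2]]),
        ];

        // Forward pass through time, caching every step
        let mut h = Array2::zeros((3, 1));
        let mut caches = Vec::new();
        for x in &sequence {
            let (hy, cache) = cell.forward_with_cache(x, &h);
            caches.push(cache);
            h = hy;
        }

        // Backward pass through time: dhx from step t feeds step t - 1
        let mut acc = cell.zero_gradients();
        let mut dh = Array2::ones((3, 1));
        for cache in caches.iter().rev() {
            let (grads, _dx, dhx) = cell.backward(&dh, cache);
            acc.w_ih = &acc.w_ih + &grads.w_ih;
            acc.w_hh = &acc.w_hh + &grads.w_hh;
            dh = dhx;
        }

        assert_eq!(acc.w_ih.shape(), &[3, 2]);
        assert_eq!(acc.w_hh.shape(), &[3, 3]);
    }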

    #[test]
    fn test_gru_cell_with_dropout() {
        let input_size = 3;
        let hidden_size = 2;
        let mut cell = GRUCell::new(input_size, hidden_size)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.1], [0.2]]);

        // Test training mode
        cell.train();
        let hy_train = cell.forward(&input, &hx);

        // Test evaluation mode
        cell.eval();
        let hy_eval = cell.forward(&input, &hx);

        assert_eq!(hy_train.shape(), &[hidden_size, 1]);
        assert_eq!(hy_eval.shape(), &[hidden_size, 1]);
    }
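
    // A determinism check for eval mode: with standard inverted dropout the
    // layers act as the identity at evaluation time, so two forward passes
    // over the same inputs should agree exactly. This assumes the Dropout
    // implementation follows that convention.
    #[test]
    fn test_gru_eval_mode_is_deterministic() {
        let mut cell = GRUCell::new(3, 2)
            .with_input_dropout(0.5, false)
            .with_output_dropout(0.5);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.1], [0.2]]);

        cell.eval();
        let first = cell.forward(&input, &hx);
        let second = cell.forward(&input, &hx);

        assert_eq!(first, second);
    }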

    #[test]
    fn test_gru_backward_pass() {
        let input_size = 2;
        let hidden_size = 3;
        let mut cell = GRUCell::new(input_size, hidden_size);

        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.1], [0.2], [0.3]]);

        let (_hy, cache) = cell.forward_with_cache(&input, &hx);

        let dhy = arr2(&[[1.0], [1.0], [1.0]]);
        let (gradients, dx, dhx) = cell.backward(&dhy, &cache);

        assert_eq!(gradients.w_ir.shape(), &[hidden_size, input_size]);
        assert_eq!(gradients.w_hr.shape(), &[hidden_size, hidden_size]);
        assert_eq!(dx.shape(), &[input_size, 1]);
        assert_eq!(dhx.shape(), &[hidden_size, 1]);
    }
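
    // A finite-difference spot check of the analytic gradients, assuming no
    // dropout so the forward pass is deterministic. With loss L = sum(h_t)
    // the upstream gradient dhy is all ones, and backward's dw_ih[[0, 0]]
    // should match (L(w + eps) - L(w - eps)) / (2 * eps).
    #[test]
    fn test_gru_gradient_check_w_ih() {
        let mut cell = GRUCell::new(2, 3);
        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.1], [0.2], [0.3]]);

        let (_hy, cache) = cell.forward_with_cache(&input, &hx);
        let dhy = Array2::ones((3, 1));
        let (grads, _dx, _dhx) = cell.backward(&dhy, &cache);

        // Central difference on a single weight entry
        let eps = 1e-6;
        let mut plus = cell.clone();
        plus.w_ih[[0, 0]] += eps;
        let mut minus = cell.clone();
        minus.w_ih[[0, 0]] -= eps;

        let loss_plus = plus.forward(&input, &hx).sum();
        let loss_minus = minus.forward(&input, &hx).sum();
        let numeric = (loss_plus - loss_minus) / (2.0 * eps);

        assert!((grads.w_ih[[0, 0]] - numeric).abs() < 1e-5);
    }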
}