rust_lstm/models/gru_network.rs

use ndarray::Array2;
use crate::layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
use crate::optimizers::Optimizer;

/// Cache for GRU network forward pass
#[derive(Clone)]
pub struct GRUNetworkCache {
    pub caches: Vec<GRUCellCache>,
}

/// Configuration for layer-specific dropout settings
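///
/// A minimal builder sketch (the rates below are arbitrary illustrations, not
/// recommended values; the `variational` flag is assumed to follow the usual
/// convention of reusing one dropout mask across all time steps):
///
/// ```rust,ignore
/// let config = LayerDropoutConfig::new()
///     .with_input_dropout(0.2, true)       // variational: one mask per sequence
///     .with_recurrent_dropout(0.1, false)  // standard: fresh mask each step
///     .with_output_dropout(0.3);
/// ```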
#[derive(Clone)]
pub struct LayerDropoutConfig {
    pub input_dropout_rate: f64,
    pub input_variational: bool,
    pub recurrent_dropout_rate: f64,
    pub recurrent_variational: bool,
    pub output_dropout_rate: f64,
}

impl LayerDropoutConfig {
    pub fn new() -> Self {
        LayerDropoutConfig {
            input_dropout_rate: 0.0,
            input_variational: false,
            recurrent_dropout_rate: 0.0,
            recurrent_variational: false,
            output_dropout_rate: 0.0,
        }
    }

    pub fn with_input_dropout(mut self, rate: f64, variational: bool) -> Self {
        self.input_dropout_rate = rate;
        self.input_variational = variational;
        self
    }

    pub fn with_recurrent_dropout(mut self, rate: f64, variational: bool) -> Self {
        self.recurrent_dropout_rate = rate;
        self.recurrent_variational = variational;
        self
    }

    pub fn with_output_dropout(mut self, rate: f64) -> Self {
        self.output_dropout_rate = rate;
        self
    }
}

/// Multi-layer GRU network for sequence modeling
#[derive(Clone)]
pub struct GRUNetwork {
    cells: Vec<GRUCell>,
    pub input_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub is_training: bool,
}

impl GRUNetwork {
    /// Creates a new multi-layer GRU network
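    ///
    /// The first layer consumes `input_size` features; every deeper layer
    /// consumes the `hidden_size`-dimensional output of the layer below it.
    ///
    /// ```rust,ignore
    /// // 8 input features, 16 hidden units per layer, 2 stacked layers
    /// let network = GRUNetwork::new(8, 16, 2);
    /// ```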
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut cells = Vec::new();

        for i in 0..num_layers {
            let layer_input_size = if i == 0 { input_size } else { hidden_size };
            cells.push(GRUCell::new(layer_input_size, hidden_size));
        }

        GRUNetwork {
            cells,
            input_size,
            hidden_size,
            num_layers,
            is_training: true,
        }
    }

    /// Apply uniform dropout across all layers
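    ///
    /// The three `with_*_dropout` builders below chain; a sketch with
    /// arbitrary rates (output dropout is skipped on the final layer, see
    /// `with_output_dropout`):
    ///
    /// ```rust,ignore
    /// let network = GRUNetwork::new(8, 16, 2)
    ///     .with_input_dropout(0.2, true)
    ///     .with_recurrent_dropout(0.1, true)
    ///     .with_output_dropout(0.3);
    /// ```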
    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_input_dropout(dropout_rate, variational);
        }
        self
    }

    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_recurrent_dropout(dropout_rate, variational);
        }
        self
    }

    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
        // Apply output dropout to all layers except the last
        for (i, cell) in self.cells.iter_mut().enumerate() {
            if i < self.num_layers - 1 {
                *cell = cell.clone().with_output_dropout(dropout_rate);
            }
        }
        self
    }

    /// Apply layer-specific dropout configuration
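    ///
    /// One config per layer, in order; a sketch for a two-layer network
    /// (rates are arbitrary):
    ///
    /// ```rust,ignore
    /// let network = GRUNetwork::new(8, 16, 2).with_layer_dropout(vec![
    ///     LayerDropoutConfig::new().with_input_dropout(0.2, true),
    ///     LayerDropoutConfig::new().with_recurrent_dropout(0.1, false),
    /// ]);
    /// ```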
    pub fn with_layer_dropout(mut self, configs: Vec<LayerDropoutConfig>) -> Self {
        if configs.len() != self.num_layers {
            panic!("Number of dropout configs must match number of layers");
        }

        for (i, config) in configs.into_iter().enumerate() {
            if config.input_dropout_rate > 0.0 {
                self.cells[i] = self.cells[i].clone()
                    .with_input_dropout(config.input_dropout_rate, config.input_variational);
            }
            if config.recurrent_dropout_rate > 0.0 {
                self.cells[i] = self.cells[i].clone()
                    .with_recurrent_dropout(config.recurrent_dropout_rate, config.recurrent_variational);
            }
            if config.output_dropout_rate > 0.0 && i < self.num_layers - 1 {
                self.cells[i] = self.cells[i].clone()
                    .with_output_dropout(config.output_dropout_rate);
            }
        }
        self
    }

    pub fn train(&mut self) {
        self.is_training = true;
        for cell in &mut self.cells {
            cell.train();
        }
    }

    pub fn eval(&mut self) {
        self.is_training = false;
        for cell in &mut self.cells {
            cell.eval();
        }
    }

    /// Forward pass for a single time step; returns the new hidden state of
    /// every layer, with the last element being the network output
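    ///
    /// Shape sketch: `input` is `(input_size, 1)` and each entry of `hx` is
    /// `(hidden_size, 1)`, one per layer:
    ///
    /// ```rust,ignore
    /// let mut network = GRUNetwork::new(2, 3, 2);
    /// let input = ndarray::arr2(&[[1.0], [0.5]]);                  // (2, 1)
    /// let hx = vec![Array2::zeros((3, 1)), Array2::zeros((3, 1))]; // per layer
    /// let outputs = network.forward(&input, &hx);
    /// let network_output = outputs.last().unwrap();                // top layer
    /// ```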
    pub fn forward(&mut self, input: &Array2<f64>, hx: &[Array2<f64>]) -> Vec<Array2<f64>> {
        if hx.len() != self.num_layers {
            panic!("Number of hidden states must match number of layers");
        }

        let mut layer_input = input.clone();
        let mut outputs = Vec::new();

        for (i, cell) in self.cells.iter_mut().enumerate() {
            let hy = cell.forward(&layer_input, &hx[i]);
            outputs.push(hy.clone());
            layer_input = hy;
        }

        outputs
    }

    /// Forward pass for a sequence with caching for training
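    ///
    /// Hidden states start at zero and are carried across time steps
    /// internally. Each returned output pair is `(top-layer hidden state,
    /// all layer hidden states)` for that step:
    ///
    /// ```rust,ignore
    /// let sequence = vec![
    ///     ndarray::arr2(&[[1.0], [0.0]]),
    ///     ndarray::arr2(&[[0.0], [1.0]]),
    /// ];
    /// let (outputs, caches) = network.forward_sequence_with_cache(&sequence);
    /// assert_eq!(outputs.len(), sequence.len());
    /// ```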
    pub fn forward_sequence_with_cache(&mut self, sequence: &[Array2<f64>]) -> (Vec<(Array2<f64>, Vec<Array2<f64>>)>, Vec<GRUNetworkCache>) {
        let mut all_outputs = Vec::new();
        let mut all_caches = Vec::new();

        // Initialize hidden states for all layers
        let mut hidden_states: Vec<Array2<f64>> = (0..self.num_layers)
            .map(|_| Array2::zeros((self.hidden_size, 1)))
            .collect();

        for input in sequence {
            let mut layer_input = input.clone();
            let mut step_outputs = Vec::new();
            let mut step_caches = Vec::new();

            for (i, cell) in self.cells.iter_mut().enumerate() {
                let (hy, cache) = cell.forward_with_cache(&layer_input, &hidden_states[i]);

                hidden_states[i] = hy.clone();
                step_outputs.push(hy.clone());
                step_caches.push(cache);
                layer_input = hy;
            }

            // The final output is from the last layer
            let final_output = step_outputs.last().unwrap().clone();
            all_outputs.push((final_output, step_outputs));
            all_caches.push(GRUNetworkCache { caches: step_caches });
        }

        (all_outputs, all_caches)
    }

    /// Backward pass for a single time step
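    ///
    /// A single-step sketch: `dhy` stands in for the loss gradient w.r.t. the
    /// top layer's hidden state at step `t`, and the matching cache comes from
    /// `forward_sequence_with_cache`:
    ///
    /// ```rust,ignore
    /// let dhy = Array2::ones((hidden_size, 1)); // stand-in loss gradient
    /// let (grads, dx) = network.backward(&dhy, &caches[t]);
    /// ```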
    pub fn backward(&self, dhy: &Array2<f64>, cache: &GRUNetworkCache) -> (Vec<GRUCellGradients>, Array2<f64>) {
        let mut gradients = Vec::new();
        let mut d_output = dhy.clone();

        // Backward through layers in reverse order. The gradient passed down
        // the stack is the gradient w.r.t. each layer's input (dx), since a
        // layer's input is the output of the layer beneath it; the gradient
        // w.r.t. the previous hidden state is not propagated through time here.
        for (i, cell) in self.cells.iter().enumerate().rev() {
            let (cell_gradients, dx, _dhx_prev) = cell.backward(&d_output, &cache.caches[i]);
            gradients.insert(0, cell_gradients);
            d_output = dx;
        }

        (gradients, d_output)
    }

    /// Update parameters using optimizer
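    ///
    /// A minimal training-step sketch (`optimizer` is any `Optimizer`
    /// implementation; `loss_grad` stands in for the gradient of your loss
    /// w.r.t. the final hidden state):
    ///
    /// ```rust,ignore
    /// let (outputs, caches) = network.forward_sequence_with_cache(&sequence);
    /// let (grads, _dx) = network.backward(&loss_grad, caches.last().unwrap());
    /// network.update_parameters(&grads, &mut optimizer);
    /// ```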
    pub fn update_parameters<O: Optimizer>(&mut self, gradients: &[GRUCellGradients], optimizer: &mut O) {
        for (i, (cell, grad)) in self.cells.iter_mut().zip(gradients.iter()).enumerate() {
            cell.update_parameters(grad, optimizer, &format!("layer_{}", i));
        }
    }

    /// Initialize zero gradients for all layers
    pub fn zero_gradients(&self) -> Vec<GRUCellGradients> {
        self.cells.iter().map(|cell| cell.zero_gradients()).collect()
    }

    /// Get references to cells for inspection
    pub fn get_cells(&self) -> &[GRUCell] {
        &self.cells
    }

    /// Get mutable references to cells
    pub fn get_cells_mut(&mut self) -> &mut [GRUCell] {
        &mut self.cells
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_gru_network_creation() {
        let network = GRUNetwork::new(3, 5, 2);
        assert_eq!(network.input_size, 3);
        assert_eq!(network.hidden_size, 5);
        assert_eq!(network.num_layers, 2);
        assert_eq!(network.cells.len(), 2);
    }

    #[test]
    fn test_gru_network_forward() {
        let mut network = GRUNetwork::new(2, 3, 2);
        let input = arr2(&[[1.0], [0.5]]);
        let hidden_states = vec![
            arr2(&[[0.1], [0.2], [0.3]]),
            arr2(&[[0.0], [0.1], [0.2]]),
        ];

        let outputs = network.forward(&input, &hidden_states);
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].shape(), &[3, 1]);
        assert_eq!(outputs[1].shape(), &[3, 1]);
    }

    #[test]
    fn test_gru_network_sequence() {
        let mut network = GRUNetwork::new(2, 3, 1);
        let sequence = vec![
            arr2(&[[1.0], [0.0]]),
            arr2(&[[0.0], [1.0]]),
            arr2(&[[-1.0], [0.5]]),
        ];

        let (outputs, caches) = network.forward_sequence_with_cache(&sequence);

        assert_eq!(outputs.len(), 3);
        assert_eq!(caches.len(), 3);

        for (output, _) in &outputs {
            assert_eq!(output.shape(), &[3, 1]);
        }
    }

    #[test]
    fn test_gru_network_with_dropout() {
        let mut network = GRUNetwork::new(2, 3, 2)
            .with_input_dropout(0.2, true)
            .with_recurrent_dropout(0.3, false)
            .with_output_dropout(0.1);

        let input = arr2(&[[1.0], [0.5]]);
        let hidden_states = vec![
            arr2(&[[0.1], [0.2], [0.3]]),
            arr2(&[[0.0], [0.1], [0.2]]),
        ];

        // Test training mode
        network.train();
        let outputs_train = network.forward(&input, &hidden_states);

        // Test evaluation mode
        network.eval();
        let outputs_eval = network.forward(&input, &hidden_states);

        assert_eq!(outputs_train.len(), 2);
        assert_eq!(outputs_eval.len(), 2);
    }

    #[test]
    fn test_gru_network_layer_dropout() {
        let layer_configs = vec![
            LayerDropoutConfig::new().with_input_dropout(0.1, false),
            LayerDropoutConfig::new().with_recurrent_dropout(0.2, true),
        ];

        let network = GRUNetwork::new(2, 3, 2)
            .with_layer_dropout(layer_configs);

        assert_eq!(network.cells.len(), 2);
    }
}