rust_lstm/models/lstm_network.rs

use ndarray::Array2;
use crate::layers::lstm_cell::{LSTMCell, LSTMCellGradients, LSTMCellCache};
use crate::optimizers::Optimizer;

/// Holds cached values for all layers during network forward pass
#[derive(Clone)]
pub struct LSTMNetworkCache {
    pub cell_caches: Vec<LSTMCellCache>,
}

/// Multi-layer LSTM network for sequence modeling with dropout support
///
/// Stacks multiple LSTM cells where the output of layer i becomes
/// the input to layer i+1. Supports both inference and training with
/// configurable dropout regularization.
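///
/// # Examples
///
/// A minimal usage sketch (marked `ignore` because the crate path `rust_lstm`
/// is assumed here): build a two-layer network with input dropout, switch to
/// evaluation mode, and run a single time step.
///
/// ```ignore
/// use ndarray::Array2;
/// use rust_lstm::models::lstm_network::LSTMNetwork;
///
/// let mut network = LSTMNetwork::new(3, 2, 2).with_input_dropout(0.2, true);
/// network.eval(); // disable dropout for inference
///
/// let input = Array2::<f64>::zeros((3, 1)); // column vector of input features
/// let hx = Array2::<f64>::zeros((2, 1));    // initial hidden state
/// let cx = Array2::<f64>::zeros((2, 1));    // initial cell state
/// let (hy, cy) = network.forward(&input, &hx, &cx);
/// assert_eq!(hy.shape(), &[2, 1]);
/// assert_eq!(cy.shape(), &[2, 1]);
/// ```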
#[derive(Clone)]
pub struct LSTMNetwork {
    cells: Vec<LSTMCell>,
    pub input_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub is_training: bool,
}

impl LSTMNetwork {
    /// Creates a new multi-layer LSTM network
    ///
    /// First layer accepts `input_size` dimensions, subsequent layers
    /// accept `hidden_size` dimensions from the previous layer.
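    ///
    /// For example, `LSTMNetwork::new(3, 2, 2)` builds layer 0 with a 3 -> 2
    /// mapping and layer 1 with a 2 -> 2 mapping. A small sketch (marked
    /// `ignore` because the crate path is assumed):
    ///
    /// ```ignore
    /// use rust_lstm::models::lstm_network::LSTMNetwork;
    ///
    /// let network = LSTMNetwork::new(3, 2, 2);
    /// assert_eq!(network.num_layers, 2);
    /// assert_eq!(network.get_cells().len(), 2);
    /// ```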
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut cells = Vec::new();

        for i in 0..num_layers {
            let layer_input_size = if i == 0 { input_size } else { hidden_size };
            cells.push(LSTMCell::new(layer_input_size, hidden_size));
        }

        LSTMNetwork {
            cells,
            input_size,
            hidden_size,
            num_layers,
            is_training: true,
        }
    }

    /// Applies input dropout to every layer (optionally variational)
    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_input_dropout(dropout_rate, variational);
        }
        self
    }

    /// Applies recurrent (hidden-to-hidden) dropout to every layer (optionally variational)
    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_recurrent_dropout(dropout_rate, variational);
        }
        self
    }

    /// Applies output dropout to all layers except the last
    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
        for (i, cell) in self.cells.iter_mut().enumerate() {
            if i < self.num_layers - 1 {
                *cell = cell.clone().with_output_dropout(dropout_rate);
            }
        }
        self
    }

    /// Applies zoneout to the cell and hidden states of every layer
    pub fn with_zoneout(mut self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_zoneout(cell_zoneout_rate, hidden_zoneout_rate);
        }
        self
    }

    /// Applies per-layer dropout settings; configs beyond the number of layers are ignored
    pub fn with_layer_dropout(mut self, layer_configs: Vec<LayerDropoutConfig>) -> Self {
        for (i, config) in layer_configs.into_iter().enumerate() {
            if i < self.cells.len() {
                let mut cell = self.cells[i].clone();

                if let Some((rate, variational)) = config.input_dropout {
                    cell = cell.with_input_dropout(rate, variational);
                }
                if let Some((rate, variational)) = config.recurrent_dropout {
                    cell = cell.with_recurrent_dropout(rate, variational);
                }
                if let Some(rate) = config.output_dropout {
                    cell = cell.with_output_dropout(rate);
                }
                if let Some((cell_rate, hidden_rate)) = config.zoneout {
                    cell = cell.with_zoneout(cell_rate, hidden_rate);
                }

                self.cells[i] = cell;
            }
        }
        self
    }

    /// Puts the network and all layers into training mode (dropout active)
    pub fn train(&mut self) {
        self.is_training = true;
        for cell in &mut self.cells {
            cell.train();
        }
    }

    /// Puts the network and all layers into evaluation mode (dropout disabled)
    pub fn eval(&mut self) {
        self.is_training = false;
        for cell in &mut self.cells {
            cell.eval();
        }
    }

    /// Creates a network from existing cells (used for deserialization)
    pub fn from_cells(cells: Vec<LSTMCell>, input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        LSTMNetwork {
            cells,
            input_size,
            hidden_size,
            num_layers,
            is_training: true,
        }
    }

    /// Get reference to the cells (used for serialization)
    pub fn get_cells(&self) -> &[LSTMCell] {
        &self.cells
    }

    /// Get mutable reference to the cells (for training mode changes)
    pub fn get_cells_mut(&mut self) -> &mut [LSTMCell] {
        &mut self.cells
    }

    /// Forward pass for inference (no cache is returned)
    pub fn forward(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
        let (hy, cy, _) = self.forward_with_cache(input, hx, cx);
        (hy, cy)
    }

    /// Forward pass with caching for training
    pub fn forward_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMNetworkCache) {
        let mut current_input = input.clone();
        let mut current_hx = hx.clone();
        let mut current_cx = cx.clone();
        let mut cell_caches = Vec::new();

        for cell in &mut self.cells {
            let (new_hx, new_cx, cache) = cell.forward_with_cache(&current_input, &current_hx, &current_cx);
            cell_caches.push(cache);

            current_input = new_hx.clone();
            current_hx = new_hx;
            current_cx = new_cx;
        }

        let network_cache = LSTMNetworkCache { cell_caches };
        (current_hx, current_cx, network_cache)
    }

    /// Backward pass through all layers (reverse order)
    ///
    /// Implements backpropagation through the multi-layer stack.
    /// Returns per-layer gradients (ordered from first to last layer) and the
    /// gradient with respect to the network input.
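    ///
    /// # Examples
    ///
    /// A minimal sketch (marked `ignore`; the crate path and the loss-derived
    /// output gradients `dhy`/`dcy` are placeholders here):
    ///
    /// ```ignore
    /// use ndarray::Array2;
    /// use rust_lstm::models::lstm_network::LSTMNetwork;
    ///
    /// let mut network = LSTMNetwork::new(3, 2, 1);
    /// let input = Array2::<f64>::zeros((3, 1));
    /// let state = Array2::<f64>::zeros((2, 1));
    /// let (hy, _cy, cache) = network.forward_with_cache(&input, &state, &state);
    ///
    /// // In a real training loop these come from the loss function.
    /// let dhy = hy.clone();
    /// let dcy = Array2::<f64>::zeros((2, 1));
    ///
    /// let (layer_gradients, dx_input) = network.backward(&dhy, &dcy, &cache);
    /// assert_eq!(layer_gradients.len(), 1);
    /// assert_eq!(dx_input.shape(), &[3, 1]);
    /// ```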
    pub fn backward(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMNetworkCache) -> (Vec<LSTMCellGradients>, Array2<f64>) {
        let mut gradients = Vec::new();
        let mut current_dhy = dhy.clone();
        let mut current_dcy = dcy.clone();
        let mut dx_input = Array2::zeros((self.input_size, 1));

        for (i, cell) in self.cells.iter().enumerate().rev() {
            let cell_cache = &cache.cell_caches[i];
            let (cell_gradients, dx, _dhx_prev, dcx_prev) = cell.backward(&current_dhy, &current_dcy, cell_cache);

            gradients.push(cell_gradients);

            if i > 0 {
                // The input of layer i is the hidden output of layer i - 1,
                // so dx flows down as the hidden-state gradient of the layer below.
                current_dhy = dx;
                current_dcy = dcx_prev;
            } else {
                // Layer 0 consumes the network input, so its dx is the input gradient.
                dx_input = dx;
            }
        }

        gradients.reverse();

        (gradients, dx_input)
    }

    /// Update parameters for all layers using computed gradients
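    ///
    /// A hedged sketch of how the pieces fit together (marked `ignore`; the
    /// concrete `Optimizer` implementation and the loss gradients are assumed):
    ///
    /// ```ignore
    /// use ndarray::Array2;
    /// use rust_lstm::models::lstm_network::LSTMNetwork;
    /// use rust_lstm::optimizers::Optimizer;
    ///
    /// fn training_step<O: Optimizer>(
    ///     network: &mut LSTMNetwork,
    ///     optimizer: &mut O,
    ///     input: &Array2<f64>,
    ///     dhy: &Array2<f64>,
    ///     dcy: &Array2<f64>,
    /// ) {
    ///     let hx = Array2::<f64>::zeros((network.hidden_size, 1));
    ///     let cx = Array2::<f64>::zeros((network.hidden_size, 1));
    ///     let (_hy, _cy, cache) = network.forward_with_cache(input, &hx, &cx);
    ///     let (gradients, _dx_input) = network.backward(dhy, dcy, &cache);
    ///     network.update_parameters(&gradients, optimizer);
    /// }
    /// ```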
    pub fn update_parameters<O: Optimizer>(&mut self, gradients: &[LSTMCellGradients], optimizer: &mut O) {
        for (i, (cell, cell_gradients)) in self.cells.iter_mut().zip(gradients.iter()).enumerate() {
            let prefix = format!("layer_{}", i);
            cell.update_parameters(cell_gradients, optimizer, &prefix);
        }
    }

    /// Initialize zero gradients for all layers
    pub fn zero_gradients(&self) -> Vec<LSTMCellGradients> {
        self.cells.iter().map(|cell| cell.zero_gradients()).collect()
    }

    /// Process an entire sequence with caching for training
    ///
    /// Maintains hidden/cell state across time steps within the sequence.
    /// Returns outputs and caches for each time step.
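    ///
    /// # Examples
    ///
    /// A minimal sketch (marked `ignore` because the crate path is assumed):
    ///
    /// ```ignore
    /// use ndarray::Array2;
    /// use rust_lstm::models::lstm_network::LSTMNetwork;
    ///
    /// let mut network = LSTMNetwork::new(3, 2, 2);
    /// let sequence: Vec<Array2<f64>> = (0..4)
    ///     .map(|_| Array2::<f64>::zeros((3, 1)))
    ///     .collect();
    ///
    /// let (outputs, caches) = network.forward_sequence_with_cache(&sequence);
    /// assert_eq!(outputs.len(), 4);
    /// assert_eq!(caches.len(), 4);
    /// // Each output is the (hidden, cell) state after that time step.
    /// assert_eq!(outputs[0].0.shape(), &[2, 1]);
    /// ```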
    pub fn forward_sequence_with_cache(&mut self, sequence: &[Array2<f64>]) -> (Vec<(Array2<f64>, Array2<f64>)>, Vec<LSTMNetworkCache>) {
        let mut outputs = Vec::new();
        let mut caches = Vec::new();
        let mut hx = Array2::zeros((self.hidden_size, 1));
        let mut cx = Array2::zeros((self.hidden_size, 1));

        for input in sequence {
            let (new_hx, new_cx, cache) = self.forward_with_cache(input, &hx, &cx);
            outputs.push((new_hx.clone(), new_cx.clone()));
            caches.push(cache);
            hx = new_hx;
            cx = new_cx;
        }

        (outputs, caches)
    }
}

/// Configuration for layer-specific dropout settings
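///
/// # Examples
///
/// A minimal sketch (marked `ignore` because the crate path is assumed):
///
/// ```ignore
/// use rust_lstm::models::lstm_network::LayerDropoutConfig;
///
/// // Variational input and recurrent dropout for one layer,
/// // zoneout only for another.
/// let first = LayerDropoutConfig::new()
///     .with_input_dropout(0.2, true)
///     .with_recurrent_dropout(0.3, true);
/// let second = LayerDropoutConfig::new().with_zoneout(0.1, 0.1);
///
/// assert!(first.input_dropout.is_some());
/// assert!(second.zoneout.is_some());
/// ```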
#[derive(Clone, Debug)]
pub struct LayerDropoutConfig {
    pub input_dropout: Option<(f64, bool)>,     // (rate, variational)
    pub recurrent_dropout: Option<(f64, bool)>, // (rate, variational)
    pub output_dropout: Option<f64>,            // rate
    pub zoneout: Option<(f64, f64)>,            // (cell_rate, hidden_rate)
}

impl LayerDropoutConfig {
    pub fn new() -> Self {
        LayerDropoutConfig {
            input_dropout: None,
            recurrent_dropout: None,
            output_dropout: None,
            zoneout: None,
        }
    }

    pub fn with_input_dropout(mut self, rate: f64, variational: bool) -> Self {
        self.input_dropout = Some((rate, variational));
        self
    }

    pub fn with_recurrent_dropout(mut self, rate: f64, variational: bool) -> Self {
        self.recurrent_dropout = Some((rate, variational));
        self
    }

    pub fn with_output_dropout(mut self, rate: f64) -> Self {
        self.output_dropout = Some(rate);
        self
    }

    pub fn with_zoneout(mut self, cell_rate: f64, hidden_rate: f64) -> Self {
        self.zoneout = Some((cell_rate, hidden_rate));
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_lstm_network_forward() {
        let input_size = 3;
        let hidden_size = 2;
        let num_layers = 2;
        let mut network = LSTMNetwork::new(input_size, hidden_size, num_layers);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        let (hy, cy) = network.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[hidden_size, 1]);
        assert_eq!(cy.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_lstm_network_with_dropout() {
        let input_size = 3;
        let hidden_size = 2;
        let num_layers = 2;
        let mut network = LSTMNetwork::new(input_size, hidden_size, num_layers)
            .with_input_dropout(0.2, true)  // Variational input dropout
            .with_recurrent_dropout(0.3, true)  // Variational recurrent dropout
            .with_output_dropout(0.1)
            .with_zoneout(0.1, 0.1);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        // Test training mode
        network.train();
        let (hy_train, cy_train) = network.forward(&input, &hx, &cx);

        // Test evaluation mode
        network.eval();
        let (hy_eval, cy_eval) = network.forward(&input, &hx, &cx);

        assert_eq!(hy_train.shape(), &[hidden_size, 1]);
        assert_eq!(cy_train.shape(), &[hidden_size, 1]);
        assert_eq!(hy_eval.shape(), &[hidden_size, 1]);
        assert_eq!(cy_eval.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_layer_specific_dropout() {
        let input_size = 3;
        let hidden_size = 2;
        let num_layers = 2;

        let layer_configs = vec![
            LayerDropoutConfig::new()
                .with_input_dropout(0.2, true)
                .with_recurrent_dropout(0.3, true),
            LayerDropoutConfig::new()
                .with_output_dropout(0.1)
                .with_zoneout(0.1, 0.1),
        ];

        let mut network = LSTMNetwork::new(input_size, hidden_size, num_layers)
            .with_layer_dropout(layer_configs);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        let (hy, cy) = network.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[hidden_size, 1]);
        assert_eq!(cy.shape(), &[hidden_size, 1]);
    }
}