rust_lstm/models/
lstm_network.rs

1use ndarray::Array2;
2use crate::layers::lstm_cell::{LSTMCell, LSTMCellGradients, LSTMCellCache, LSTMCellBatchCache};
3use crate::optimizers::Optimizer;
4
/// Per-layer cached values collected during a single-sample network forward
/// pass; consumed by `backward`. `cell_caches[i]` belongs to layer `i`,
/// in bottom-to-top order.
#[derive(Clone)]
pub struct LSTMNetworkCache {
    pub cell_caches: Vec<LSTMCellCache>,
}
10
/// Per-layer cached values collected during a batched network forward pass;
/// consumed by `backward_batch`. `cell_caches[i]` belongs to layer `i`.
#[derive(Clone)]
pub struct LSTMNetworkBatchCache {
    pub cell_caches: Vec<LSTMCellBatchCache>,
    // Number of columns (parallel sequences) in the batch this cache was built from.
    pub batch_size: usize,
}
17
/// Multi-layer LSTM network for sequence modeling with dropout support.
///
/// Stacks `num_layers` LSTM cells where the output of layer i becomes
/// the input to layer i+1. Supports both inference and training with
/// configurable dropout regularization.
#[derive(Clone)]
pub struct LSTMNetwork {
    // One cell per layer; cells[0] consumes the raw input.
    cells: Vec<LSTMCell>,
    pub input_size: usize,  // dimensionality of the raw input to layer 0
    pub hidden_size: usize, // hidden/cell state size shared by every layer
    pub num_layers: usize,  // number of stacked cells
    pub is_training: bool,  // mirrored into each cell by train()/eval()
}
31
32impl LSTMNetwork {
33    /// Creates a new multi-layer LSTM network
34    /// 
35    /// First layer accepts `input_size` dimensions, subsequent layers 
36    /// accept `hidden_size` dimensions from the previous layer.
37    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
38        let mut cells = Vec::new();
39
40        for i in 0..num_layers {
41            let layer_input_size = if i == 0 { input_size } else { hidden_size };
42            cells.push(LSTMCell::new(layer_input_size, hidden_size));
43        }
44        
45        LSTMNetwork { 
46            cells,
47            input_size,
48            hidden_size,
49            num_layers,
50            is_training: true,
51        }
52    }
53
54    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
55        for cell in &mut self.cells {
56            *cell = cell.clone().with_input_dropout(dropout_rate, variational);
57        }
58        self
59    }
60
61    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
62        for cell in &mut self.cells {
63            *cell = cell.clone().with_recurrent_dropout(dropout_rate, variational);
64        }
65        self
66    }
67
68    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
69        for (i, cell) in self.cells.iter_mut().enumerate() {
70            if i < self.num_layers - 1 {
71                *cell = cell.clone().with_output_dropout(dropout_rate);
72            }
73        }
74        self
75    }
76
77    pub fn with_zoneout(mut self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self {
78        for cell in &mut self.cells {
79            *cell = cell.clone().with_zoneout(cell_zoneout_rate, hidden_zoneout_rate);
80        }
81        self
82    }
83
84    pub fn with_layer_dropout(mut self, layer_configs: Vec<LayerDropoutConfig>) -> Self {
85        for (i, config) in layer_configs.into_iter().enumerate() {
86            if i < self.cells.len() {
87                let mut cell = self.cells[i].clone();
88                
89                if let Some((rate, variational)) = config.input_dropout {
90                    cell = cell.with_input_dropout(rate, variational);
91                }
92                if let Some((rate, variational)) = config.recurrent_dropout {
93                    cell = cell.with_recurrent_dropout(rate, variational);
94                }
95                if let Some(rate) = config.output_dropout {
96                    cell = cell.with_output_dropout(rate);
97                }
98                if let Some((cell_rate, hidden_rate)) = config.zoneout {
99                    cell = cell.with_zoneout(cell_rate, hidden_rate);
100                }
101                
102                self.cells[i] = cell;
103            }
104        }
105        self
106    }
107
108    pub fn train(&mut self) {
109        self.is_training = true;
110        for cell in &mut self.cells {
111            cell.train();
112        }
113    }
114
115    pub fn eval(&mut self) {
116        self.is_training = false;
117        for cell in &mut self.cells {
118            cell.eval();
119        }
120    }
121
122    /// Creates a network from existing cells (used for deserialization)
123    pub fn from_cells(cells: Vec<LSTMCell>, input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
124        LSTMNetwork {
125            cells,
126            input_size,
127            hidden_size,
128            num_layers,
129            is_training: true,
130        }
131    }
132
133    /// Get reference to the cells (used for serialization)
134    pub fn get_cells(&self) -> &[LSTMCell] {
135        &self.cells
136    }
137
138    /// Get mutable reference to the cells (for training mode changes)
139    pub fn get_cells_mut(&mut self) -> &mut [LSTMCell] {
140        &mut self.cells
141    }
142
143    /// Forward pass for inference (no caching)
144    pub fn forward(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
145        let (hy, cy, _) = self.forward_with_cache(input, hx, cx);
146        (hy, cy)
147    }
148
149    /// Forward pass with caching for training
150    pub fn forward_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMNetworkCache) {
151        let mut current_input = input.clone();
152        let mut current_hx = hx.clone();
153        let mut current_cx = cx.clone();
154        let mut cell_caches = Vec::new();
155
156        for cell in &mut self.cells {
157            let (new_hx, new_cx, cache) = cell.forward_with_cache(&current_input, &current_hx, &current_cx);
158            cell_caches.push(cache);
159
160            current_input = new_hx.clone();
161            current_hx = new_hx;
162            current_cx = new_cx;
163        }
164
165        let network_cache = LSTMNetworkCache { cell_caches };
166        (current_hx, current_cx, network_cache)
167    }
168
169    /// Backward pass through all layers (reverse order)
170    /// 
171    /// Implements backpropagation through the multi-layer stack.
172    /// Returns gradients for each layer and input gradients.
173    pub fn backward(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMNetworkCache) -> (Vec<LSTMCellGradients>, Array2<f64>) {
174        let mut gradients = Vec::new();
175        let mut current_dhy = dhy.clone();
176        let mut current_dcy = dcy.clone();
177
178        for (i, cell) in self.cells.iter().enumerate().rev() {
179            let cell_cache = &cache.cell_caches[i];
180            let (cell_gradients, dx, _dhx_prev, dcx_prev) = cell.backward(&current_dhy, &current_dcy, cell_cache);
181            
182            gradients.push(cell_gradients);
183
184            if i > 0 {
185                current_dhy = dx;
186                current_dcy = dcx_prev;
187            }
188        }
189
190        gradients.reverse();
191        
192        let dx_input = if !gradients.is_empty() {
193            let first_cell = &self.cells[0];
194            let first_cache = &cache.cell_caches[0];
195            let (_, dx_input, _, _) = first_cell.backward(dhy, dcy, first_cache);
196            dx_input
197        } else {
198            Array2::zeros(dhy.raw_dim())
199        };
200
201        (gradients, dx_input)
202    }
203
204    /// Update parameters for all layers using computed gradients
205    pub fn update_parameters<O: Optimizer>(&mut self, gradients: &[LSTMCellGradients], optimizer: &mut O) {
206        for (i, (cell, cell_gradients)) in self.cells.iter_mut().zip(gradients.iter()).enumerate() {
207            let prefix = format!("layer_{}", i);
208            cell.update_parameters(cell_gradients, optimizer, &prefix);
209        }
210    }
211
212    /// Initialize zero gradients for all layers
213    pub fn zero_gradients(&self) -> Vec<LSTMCellGradients> {
214        self.cells.iter().map(|cell| cell.zero_gradients()).collect()
215    }
216
217    /// Process an entire sequence with caching for training
218    /// 
219    /// Maintains hidden/cell state across time steps within the sequence.
220    /// Returns outputs and caches for each time step.
221    pub fn forward_sequence_with_cache(&mut self, sequence: &[Array2<f64>]) -> (Vec<(Array2<f64>, Array2<f64>)>, Vec<LSTMNetworkCache>) {
222        let mut outputs = Vec::new();
223        let mut caches = Vec::new();
224        let mut hx = Array2::zeros((self.hidden_size, 1));
225        let mut cx = Array2::zeros((self.hidden_size, 1));
226
227        for input in sequence {
228            let (new_hx, new_cx, cache) = self.forward_with_cache(input, &hx, &cx);
229            outputs.push((new_hx.clone(), new_cx.clone()));
230            caches.push(cache);
231            hx = new_hx;
232            cx = new_cx;
233        }
234
235        (outputs, caches)
236    }
237
238    /// Process multiple sequences in a batch
239    /// 
240    /// # Arguments
241    /// * `batch_sequences` - Vector of sequences, each sequence is a Vec<Array2<f64>>
242    ///   where each Array2 has shape (input_size, 1) for single sequences
243    /// 
244    /// # Returns
245    /// * Vector of sequence outputs, where each sequence output is Vec<(Array2<f64>, Array2<f64>)>
246    pub fn forward_batch_sequences(&mut self, batch_sequences: &[Vec<Array2<f64>>]) -> Vec<Vec<(Array2<f64>, Array2<f64>)>> {
247        // Find the maximum sequence length for padding
248        let max_seq_len = batch_sequences.iter().map(|seq| seq.len()).max().unwrap_or(0);
249        let batch_size = batch_sequences.len();
250        
251        if batch_size == 0 || max_seq_len == 0 {
252            return Vec::new();
253        }
254
255        let mut batch_outputs = vec![Vec::new(); batch_size];
256        
257        // Initialize batch hidden and cell states
258        let mut batch_hx = Array2::zeros((self.hidden_size, batch_size));
259        let mut batch_cx = Array2::zeros((self.hidden_size, batch_size));
260
261        // Process each time step across all sequences in the batch
262        for t in 0..max_seq_len {
263            // Prepare batch input for current time step
264            let mut batch_input = Array2::zeros((self.input_size, batch_size));
265            let mut active_sequences = Vec::new();
266            
267            for (batch_idx, sequence) in batch_sequences.iter().enumerate() {
268                if t < sequence.len() {
269                    // Copy input for this sequence at time step t
270                    batch_input.column_mut(batch_idx).assign(&sequence[t].column(0));
271                    active_sequences.push(batch_idx);
272                }
273            }
274
275            if active_sequences.is_empty() {
276                break; // No more active sequences
277            }
278
279            // Forward pass for this time step across the batch
280            let (new_batch_hx, new_batch_cx) = self.forward_batch(&batch_input, &batch_hx, &batch_cx);
281            
282            // Update states and collect outputs for active sequences
283            batch_hx = new_batch_hx.clone();
284            batch_cx = new_batch_cx.clone();
285
286            // Store outputs for each active sequence
287            for &batch_idx in &active_sequences {
288                let hy = new_batch_hx.column(batch_idx).to_owned().insert_axis(ndarray::Axis(1));
289                let cy = new_batch_cx.column(batch_idx).to_owned().insert_axis(ndarray::Axis(1));
290                batch_outputs[batch_idx].push((hy, cy));
291            }
292        }
293
294        batch_outputs
295    }
296
297    /// Batch forward pass for single time step across multiple sequences
298    /// 
299    /// # Arguments
300    /// * `batch_input` - Input tensor of shape (input_size, batch_size)
301    /// * `batch_hx` - Hidden states tensor of shape (hidden_size, batch_size)  
302    /// * `batch_cx` - Cell states tensor of shape (hidden_size, batch_size)
303    /// 
304    /// # Returns
305    /// * Tuple of (new_hidden_states, new_cell_states) with same batch dimensions
306    pub fn forward_batch(&mut self, batch_input: &Array2<f64>, batch_hx: &Array2<f64>, batch_cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
307        let mut current_input = batch_input.clone();
308        let mut current_hx = batch_hx.clone();
309        let mut current_cx = batch_cx.clone();
310
311        // Process through each layer
312        for cell in &mut self.cells {
313            let (new_hx, new_cx) = cell.forward_batch(&current_input, &current_hx, &current_cx);
314            current_input = new_hx.clone(); // Output of layer i becomes input to layer i+1
315            current_hx = new_hx;
316            current_cx = new_cx;
317        }
318
319        (current_hx, current_cx)
320    }
321
322    /// Batch forward pass with caching for training
323    /// 
324    /// Similar to forward_batch but caches intermediate values needed for backpropagation
325    pub fn forward_batch_with_cache(&mut self, batch_input: &Array2<f64>, batch_hx: &Array2<f64>, batch_cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMNetworkBatchCache) {
326        let mut current_input = batch_input.clone();
327        let mut current_hx = batch_hx.clone();
328        let mut current_cx = batch_cx.clone();
329        let mut cell_caches = Vec::new();
330
331        // Process through each layer with caching
332        for cell in &mut self.cells {
333            let (new_hx, new_cx, cache) = cell.forward_batch_with_cache(&current_input, &current_hx, &current_cx);
334            cell_caches.push(cache);
335
336            current_input = new_hx.clone();
337            current_hx = new_hx;
338            current_cx = new_cx;
339        }
340
341        let network_cache = LSTMNetworkBatchCache { 
342            cell_caches,
343            batch_size: batch_input.ncols(),
344        };
345        
346        (current_hx, current_cx, network_cache)
347    }
348
349    /// Batch backward pass for training
350    /// 
351    /// Computes gradients for an entire batch simultaneously
352    pub fn backward_batch(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMNetworkBatchCache) -> (Vec<LSTMCellGradients>, Array2<f64>) {
353        let mut gradients = Vec::new();
354        let mut current_dhy = dhy.clone();
355        let mut current_dcy = dcy.clone();
356
357        // Backward through layers in reverse order
358        for (i, cell) in self.cells.iter().enumerate().rev() {
359            let cell_cache = &cache.cell_caches[i];
360            let (cell_gradients, dx, _dhx_prev, dcx_prev) = cell.backward_batch(&current_dhy, &current_dcy, cell_cache);
361            
362            gradients.push(cell_gradients);
363
364            if i > 0 {
365                current_dhy = dx;
366                current_dcy = dcx_prev;
367            }
368        }
369
370        gradients.reverse();
371        
372        let dx_input = if !gradients.is_empty() {
373            let first_cell = &self.cells[0];
374            let first_cache = &cache.cell_caches[0];
375            let (_, dx_input, _, _) = first_cell.backward_batch(dhy, dcy, first_cache);
376            dx_input
377        } else {
378            Array2::<f64>::zeros(dhy.raw_dim())
379        };
380
381        (gradients, dx_input)
382    }
383}
384
/// Per-layer dropout settings; each `Some` field enables that dropout kind
/// for the layer the config is applied to (see
/// `LSTMNetwork::with_layer_dropout`).
#[derive(Clone, Debug, Default)]
pub struct LayerDropoutConfig {
    pub input_dropout: Option<(f64, bool)>,     // (rate, variational)
    pub recurrent_dropout: Option<(f64, bool)>, // (rate, variational)
    pub output_dropout: Option<f64>,            // rate
    pub zoneout: Option<(f64, f64)>,            // (cell_rate, hidden_rate)
}

impl LayerDropoutConfig {
    /// Creates a config with every dropout kind disabled.
    pub fn new() -> Self {
        // Default is derived (all fields None), so delegate to it.
        Self::default()
    }

    /// Enables input dropout: `(rate, variational)`.
    pub fn with_input_dropout(mut self, rate: f64, variational: bool) -> Self {
        self.input_dropout = Some((rate, variational));
        self
    }

    /// Enables recurrent dropout: `(rate, variational)`.
    pub fn with_recurrent_dropout(mut self, rate: f64, variational: bool) -> Self {
        self.recurrent_dropout = Some((rate, variational));
        self
    }

    /// Enables output dropout at `rate`.
    pub fn with_output_dropout(mut self, rate: f64) -> Self {
        self.output_dropout = Some(rate);
        self
    }

    /// Enables zoneout: `(cell_rate, hidden_rate)`.
    pub fn with_zoneout(mut self, cell_rate: f64, hidden_rate: f64) -> Self {
        self.zoneout = Some((cell_rate, hidden_rate));
        self
    }
}
424
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    const INPUT_SIZE: usize = 3;
    const HIDDEN_SIZE: usize = 2;
    const NUM_LAYERS: usize = 2;

    /// Shared fixture: a sample input column plus zeroed hidden/cell state.
    fn sample_io() -> (Array2<f64>, Array2<f64>, Array2<f64>) {
        (
            arr2(&[[0.5], [0.1], [-0.3]]),
            arr2(&[[0.0], [0.0]]),
            arr2(&[[0.0], [0.0]]),
        )
    }

    #[test]
    fn test_lstm_network_forward() {
        let mut network = LSTMNetwork::new(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS);
        let (input, hx, cx) = sample_io();

        let (hy, cy) = network.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[HIDDEN_SIZE, 1]);
        assert_eq!(cy.shape(), &[HIDDEN_SIZE, 1]);
    }

    #[test]
    fn test_lstm_network_with_dropout() {
        let mut network = LSTMNetwork::new(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS)
            .with_input_dropout(0.2, true) // variational input dropout
            .with_recurrent_dropout(0.3, true) // variational recurrent dropout
            .with_output_dropout(0.1)
            .with_zoneout(0.1, 0.1);
        let (input, hx, cx) = sample_io();

        // Output shapes must be stable in both training and evaluation mode.
        network.train();
        let (hy_train, cy_train) = network.forward(&input, &hx, &cx);
        network.eval();
        let (hy_eval, cy_eval) = network.forward(&input, &hx, &cx);

        for out in [&hy_train, &cy_train, &hy_eval, &cy_eval] {
            assert_eq!(out.shape(), &[HIDDEN_SIZE, 1]);
        }
    }

    #[test]
    fn test_layer_specific_dropout() {
        let layer_configs = vec![
            LayerDropoutConfig::new()
                .with_input_dropout(0.2, true)
                .with_recurrent_dropout(0.3, true),
            LayerDropoutConfig::new()
                .with_output_dropout(0.1)
                .with_zoneout(0.1, 0.1),
        ];
        let mut network = LSTMNetwork::new(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS)
            .with_layer_dropout(layer_configs);
        let (input, hx, cx) = sample_io();

        let (hy, cy) = network.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[HIDDEN_SIZE, 1]);
        assert_eq!(cy.shape(), &[HIDDEN_SIZE, 1]);
    }
}