// rust_lstm/layers/bilstm_network.rs

use ndarray::Array2;
use crate::layers::lstm_cell::{LSTMCell, LSTMCellGradients, LSTMCellCache};
use crate::optimizers::Optimizer;

/// Cache for bidirectional LSTM forward pass
#[derive(Clone)]
pub struct BiLSTMNetworkCache {
    pub forward_caches: Vec<LSTMCellCache>,
    pub backward_caches: Vec<LSTMCellCache>,
}

/// Configuration for combining forward and backward outputs
#[derive(Clone, Debug)]
pub enum CombineMode {
    Concat,
    Sum,
    Average,
}

/// Bidirectional LSTM network for sequence modeling
#[derive(Clone)]
pub struct BiLSTMNetwork {
    forward_cells: Vec<LSTMCell>,
    backward_cells: Vec<LSTMCell>,
    pub input_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub combine_mode: CombineMode,
    pub is_training: bool,
}

impl BiLSTMNetwork {
    /// Creates a new bidirectional LSTM network
    ///
    /// # Arguments
    /// * `input_size` - Size of input features
    /// * `hidden_size` - Size of hidden state for each direction
    /// * `num_layers` - Number of bidirectional layers
    /// * `combine_mode` - How to combine forward and backward outputs
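    ///
    /// # Example
    /// A minimal usage sketch; the `rust_lstm::layers::bilstm_network` import path
    /// is assumed from this file's location in the crate.
    /// ```ignore
    /// use rust_lstm::layers::bilstm_network::{BiLSTMNetwork, CombineMode};
    ///
    /// // Two stacked bidirectional layers over 16-dimensional inputs,
    /// // 32 hidden units per direction, outputs concatenated to size 64.
    /// let network = BiLSTMNetwork::new(16, 32, 2, CombineMode::Concat);
    /// assert_eq!(network.output_size(), 64);
    /// ```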
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize, combine_mode: CombineMode) -> Self {
        let mut forward_cells = Vec::new();
        let mut backward_cells = Vec::new();

        for i in 0..num_layers {
            let layer_input_size = if i == 0 {
                input_size
            } else {
                match combine_mode {
                    CombineMode::Concat => 2 * hidden_size,
                    CombineMode::Sum | CombineMode::Average => hidden_size,
                }
            };

            forward_cells.push(LSTMCell::new(layer_input_size, hidden_size));
            backward_cells.push(LSTMCell::new(layer_input_size, hidden_size));
        }

        BiLSTMNetwork {
            forward_cells,
            backward_cells,
            input_size,
            hidden_size,
            num_layers,
            combine_mode,
            is_training: true,
        }
    }

    /// Create BiLSTM with concatenated outputs (most common)
    pub fn new_concat(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::new(input_size, hidden_size, num_layers, CombineMode::Concat)
    }

    /// Create BiLSTM with summed outputs
    pub fn new_sum(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::new(input_size, hidden_size, num_layers, CombineMode::Sum)
    }

    /// Create BiLSTM with averaged outputs
    pub fn new_average(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::new(input_size, hidden_size, num_layers, CombineMode::Average)
    }

    /// Get the output size based on combine mode
    pub fn output_size(&self) -> usize {
        match self.combine_mode {
            CombineMode::Concat => 2 * self.hidden_size,
            CombineMode::Sum | CombineMode::Average => self.hidden_size,
        }
    }

    /// Apply input dropout configuration to all cells
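    ///
    /// # Example
    /// A minimal sketch of the builder-style configuration; the dropout rates
    /// here are illustrative, not recommendations.
    /// ```ignore
    /// let network = BiLSTMNetwork::new_concat(16, 32, 2)
    ///     .with_input_dropout(0.2, false)      // standard input dropout
    ///     .with_recurrent_dropout(0.1, true)   // variational recurrent dropout
    ///     .with_output_dropout(0.3);           // applied between stacked layers only
    /// ```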
    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.forward_cells {
            *cell = cell.clone().with_input_dropout(dropout_rate, variational);
        }
        for cell in &mut self.backward_cells {
            *cell = cell.clone().with_input_dropout(dropout_rate, variational);
        }
        self
    }

    /// Apply recurrent (hidden-to-hidden) dropout to all cells
    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.forward_cells {
            *cell = cell.clone().with_recurrent_dropout(dropout_rate, variational);
        }
        for cell in &mut self.backward_cells {
            *cell = cell.clone().with_recurrent_dropout(dropout_rate, variational);
        }
        self
    }

    /// Apply output dropout between stacked layers
    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
        // Apply output dropout to all layers except the last
        for (i, cell) in self.forward_cells.iter_mut().enumerate() {
            if i < self.num_layers - 1 {
                *cell = cell.clone().with_output_dropout(dropout_rate);
            }
        }
        for (i, cell) in self.backward_cells.iter_mut().enumerate() {
            if i < self.num_layers - 1 {
                *cell = cell.clone().with_output_dropout(dropout_rate);
            }
        }
        self
    }

    /// Apply zoneout to the cell and hidden states of all cells
    pub fn with_zoneout(mut self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self {
        for cell in &mut self.forward_cells {
            *cell = cell.clone().with_zoneout(cell_zoneout_rate, hidden_zoneout_rate);
        }
        for cell in &mut self.backward_cells {
            *cell = cell.clone().with_zoneout(cell_zoneout_rate, hidden_zoneout_rate);
        }
        self
    }

    /// Switch the network and all cells to training mode
    pub fn train(&mut self) {
        self.is_training = true;
        for cell in &mut self.forward_cells {
            cell.train();
        }
        for cell in &mut self.backward_cells {
            cell.train();
        }
    }

    /// Switch the network and all cells to evaluation (inference) mode
    pub fn eval(&mut self) {
        self.is_training = false;
        for cell in &mut self.forward_cells {
            cell.eval();
        }
        for cell in &mut self.backward_cells {
            cell.eval();
        }
    }

    /// Combine forward and backward outputs according to the combine mode
    fn combine_outputs(&self, forward: &Array2<f64>, backward: &Array2<f64>) -> Array2<f64> {
        match self.combine_mode {
            CombineMode::Concat => {
                // Stack forward and backward outputs along the feature (row) axis
                let mut combined = Array2::zeros((forward.nrows() + backward.nrows(), forward.ncols()));
                combined.slice_mut(ndarray::s![..forward.nrows(), ..]).assign(forward);
                combined.slice_mut(ndarray::s![forward.nrows().., ..]).assign(backward);
                combined
            },
            CombineMode::Sum => forward + backward,
            CombineMode::Average => (forward + backward) * 0.5,
        }
    }

    /// Forward pass for a complete sequence
    ///
    /// This is the main method for BiLSTM processing. It runs the forward direction
    /// from start to end and the backward direction from end to start, then combines
    /// the outputs at each time step.
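    ///
    /// # Example
    /// A minimal sketch; each time step is a column vector of shape
    /// `(input_size, 1)`, as in the unit tests below.
    /// ```ignore
    /// use ndarray::arr2;
    ///
    /// let mut network = BiLSTMNetwork::new_concat(2, 3, 1);
    /// let sequence = vec![
    ///     arr2(&[[1.0], [0.5]]),
    ///     arr2(&[[0.8], [0.2]]),
    /// ];
    /// let outputs = network.forward_sequence(&sequence);
    /// // One output per time step, each of shape (2 * hidden_size, 1) in Concat mode.
    /// assert_eq!(outputs.len(), 2);
    /// assert_eq!(outputs[0].shape(), &[6, 1]);
    /// ```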
    pub fn forward_sequence(&mut self, sequence: &[Array2<f64>]) -> Vec<Array2<f64>> {
        let seq_len = sequence.len();
        if seq_len == 0 {
            return Vec::new();
        }

        // Process each layer sequentially
        let mut layer_input_sequence = sequence.to_vec();

        for layer_idx in 0..self.num_layers {
            let mut forward_outputs = Vec::new();
            let mut backward_outputs = Vec::new();

            // Initialize states for this layer
            let mut forward_hidden_state = Array2::zeros((self.hidden_size, 1));
            let mut forward_cell_state = Array2::zeros((self.hidden_size, 1));
            let mut backward_hidden_state = Array2::zeros((self.hidden_size, 1));
            let mut backward_cell_state = Array2::zeros((self.hidden_size, 1));

            // Forward direction
            for t in 0..seq_len {
                let (hy, cy) = self.forward_cells[layer_idx].forward(
                    &layer_input_sequence[t],
                    &forward_hidden_state,
                    &forward_cell_state
                );

                forward_hidden_state = hy.clone();
                forward_cell_state = cy;
                forward_outputs.push(hy);
            }

            // Backward direction
            for t in (0..seq_len).rev() {
                let (hy, cy) = self.backward_cells[layer_idx].forward(
                    &layer_input_sequence[t],
                    &backward_hidden_state,
                    &backward_cell_state
                );

                backward_hidden_state = hy.clone();
                backward_cell_state = cy;
                backward_outputs.push(hy);
            }

            // Reverse backward outputs to match forward sequence order
            backward_outputs.reverse();

            // Combine forward and backward outputs for this layer
            let mut combined_outputs = Vec::new();
            for (forward_out, backward_out) in forward_outputs.iter().zip(backward_outputs.iter()) {
                combined_outputs.push(self.combine_outputs(forward_out, backward_out));
            }

            // Output of this layer becomes input to next layer
            layer_input_sequence = combined_outputs;
        }

        layer_input_sequence
    }

    /// Forward pass with caching for training
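    ///
    /// Returns the combined outputs together with per-timestep cell caches
    /// (flattened across layers: `num_layers * seq_len` entries per direction),
    /// which a training loop would feed into backpropagation through time.
    ///
    /// # Example
    /// A minimal sketch reusing the column-vector inputs from `forward_sequence`:
    /// ```ignore
    /// let (outputs, cache) = network.forward_sequence_with_cache(&sequence);
    /// assert_eq!(outputs.len(), sequence.len());
    /// assert_eq!(cache.forward_caches.len(), network.num_layers * sequence.len());
    /// ```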
    pub fn forward_sequence_with_cache(&mut self, sequence: &[Array2<f64>]) -> (Vec<Array2<f64>>, BiLSTMNetworkCache) {
        let seq_len = sequence.len();
        if seq_len == 0 {
            return (Vec::new(), BiLSTMNetworkCache {
                forward_caches: Vec::new(),
                backward_caches: Vec::new(),
            });
        }

        let mut all_forward_caches = Vec::new();
        let mut all_backward_caches = Vec::new();

        // Process each layer sequentially
        let mut layer_input_sequence = sequence.to_vec();

        for layer_idx in 0..self.num_layers {
            let mut forward_outputs = Vec::new();
            let mut backward_outputs = Vec::new();
            let mut forward_caches = Vec::new();
            let mut backward_caches = Vec::new();

            // Initialize states for this layer
            let mut forward_hidden_state = Array2::zeros((self.hidden_size, 1));
            let mut forward_cell_state = Array2::zeros((self.hidden_size, 1));
            let mut backward_hidden_state = Array2::zeros((self.hidden_size, 1));
            let mut backward_cell_state = Array2::zeros((self.hidden_size, 1));

            // Forward direction with caching
            for t in 0..seq_len {
                let (hy, cy, cache) = self.forward_cells[layer_idx].forward_with_cache(
                    &layer_input_sequence[t],
                    &forward_hidden_state,
                    &forward_cell_state
                );

                forward_hidden_state = hy.clone();
                forward_cell_state = cy;
                forward_outputs.push(hy);
                forward_caches.push(cache);
            }

            // Backward direction with caching
            for t in (0..seq_len).rev() {
                let (hy, cy, cache) = self.backward_cells[layer_idx].forward_with_cache(
                    &layer_input_sequence[t],
                    &backward_hidden_state,
                    &backward_cell_state
                );

                backward_hidden_state = hy.clone();
                backward_cell_state = cy;
                backward_outputs.push(hy);
                backward_caches.push(cache);
            }

            // Reverse backward outputs and caches to match forward sequence order
            backward_outputs.reverse();
            backward_caches.reverse();

            // Combine outputs for this layer
            let mut combined_outputs = Vec::new();
            for (forward_out, backward_out) in forward_outputs.iter().zip(backward_outputs.iter()) {
                combined_outputs.push(self.combine_outputs(forward_out, backward_out));
            }

            // Store caches for this layer
            all_forward_caches.extend(forward_caches);
            all_backward_caches.extend(backward_caches);

            // Output of this layer becomes input to next layer
            layer_input_sequence = combined_outputs;
        }

        let cache = BiLSTMNetworkCache {
            forward_caches: all_forward_caches,
            backward_caches: all_backward_caches,
        };

        (layer_input_sequence, cache)
    }

    /// Get references to forward cells for serialization
    pub fn get_forward_cells(&self) -> &[LSTMCell] {
        &self.forward_cells
    }

    /// Get references to backward cells for serialization
    pub fn get_backward_cells(&self) -> &[LSTMCell] {
        &self.backward_cells
    }

    /// Get mutable references to forward cells for training mode changes
    pub fn get_forward_cells_mut(&mut self) -> &mut [LSTMCell] {
        &mut self.forward_cells
    }

    /// Get mutable references to backward cells for training mode changes
    pub fn get_backward_cells_mut(&mut self) -> &mut [LSTMCell] {
        &mut self.backward_cells
    }

    /// Update parameters for both directions
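    ///
    /// # Example
    /// A minimal sketch; `SGD` stands in for any optimizer in this crate's
    /// `optimizers` module (hypothetical name here), and the zeroed gradients
    /// from `zero_gradients` are placeholders for gradients a real training
    /// loop would accumulate via backpropagation through time.
    /// ```ignore
    /// let mut network = BiLSTMNetwork::new_concat(4, 8, 1);
    /// let mut optimizer = SGD::new(0.01);
    /// let (forward_grads, backward_grads) = network.zero_gradients();
    /// network.update_parameters(&forward_grads, &backward_grads, &mut optimizer);
    /// ```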
    pub fn update_parameters<O: Optimizer>(&mut self,
                                           forward_gradients: &[LSTMCellGradients],
                                           backward_gradients: &[LSTMCellGradients],
                                           optimizer: &mut O) {
        // Update forward cells
        for (i, (cell, gradients)) in self.forward_cells.iter_mut().zip(forward_gradients.iter()).enumerate() {
            cell.update_parameters(gradients, optimizer, &format!("forward_layer_{}", i));
        }

        // Update backward cells
        for (i, (cell, gradients)) in self.backward_cells.iter_mut().zip(backward_gradients.iter()).enumerate() {
            cell.update_parameters(gradients, optimizer, &format!("backward_layer_{}", i));
        }
    }

    /// Zero gradients for all cells
    pub fn zero_gradients(&self) -> (Vec<LSTMCellGradients>, Vec<LSTMCellGradients>) {
        let forward_gradients: Vec<_> = self.forward_cells.iter()
            .map(|cell| cell.zero_gradients())
            .collect();

        let backward_gradients: Vec<_> = self.backward_cells.iter()
            .map(|cell| cell.zero_gradients())
            .collect();

        (forward_gradients, backward_gradients)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_bilstm_creation() {
        let network = BiLSTMNetwork::new_concat(3, 5, 2);
        assert_eq!(network.input_size, 3);
        assert_eq!(network.hidden_size, 5);
        assert_eq!(network.num_layers, 2);
        assert_eq!(network.output_size(), 10); // 2 * hidden_size for concat mode
    }

    #[test]
    fn test_bilstm_combine_modes() {
        let forward = arr2(&[[1.0], [2.0]]);
        let backward = arr2(&[[3.0], [4.0]]);

        let concat_network = BiLSTMNetwork::new_concat(2, 2, 1);
        let concat_result = concat_network.combine_outputs(&forward, &backward);
        assert_eq!(concat_result.shape(), &[4, 1]);
        assert_eq!(concat_result[[0, 0]], 1.0);
        assert_eq!(concat_result[[1, 0]], 2.0);
        assert_eq!(concat_result[[2, 0]], 3.0);
        assert_eq!(concat_result[[3, 0]], 4.0);

        let sum_network = BiLSTMNetwork::new_sum(2, 2, 1);
        let sum_result = sum_network.combine_outputs(&forward, &backward);
        assert_eq!(sum_result.shape(), &[2, 1]);
        assert_eq!(sum_result[[0, 0]], 4.0);
        assert_eq!(sum_result[[1, 0]], 6.0);

        let avg_network = BiLSTMNetwork::new_average(2, 2, 1);
        let avg_result = avg_network.combine_outputs(&forward, &backward);
        assert_eq!(avg_result.shape(), &[2, 1]);
        assert_eq!(avg_result[[0, 0]], 2.0);
        assert_eq!(avg_result[[1, 0]], 3.0);
    }

    #[test]
    fn test_bilstm_forward_sequence() {
        let mut network = BiLSTMNetwork::new_concat(2, 3, 1);

        let sequence = vec![
            arr2(&[[1.0], [0.5]]),
            arr2(&[[0.8], [0.2]]),
            arr2(&[[0.3], [0.9]]),
        ];

        let outputs = network.forward_sequence(&sequence);

        assert_eq!(outputs.len(), 3);
        for output in &outputs {
            assert_eq!(output.shape(), &[6, 1]); // 2 * hidden_size for concat
        }
    }

    #[test]
    fn test_bilstm_training_mode() {
        let mut network = BiLSTMNetwork::new_concat(2, 3, 1)
            .with_input_dropout(0.1, false)
            .with_recurrent_dropout(0.2, true);

        // Test mode switching
        network.train();
        assert!(network.is_training);

        network.eval();
        assert!(!network.is_training);
    }
}