use ndarray::Array2;
use crate::layers::lstm_cell::{LSTMCell, LSTMCellGradients, LSTMCellCache};
use crate::optimizers::Optimizer;

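/// Per-timestep cell caches from a full bidirectional forward pass, used for
/// backpropagation through time. Each direction's caches are stored flat in
/// layer-major order: all timesteps of layer 0 first, then layer 1, and so on
/// (`num_layers * seq_len` entries per direction).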
#[derive(Clone)]
pub struct BiLSTMNetworkCache {
    pub forward_caches: Vec<LSTMCellCache>,
    pub backward_caches: Vec<LSTMCellCache>,
}

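/// How forward and backward hidden states are merged at each timestep.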
#[derive(Clone, Debug)]
pub enum CombineMode {
    /// Stack forward and backward states vertically; output has `2 * hidden_size` rows.
    Concat,
    /// Element-wise sum; output keeps `hidden_size` rows.
    Sum,
    /// Element-wise average; output keeps `hidden_size` rows.
    Average,
}

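/// A stacked bidirectional LSTM. Each layer runs one `LSTMCell` left-to-right
/// and another right-to-left over the sequence, then merges the two output
/// streams per `combine_mode` before feeding the layer above.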
#[derive(Clone)]
pub struct BiLSTMNetwork {
    forward_cells: Vec<LSTMCell>,
    backward_cells: Vec<LSTMCell>,
    pub input_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub combine_mode: CombineMode,
    pub is_training: bool,
}

impl BiLSTMNetwork {
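    /// Builds `num_layers` stacked bidirectional layers. Layers above the
    /// first consume the combined output of the layer below, so their input
    /// size depends on `combine_mode`.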
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize, combine_mode: CombineMode) -> Self {
        let mut forward_cells = Vec::new();
        let mut backward_cells = Vec::new();

        for i in 0..num_layers {
            // Layer 0 sees the raw input; deeper layers see the combined
            // bidirectional output of the layer below.
            let layer_input_size = if i == 0 {
                input_size
            } else {
                match combine_mode {
                    CombineMode::Concat => 2 * hidden_size,
                    CombineMode::Sum | CombineMode::Average => hidden_size,
                }
            };

            forward_cells.push(LSTMCell::new(layer_input_size, hidden_size));
            backward_cells.push(LSTMCell::new(layer_input_size, hidden_size));
        }

        BiLSTMNetwork {
            forward_cells,
            backward_cells,
            input_size,
            hidden_size,
            num_layers,
            combine_mode,
            is_training: true,
        }
    }

    /// Convenience constructor for `CombineMode::Concat`.
    pub fn new_concat(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::new(input_size, hidden_size, num_layers, CombineMode::Concat)
    }

    /// Convenience constructor for `CombineMode::Sum`.
    pub fn new_sum(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::new(input_size, hidden_size, num_layers, CombineMode::Sum)
    }

    /// Convenience constructor for `CombineMode::Average`.
    pub fn new_average(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::new(input_size, hidden_size, num_layers, CombineMode::Average)
    }

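    /// Size of each per-timestep output vector, determined by the combine mode.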
    pub fn output_size(&self) -> usize {
        match self.combine_mode {
            CombineMode::Concat => 2 * self.hidden_size,
            CombineMode::Sum | CombineMode::Average => self.hidden_size,
        }
    }

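    /// Applies input dropout to every cell in both directions. The
    /// `dropout_rate` and `variational` flag are forwarded to each `LSTMCell`
    /// (variational dropout conventionally reuses one mask across timesteps;
    /// the exact semantics live in the cell implementation).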
    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.forward_cells {
            *cell = cell.clone().with_input_dropout(dropout_rate, variational);
        }
        for cell in &mut self.backward_cells {
            *cell = cell.clone().with_input_dropout(dropout_rate, variational);
        }
        self
    }

    /// Applies recurrent dropout to every cell in both directions.
    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.forward_cells {
            *cell = cell.clone().with_recurrent_dropout(dropout_rate, variational);
        }
        for cell in &mut self.backward_cells {
            *cell = cell.clone().with_recurrent_dropout(dropout_rate, variational);
        }
        self
    }

    /// Applies output dropout between stacked layers. The final layer is
    /// skipped so the network's own output is never dropped.
    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
        for (i, cell) in self.forward_cells.iter_mut().enumerate() {
            if i < self.num_layers - 1 {
                *cell = cell.clone().with_output_dropout(dropout_rate);
            }
        }
        for (i, cell) in self.backward_cells.iter_mut().enumerate() {
            if i < self.num_layers - 1 {
                *cell = cell.clone().with_output_dropout(dropout_rate);
            }
        }
        self
    }

    /// Applies zoneout to the cell and hidden states of every cell in both directions.
    pub fn with_zoneout(mut self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self {
        for cell in &mut self.forward_cells {
            *cell = cell.clone().with_zoneout(cell_zoneout_rate, hidden_zoneout_rate);
        }
        for cell in &mut self.backward_cells {
            *cell = cell.clone().with_zoneout(cell_zoneout_rate, hidden_zoneout_rate);
        }
        self
    }

    /// Switches the network and every cell to training mode.
    pub fn train(&mut self) {
        self.is_training = true;
        for cell in &mut self.forward_cells {
            cell.train();
        }
        for cell in &mut self.backward_cells {
            cell.train();
        }
    }

    /// Switches the network and every cell to evaluation mode.
    pub fn eval(&mut self) {
        self.is_training = false;
        for cell in &mut self.forward_cells {
            cell.eval();
        }
        for cell in &mut self.backward_cells {
            cell.eval();
        }
    }

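    /// Merges one timestep's forward and backward hidden states (each
    /// `hidden_size x 1`) according to `combine_mode`.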
    fn combine_outputs(&self, forward: &Array2<f64>, backward: &Array2<f64>) -> Array2<f64> {
        match self.combine_mode {
            CombineMode::Concat => {
                // Stack vertically: forward states on top, backward below.
                let mut combined = Array2::zeros((forward.nrows() + backward.nrows(), forward.ncols()));
                combined.slice_mut(ndarray::s![..forward.nrows(), ..]).assign(forward);
                combined.slice_mut(ndarray::s![forward.nrows().., ..]).assign(backward);
                combined
            },
            CombineMode::Sum => forward + backward,
            CombineMode::Average => (forward + backward) * 0.5,
        }
    }

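    /// Runs the full stacked bidirectional pass over `sequence` (one
    /// `input_size x 1` column vector per timestep), returning one combined
    /// output of shape `(output_size(), 1)` per timestep.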
    pub fn forward_sequence(&mut self, sequence: &[Array2<f64>]) -> Vec<Array2<f64>> {
        let seq_len = sequence.len();
        if seq_len == 0 {
            return Vec::new();
        }

        let mut layer_input_sequence = sequence.to_vec();

        for layer_idx in 0..self.num_layers {
            let mut forward_outputs = Vec::new();
            let mut backward_outputs = Vec::new();

            // Both directions start each layer from zero hidden and cell states.
            let mut forward_hidden_state = Array2::zeros((self.hidden_size, 1));
            let mut forward_cell_state = Array2::zeros((self.hidden_size, 1));
            let mut backward_hidden_state = Array2::zeros((self.hidden_size, 1));
            let mut backward_cell_state = Array2::zeros((self.hidden_size, 1));

            // Forward pass over time, left to right.
            for t in 0..seq_len {
                let (hy, cy) = self.forward_cells[layer_idx].forward(
                    &layer_input_sequence[t],
                    &forward_hidden_state,
                    &forward_cell_state,
                );

                forward_hidden_state = hy.clone();
                forward_cell_state = cy;
                forward_outputs.push(hy);
            }

            // Backward pass over time, right to left.
            for t in (0..seq_len).rev() {
                let (hy, cy) = self.backward_cells[layer_idx].forward(
                    &layer_input_sequence[t],
                    &backward_hidden_state,
                    &backward_cell_state,
                );

                backward_hidden_state = hy.clone();
                backward_cell_state = cy;
                backward_outputs.push(hy);
            }

            // Backward outputs were collected in reverse time order; realign them.
            backward_outputs.reverse();

            let mut combined_outputs = Vec::new();
            for (forward_out, backward_out) in forward_outputs.iter().zip(backward_outputs.iter()) {
                combined_outputs.push(self.combine_outputs(forward_out, backward_out));
            }

            // The combined sequence feeds the next layer.
            layer_input_sequence = combined_outputs;
        }

        layer_input_sequence
    }

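    /// Like `forward_sequence`, but also records the per-timestep cell caches
    /// needed for backpropagation through time. Per direction the caches are
    /// flattened layer-major, so entry `layer_idx * seq_len + t` belongs to
    /// layer `layer_idx` at timestep `t`.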
    pub fn forward_sequence_with_cache(&mut self, sequence: &[Array2<f64>]) -> (Vec<Array2<f64>>, BiLSTMNetworkCache) {
        let seq_len = sequence.len();
        if seq_len == 0 {
            return (Vec::new(), BiLSTMNetworkCache {
                forward_caches: Vec::new(),
                backward_caches: Vec::new(),
            });
        }

        let mut all_forward_caches = Vec::new();
        let mut all_backward_caches = Vec::new();

        let mut layer_input_sequence = sequence.to_vec();

        for layer_idx in 0..self.num_layers {
            let mut forward_outputs = Vec::new();
            let mut backward_outputs = Vec::new();
            let mut forward_caches = Vec::new();
            let mut backward_caches = Vec::new();

            let mut forward_hidden_state = Array2::zeros((self.hidden_size, 1));
            let mut forward_cell_state = Array2::zeros((self.hidden_size, 1));
            let mut backward_hidden_state = Array2::zeros((self.hidden_size, 1));
            let mut backward_cell_state = Array2::zeros((self.hidden_size, 1));

            // Forward pass over time, caching every step.
            for t in 0..seq_len {
                let (hy, cy, cache) = self.forward_cells[layer_idx].forward_with_cache(
                    &layer_input_sequence[t],
                    &forward_hidden_state,
                    &forward_cell_state,
                );

                forward_hidden_state = hy.clone();
                forward_cell_state = cy;
                forward_outputs.push(hy);
                forward_caches.push(cache);
            }

            // Backward pass over time, caching every step.
            for t in (0..seq_len).rev() {
                let (hy, cy, cache) = self.backward_cells[layer_idx].forward_with_cache(
                    &layer_input_sequence[t],
                    &backward_hidden_state,
                    &backward_cell_state,
                );

                backward_hidden_state = hy.clone();
                backward_cell_state = cy;
                backward_outputs.push(hy);
                backward_caches.push(cache);
            }

            // Realign backward outputs and caches to forward time order.
            backward_outputs.reverse();
            backward_caches.reverse();

            let mut combined_outputs = Vec::new();
            for (forward_out, backward_out) in forward_outputs.iter().zip(backward_outputs.iter()) {
                combined_outputs.push(self.combine_outputs(forward_out, backward_out));
            }

            // Append this layer's caches; the flat vectors stay layer-major.
            all_forward_caches.extend(forward_caches);
            all_backward_caches.extend(backward_caches);

            layer_input_sequence = combined_outputs;
        }

        let cache = BiLSTMNetworkCache {
            forward_caches: all_forward_caches,
            backward_caches: all_backward_caches,
        };

        (layer_input_sequence, cache)
    }

    pub fn get_forward_cells(&self) -> &[LSTMCell] {
        &self.forward_cells
    }

    pub fn get_backward_cells(&self) -> &[LSTMCell] {
        &self.backward_cells
    }

    pub fn get_forward_cells_mut(&mut self) -> &mut [LSTMCell] {
        &mut self.forward_cells
    }

    pub fn get_backward_cells_mut(&mut self) -> &mut [LSTMCell] {
        &mut self.backward_cells
    }

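    /// Applies one optimizer step per layer and direction. Gradients are
    /// matched to cells by index, and each cell gets a unique parameter key
    /// (e.g. `"forward_layer_0"`) so per-cell optimizer state stays separate.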
    pub fn update_parameters<O: Optimizer>(
        &mut self,
        forward_gradients: &[LSTMCellGradients],
        backward_gradients: &[LSTMCellGradients],
        optimizer: &mut O,
    ) {
        for (i, (cell, gradients)) in self.forward_cells.iter_mut().zip(forward_gradients.iter()).enumerate() {
            cell.update_parameters(gradients, optimizer, &format!("forward_layer_{}", i));
        }

        for (i, (cell, gradients)) in self.backward_cells.iter_mut().zip(backward_gradients.iter()).enumerate() {
            cell.update_parameters(gradients, optimizer, &format!("backward_layer_{}", i));
        }
    }

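    /// Returns zero-initialized gradient accumulators, one per layer, for the
    /// forward and backward directions respectively.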
    pub fn zero_gradients(&self) -> (Vec<LSTMCellGradients>, Vec<LSTMCellGradients>) {
        let forward_gradients: Vec<_> = self.forward_cells.iter()
            .map(|cell| cell.zero_gradients())
            .collect();

        let backward_gradients: Vec<_> = self.backward_cells.iter()
            .map(|cell| cell.zero_gradients())
            .collect();

        (forward_gradients, backward_gradients)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_bilstm_creation() {
        let network = BiLSTMNetwork::new_concat(3, 5, 2);
        assert_eq!(network.input_size, 3);
        assert_eq!(network.hidden_size, 5);
        assert_eq!(network.num_layers, 2);
        assert_eq!(network.output_size(), 10);
    }

    #[test]
    fn test_bilstm_combine_modes() {
        let forward = arr2(&[[1.0], [2.0]]);
        let backward = arr2(&[[3.0], [4.0]]);

        let concat_network = BiLSTMNetwork::new_concat(2, 2, 1);
        let concat_result = concat_network.combine_outputs(&forward, &backward);
        assert_eq!(concat_result.shape(), &[4, 1]);
        assert_eq!(concat_result[[0, 0]], 1.0);
        assert_eq!(concat_result[[1, 0]], 2.0);
        assert_eq!(concat_result[[2, 0]], 3.0);
        assert_eq!(concat_result[[3, 0]], 4.0);

        let sum_network = BiLSTMNetwork::new_sum(2, 2, 1);
        let sum_result = sum_network.combine_outputs(&forward, &backward);
        assert_eq!(sum_result.shape(), &[2, 1]);
        assert_eq!(sum_result[[0, 0]], 4.0);
        assert_eq!(sum_result[[1, 0]], 6.0);

        let avg_network = BiLSTMNetwork::new_average(2, 2, 1);
        let avg_result = avg_network.combine_outputs(&forward, &backward);
        assert_eq!(avg_result.shape(), &[2, 1]);
        assert_eq!(avg_result[[0, 0]], 2.0);
        assert_eq!(avg_result[[1, 0]], 3.0);
    }

    #[test]
    fn test_bilstm_forward_sequence() {
        let mut network = BiLSTMNetwork::new_concat(2, 3, 1);

        let sequence = vec![
            arr2(&[[1.0], [0.5]]),
            arr2(&[[0.8], [0.2]]),
            arr2(&[[0.3], [0.9]]),
        ];

        let outputs = network.forward_sequence(&sequence);

        assert_eq!(outputs.len(), 3);
        for output in &outputs {
            // Concat mode: 2 * hidden_size = 6 rows per timestep.
            assert_eq!(output.shape(), &[6, 1]);
        }
    }

    #[test]
    fn test_bilstm_training_mode() {
        let mut network = BiLSTMNetwork::new_concat(2, 3, 1)
            .with_input_dropout(0.1, false)
            .with_recurrent_dropout(0.2, true);

        network.train();
        assert!(network.is_training);

        network.eval();
        assert!(!network.is_training);
    }
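
    // A sketch of an extra check, not in the original suite: Sum mode should
    // keep `hidden_size` rows per timestep, and the flattened caches should
    // hold `num_layers * seq_len` entries per direction. Uses only the public
    // API exercised by the tests above.
    #[test]
    fn test_bilstm_sum_shapes_and_cache_len() {
        let mut network = BiLSTMNetwork::new_sum(2, 3, 2);
        let sequence = vec![
            arr2(&[[1.0], [0.5]]),
            arr2(&[[0.8], [0.2]]),
        ];

        let (outputs, cache) = network.forward_sequence_with_cache(&sequence);

        assert_eq!(outputs.len(), 2);
        for output in &outputs {
            // Sum mode: output keeps hidden_size (3) rows.
            assert_eq!(output.shape(), &[3, 1]);
        }
        // 2 layers * 2 timesteps, flattened layer-major.
        assert_eq!(cache.forward_caches.len(), 4);
        assert_eq!(cache.backward_caches.len(), 4);
    }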
}