use ndarray::Array2;
use crate::layers::lstm_cell::{LSTMCell, LSTMCellGradients, LSTMCellCache, LSTMCellBatchCache};
use crate::optimizers::Optimizer;

#[derive(Clone)]
pub struct LSTMNetworkCache {
    pub cell_caches: Vec<LSTMCellCache>,
}

#[derive(Clone)]
pub struct LSTMNetworkBatchCache {
    pub cell_caches: Vec<LSTMCellBatchCache>,
    pub batch_size: usize,
}

#[derive(Clone)]
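/// A stack of `num_layers` LSTM cells. Layer 0 consumes `input_size`-wide
/// inputs; every deeper layer consumes the hidden state of the layer below,
/// so its input width is `hidden_size`.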
pub struct LSTMNetwork {
    cells: Vec<LSTMCell>,
    pub input_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub is_training: bool,
}

impl LSTMNetwork {
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut cells = Vec::new();

        for i in 0..num_layers {
            let layer_input_size = if i == 0 { input_size } else { hidden_size };
            cells.push(LSTMCell::new(layer_input_size, hidden_size));
        }

        LSTMNetwork {
            cells,
            input_size,
            hidden_size,
            num_layers,
            is_training: true,
        }
    }

    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_input_dropout(dropout_rate, variational);
        }
        self
    }

    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_recurrent_dropout(dropout_rate, variational);
        }
        self
    }

    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
        for (i, cell) in self.cells.iter_mut().enumerate() {
            // The top layer is skipped: its hidden state is the network output.
            if i < self.num_layers - 1 {
                *cell = cell.clone().with_output_dropout(dropout_rate);
            }
        }
        self
    }

    pub fn with_zoneout(mut self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_zoneout(cell_zoneout_rate, hidden_zoneout_rate);
        }
        self
    }

    pub fn with_layer_dropout(mut self, layer_configs: Vec<LayerDropoutConfig>) -> Self {
        for (i, config) in layer_configs.into_iter().enumerate() {
            if i < self.cells.len() {
                let mut cell = self.cells[i].clone();

                if let Some((rate, variational)) = config.input_dropout {
                    cell = cell.with_input_dropout(rate, variational);
                }
                if let Some((rate, variational)) = config.recurrent_dropout {
                    cell = cell.with_recurrent_dropout(rate, variational);
                }
                if let Some(rate) = config.output_dropout {
                    cell = cell.with_output_dropout(rate);
                }
                if let Some((cell_rate, hidden_rate)) = config.zoneout {
                    cell = cell.with_zoneout(cell_rate, hidden_rate);
                }

                self.cells[i] = cell;
            }
        }
        self
    }

    pub fn train(&mut self) {
        self.is_training = true;
        for cell in &mut self.cells {
            cell.train();
        }
    }

    pub fn eval(&mut self) {
        self.is_training = false;
        for cell in &mut self.cells {
            cell.eval();
        }
    }

    pub fn from_cells(cells: Vec<LSTMCell>, input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        LSTMNetwork {
            cells,
            input_size,
            hidden_size,
            num_layers,
            is_training: true,
        }
    }

    pub fn get_cells(&self) -> &[LSTMCell] {
        &self.cells
    }

    pub fn get_cells_mut(&mut self) -> &mut [LSTMCell] {
        &mut self.cells
    }
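
    /// Runs a single time step through all layers and returns the top
    /// layer's hidden and cell state. `input` is an `(input_size, 1)` column
    /// vector; `hx` and `cx` are `(hidden_size, 1)`.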
    pub fn forward(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
        let (hy, cy, _) = self.forward_with_cache(input, hx, cx);
        (hy, cy)
    }
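
    /// Like `forward`, but also returns the per-layer caches that
    /// `backward` needs.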
    pub fn forward_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMNetworkCache) {
        let mut current_input = input.clone();
        let mut current_hx = hx.clone();
        let mut current_cx = cx.clone();
        let mut cell_caches = Vec::new();

        for cell in &mut self.cells {
            let (new_hx, new_cx, cache) = cell.forward_with_cache(&current_input, &current_hx, &current_cx);
            cell_caches.push(cache);

            // Each layer's hidden state becomes the next layer's input.
            current_input = new_hx.clone();
            current_hx = new_hx;
            current_cx = new_cx;
        }

        let network_cache = LSTMNetworkCache { cell_caches };
        (current_hx, current_cx, network_cache)
    }
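
    /// Backpropagates one time step through the layer stack. `dhy` and `dcy`
    /// are gradients w.r.t. the top layer's outputs. Returns per-layer
    /// parameter gradients (layer 0 first) and the gradient w.r.t. the
    /// network input.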
    pub fn backward(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMNetworkCache) -> (Vec<LSTMCellGradients>, Array2<f64>) {
        let mut gradients = Vec::new();
        let mut current_dhy = dhy.clone();
        let mut current_dcy = dcy.clone();
        let mut dx_input = Array2::zeros(dhy.raw_dim());

        // Walk the layers from top to bottom, passing each layer's input
        // gradient down as the output gradient of the layer below.
        for (i, cell) in self.cells.iter().enumerate().rev() {
            let cell_cache = &cache.cell_caches[i];
            let (cell_gradients, dx, _dhx_prev, dcx_prev) = cell.backward(&current_dhy, &current_dcy, cell_cache);

            gradients.push(cell_gradients);

            if i > 0 {
                current_dhy = dx;
                current_dcy = dcx_prev;
            } else {
                // At layer 0, `dx` is the gradient w.r.t. the network input,
                // so there is no need to re-run the first cell's backward pass.
                dx_input = dx;
            }
        }

        gradients.reverse();

        (gradients, dx_input)
    }
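
    /// Applies one optimizer update per layer. Each layer's parameters are
    /// registered under a distinct `layer_{i}` prefix so their optimizer
    /// state does not collide.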
    pub fn update_parameters<O: Optimizer>(&mut self, gradients: &[LSTMCellGradients], optimizer: &mut O) {
        for (i, (cell, cell_gradients)) in self.cells.iter_mut().zip(gradients.iter()).enumerate() {
            let prefix = format!("layer_{}", i);
            cell.update_parameters(cell_gradients, optimizer, &prefix);
        }
    }

    pub fn zero_gradients(&self) -> Vec<LSTMCellGradients> {
        self.cells.iter().map(|cell| cell.zero_gradients()).collect()
    }
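
    /// Unrolls the network over a whole sequence, starting from zeroed
    /// states, and returns the `(hy, cy)` pair plus the cache for every
    /// time step.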
    pub fn forward_sequence_with_cache(&mut self, sequence: &[Array2<f64>]) -> (Vec<(Array2<f64>, Array2<f64>)>, Vec<LSTMNetworkCache>) {
        let mut outputs = Vec::new();
        let mut caches = Vec::new();
        let mut hx = Array2::zeros((self.hidden_size, 1));
        let mut cx = Array2::zeros((self.hidden_size, 1));

        for input in sequence {
            let (new_hx, new_cx, cache) = self.forward_with_cache(input, &hx, &cx);
            outputs.push((new_hx.clone(), new_cx.clone()));
            caches.push(cache);
            hx = new_hx;
            cx = new_cx;
        }

        (outputs, caches)
    }
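
    /// Steps a batch of (possibly different-length) sequences in lock-step.
    /// At each time step the current input of every still-active sequence is
    /// packed into the columns of a single `(input_size, batch_size)` matrix;
    /// sequences that have ended stop receiving outputs.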
    pub fn forward_batch_sequences(&mut self, batch_sequences: &[Vec<Array2<f64>>]) -> Vec<Vec<(Array2<f64>, Array2<f64>)>> {
        let max_seq_len = batch_sequences.iter().map(|seq| seq.len()).max().unwrap_or(0);
        let batch_size = batch_sequences.len();

        if batch_size == 0 || max_seq_len == 0 {
            return Vec::new();
        }

        let mut batch_outputs = vec![Vec::new(); batch_size];

        let mut batch_hx = Array2::zeros((self.hidden_size, batch_size));
        let mut batch_cx = Array2::zeros((self.hidden_size, batch_size));

        for t in 0..max_seq_len {
            // Pack the t-th input of every still-active sequence into columns;
            // exhausted sequences keep zero columns.
            let mut batch_input = Array2::zeros((self.input_size, batch_size));
            let mut active_sequences = Vec::new();

            for (batch_idx, sequence) in batch_sequences.iter().enumerate() {
                if t < sequence.len() {
                    batch_input.column_mut(batch_idx).assign(&sequence[t].column(0));
                    active_sequences.push(batch_idx);
                }
            }

            if active_sequences.is_empty() {
                break;
            }

            let (new_batch_hx, new_batch_cx) = self.forward_batch(&batch_input, &batch_hx, &batch_cx);

            batch_hx = new_batch_hx.clone();
            batch_cx = new_batch_cx.clone();

            // Record outputs only for sequences that were active at step t.
            for &batch_idx in &active_sequences {
                let hy = new_batch_hx.column(batch_idx).to_owned().insert_axis(ndarray::Axis(1));
                let cy = new_batch_cx.column(batch_idx).to_owned().insert_axis(ndarray::Axis(1));
                batch_outputs[batch_idx].push((hy, cy));
            }
        }

        batch_outputs
    }
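
    /// One batched time step: column `j` of the inputs and states belongs to
    /// batch element `j`.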
    pub fn forward_batch(&mut self, batch_input: &Array2<f64>, batch_hx: &Array2<f64>, batch_cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
        let mut current_input = batch_input.clone();
        let mut current_hx = batch_hx.clone();
        let mut current_cx = batch_cx.clone();

        for cell in &mut self.cells {
            let (new_hx, new_cx) = cell.forward_batch(&current_input, &current_hx, &current_cx);
            current_input = new_hx.clone();
            current_hx = new_hx;
            current_cx = new_cx;
        }

        (current_hx, current_cx)
    }
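
    /// Batched version of `forward_with_cache`; the returned cache also
    /// records the batch size.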
    pub fn forward_batch_with_cache(&mut self, batch_input: &Array2<f64>, batch_hx: &Array2<f64>, batch_cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMNetworkBatchCache) {
        let mut current_input = batch_input.clone();
        let mut current_hx = batch_hx.clone();
        let mut current_cx = batch_cx.clone();
        let mut cell_caches = Vec::new();

        for cell in &mut self.cells {
            let (new_hx, new_cx, cache) = cell.forward_batch_with_cache(&current_input, &current_hx, &current_cx);
            cell_caches.push(cache);

            current_input = new_hx.clone();
            current_hx = new_hx;
            current_cx = new_cx;
        }

        let network_cache = LSTMNetworkBatchCache {
            cell_caches,
            batch_size: batch_input.ncols(),
        };

        (current_hx, current_cx, network_cache)
    }
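
    /// Batched version of `backward`; gradients flow from the top layer to
    /// the bottom exactly as in the single-sample path.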
    pub fn backward_batch(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMNetworkBatchCache) -> (Vec<LSTMCellGradients>, Array2<f64>) {
        let mut gradients = Vec::new();
        let mut current_dhy = dhy.clone();
        let mut current_dcy = dcy.clone();
        let mut dx_input = Array2::<f64>::zeros(dhy.raw_dim());

        for (i, cell) in self.cells.iter().enumerate().rev() {
            let cell_cache = &cache.cell_caches[i];
            let (cell_gradients, dx, _dhx_prev, dcx_prev) = cell.backward_batch(&current_dhy, &current_dcy, cell_cache);

            gradients.push(cell_gradients);

            if i > 0 {
                current_dhy = dx;
                current_dcy = dcx_prev;
            } else {
                // As in `backward`, capture layer 0's input gradient directly
                // instead of re-running its backward pass with stale gradients.
                dx_input = dx;
            }
        }

        gradients.reverse();

        (gradients, dx_input)
    }
}

#[derive(Clone, Debug)]
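/// Per-layer regularization settings consumed by
/// `LSTMNetwork::with_layer_dropout`. Each `(f64, bool)` tuple pairs a rate
/// with the `variational` flag passed through to the cell builders.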
pub struct LayerDropoutConfig {
    pub input_dropout: Option<(f64, bool)>,
    pub recurrent_dropout: Option<(f64, bool)>,
    pub output_dropout: Option<f64>,
    pub zoneout: Option<(f64, f64)>,
}

impl LayerDropoutConfig {
    pub fn new() -> Self {
        LayerDropoutConfig {
            input_dropout: None,
            recurrent_dropout: None,
            output_dropout: None,
            zoneout: None,
        }
    }

    pub fn with_input_dropout(mut self, rate: f64, variational: bool) -> Self {
        self.input_dropout = Some((rate, variational));
        self
    }

    pub fn with_recurrent_dropout(mut self, rate: f64, variational: bool) -> Self {
        self.recurrent_dropout = Some((rate, variational));
        self
    }

    pub fn with_output_dropout(mut self, rate: f64) -> Self {
        self.output_dropout = Some(rate);
        self
    }

    pub fn with_zoneout(mut self, cell_rate: f64, hidden_rate: f64) -> Self {
        self.zoneout = Some((cell_rate, hidden_rate));
        self
    }
}
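
// A small convenience, not part of the original API: `Default` delegates to
// `new()`, so an empty config can also be written
// `LayerDropoutConfig::default()`.
impl Default for LayerDropoutConfig {
    fn default() -> Self {
        Self::new()
    }
}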
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_lstm_network_forward() {
        let input_size = 3;
        let hidden_size = 2;
        let num_layers = 2;
        let mut network = LSTMNetwork::new(input_size, hidden_size, num_layers);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        let (hy, cy) = network.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[hidden_size, 1]);
        assert_eq!(cy.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_lstm_network_with_dropout() {
        let input_size = 3;
        let hidden_size = 2;
        let num_layers = 2;
        let mut network = LSTMNetwork::new(input_size, hidden_size, num_layers)
            .with_input_dropout(0.2, true)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1)
            .with_zoneout(0.1, 0.1);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        network.train();
        let (hy_train, cy_train) = network.forward(&input, &hx, &cx);

        network.eval();
        let (hy_eval, cy_eval) = network.forward(&input, &hx, &cx);

        assert_eq!(hy_train.shape(), &[hidden_size, 1]);
        assert_eq!(cy_train.shape(), &[hidden_size, 1]);
        assert_eq!(hy_eval.shape(), &[hidden_size, 1]);
        assert_eq!(cy_eval.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_layer_specific_dropout() {
        let input_size = 3;
        let hidden_size = 2;
        let num_layers = 2;

        let layer_configs = vec![
            LayerDropoutConfig::new()
                .with_input_dropout(0.2, true)
                .with_recurrent_dropout(0.3, true),
            LayerDropoutConfig::new()
                .with_output_dropout(0.1)
                .with_zoneout(0.1, 0.1),
        ];

        let mut network = LSTMNetwork::new(input_size, hidden_size, num_layers)
            .with_layer_dropout(layer_configs);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        let (hy, cy) = network.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[hidden_size, 1]);
        assert_eq!(cy.shape(), &[hidden_size, 1]);
    }
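
    // A minimal end-to-end sketch: unroll a short sequence with
    // `forward_sequence_with_cache`, then backpropagate through the final
    // step only. It assumes `LSTMCell::backward` returns an input gradient
    // shaped like the layer's input; gradients are not accumulated across
    // time steps here.
    #[test]
    fn test_sequence_forward_and_single_step_backward() {
        let mut network = LSTMNetwork::new(3, 2, 2);
        network.eval();

        let sequence = vec![
            arr2(&[[0.5], [0.1], [-0.3]]),
            arr2(&[[0.2], [-0.4], [0.7]]),
        ];

        let (outputs, caches) = network.forward_sequence_with_cache(&sequence);
        assert_eq!(outputs.len(), sequence.len());

        // Seed the backward pass with a dummy all-ones gradient.
        let dhy = Array2::ones((2, 1));
        let dcy = Array2::zeros((2, 1));
        let (gradients, dx) = network.backward(&dhy, &dcy, caches.last().unwrap());

        assert_eq!(gradients.len(), 2); // one gradient set per layer
        assert_eq!(dx.shape(), &[3, 1]); // gradient w.r.t. the network input
    }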
}