use ndarray::Array2;
use crate::layers::lstm_cell::{LSTMCell, LSTMCellGradients, LSTMCellCache};
use crate::optimizers::Optimizer;

/// Per-layer forward-pass caches, consumed by `LSTMNetwork::backward`.
#[derive(Clone)]
pub struct LSTMNetworkCache {
    pub cell_caches: Vec<LSTMCellCache>,
}

/// A stack of `LSTMCell`s. Layer 0 consumes the network input; each
/// subsequent layer consumes the hidden output of the layer below.
#[derive(Clone)]
pub struct LSTMNetwork {
    cells: Vec<LSTMCell>,
    pub input_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub is_training: bool,
}

impl LSTMNetwork {
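    /// Builds a stacked LSTM. Layer 0 maps `input_size -> hidden_size`;
    /// every later layer maps `hidden_size -> hidden_size`.
    ///
    /// A minimal usage sketch (column-vector shapes `(size, 1)`, matching
    /// the tests below):
    ///
    /// ```ignore
    /// let mut net = LSTMNetwork::new(3, 2, 2);
    /// let x = ndarray::arr2(&[[0.5], [0.1], [-0.3]]);
    /// let h0 = ndarray::Array2::<f64>::zeros((2, 1));
    /// let c0 = ndarray::Array2::<f64>::zeros((2, 1));
    /// let (h1, c1) = net.forward(&x, &h0, &c0);
    /// ```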
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut cells = Vec::new();

        for i in 0..num_layers {
            let layer_input_size = if i == 0 { input_size } else { hidden_size };
            cells.push(LSTMCell::new(layer_input_size, hidden_size));
        }

        LSTMNetwork {
            cells,
            input_size,
            hidden_size,
            num_layers,
            is_training: true,
        }
    }

    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_input_dropout(dropout_rate, variational);
        }
        self
    }

    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_recurrent_dropout(dropout_rate, variational);
        }
        self
    }

    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
        // Output dropout is applied between stacked layers only; the top
        // layer's output is left untouched.
        for (i, cell) in self.cells.iter_mut().enumerate() {
            if i < self.num_layers - 1 {
                *cell = cell.clone().with_output_dropout(dropout_rate);
            }
        }
        self
    }

    pub fn with_zoneout(mut self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_zoneout(cell_zoneout_rate, hidden_zoneout_rate);
        }
        self
    }

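    /// Applies a per-layer dropout configuration; entries beyond the number
    /// of layers are ignored. See `LayerDropoutConfig` and the
    /// `test_layer_specific_dropout` test below for usage.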
    pub fn with_layer_dropout(mut self, layer_configs: Vec<LayerDropoutConfig>) -> Self {
        for (i, config) in layer_configs.into_iter().enumerate() {
            if i < self.cells.len() {
                let mut cell = self.cells[i].clone();

                if let Some((rate, variational)) = config.input_dropout {
                    cell = cell.with_input_dropout(rate, variational);
                }
                if let Some((rate, variational)) = config.recurrent_dropout {
                    cell = cell.with_recurrent_dropout(rate, variational);
                }
                if let Some(rate) = config.output_dropout {
                    cell = cell.with_output_dropout(rate);
                }
                if let Some((cell_rate, hidden_rate)) = config.zoneout {
                    cell = cell.with_zoneout(cell_rate, hidden_rate);
                }

                self.cells[i] = cell;
            }
        }
        self
    }

    pub fn train(&mut self) {
        self.is_training = true;
        for cell in &mut self.cells {
            cell.train();
        }
    }

    pub fn eval(&mut self) {
        self.is_training = false;
        for cell in &mut self.cells {
            cell.eval();
        }
    }

    pub fn from_cells(cells: Vec<LSTMCell>, input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        LSTMNetwork {
            cells,
            input_size,
            hidden_size,
            num_layers,
            is_training: true,
        }
    }

    pub fn get_cells(&self) -> &[LSTMCell] {
        &self.cells
    }

    pub fn get_cells_mut(&mut self) -> &mut [LSTMCell] {
        &mut self.cells
    }

    pub fn forward(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
        let (hy, cy, _) = self.forward_with_cache(input, hx, cx);
        (hy, cy)
    }

    pub fn forward_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMNetworkCache) {
        let mut current_input = input.clone();
        let mut current_hx = hx.clone();
        let mut current_cx = cx.clone();
        let mut cell_caches = Vec::new();

        for cell in &mut self.cells {
            let (new_hx, new_cx, cache) = cell.forward_with_cache(&current_input, &current_hx, &current_cx);
            cell_caches.push(cache);

            // Each layer's hidden output becomes both the input and the
            // (hx, cx) state fed to the layer above.
            current_input = new_hx.clone();
            current_hx = new_hx;
            current_cx = new_cx;
        }

        let network_cache = LSTMNetworkCache { cell_caches };
        (current_hx, current_cx, network_cache)
    }

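    /// Backpropagates one step through the stack, top layer down. Returns
    /// per-layer gradients ordered bottom-to-top (matching `cells`) and the
    /// gradient with respect to the network input.
    ///
    /// A minimal sketch of the call pattern, assuming a unit upstream
    /// hidden-state gradient:
    ///
    /// ```ignore
    /// let (hy, cy, cache) = net.forward_with_cache(&x, &h0, &c0);
    /// let dhy = ndarray::Array2::ones(hy.raw_dim());
    /// let dcy = ndarray::Array2::zeros(cy.raw_dim());
    /// let (grads, dx) = net.backward(&dhy, &dcy, &cache);
    /// ```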
    pub fn backward(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMNetworkCache) -> (Vec<LSTMCellGradients>, Array2<f64>) {
        let mut gradients = Vec::new();
        let mut current_dhy = dhy.clone();
        let mut current_dcy = dcy.clone();
        let mut dx_input = Array2::zeros((self.input_size, dhy.ncols()));

        for (i, cell) in self.cells.iter().enumerate().rev() {
            let cell_cache = &cache.cell_caches[i];
            let (cell_gradients, dx, dhx_prev, dcx_prev) = cell.backward(&current_dhy, &current_dcy, cell_cache);

            gradients.push(cell_gradients);

            if i > 0 {
                // The layer below supplied both this layer's input and its
                // (hx, cx) state (see `forward_with_cache`), so its output
                // gradient is the sum of the input and hidden-state
                // gradients; the cell-state gradient passes straight down.
                current_dhy = dx + dhx_prev;
                current_dcy = dcx_prev;
            } else {
                // At the bottom layer, dx is the gradient with respect to
                // the network input.
                dx_input = dx;
            }
        }

        // Gradients were collected top-down; flip to match `cells` order.
        gradients.reverse();

        (gradients, dx_input)
    }

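    /// Applies one optimizer step per layer. Parameter names are prefixed
    /// `layer_0`, `layer_1`, ... so stateful optimizers keep per-layer state.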
    pub fn update_parameters<O: Optimizer>(&mut self, gradients: &[LSTMCellGradients], optimizer: &mut O) {
        for (i, (cell, cell_gradients)) in self.cells.iter_mut().zip(gradients.iter()).enumerate() {
            let prefix = format!("layer_{}", i);
            cell.update_parameters(cell_gradients, optimizer, &prefix);
        }
    }

    pub fn zero_gradients(&self) -> Vec<LSTMCellGradients> {
        self.cells.iter().map(|cell| cell.zero_gradients()).collect()
    }

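    /// Runs the network over a whole sequence, starting from zero states.
    /// Assumes a batch size of 1: states are `(hidden_size, 1)` columns.
    ///
    /// A minimal sketch (each step of `sequence` has shape `(input_size, 1)`):
    ///
    /// ```ignore
    /// let (outputs, caches) = net.forward_sequence_with_cache(&sequence);
    /// assert_eq!(outputs.len(), sequence.len());
    /// assert_eq!(caches.len(), sequence.len());
    /// ```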
    pub fn forward_sequence_with_cache(&mut self, sequence: &[Array2<f64>]) -> (Vec<(Array2<f64>, Array2<f64>)>, Vec<LSTMNetworkCache>) {
        let mut outputs = Vec::new();
        let mut caches = Vec::new();
        let mut hx = Array2::zeros((self.hidden_size, 1));
        let mut cx = Array2::zeros((self.hidden_size, 1));

        for input in sequence {
            let (new_hx, new_cx, cache) = self.forward_with_cache(input, &hx, &cx);
            outputs.push((new_hx.clone(), new_cx.clone()));
            caches.push(cache);
            hx = new_hx;
            cx = new_cx;
        }

        (outputs, caches)
    }
}

/// Per-layer dropout settings consumed by `LSTMNetwork::with_layer_dropout`.
#[derive(Clone, Debug)]
pub struct LayerDropoutConfig {
    pub input_dropout: Option<(f64, bool)>,     // (rate, variational)
    pub recurrent_dropout: Option<(f64, bool)>, // (rate, variational)
    pub output_dropout: Option<f64>,
    pub zoneout: Option<(f64, f64)>,            // (cell_rate, hidden_rate)
}

impl LayerDropoutConfig {
    pub fn new() -> Self {
        LayerDropoutConfig {
            input_dropout: None,
            recurrent_dropout: None,
            output_dropout: None,
            zoneout: None,
        }
    }

    pub fn with_input_dropout(mut self, rate: f64, variational: bool) -> Self {
        self.input_dropout = Some((rate, variational));
        self
    }

    pub fn with_recurrent_dropout(mut self, rate: f64, variational: bool) -> Self {
        self.recurrent_dropout = Some((rate, variational));
        self
    }

    pub fn with_output_dropout(mut self, rate: f64) -> Self {
        self.output_dropout = Some(rate);
        self
    }

    pub fn with_zoneout(mut self, cell_rate: f64, hidden_rate: f64) -> Self {
        self.zoneout = Some((cell_rate, hidden_rate));
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_lstm_network_forward() {
        let input_size = 3;
        let hidden_size = 2;
        let num_layers = 2;
        let mut network = LSTMNetwork::new(input_size, hidden_size, num_layers);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        let (hy, cy) = network.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[hidden_size, 1]);
        assert_eq!(cy.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_lstm_network_with_dropout() {
        let input_size = 3;
        let hidden_size = 2;
        let num_layers = 2;
        let mut network = LSTMNetwork::new(input_size, hidden_size, num_layers)
            .with_input_dropout(0.2, true)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1)
            .with_zoneout(0.1, 0.1);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        network.train();
        let (hy_train, cy_train) = network.forward(&input, &hx, &cx);

        network.eval();
        let (hy_eval, cy_eval) = network.forward(&input, &hx, &cx);

        assert_eq!(hy_train.shape(), &[hidden_size, 1]);
        assert_eq!(cy_train.shape(), &[hidden_size, 1]);
        assert_eq!(hy_eval.shape(), &[hidden_size, 1]);
        assert_eq!(cy_eval.shape(), &[hidden_size, 1]);
    }
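
    // A shape-level sketch of the backward API: one gradient struct per
    // layer, and an input gradient matching the input's shape.
    #[test]
    fn test_lstm_network_backward_shapes() {
        let mut network = LSTMNetwork::new(3, 2, 2);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        let (hy, cy, cache) = network.forward_with_cache(&input, &hx, &cx);
        let dhy = Array2::ones(hy.raw_dim());
        let dcy = Array2::zeros(cy.raw_dim());

        let (gradients, dx_input) = network.backward(&dhy, &dcy, &cache);

        assert_eq!(gradients.len(), 2);
        assert_eq!(dx_input.shape(), &[3, 1]);
    }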

    #[test]
    fn test_layer_specific_dropout() {
        let input_size = 3;
        let hidden_size = 2;
        let num_layers = 2;

        let layer_configs = vec![
            LayerDropoutConfig::new()
                .with_input_dropout(0.2, true)
                .with_recurrent_dropout(0.3, true),
            LayerDropoutConfig::new()
                .with_output_dropout(0.1)
                .with_zoneout(0.1, 0.1),
        ];

        let mut network = LSTMNetwork::new(input_size, hidden_size, num_layers)
            .with_layer_dropout(layer_configs);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let hx = arr2(&[[0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0]]);

        let (hy, cy) = network.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[hidden_size, 1]);
        assert_eq!(cy.shape(), &[hidden_size, 1]);
    }
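
    // A minimal sketch of sequence processing: outputs and caches line up
    // one-to-one with the input timesteps.
    #[test]
    fn test_forward_sequence_with_cache() {
        let mut network = LSTMNetwork::new(3, 2, 1);

        let sequence = vec![
            arr2(&[[0.5], [0.1], [-0.3]]),
            arr2(&[[0.2], [-0.4], [0.6]]),
        ];

        let (outputs, caches) = network.forward_sequence_with_cache(&sequence);

        assert_eq!(outputs.len(), 2);
        assert_eq!(caches.len(), 2);
        for (hy, cy) in &outputs {
            assert_eq!(hy.shape(), &[2, 1]);
            assert_eq!(cy.shape(), &[2, 1]);
        }
    }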
}