use ndarray::Array2;
use crate::layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
use crate::optimizers::Optimizer;

/// The per-layer cell caches recorded for one time step by
/// `forward_sequence_with_cache`, later consumed by `backward`.
#[derive(Clone)]
pub struct GRUNetworkCache {
    pub caches: Vec<GRUCellCache>,
}

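/// Builder-style dropout settings for one layer of a `GRUNetwork`. Each rate
/// is forwarded verbatim to the matching `GRUCell` builder, and a rate of
/// 0.0 leaves that path untouched (see `GRUNetwork::with_layer_dropout`).
/// A minimal sketch, assuming the builders below:
///
/// ```ignore
/// let cfg = LayerDropoutConfig::new()
///     .with_input_dropout(0.2, true) // variational: one mask reused per sequence
///     .with_recurrent_dropout(0.1, false);
/// ```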
#[derive(Clone)]
pub struct LayerDropoutConfig {
    pub input_dropout_rate: f64,
    pub input_variational: bool,
    pub recurrent_dropout_rate: f64,
    pub recurrent_variational: bool,
    pub output_dropout_rate: f64,
}

impl LayerDropoutConfig {
    /// Creates a config with every dropout rate set to 0.0 (disabled).
    pub fn new() -> Self {
        LayerDropoutConfig {
            input_dropout_rate: 0.0,
            input_variational: false,
            recurrent_dropout_rate: 0.0,
            recurrent_variational: false,
            output_dropout_rate: 0.0,
        }
    }

    /// Sets the dropout rate for the layer's input path.
    pub fn with_input_dropout(mut self, rate: f64, variational: bool) -> Self {
        self.input_dropout_rate = rate;
        self.input_variational = variational;
        self
    }

    /// Sets the dropout rate for the recurrent (hidden-to-hidden) path.
    pub fn with_recurrent_dropout(mut self, rate: f64, variational: bool) -> Self {
        self.recurrent_dropout_rate = rate;
        self.recurrent_variational = variational;
        self
    }

    /// Sets the dropout rate for the layer's output path.
    pub fn with_output_dropout(mut self, rate: f64) -> Self {
        self.output_dropout_rate = rate;
        self
    }
}

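/// A stack of `GRUCell`s in which each layer consumes the hidden state
/// produced by the layer below it. A minimal usage sketch (inputs are
/// column vectors, matching the tests at the bottom of this file):
///
/// ```ignore
/// let mut net = GRUNetwork::new(4, 8, 2) // input 4, hidden 8, 2 layers
///     .with_output_dropout(0.1);         // skipped on the last layer
/// net.train();
/// let (outputs, caches) = net.forward_sequence_with_cache(&sequence);
/// ```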
#[derive(Clone)]
pub struct GRUNetwork {
    cells: Vec<GRUCell>,
    pub input_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub is_training: bool,
}

impl GRUNetwork {
    /// Builds `num_layers` stacked GRU cells. The first layer reads
    /// `input_size` features; every later layer reads the `hidden_size`
    /// output of the layer below it. Networks start in training mode.
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut cells = Vec::new();

        for i in 0..num_layers {
            let layer_input_size = if i == 0 { input_size } else { hidden_size };
            cells.push(GRUCell::new(layer_input_size, hidden_size));
        }

        GRUNetwork {
            cells,
            input_size,
            hidden_size,
            num_layers,
            is_training: true,
        }
    }

    /// Applies the same input dropout to every layer.
    pub fn with_input_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_input_dropout(dropout_rate, variational);
        }
        self
    }

    /// Applies the same recurrent dropout to every layer.
    pub fn with_recurrent_dropout(mut self, dropout_rate: f64, variational: bool) -> Self {
        for cell in &mut self.cells {
            *cell = cell.clone().with_recurrent_dropout(dropout_rate, variational);
        }
        self
    }

    /// Applies output dropout to every layer except the last, so the
    /// network's final output is never dropped.
    pub fn with_output_dropout(mut self, dropout_rate: f64) -> Self {
        for (i, cell) in self.cells.iter_mut().enumerate() {
            if i < self.num_layers - 1 {
                *cell = cell.clone().with_output_dropout(dropout_rate);
            }
        }
        self
    }

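    /// Applies a distinct dropout configuration to each layer. Output
    /// dropout is still skipped on the last layer even when its config
    /// requests it.
    ///
    /// # Panics
    ///
    /// Panics if `configs.len() != self.num_layers`.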
    pub fn with_layer_dropout(mut self, configs: Vec<LayerDropoutConfig>) -> Self {
        if configs.len() != self.num_layers {
            panic!("Number of dropout configs must match number of layers");
        }

        for (i, config) in configs.into_iter().enumerate() {
            if config.input_dropout_rate > 0.0 {
                self.cells[i] = self.cells[i].clone()
                    .with_input_dropout(config.input_dropout_rate, config.input_variational);
            }
            if config.recurrent_dropout_rate > 0.0 {
                self.cells[i] = self.cells[i].clone()
                    .with_recurrent_dropout(config.recurrent_dropout_rate, config.recurrent_variational);
            }
            if config.output_dropout_rate > 0.0 && i < self.num_layers - 1 {
                self.cells[i] = self.cells[i].clone()
                    .with_output_dropout(config.output_dropout_rate);
            }
        }
        self
    }

    /// Switches the network and every cell to training mode.
    pub fn train(&mut self) {
        self.is_training = true;
        for cell in &mut self.cells {
            cell.train();
        }
    }

    /// Switches the network and every cell to evaluation (inference) mode.
    pub fn eval(&mut self) {
        self.is_training = false;
        for cell in &mut self.cells {
            cell.eval();
        }
    }

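    /// Runs a single time step through the stack. `hx` supplies one previous
    /// hidden state per layer; the returned vector contains each layer's new
    /// hidden state, so the last entry is the network's output.
    ///
    /// # Panics
    ///
    /// Panics if `hx.len() != self.num_layers`.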
    pub fn forward(&mut self, input: &Array2<f64>, hx: &[Array2<f64>]) -> Vec<Array2<f64>> {
        if hx.len() != self.num_layers {
            panic!("Number of hidden states must match number of layers");
        }

        let mut layer_input = input.clone();
        let mut outputs = Vec::new();

        for (i, cell) in self.cells.iter_mut().enumerate() {
            let hy = cell.forward(&layer_input, &hx[i]);
            outputs.push(hy.clone());
            // The hidden state of this layer is the input to the next one.
            layer_input = hy;
        }

        outputs
    }

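    /// Runs an entire sequence from zero initial hidden states, recording one
    /// `GRUNetworkCache` per time step for later use by `backward`. Each
    /// returned tuple pairs the top layer's output with the per-layer outputs
    /// for that step.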
    pub fn forward_sequence_with_cache(&mut self, sequence: &[Array2<f64>]) -> (Vec<(Array2<f64>, Vec<Array2<f64>>)>, Vec<GRUNetworkCache>) {
        let mut all_outputs = Vec::new();
        let mut all_caches = Vec::new();

        // Every layer starts the sequence with a zero hidden state.
        let mut hidden_states: Vec<Array2<f64>> = (0..self.num_layers)
            .map(|_| Array2::zeros((self.hidden_size, 1)))
            .collect();

        for input in sequence {
            let mut layer_input = input.clone();
            let mut step_outputs = Vec::new();
            let mut step_caches = Vec::new();

            for (i, cell) in self.cells.iter_mut().enumerate() {
                let (hy, cache) = cell.forward_with_cache(&layer_input, &hidden_states[i]);

                hidden_states[i] = hy.clone();
                step_outputs.push(hy.clone());
                step_caches.push(cache);
                layer_input = hy;
            }

            // Pair the top layer's output with all per-layer outputs for this step.
            let final_output = step_outputs.last().unwrap().clone();
            all_outputs.push((final_output, step_outputs));
            all_caches.push(GRUNetworkCache { caches: step_caches });
        }

        (all_outputs, all_caches)
    }

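    /// Backpropagates the upstream gradient `dhy` through every layer for a
    /// single cached time step, returning the per-layer parameter gradients
    /// (in bottom-to-top layer order) together with the gradient emerging
    /// from the bottom layer. A full BPTT loop would call this once per
    /// cached step, from the last step back to the first.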
    pub fn backward(&self, dhy: &Array2<f64>, cache: &GRUNetworkCache) -> (Vec<GRUCellGradients>, Array2<f64>) {
        let mut gradients = Vec::new();
        let mut dhx = dhy.clone();

        // Walk the layers top-down, chaining each cell's returned gradient
        // into the layer below; inserting at the front keeps the collected
        // gradients in layer order.
        for (i, cell) in self.cells.iter().enumerate().rev() {
            let (cell_gradients, _, dhx_prev) = cell.backward(&dhx, &cache.caches[i]);
            gradients.insert(0, cell_gradients);
            dhx = dhx_prev;
        }

        (gradients, dhx)
    }

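    /// Applies one gradient set per layer (as produced by `backward`) through
    /// `optimizer`. Each layer is registered under its own `"layer_{i}"` name
    /// prefix, which lets stateful optimizers keep per-layer state apart.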
    pub fn update_parameters<O: Optimizer>(&mut self, gradients: &[GRUCellGradients], optimizer: &mut O) {
        for (i, (cell, grad)) in self.cells.iter_mut().zip(gradients.iter()).enumerate() {
            cell.update_parameters(grad, optimizer, &format!("layer_{}", i));
        }
    }

    /// Returns a zero-filled gradient set for every layer, handy as an
    /// accumulator when summing per-step gradients over a sequence.
    pub fn zero_gradients(&self) -> Vec<GRUCellGradients> {
        self.cells.iter().map(|cell| cell.zero_gradients()).collect()
    }

    pub fn get_cells(&self) -> &[GRUCell] {
        &self.cells
    }

    pub fn get_cells_mut(&mut self) -> &mut [GRUCell] {
        &mut self.cells
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_gru_network_creation() {
        let network = GRUNetwork::new(3, 5, 2);
        assert_eq!(network.input_size, 3);
        assert_eq!(network.hidden_size, 5);
        assert_eq!(network.num_layers, 2);
        assert_eq!(network.cells.len(), 2);
    }

    #[test]
    fn test_gru_network_forward() {
        let mut network = GRUNetwork::new(2, 3, 2);
        let input = arr2(&[[1.0], [0.5]]);
        let hidden_states = vec![
            arr2(&[[0.1], [0.2], [0.3]]),
            arr2(&[[0.0], [0.1], [0.2]]),
        ];

        let outputs = network.forward(&input, &hidden_states);
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].shape(), &[3, 1]);
        assert_eq!(outputs[1].shape(), &[3, 1]);
    }

    #[test]
    fn test_gru_network_sequence() {
        let mut network = GRUNetwork::new(2, 3, 1);
        let sequence = vec![
            arr2(&[[1.0], [0.0]]),
            arr2(&[[0.0], [1.0]]),
            arr2(&[[-1.0], [0.5]]),
        ];

        let (outputs, caches) = network.forward_sequence_with_cache(&sequence);

        assert_eq!(outputs.len(), 3);
        assert_eq!(caches.len(), 3);

        for (output, _) in &outputs {
            assert_eq!(output.shape(), &[3, 1]);
        }
    }

    #[test]
    fn test_gru_network_with_dropout() {
        let mut network = GRUNetwork::new(2, 3, 2)
            .with_input_dropout(0.2, true)
            .with_recurrent_dropout(0.3, false)
            .with_output_dropout(0.1);

        let input = arr2(&[[1.0], [0.5]]);
        let hidden_states = vec![
            arr2(&[[0.1], [0.2], [0.3]]),
            arr2(&[[0.0], [0.1], [0.2]]),
        ];

        network.train();
        let outputs_train = network.forward(&input, &hidden_states);

        network.eval();
        let outputs_eval = network.forward(&input, &hidden_states);

        // Dropout makes training-mode outputs stochastic, so only the
        // output counts are asserted here.
        assert_eq!(outputs_train.len(), 2);
        assert_eq!(outputs_eval.len(), 2);
    }

    #[test]
    fn test_gru_network_layer_dropout() {
        let layer_configs = vec![
            LayerDropoutConfig::new().with_input_dropout(0.1, false),
            LayerDropoutConfig::new().with_recurrent_dropout(0.2, true),
        ];

        let network = GRUNetwork::new(2, 3, 2)
            .with_layer_dropout(layer_configs);

        assert_eq!(network.cells.len(), 2);
    }
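
    // A hedged end-to-end sketch of the backward API: run a short sequence
    // forward with caching, then push a dummy all-ones upstream gradient
    // through each cached step in reverse, as BPTT would. A real training
    // loop would also carry hidden-state gradients between steps and hand
    // the per-layer gradients to an `Optimizer` via `update_parameters`.
    #[test]
    fn test_gru_network_backward_through_sequence() {
        let mut network = GRUNetwork::new(2, 3, 2);
        network.eval(); // no dropout, keeps the check deterministic

        let sequence = vec![
            arr2(&[[1.0], [0.0]]),
            arr2(&[[0.0], [1.0]]),
        ];
        let (_outputs, caches) = network.forward_sequence_with_cache(&sequence);

        for cache in caches.iter().rev() {
            let dhy: Array2<f64> = Array2::ones((3, 1)); // stand-in loss gradient
            let (gradients, _dhx) = network.backward(&dhy, cache);
            assert_eq!(gradients.len(), 2); // one gradient set per layer
        }
    }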
}