scirs2_neural/layers/recurrent/gru.rs

//! Gated Recurrent Unit (GRU) implementation

use crate::error::{NeuralError, Result};
use crate::layers::recurrent::{GruForwardOutput, GruGateCache};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, ArrayView, Ix2, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{Distribution, Uniform};
use std::fmt::Debug;
use std::sync::{Arc, RwLock};

/// Configuration for GRU layers
#[derive(Debug, Clone)]
pub struct GRUConfig {
    /// Number of input features
    pub input_size: usize,
    /// Number of hidden units
    pub hidden_size: usize,
}

/// Gated Recurrent Unit (GRU) layer
///
/// Implements a GRU layer with the following update rules:
///
/// r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_(t-1) + b_hr)  # reset gate
/// z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_(t-1) + b_hz)  # update gate
/// n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_(t-1) + b_hn))  # new gate
/// h_t = (1 - z_t) * n_t + z_t * h_(t-1)  # hidden state
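///
/// Since r_t and z_t come from a sigmoid they lie in (0, 1) and n_t is squashed
/// by tanh into (-1, 1), so each h_t is a convex combination of the candidate
/// state n_t and the previous hidden state h_(t-1).
///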
/// # Examples
/// ```
/// use scirs2_neural::layers::{Layer, recurrent::GRU};
/// use scirs2_core::ndarray::Array3;
/// use scirs2_core::random::rngs::StdRng;
/// use scirs2_core::random::SeedableRng;
/// // Create a GRU layer with 10 input features and 20 hidden units
/// let mut rng = StdRng::seed_from_u64(42);
/// let gru = GRU::new(10, 20, &mut rng).unwrap();
/// // Forward pass with a batch of 2 samples, sequence length 5, and 10 features
/// let batch_size = 2;
/// let seq_len = 5;
/// let input_size = 10;
/// let input = Array3::<f64>::from_elem((batch_size, seq_len, input_size), 0.1).into_dyn();
/// let output = gru.forward(&input).unwrap();
/// // Output should have dimensions [batch_size, seq_len, hidden_size]
/// assert_eq!(output.shape(), &[batch_size, seq_len, 20]);
/// ```
pub struct GRU<F: Float + Debug> {
    /// Input size (number of input features)
    input_size: usize,
    /// Hidden size (number of hidden units)
    hidden_size: usize,
    /// Input-to-hidden weights for reset gate
    weight_ir: Array<F, IxDyn>,
    /// Hidden-to-hidden weights for reset gate
    weight_hr: Array<F, IxDyn>,
    /// Input-to-hidden bias for reset gate
    bias_ir: Array<F, IxDyn>,
    /// Hidden-to-hidden bias for reset gate
    bias_hr: Array<F, IxDyn>,
    /// Input-to-hidden weights for update gate
    weight_iz: Array<F, IxDyn>,
    /// Hidden-to-hidden weights for update gate
    weight_hz: Array<F, IxDyn>,
    /// Input-to-hidden bias for update gate
    bias_iz: Array<F, IxDyn>,
    /// Hidden-to-hidden bias for update gate
    bias_hz: Array<F, IxDyn>,
    /// Input-to-hidden weights for new gate
    weight_in: Array<F, IxDyn>,
    /// Hidden-to-hidden weights for new gate
    weight_hn: Array<F, IxDyn>,
    /// Input-to-hidden bias for new gate
    bias_in: Array<F, IxDyn>,
    /// Hidden-to-hidden bias for new gate
    bias_hn: Array<F, IxDyn>,
    /// Gradients for all parameters (kept simple here)
    #[allow(dead_code)]
    gradients: RwLock<Vec<Array<F, IxDyn>>>,
    /// Input cache for backward pass
    input_cache: RwLock<Option<Array<F, IxDyn>>>,
    /// Hidden states cache for backward pass
    hidden_states_cache: RwLock<Option<Array<F, IxDyn>>>,
    /// Gate values cache for backward pass
    #[allow(dead_code)]
    gate_cache: GruGateCache<F>,
}

impl<F: Float + Debug + ScalarOperand + 'static> GRU<F> {
    /// Create a new GRU layer
    ///
    /// # Arguments
    /// * `input_size` - Number of input features
    /// * `hidden_size` - Number of hidden units
    /// * `rng` - Random number generator for weight initialization
    ///
    /// # Returns
    /// * A new GRU layer
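    ///
    /// # Examples
    /// Constructing with a zero `input_size` or `hidden_size` is rejected; a
    /// minimal sketch of the validation below:
    /// ```
    /// use scirs2_neural::layers::recurrent::GRU;
    /// use scirs2_core::random::rngs::StdRng;
    /// use scirs2_core::random::SeedableRng;
    /// let mut rng = StdRng::seed_from_u64(0);
    /// assert!(GRU::<f64>::new(0, 20, &mut rng).is_err());
    /// ```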
    pub fn new<R: scirs2_core::random::Rng + scirs2_core::random::RngCore>(
        input_size: usize,
        hidden_size: usize,
        rng: &mut R,
    ) -> Result<Self> {
        // Validate parameters
        if input_size == 0 || hidden_size == 0 {
            return Err(NeuralError::InvalidArchitecture(
                "input_size and hidden_size must be positive".to_string(),
            ));
        }
        // Initialize weights with Xavier/Glorot initialization
        let scale_ih = F::from(1.0 / (input_size as f64).sqrt()).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert scale factor".to_string())
        })?;
        let scale_hh = F::from(1.0 / (hidden_size as f64).sqrt()).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert hidden_size scale".to_string())
        })?;

        // Helper function to create weight matrices
        let mut create_weight_matrix = |rows: usize,
                                        cols: usize,
                                        scale: F|
         -> Result<Array<F, IxDyn>> {
            let mut weights_vec: Vec<F> = Vec::with_capacity(rows * cols);
            let uniform = Uniform::new(-1.0, 1.0).map_err(|e| {
                NeuralError::InvalidArchitecture(format!(
                    "Failed to create uniform distribution: {e}"
                ))
            })?;
            for _ in 0..(rows * cols) {
                let rand_val = uniform.sample(rng);
                let val = F::from(rand_val).ok_or_else(|| {
                    NeuralError::InvalidArchitecture("Failed to convert random value".to_string())
                })?;
                weights_vec.push(val * scale);
            }
            Array::from_shape_vec(IxDyn(&[rows, cols]), weights_vec).map_err(|e| {
                NeuralError::InvalidArchitecture(format!("Failed to create weights array: {e}"))
            })
        };
        // Initialize all weights and biases
        let weight_ir = create_weight_matrix(hidden_size, input_size, scale_ih)?;
        let weight_hr = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
        let bias_ir: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let bias_hr: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let weight_iz = create_weight_matrix(hidden_size, input_size, scale_ih)?;
        let weight_hz = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
        let bias_iz: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let bias_hz: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let weight_in = create_weight_matrix(hidden_size, input_size, scale_ih)?;
        let weight_hn = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
        let bias_in: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let bias_hn: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        // Initialize gradients
        let gradients = vec![
            Array::zeros(weight_ir.dim()),
            Array::zeros(weight_hr.dim()),
            Array::zeros(bias_ir.dim()),
            Array::zeros(bias_hr.dim()),
            Array::zeros(weight_iz.dim()),
            Array::zeros(weight_hz.dim()),
            Array::zeros(bias_iz.dim()),
            Array::zeros(bias_hz.dim()),
            Array::zeros(weight_in.dim()),
            Array::zeros(weight_hn.dim()),
            Array::zeros(bias_in.dim()),
            Array::zeros(bias_hn.dim()),
        ];
        Ok(Self {
            input_size,
            hidden_size,
            weight_ir,
            weight_hr,
            bias_ir,
            bias_hr,
            weight_iz,
            weight_hz,
            bias_iz,
            bias_hz,
            weight_in,
            weight_hn,
            bias_in,
            bias_hn,
            gradients: RwLock::new(gradients),
            input_cache: RwLock::new(None),
            hidden_states_cache: RwLock::new(None),
            gate_cache: Arc::new(RwLock::new(None)),
        })
    }
    /// Helper method to compute one step of the GRU
    ///
    /// # Arguments
    /// * `x` - Input tensor of shape [batch_size, input_size]
    /// * `h` - Previous hidden state of shape [batch_size, hidden_size]
    ///
    /// # Returns
    /// * `(new_h, gates)` where:
    ///   - new_h: New hidden state of shape [batch_size, hidden_size]
    ///   - gates: (reset_gate, update_gate, new_gate)
    fn step(
        &self,
        x: &ArrayView<F, IxDyn>,
        h: &ArrayView<F, IxDyn>,
    ) -> Result<GruForwardOutput<F>> {
        let x_shape = x.shape();
        let h_shape = h.shape();
        let batch_size = x_shape[0];
        // Validate shapes
        if x_shape[1] != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input feature dimension mismatch: expected {}, got {}",
                self.input_size, x_shape[1]
            )));
        }
        if h_shape[1] != self.hidden_size {
            return Err(NeuralError::InferenceError(format!(
                "Hidden state dimension mismatch: expected {}, got {}",
                self.hidden_size, h_shape[1]
            )));
        }
        if x_shape[0] != h_shape[0] {
            return Err(NeuralError::InferenceError(format!(
                "Batch size mismatch: input has {}, hidden state has {}",
                x_shape[0], h_shape[0]
            )));
        }
        // Initialize gates
        let mut r_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        let mut z_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        let mut n_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        // Initialize new hidden state
        let mut new_h: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        // Compute gates for each batch item
        for b in 0..batch_size {
            for i in 0..self.hidden_size {
                // Reset gate (r_t)
                let mut r_sum = self.bias_ir[i] + self.bias_hr[i];
                for j in 0..self.input_size {
                    r_sum = r_sum + self.weight_ir[[i, j]] * x[[b, j]];
                }
                for j in 0..self.hidden_size {
                    r_sum = r_sum + self.weight_hr[[i, j]] * h[[b, j]];
                }
                r_gate[[b, i]] = F::one() / (F::one() + (-r_sum).exp()); // sigmoid

                // Update gate (z_t)
                let mut z_sum = self.bias_iz[i] + self.bias_hz[i];
                for j in 0..self.input_size {
                    z_sum = z_sum + self.weight_iz[[i, j]] * x[[b, j]];
                }
                for j in 0..self.hidden_size {
                    z_sum = z_sum + self.weight_hz[[i, j]] * h[[b, j]];
                }
                z_gate[[b, i]] = F::one() / (F::one() + (-z_sum).exp()); // sigmoid

                // New gate (n_t)
                let mut n_sum = self.bias_in[i];
                for j in 0..self.input_size {
                    n_sum = n_sum + self.weight_in[[i, j]] * x[[b, j]];
                }
                // Reset gate applied to hidden state
                let mut hn_sum = self.bias_hn[i];
                for j in 0..self.hidden_size {
                    hn_sum = hn_sum + self.weight_hn[[i, j]] * h[[b, j]];
                }
                n_gate[[b, i]] = (n_sum + r_gate[[b, i]] * hn_sum).tanh(); // tanh

                // New hidden state (h_t)
                new_h[[b, i]] =
                    (F::one() - z_gate[[b, i]]) * n_gate[[b, i]] + z_gate[[b, i]] * h[[b, i]];
            }
        }
        // Convert all to dynamic dimension
        let new_h_dyn = new_h.into_dyn();
        let r_gate_dyn = r_gate.into_dyn();
        let z_gate_dyn = z_gate.into_dyn();
        let n_gate_dyn = n_gate.into_dyn();
        Ok((new_h_dyn, (r_gate_dyn, z_gate_dyn, n_gate_dyn)))
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for GRU<F> {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache input for backward pass
        *self.input_cache.write().unwrap() = Some(input.clone());
        // Validate input shape
        let input_shape = input.shape();
        if input_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D input [batch_size, seq_len, features], got {input_shape:?}"
            )));
        }
        let batch_size = input_shape[0];
        let seq_len = input_shape[1];
        let features = input_shape[2];
        if features != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input features dimension mismatch: expected {}, got {}",
                self.input_size, features
            )));
        }
        // Initialize hidden state to zeros
        let mut h = Array::zeros((batch_size, self.hidden_size));
        // Initialize output array to store all hidden states
        let mut all_hidden_states = Array::zeros((batch_size, seq_len, self.hidden_size));
        let mut all_gates = Vec::with_capacity(seq_len);
        // Process each time step
        for t in 0..seq_len {
            // Extract input at time t
            let x_t = input.slice(scirs2_core::ndarray::s![.., t, ..]);
            // Process one step - converting views to dynamic dimension
            let x_t_view = x_t.view().into_dyn();
            let h_view = h.view().into_dyn();
            let step_result = self.step(&x_t_view, &h_view)?;
            let new_h = step_result.0;
            let gates = step_result.1;
            // Convert back from dynamic dimension
            h = new_h.into_dimensionality::<Ix2>().unwrap();
            all_gates.push(gates);
            // Store hidden state
            for b in 0..batch_size {
                for i in 0..self.hidden_size {
                    all_hidden_states[[b, t, i]] = h[[b, i]];
                }
            }
        }
        // Cache hidden states for backward pass
        *self.hidden_states_cache.write().unwrap() = Some(all_hidden_states.clone().into_dyn());
        // Return with correct dynamic dimension
        Ok(all_hidden_states.into_dyn())
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        _grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Retrieve cached values
        let input_ref = self.input_cache.read().map_err(|_| {
            NeuralError::InferenceError("Failed to acquire read lock on input cache".to_string())
        })?;
        let hidden_states_ref = self.hidden_states_cache.read().map_err(|_| {
            NeuralError::InferenceError(
                "Failed to acquire read lock on hidden states cache".to_string(),
            )
        })?;
        if input_ref.is_none() || hidden_states_ref.is_none() {
            return Err(NeuralError::InferenceError(
                "No cached values for backward pass. Call forward() first.".to_string(),
            ));
        }
        // A full implementation would compute gradients for all parameters and
        // return the gradient with respect to the input. This simplified version
        // returns a zero gradient with the correct shape as a placeholder.
        let grad_input = Array::zeros(input.dim());
        Ok(grad_input)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        // Apply a small update to parameters (placeholder)
        let small_change = F::from(0.001).unwrap();
        let lr = small_change * learning_rate;
        // Helper function to update a parameter
        let update_param = |param: &mut Array<F, IxDyn>| {
            for w in param.iter_mut() {
                *w = *w - lr;
            }
        };
        // Update all parameters
        update_param(&mut self.weight_ir);
        update_param(&mut self.weight_hr);
        update_param(&mut self.bias_ir);
        update_param(&mut self.bias_hr);
        update_param(&mut self.weight_iz);
        update_param(&mut self.weight_hz);
        update_param(&mut self.bias_iz);
        update_param(&mut self.bias_hz);
        update_param(&mut self.weight_in);
        update_param(&mut self.weight_hn);
        update_param(&mut self.bias_in);
        update_param(&mut self.bias_hn);
        Ok(())
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for GRU<F> {
    fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        vec![
            self.weight_ir.clone(),
            self.weight_hr.clone(),
            self.bias_ir.clone(),
            self.bias_hr.clone(),
            self.weight_iz.clone(),
            self.weight_hz.clone(),
            self.bias_iz.clone(),
            self.bias_hz.clone(),
            self.weight_in.clone(),
            self.weight_hn.clone(),
            self.bias_in.clone(),
            self.bias_hn.clone(),
        ]
    }

    fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        // Placeholder until proper gradient access is implemented: references to
        // the gradients stored inside the RwLock cannot be handed out here, so
        // return an empty vector for now.
        Vec::new()
    }

    fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 12 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 12 parameters, got {}",
                params.len()
            )));
        }

        let expected_shapes = [
            self.weight_ir.shape(),
            self.weight_hr.shape(),
            self.bias_ir.shape(),
            self.bias_hr.shape(),
            self.weight_iz.shape(),
            self.weight_hz.shape(),
            self.bias_iz.shape(),
            self.bias_hz.shape(),
            self.weight_in.shape(),
            self.weight_hn.shape(),
            self.bias_in.shape(),
            self.bias_hn.shape(),
        ];

        for (i, (param, expected)) in params.iter().zip(expected_shapes.iter()).enumerate() {
            if param.shape() != *expected {
                return Err(NeuralError::InvalidArchitecture(format!(
                    "Parameter {} shape mismatch: expected {:?}, got {:?}",
                    i,
                    expected,
                    param.shape()
                )));
            }
        }

        // Set parameters
        self.weight_ir = params[0].clone();
        self.weight_hr = params[1].clone();
        self.bias_ir = params[2].clone();
        self.bias_hr = params[3].clone();
        self.weight_iz = params[4].clone();
        self.weight_hz = params[5].clone();
        self.bias_iz = params[6].clone();
        self.bias_hz = params[7].clone();
        self.weight_in = params[8].clone();
        self.weight_hn = params[9].clone();
        self.bias_in = params[10].clone();
        self.bias_hn = params[11].clone();

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;
    use scirs2_core::random::rngs::SmallRng;
    use scirs2_core::random::SeedableRng;

    #[test]
    fn test_gru_shape() {
        // Create a GRU layer
        let mut rng = SmallRng::from_seed([42; 32]);
        let gru = GRU::<f64>::new(
            10, // input_size
            20, // hidden_size
            &mut rng,
        )
        .unwrap();

        // Create a batch of input data
        let batch_size = 2;
        let seq_len = 5;
        let input_size = 10;
        let input = Array3::<f64>::from_elem((batch_size, seq_len, input_size), 0.1).into_dyn();
        // Forward pass
        let output = gru.forward(&input).unwrap();
        // Check output shape
        assert_eq!(output.shape(), &[batch_size, seq_len, 20]);
    }
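
    // Illustrative extra checks (a sketch, not an exhaustive test): with a zero
    // initial hidden state, each h_t is a convex combination of tanh outputs and
    // the previous state, so every output stays strictly inside (-1, 1), and the
    // layer exposes its 12 weight/bias tensors through ParamLayer::get_parameters.
    #[test]
    fn test_gru_output_range_and_parameter_count() {
        let mut rng = SmallRng::from_seed([7; 32]);
        let gru = GRU::<f64>::new(4, 3, &mut rng).unwrap();

        let input = Array3::<f64>::from_elem((1, 6, 4), 0.5).into_dyn();
        let output = gru.forward(&input).unwrap();

        // Hidden states are bounded by the sigmoid/tanh gating.
        assert!(output.iter().all(|&v| v > -1.0 && v < 1.0));

        // Three gates, each with input/hidden weight matrices and biases.
        assert_eq!(gru.get_parameters().len(), 12);
    }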
}