scirs2_neural/layers/recurrent/rnn.rs

//! Basic Recurrent Neural Network (RNN) implementation

use crate::error::{NeuralError, Result};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, ArrayView, Ix2, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::Rng;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};

/// Activation function types for recurrent layers
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RecurrentActivation {
    /// Hyperbolic tangent (tanh) activation
    Tanh,
    /// Sigmoid activation
    Sigmoid,
    /// Rectified Linear Unit (ReLU)
    ReLU,
}
/// Configuration for RNN layers
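///
/// A plain data struct; a minimal construction sketch:
///
/// ```
/// use scirs2_neural::layers::recurrent::rnn::{RNNConfig, RecurrentActivation};
/// let config = RNNConfig {
///     input_size: 10,
///     hidden_size: 20,
///     activation: RecurrentActivation::Tanh,
/// };
/// assert_eq!(config.hidden_size, 20);
/// ```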
#[derive(Debug, Clone)]
pub struct RNNConfig {
    /// Number of input features
    pub input_size: usize,
    /// Number of hidden units
    pub hidden_size: usize,
    /// Activation function
    pub activation: RecurrentActivation,
}

impl RecurrentActivation {
    /// Apply the activation function
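    ///
    /// # Examples
    /// A small scalar check of the ReLU variant:
    /// ```
    /// use scirs2_neural::layers::recurrent::rnn::RecurrentActivation;
    /// let relu = RecurrentActivation::ReLU;
    /// assert_eq!(relu.apply(-2.0f64), 0.0);
    /// assert_eq!(relu.apply(3.0f64), 3.0);
    /// ```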
    pub fn apply<F: Float>(&self, x: F) -> F {
        match self {
            RecurrentActivation::Tanh => x.tanh(),
            RecurrentActivation::Sigmoid => F::one() / (F::one() + (-x).exp()),
            RecurrentActivation::ReLU => {
                if x > F::zero() {
                    x
                } else {
                    F::zero()
                }
            }
        }
    }

    /// Apply the activation function to an array
    #[allow(dead_code)]
    pub fn apply_array<F: Float + ScalarOperand>(&self, x: &Array<F, IxDyn>) -> Array<F, IxDyn> {
        match self {
            RecurrentActivation::Tanh => x.mapv(|v| v.tanh()),
            RecurrentActivation::Sigmoid => x.mapv(|v| F::one() / (F::one() + (-v).exp())),
            RecurrentActivation::ReLU => x.mapv(|v| if v > F::zero() { v } else { F::zero() }),
        }
    }
}
/// Basic Recurrent Neural Network (RNN) layer
///
/// Implements a simple RNN layer with the following update rule:
///
/// h_t = activation(W_ih * x_t + b_ih + W_hh * h_(t-1) + b_hh)
///
/// # Examples
/// ```
/// use scirs2_neural::layers::{Layer, recurrent::{RNN, rnn::RecurrentActivation}};
/// use scirs2_core::ndarray::Array3;
/// use scirs2_core::random::rngs::SmallRng;
/// use scirs2_core::random::SeedableRng;
/// // Create an RNN layer with 10 input features and 20 hidden units
/// let mut rng = SmallRng::from_seed([42; 32]);
/// let rnn = RNN::new(10, 20, RecurrentActivation::Tanh, &mut rng).unwrap();
/// // Forward pass with a batch of 2 samples, sequence length 5, and 10 features
/// let batch_size = 2;
/// let seq_len = 5;
/// let input_size = 10;
/// let input = Array3::<f64>::from_elem((batch_size, seq_len, input_size), 0.1).into_dyn();
/// let output = rnn.forward(&input).unwrap();
/// // Output should have dimensions [batch_size, seq_len, hidden_size]
/// assert_eq!(output.shape(), &[batch_size, seq_len, 20]);
/// ```
pub struct RNN<F: Float + Debug + Send + Sync> {
    /// Input size (number of input features)
    input_size: usize,
    /// Hidden size (number of hidden units)
    hidden_size: usize,
    /// Activation function
    activation: RecurrentActivation,
    /// Input-to-hidden weights
    weight_ih: Array<F, IxDyn>,
    /// Hidden-to-hidden weights
    weight_hh: Array<F, IxDyn>,
    /// Input-to-hidden bias
    bias_ih: Array<F, IxDyn>,
    /// Hidden-to-hidden bias
    bias_hh: Array<F, IxDyn>,
    /// Gradient of input-to-hidden weights
    dweight_ih: Array<F, IxDyn>,
    /// Gradient of hidden-to-hidden weights
    dweight_hh: Array<F, IxDyn>,
    /// Gradient of input-to-hidden bias
    dbias_ih: Array<F, IxDyn>,
    /// Gradient of hidden-to-hidden bias
    dbias_hh: Array<F, IxDyn>,
    /// Input cache for backward pass
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Hidden states cache for backward pass
    hidden_states_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> RNN<F> {
    /// Create a new RNN layer
    ///
    /// # Arguments
    /// * `input_size` - Number of input features
    /// * `hidden_size` - Number of hidden units
    /// * `activation` - Activation function
    /// * `rng` - Random number generator for weight initialization
    ///
    /// # Returns
    /// * A new RNN layer
    pub fn new<R: Rng>(
        input_size: usize,
        hidden_size: usize,
        activation: RecurrentActivation,
        rng: &mut R,
    ) -> Result<Self> {
        // Validate parameters
        if input_size == 0 || hidden_size == 0 {
            return Err(NeuralError::InvalidArchitecture(
                "Input size and hidden size must be positive".to_string(),
            ));
        }
        // Initialize weights with Xavier/Glorot initialization
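        // Each weight is drawn uniformly from (-scale, scale), with scale = 1/sqrt(fan_in)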
        let scale_ih = F::from(1.0 / (input_size as f64).sqrt()).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert scale factor".to_string())
        })?;
        let scale_hh = F::from(1.0 / (hidden_size as f64).sqrt()).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert hidden size scale".to_string())
        })?;
        // Initialize input-to-hidden weights
        let mut weight_ih_vec: Vec<F> = Vec::with_capacity(hidden_size * input_size);
        for _ in 0..(hidden_size * input_size) {
            let rand_val = rng.gen_range(-1.0..1.0);
            let val = F::from(rand_val).ok_or_else(|| {
                NeuralError::InvalidArchitecture("Failed to convert random value".to_string())
            })?;
            weight_ih_vec.push(val * scale_ih);
        }
        let weight_ih = Array::from_shape_vec(IxDyn(&[hidden_size, input_size]), weight_ih_vec)
            .map_err(|e| {
                NeuralError::InvalidArchitecture(format!("Failed to create weights array: {e}"))
            })?;
        // Initialize hidden-to-hidden weights
        let mut weight_hh_vec: Vec<F> = Vec::with_capacity(hidden_size * hidden_size);
        for _ in 0..(hidden_size * hidden_size) {
            let rand_val = rng.gen_range(-1.0..1.0);
            let val = F::from(rand_val).ok_or_else(|| {
                NeuralError::InvalidArchitecture("Failed to convert random value".to_string())
            })?;
            weight_hh_vec.push(val * scale_hh);
        }
        let weight_hh = Array::from_shape_vec(IxDyn(&[hidden_size, hidden_size]), weight_hh_vec)
            .map_err(|e| {
                NeuralError::InvalidArchitecture(format!("Failed to create weights array: {e}"))
            })?;
        // Initialize biases
        let bias_ih = Array::zeros(IxDyn(&[hidden_size]));
        let bias_hh = Array::zeros(IxDyn(&[hidden_size]));
        // Initialize gradients
        let dweight_ih = Array::zeros(weight_ih.dim());
        let dweight_hh = Array::zeros(weight_hh.dim());
        let dbias_ih = Array::zeros(bias_ih.dim());
        let dbias_hh = Array::zeros(bias_hh.dim());
        Ok(Self {
            input_size,
            hidden_size,
            activation,
            weight_ih,
            weight_hh,
            bias_ih,
            bias_hh,
            dweight_ih,
            dweight_hh,
            dbias_ih,
            dbias_hh,
            input_cache: Arc::new(RwLock::new(None)),
            hidden_states_cache: Arc::new(RwLock::new(None)),
        })
    }
    /// Helper method to compute one step of the RNN
    ///
    /// # Arguments
    /// * `x` - Input tensor of shape [batch_size, input_size]
    /// * `h` - Previous hidden state of shape [batch_size, hidden_size]
    ///
    /// # Returns
    /// * New hidden state of shape [batch_size, hidden_size]
    fn step(&self, x: &ArrayView<F, IxDyn>, h: &ArrayView<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        let x_shape = x.shape();
        let h_shape = h.shape();
        let batch_size = x_shape[0];
        // Validate shapes
        if x_shape[1] != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input feature dimension mismatch: expected {}, got {}",
                self.input_size, x_shape[1]
            )));
        }
        if h_shape[1] != self.hidden_size {
            return Err(NeuralError::InferenceError(format!(
                "Hidden state dimension mismatch: expected {}, got {}",
                self.hidden_size, h_shape[1]
            )));
        }
        if x_shape[0] != h_shape[0] {
            return Err(NeuralError::InferenceError(format!(
                "Batch size mismatch: input has {}, hidden state has {}",
                x_shape[0], h_shape[0]
            )));
        }
        // Initialize output
        let mut new_h = Array::zeros((batch_size, self.hidden_size));
        // Compute h_t = activation(W_ih * x_t + b_ih + W_hh * h_(t-1) + b_hh)
        for b in 0..batch_size {
            for i in 0..self.hidden_size {
                // Input-to-hidden contribution: W_ih * x_t + b_ih
                let mut ih_sum = self.bias_ih[i];
                for j in 0..self.input_size {
                    ih_sum = ih_sum + self.weight_ih[[i, j]] * x[[b, j]];
                }
                // Hidden-to-hidden contribution: W_hh * h_(t-1) + b_hh
                let mut hh_sum = self.bias_hh[i];
                for j in 0..self.hidden_size {
                    hh_sum = hh_sum + self.weight_hh[[i, j]] * h[[b, j]];
                }
                // Apply activation
                new_h[[b, i]] = self.activation.apply(ih_sum + hh_sum);
            }
        }
        // Convert to IxDyn dimension
        let new_h_dyn = new_h.into_dyn();
        Ok(new_h_dyn)
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for RNN<F> {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache input for backward pass
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(input.to_owned());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }
        // Validate input shape
        let input_shape = input.shape();
        if input_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D input [batch_size, seq_len, features], got {input_shape:?}"
            )));
        }
        let batch_size = input_shape[0];
        let seq_len = input_shape[1];
        let features = input_shape[2];
        if features != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input features dimension mismatch: expected {}, got {}",
                self.input_size, features
            )));
        }
        // Initialize hidden state to zeros
        let mut h = Array::zeros((batch_size, self.hidden_size));
        // Initialize output array to store all hidden states
        let mut all_hidden_states = Array::zeros((batch_size, seq_len, self.hidden_size));
        // Process each time step
        for t in 0..seq_len {
            // Extract input at time t
            let x_t = input.slice(scirs2_core::ndarray::s![.., t, ..]);
            // Process one step
            let x_t_view = x_t.view().into_dyn();
            let h_view = h.view().into_dyn();
            h = self
                .step(&x_t_view, &h_view)?
                .into_dimensionality::<Ix2>()
                .unwrap();
            // Store hidden state
            for b in 0..batch_size {
                for i in 0..self.hidden_size {
                    all_hidden_states[[b, t, i]] = h[[b, i]];
                }
            }
        }
        // Cache all hidden states for backward pass
        if let Ok(mut cache) = self.hidden_states_cache.write() {
            *cache = Some(all_hidden_states.to_owned().into_dyn());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on hidden states cache".to_string(),
            ));
        }
        // Return all hidden states
        Ok(all_hidden_states.into_dyn())
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        _grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Retrieve cached values
        let input_ref = match self.input_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on input cache".to_string(),
                ))
            }
        };
        let hidden_states_ref = match self.hidden_states_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on hidden states cache".to_string(),
                ))
            }
        };
        if input_ref.is_none() || hidden_states_ref.is_none() {
            return Err(NeuralError::InferenceError(
                "No cached values for backward pass. Call forward() first.".to_string(),
            ));
        }
        // In a real implementation, we would compute gradients for all parameters
        // and return the gradient with respect to the input.
        // Here we're providing a simplified version that returns a gradient of zeros
        // with the correct shape.
        let grad_input = Array::zeros(input.dim());
        Ok(grad_input)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        // Apply a small update to parameters (placeholder)
        let small_change = F::from(0.001).unwrap();
        let lr = small_change * learning_rate;
        // Update weights and biases
        for w in self.weight_ih.iter_mut() {
            *w = *w - lr;
        }
        for w in self.weight_hh.iter_mut() {
            *w = *w - lr;
        }
        for b in self.bias_ih.iter_mut() {
            *b = *b - lr;
        }
        for b in self.bias_hh.iter_mut() {
            *b = *b - lr;
        }
        Ok(())
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for RNN<F> {
    fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        vec![
            self.dweight_ih.clone(),
            self.dweight_hh.clone(),
            self.dbias_ih.clone(),
            self.dbias_hh.clone(),
        ]
    }

    fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 4 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 4 parameters, got {}",
                params.len()
            )));
        }

        // Check shapes
        if params[0].shape() != self.weight_ih.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Weight_ih shape mismatch: expected {:?}, got {:?}",
                self.weight_ih.shape(),
                params[0].shape()
            )));
        }
        if params[1].shape() != self.weight_hh.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Weight_hh shape mismatch: expected {:?}, got {:?}",
                self.weight_hh.shape(),
                params[1].shape()
            )));
        }
        if params[2].shape() != self.bias_ih.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Bias_ih shape mismatch: expected {:?}, got {:?}",
                self.bias_ih.shape(),
                params[2].shape()
            )));
        }
        if params[3].shape() != self.bias_hh.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Bias_hh shape mismatch: expected {:?}, got {:?}",
                self.bias_hh.shape(),
                params[3].shape()
            )));
        }

        self.weight_ih = params[0].clone();
        self.weight_hh = params[1].clone();
        self.bias_ih = params[2].clone();
        self.bias_hh = params[3].clone();

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;
    use scirs2_core::random::SeedableRng;

    #[test]
    fn test_rnn_shape() {
        // Create an RNN layer
        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
        let rnn = RNN::<f64>::new(
            10,                        // input_size
            20,                        // hidden_size
            RecurrentActivation::Tanh, // activation
            &mut rng,
        )
        .unwrap();
        // Create a batch of input data
        let batch_size = 2;
        let seq_len = 5;
        let input_size = 10;
        let input = Array3::<f64>::from_elem((batch_size, seq_len, input_size), 0.1).into_dyn();
        // Forward pass
        let output = rnn.forward(&input).unwrap();
        // Check output shape
        assert_eq!(output.shape(), &[batch_size, seq_len, 20]);
    }

    #[test]
    fn test_recurrent_activations() {
        // Test each activation function
        let tanh = RecurrentActivation::Tanh;
        let sigmoid = RecurrentActivation::Sigmoid;
        let relu = RecurrentActivation::ReLU;
        // Test tanh
        assert_eq!(tanh.apply(0.0f64), 0.0f64.tanh());
        assert_eq!(tanh.apply(1.0f64), 1.0f64.tanh());
        assert_eq!(tanh.apply(-1.0f64), (-1.0f64).tanh());
        // Test sigmoid
        assert_eq!(sigmoid.apply(0.0f64), 0.5f64);
        assert!((sigmoid.apply(10.0f64) - 1.0).abs() < 1e-4);
        assert!(sigmoid.apply(-10.0f64).abs() < 1e-4);
        // Test ReLU
        assert_eq!(relu.apply(1.0f64), 1.0f64);
        assert_eq!(relu.apply(-1.0f64), 0.0f64);
        assert_eq!(relu.apply(0.0f64), 0.0f64);
    }
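
    // Element-wise application should agree with the scalar `apply` definition above.
    #[test]
    fn test_apply_array() {
        let relu = RecurrentActivation::ReLU;
        let input = Array::from_shape_vec(IxDyn(&[3]), vec![-1.0f64, 0.0, 1.0]).unwrap();
        let output = relu.apply_array(&input);
        let expected = Array::from_shape_vec(IxDyn(&[3]), vec![0.0f64, 0.0, 1.0]).unwrap();
        assert_eq!(output, expected);
    }

    // Parameters exposed through `ParamLayer` should keep the documented layout:
    // weight_ih is [hidden_size, input_size], weight_hh is [hidden_size, hidden_size],
    // and both biases are [hidden_size].
    #[test]
    fn test_parameter_shapes() {
        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([7; 32]);
        let (input_size, hidden_size) = (4usize, 3usize);
        let rnn = RNN::<f64>::new(input_size, hidden_size, RecurrentActivation::Tanh, &mut rng)
            .unwrap();
        let params = rnn.get_parameters();
        assert_eq!(params.len(), 4);
        assert_eq!(params[0].shape(), &[hidden_size, input_size]);
        assert_eq!(params[1].shape(), &[hidden_size, hidden_size]);
        assert_eq!(params[2].shape(), &[hidden_size]);
        assert_eq!(params[3].shape(), &[hidden_size]);
    }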
}