use crate::error::{NeuralError, Result};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, ArrayView, Ix2, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::Rng;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
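
/// Activation functions available for the recurrent step.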
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RecurrentActivation {
    Tanh,
    Sigmoid,
    ReLU,
}
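
/// Configuration options for constructing an RNN layer.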
#[derive(Debug, Clone)]
pub struct RNNConfig {
    pub input_size: usize,
    pub hidden_size: usize,
    pub activation: RecurrentActivation,
}

impl RecurrentActivation {
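    /// Apply the activation to a single scalar value.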
    pub fn apply<F: Float>(&self, x: F) -> F {
        match self {
            RecurrentActivation::Tanh => x.tanh(),
            RecurrentActivation::Sigmoid => F::one() / (F::one() + (-x).exp()),
            RecurrentActivation::ReLU => {
                if x > F::zero() {
                    x
                } else {
                    F::zero()
                }
            }
        }
    }
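
    /// Apply the activation element-wise to an array.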
    #[allow(dead_code)]
    pub fn apply_array<F: Float + ScalarOperand>(&self, x: &Array<F, IxDyn>) -> Array<F, IxDyn> {
        match self {
            RecurrentActivation::Tanh => x.mapv(|v| v.tanh()),
            RecurrentActivation::Sigmoid => x.mapv(|v| F::one() / (F::one() + (-v).exp())),
            RecurrentActivation::ReLU => x.mapv(|v| if v > F::zero() { v } else { F::zero() }),
        }
    }
}
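
/// A simple Elman-style recurrent layer.
///
/// The layer consumes input of shape `[batch_size, seq_len, input_size]` and
/// returns the hidden state at every time step, with shape
/// `[batch_size, seq_len, hidden_size]`.
///
/// # Example
///
/// A minimal usage sketch (mirroring the shape test below; the enclosing crate
/// path is assumed, so the example is not compiled as a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::Array3;
/// use scirs2_core::random::SeedableRng;
///
/// let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
/// let rnn = RNN::<f64>::new(10, 20, RecurrentActivation::Tanh, &mut rng).unwrap();
/// let input = Array3::<f64>::from_elem((2, 5, 10), 0.1).into_dyn();
/// let hidden_states = rnn.forward(&input).unwrap(); // shape: [2, 5, 20]
/// ```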
pub struct RNN<F: Float + Debug + Send + Sync> {
    /// Number of input features
    input_size: usize,
    /// Number of hidden units
    hidden_size: usize,
    /// Activation function for the recurrent step
    activation: RecurrentActivation,
    /// Input-to-hidden weights, shape `[hidden_size, input_size]`
    weight_ih: Array<F, IxDyn>,
    /// Hidden-to-hidden weights, shape `[hidden_size, hidden_size]`
    weight_hh: Array<F, IxDyn>,
    /// Input-to-hidden bias, shape `[hidden_size]`
    bias_ih: Array<F, IxDyn>,
    /// Hidden-to-hidden bias, shape `[hidden_size]`
    bias_hh: Array<F, IxDyn>,
    /// Gradient of the input-to-hidden weights
    dweight_ih: Array<F, IxDyn>,
    /// Gradient of the hidden-to-hidden weights
    dweight_hh: Array<F, IxDyn>,
    /// Gradient of the input-to-hidden bias
    dbias_ih: Array<F, IxDyn>,
    /// Gradient of the hidden-to-hidden bias
    dbias_hh: Array<F, IxDyn>,
    /// Cached input from the last forward pass
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Cached hidden states from the last forward pass
    hidden_states_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> RNN<F> {
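    /// Create a new RNN layer.
    ///
    /// Weights are drawn uniformly from `[-1, 1)` and scaled by
    /// `1 / sqrt(fan_in)` (the input size for `weight_ih`, the hidden size for
    /// `weight_hh`); biases start at zero. Returns an error if `input_size` or
    /// `hidden_size` is zero.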
    pub fn new<R: Rng>(
        input_size: usize,
        hidden_size: usize,
        activation: RecurrentActivation,
        rng: &mut R,
    ) -> Result<Self> {
        if input_size == 0 || hidden_size == 0 {
            return Err(NeuralError::InvalidArchitecture(
                "Input size and hidden size must be positive".to_string(),
            ));
        }
        let scale_ih = F::from(1.0 / (input_size as f64).sqrt()).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert input scale factor".to_string())
        })?;
        let scale_hh = F::from(1.0 / (hidden_size as f64).sqrt()).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert hidden scale factor".to_string())
        })?;
        let mut weight_ih_vec: Vec<F> = Vec::with_capacity(hidden_size * input_size);
        for _ in 0..(hidden_size * input_size) {
            let rand_val = rng.gen_range(-1.0..1.0);
            let val = F::from(rand_val).ok_or_else(|| {
                NeuralError::InvalidArchitecture("Failed to convert random value".to_string())
            })?;
            weight_ih_vec.push(val * scale_ih);
        }
        let weight_ih = Array::from_shape_vec(IxDyn(&[hidden_size, input_size]), weight_ih_vec)
            .map_err(|e| {
                NeuralError::InvalidArchitecture(format!("Failed to create weight_ih array: {e}"))
            })?;
        let mut weight_hh_vec: Vec<F> = Vec::with_capacity(hidden_size * hidden_size);
        for _ in 0..(hidden_size * hidden_size) {
            let rand_val = rng.gen_range(-1.0..1.0);
            let val = F::from(rand_val).ok_or_else(|| {
                NeuralError::InvalidArchitecture("Failed to convert random value".to_string())
            })?;
            weight_hh_vec.push(val * scale_hh);
        }
        let weight_hh = Array::from_shape_vec(IxDyn(&[hidden_size, hidden_size]), weight_hh_vec)
            .map_err(|e| {
                NeuralError::InvalidArchitecture(format!("Failed to create weight_hh array: {e}"))
            })?;
        let bias_ih = Array::zeros(IxDyn(&[hidden_size]));
        let bias_hh = Array::zeros(IxDyn(&[hidden_size]));
        let dweight_ih = Array::zeros(weight_ih.dim());
        let dweight_hh = Array::zeros(weight_hh.dim());
        let dbias_ih = Array::zeros(bias_ih.dim());
        let dbias_hh = Array::zeros(bias_hh.dim());
        Ok(Self {
            input_size,
            hidden_size,
            activation,
            weight_ih,
            weight_hh,
            bias_ih,
            bias_hh,
            dweight_ih,
            dweight_hh,
            dbias_ih,
            dbias_hh,
            input_cache: Arc::new(RwLock::new(None)),
            hidden_states_cache: Arc::new(RwLock::new(None)),
        })
    }
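    /// Compute one recurrent step for a batch:
    /// `h_new = activation(weight_ih * x + bias_ih + weight_hh * h + bias_hh)`.
    ///
    /// `x` has shape `[batch_size, input_size]`, `h` has shape
    /// `[batch_size, hidden_size]`, and the returned hidden state has shape
    /// `[batch_size, hidden_size]`.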
    fn step(&self, x: &ArrayView<F, IxDyn>, h: &ArrayView<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        let x_shape = x.shape();
        let h_shape = h.shape();
        let batch_size = x_shape[0];
        if x_shape[1] != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input feature dimension mismatch: expected {}, got {}",
                self.input_size, x_shape[1]
            )));
        }
        if h_shape[1] != self.hidden_size {
            return Err(NeuralError::InferenceError(format!(
                "Hidden state dimension mismatch: expected {}, got {}",
                self.hidden_size, h_shape[1]
            )));
        }
        if x_shape[0] != h_shape[0] {
            return Err(NeuralError::InferenceError(format!(
                "Batch size mismatch: input has {}, hidden state has {}",
                x_shape[0], h_shape[0]
            )));
        }
        let mut new_h = Array::zeros((batch_size, self.hidden_size));
        for b in 0..batch_size {
            for i in 0..self.hidden_size {
                let mut ih_sum = self.bias_ih[i];
                for j in 0..self.input_size {
                    ih_sum = ih_sum + self.weight_ih[[i, j]] * x[[b, j]];
                }
                let mut hh_sum = self.bias_hh[i];
                for j in 0..self.hidden_size {
                    hh_sum = hh_sum + self.weight_hh[[i, j]] * h[[b, j]];
                }
                new_h[[b, i]] = self.activation.apply(ih_sum + hh_sum);
            }
        }
        Ok(new_h.into_dyn())
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for RNN<F> {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(input.to_owned());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }
        let input_shape = input.shape();
        if input_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D input [batch_size, seq_len, features], got {input_shape:?}"
            )));
        }
        let batch_size = input_shape[0];
        let seq_len = input_shape[1];
        let features = input_shape[2];
        if features != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input features dimension mismatch: expected {}, got {}",
                self.input_size, features
            )));
        }
        let mut h = Array::zeros((batch_size, self.hidden_size));
        let mut all_hidden_states = Array::zeros((batch_size, seq_len, self.hidden_size));
        for t in 0..seq_len {
            let x_t = input.slice(scirs2_core::ndarray::s![.., t, ..]);
            let x_t_view = x_t.view().into_dyn();
            let h_view = h.view().into_dyn();
            h = self
                .step(&x_t_view, &h_view)?
                .into_dimensionality::<Ix2>()
                .unwrap();
            for b in 0..batch_size {
                for i in 0..self.hidden_size {
                    all_hidden_states[[b, t, i]] = h[[b, i]];
                }
            }
        }
        if let Ok(mut cache) = self.hidden_states_cache.write() {
            *cache = Some(all_hidden_states.to_owned().into_dyn());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on hidden states cache".to_string(),
            ));
        }
        Ok(all_hidden_states.into_dyn())
    }

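    /// Backward pass.
    ///
    /// Note: backpropagation through time is not implemented yet; the cached
    /// forward values are only checked, and a zero gradient with the input's
    /// shape is returned.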
    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        _grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        let input_ref = match self.input_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on input cache".to_string(),
                ))
            }
        };
        let hidden_states_ref = match self.hidden_states_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on hidden states cache".to_string(),
                ))
            }
        };
        if input_ref.is_none() || hidden_states_ref.is_none() {
            return Err(NeuralError::InferenceError(
                "No cached values for backward pass. Call forward() first.".to_string(),
            ));
        }
        let grad_input = Array::zeros(input.dim());
        Ok(grad_input)
    }

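    /// Parameter update.
    ///
    /// Note: this is a placeholder that subtracts a fixed small value scaled
    /// by `learningrate` from every weight and bias; it does not use the
    /// stored gradients.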
    fn update(&mut self, learningrate: F) -> Result<()> {
        let small_change = F::from(0.001).unwrap();
        let lr = small_change * learningrate;
        for w in self.weight_ih.iter_mut() {
            *w = *w - lr;
        }
        for w in self.weight_hh.iter_mut() {
            *w = *w - lr;
        }
        for b in self.bias_ih.iter_mut() {
            *b = *b - lr;
        }
        for b in self.bias_hh.iter_mut() {
            *b = *b - lr;
        }
        Ok(())
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for RNN<F> {
    fn get_parameters(&self) -> Vec<Array<F, IxDyn>> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn get_gradients(&self) -> Vec<Array<F, IxDyn>> {
        vec![
            self.dweight_ih.clone(),
            self.dweight_hh.clone(),
            self.dbias_ih.clone(),
            self.dbias_hh.clone(),
        ]
    }
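
    /// Replace the layer parameters.
    ///
    /// Expects exactly four arrays in the order
    /// `[weight_ih, weight_hh, bias_ih, bias_hh]`, each with a shape matching
    /// the current parameter.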
    fn set_parameters(&mut self, params: Vec<Array<F, IxDyn>>) -> Result<()> {
        if params.len() != 4 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 4 parameters, got {}",
                params.len()
            )));
        }

        if params[0].shape() != self.weight_ih.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "weight_ih shape mismatch: expected {:?}, got {:?}",
                self.weight_ih.shape(),
                params[0].shape()
            )));
        }
        if params[1].shape() != self.weight_hh.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "weight_hh shape mismatch: expected {:?}, got {:?}",
                self.weight_hh.shape(),
                params[1].shape()
            )));
        }
        if params[2].shape() != self.bias_ih.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "bias_ih shape mismatch: expected {:?}, got {:?}",
                self.bias_ih.shape(),
                params[2].shape()
            )));
        }
        if params[3].shape() != self.bias_hh.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "bias_hh shape mismatch: expected {:?}, got {:?}",
                self.bias_hh.shape(),
                params[3].shape()
            )));
        }

        self.weight_ih = params[0].clone();
        self.weight_hh = params[1].clone();
        self.bias_ih = params[2].clone();
        self.bias_hh = params[3].clone();

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;
    use scirs2_core::random::SeedableRng;

    #[test]
    fn test_rnn_shape() {
        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
        let rnn = RNN::<f64>::new(10, 20, RecurrentActivation::Tanh, &mut rng).unwrap();

        let batch_size = 2;
        let seq_len = 5;
        let input_size = 10;
        let input = Array3::<f64>::from_elem((batch_size, seq_len, input_size), 0.1).into_dyn();

        let output = rnn.forward(&input).unwrap();
        assert_eq!(output.shape(), &[batch_size, seq_len, 20]);
    }

    #[test]
    fn test_recurrent_activations() {
        let tanh = RecurrentActivation::Tanh;
        let sigmoid = RecurrentActivation::Sigmoid;
        let relu = RecurrentActivation::ReLU;

        assert_eq!(tanh.apply(0.0f64), 0.0f64.tanh());
        assert_eq!(tanh.apply(1.0f64), 1.0f64.tanh());
        assert_eq!(tanh.apply(-1.0f64), (-1.0f64).tanh());

        assert_eq!(sigmoid.apply(0.0f64), 0.5f64);
        assert!((sigmoid.apply(10.0f64) - 1.0).abs() < 1e-4);
        assert!(sigmoid.apply(-10.0f64).abs() < 1e-4);

        assert_eq!(relu.apply(1.0f64), 1.0f64);
        assert_eq!(relu.apply(-1.0f64), 0.0f64);
        assert_eq!(relu.apply(0.0f64), 0.0f64);
    }
}