1use scirs2_core::ndarray::{Array1, Array2};
19
/// Errors returned by the recurrent cells and sequence helpers in this module.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly
/// instead of pattern-matching every field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecurrentError {
    /// An array's dimensions did not match what the operation requires.
    ShapeMismatch {
        /// Required dimensions.
        expected: Vec<usize>,
        /// Dimensions actually supplied.
        got: Vec<usize>,
    },
    /// Rejected hidden size (the constructors in this module reject `0`).
    InvalidHiddenSize(usize),
    /// Rejected input size (the constructors in this module reject `0`).
    InvalidInputSize(usize),
    /// A sequence helper was called with no time steps.
    EmptySequence,
    /// Rejected sequence length; `got` is the offending value.
    InvalidSequenceLength {
        /// The offending length.
        got: usize,
    },
}

impl std::fmt::Display for RecurrentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RecurrentError::ShapeMismatch { expected, got } => {
                write!(f, "shape mismatch: expected {:?}, got {:?}", expected, got)
            }
            RecurrentError::InvalidHiddenSize(s) => {
                write!(f, "invalid hidden_size: {s}")
            }
            RecurrentError::InvalidInputSize(s) => {
                write!(f, "invalid input_size: {s}")
            }
            RecurrentError::EmptySequence => {
                write!(f, "input sequence must not be empty")
            }
            RecurrentError::InvalidSequenceLength { got } => {
                write!(f, "invalid sequence length: {got}")
            }
        }
    }
}

impl std::error::Error for RecurrentError {}
70
/// Logistic sigmoid `σ(x) = 1 / (1 + e^(-x))`.
///
/// For large-magnitude inputs `exp` overflows to `+inf` or underflows to `0.0`,
/// so the result saturates to exactly `0.0` or `1.0`; finite inputs never
/// produce NaN.
#[inline]
fn sigmoid(x: f64) -> f64 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
80
/// Advances a 64-bit linear congruential generator in place and maps the new
/// state onto `[-scale, scale]`.
///
/// Used only to produce reproducible initial weights; the same starting
/// `state` always yields the same stream of values.
#[inline]
fn lcg_value(state: &mut u64, scale: f64) -> f64 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let next = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = next;

    // Map the full 64-bit state to [0, 1], then stretch to [-1, 1] and scale.
    let unit = next as f64 / u64::MAX as f64;
    scale * (2.0 * unit - 1.0)
}
94
95fn lcg_fill_2d(rows: usize, cols: usize, scale: f64, state: &mut u64) -> Array2<f64> {
97 let data: Vec<f64> = (0..rows * cols).map(|_| lcg_value(state, scale)).collect();
98 Array2::from_shape_vec((rows, cols), data).unwrap_or_else(|_| Array2::zeros((rows, cols)))
100}
101
102fn lcg_fill_1d(len: usize, scale: f64, state: &mut u64) -> Array1<f64> {
104 let data: Vec<f64> = (0..len).map(|_| lcg_value(state, scale)).collect();
105 Array1::from_vec(data)
106}
107
108#[derive(Debug, Clone)]
116pub struct RnnCell {
117 pub input_size: usize,
119 pub hidden_size: usize,
121 pub w_ih: Array2<f64>,
123 pub w_hh: Array2<f64>,
125 pub b_ih: Array1<f64>,
127 pub b_hh: Array1<f64>,
129}
130
131impl RnnCell {
132 pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, RecurrentError> {
134 if input_size == 0 {
135 return Err(RecurrentError::InvalidInputSize(input_size));
136 }
137 if hidden_size == 0 {
138 return Err(RecurrentError::InvalidHiddenSize(hidden_size));
139 }
140 let scale = 0.1_f64;
141 let mut state: u64 = 0xdeadbeef_12345678_u64;
142 let w_ih = lcg_fill_2d(hidden_size, input_size, scale, &mut state);
143 let w_hh = lcg_fill_2d(hidden_size, hidden_size, scale, &mut state);
144 let b_ih = lcg_fill_1d(hidden_size, scale, &mut state);
145 let b_hh = lcg_fill_1d(hidden_size, scale, &mut state);
146 Ok(Self {
147 input_size,
148 hidden_size,
149 w_ih,
150 w_hh,
151 b_ih,
152 b_hh,
153 })
154 }
155
156 pub fn from_weights(
158 w_ih: Array2<f64>,
159 w_hh: Array2<f64>,
160 b_ih: Array1<f64>,
161 b_hh: Array1<f64>,
162 ) -> Result<Self, RecurrentError> {
163 let hidden_size = w_ih.nrows();
164 let input_size = w_ih.ncols();
165 if hidden_size == 0 {
166 return Err(RecurrentError::InvalidHiddenSize(hidden_size));
167 }
168 if input_size == 0 {
169 return Err(RecurrentError::InvalidInputSize(input_size));
170 }
171 if w_hh.nrows() != hidden_size || w_hh.ncols() != hidden_size {
173 return Err(RecurrentError::ShapeMismatch {
174 expected: vec![hidden_size, hidden_size],
175 got: vec![w_hh.nrows(), w_hh.ncols()],
176 });
177 }
178 if b_ih.len() != hidden_size {
179 return Err(RecurrentError::ShapeMismatch {
180 expected: vec![hidden_size],
181 got: vec![b_ih.len()],
182 });
183 }
184 if b_hh.len() != hidden_size {
185 return Err(RecurrentError::ShapeMismatch {
186 expected: vec![hidden_size],
187 got: vec![b_hh.len()],
188 });
189 }
190 Ok(Self {
191 input_size,
192 hidden_size,
193 w_ih,
194 w_hh,
195 b_ih,
196 b_hh,
197 })
198 }
199
200 pub fn forward(
209 &self,
210 input: &Array1<f64>,
211 hidden: &Array1<f64>,
212 ) -> Result<Array1<f64>, RecurrentError> {
213 if input.len() != self.input_size {
214 return Err(RecurrentError::ShapeMismatch {
215 expected: vec![self.input_size],
216 got: vec![input.len()],
217 });
218 }
219 if hidden.len() != self.hidden_size {
220 return Err(RecurrentError::ShapeMismatch {
221 expected: vec![self.hidden_size],
222 got: vec![hidden.len()],
223 });
224 }
225 let pre_act = self.w_ih.dot(input) + &self.b_ih + self.w_hh.dot(hidden) + &self.b_hh;
227 Ok(pre_act.mapv(f64::tanh))
228 }
229
230 pub fn init_hidden(&self) -> Array1<f64> {
232 Array1::zeros(self.hidden_size)
233 }
234
235 pub fn num_parameters(&self) -> usize {
237 self.hidden_size * self.input_size + self.hidden_size * self.hidden_size + self.hidden_size + self.hidden_size }
242}
243
244#[derive(Debug, Clone)]
250pub struct LstmState {
251 pub h: Array1<f64>,
253 pub c: Array1<f64>,
255}
256
257impl LstmState {
258 pub fn zeros(hidden_size: usize) -> Self {
260 Self {
261 h: Array1::zeros(hidden_size),
262 c: Array1::zeros(hidden_size),
263 }
264 }
265}
266
267#[derive(Debug, Clone)]
285pub struct LstmCell {
286 pub input_size: usize,
288 pub hidden_size: usize,
290 pub w_ih: Array2<f64>,
292 pub w_hh: Array2<f64>,
294 pub b_ih: Array1<f64>,
296 pub b_hh: Array1<f64>,
298}
299
300impl LstmCell {
301 pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, RecurrentError> {
303 if input_size == 0 {
304 return Err(RecurrentError::InvalidInputSize(input_size));
305 }
306 if hidden_size == 0 {
307 return Err(RecurrentError::InvalidHiddenSize(hidden_size));
308 }
309 let scale = 0.1_f64;
310 let mut state: u64 = 0xfeedface_abcd1234_u64;
311 let gates = 4;
312 let w_ih = lcg_fill_2d(gates * hidden_size, input_size, scale, &mut state);
313 let w_hh = lcg_fill_2d(gates * hidden_size, hidden_size, scale, &mut state);
314 let b_ih = lcg_fill_1d(gates * hidden_size, scale, &mut state);
315 let b_hh = lcg_fill_1d(gates * hidden_size, scale, &mut state);
316 Ok(Self {
317 input_size,
318 hidden_size,
319 w_ih,
320 w_hh,
321 b_ih,
322 b_hh,
323 })
324 }
325
326 pub fn from_weights(
328 w_ih: Array2<f64>,
329 w_hh: Array2<f64>,
330 b_ih: Array1<f64>,
331 b_hh: Array1<f64>,
332 ) -> Result<Self, RecurrentError> {
333 let input_size = w_ih.ncols();
334 if input_size == 0 {
335 return Err(RecurrentError::InvalidInputSize(input_size));
336 }
337 let combined_rows = w_ih.nrows();
338 if combined_rows == 0 || !combined_rows.is_multiple_of(4) {
339 return Err(RecurrentError::ShapeMismatch {
340 expected: vec![0 , input_size],
341 got: vec![combined_rows, input_size],
342 });
343 }
344 let hidden_size = combined_rows / 4;
345 if w_hh.nrows() != combined_rows || w_hh.ncols() != hidden_size {
347 return Err(RecurrentError::ShapeMismatch {
348 expected: vec![combined_rows, hidden_size],
349 got: vec![w_hh.nrows(), w_hh.ncols()],
350 });
351 }
352 if b_ih.len() != combined_rows {
353 return Err(RecurrentError::ShapeMismatch {
354 expected: vec![combined_rows],
355 got: vec![b_ih.len()],
356 });
357 }
358 if b_hh.len() != combined_rows {
359 return Err(RecurrentError::ShapeMismatch {
360 expected: vec![combined_rows],
361 got: vec![b_hh.len()],
362 });
363 }
364 Ok(Self {
365 input_size,
366 hidden_size,
367 w_ih,
368 w_hh,
369 b_ih,
370 b_hh,
371 })
372 }
373
374 pub fn forward(
383 &self,
384 input: &Array1<f64>,
385 state: &LstmState,
386 ) -> Result<LstmState, RecurrentError> {
387 if input.len() != self.input_size {
388 return Err(RecurrentError::ShapeMismatch {
389 expected: vec![self.input_size],
390 got: vec![input.len()],
391 });
392 }
393 if state.h.len() != self.hidden_size {
394 return Err(RecurrentError::ShapeMismatch {
395 expected: vec![self.hidden_size],
396 got: vec![state.h.len()],
397 });
398 }
399 if state.c.len() != self.hidden_size {
400 return Err(RecurrentError::ShapeMismatch {
401 expected: vec![self.hidden_size],
402 got: vec![state.c.len()],
403 });
404 }
405
406 let gates_pre = self.w_ih.dot(input) + &self.b_ih + self.w_hh.dot(&state.h) + &self.b_hh;
408
409 let h = self.hidden_size;
410
411 let i_pre = gates_pre.slice(scirs2_core::ndarray::s![..h]).to_owned();
413 let f_pre = gates_pre
414 .slice(scirs2_core::ndarray::s![h..2 * h])
415 .to_owned();
416 let g_pre = gates_pre
417 .slice(scirs2_core::ndarray::s![2 * h..3 * h])
418 .to_owned();
419 let o_pre = gates_pre
420 .slice(scirs2_core::ndarray::s![3 * h..])
421 .to_owned();
422
423 let i_gate = i_pre.mapv(sigmoid);
424 let f_gate = f_pre.mapv(sigmoid);
425 let g_gate = g_pre.mapv(f64::tanh);
426 let o_gate = o_pre.mapv(sigmoid);
427
428 let new_c = &f_gate * &state.c + &i_gate * &g_gate;
430 let new_h = &o_gate * new_c.mapv(f64::tanh);
432
433 Ok(LstmState { h: new_h, c: new_c })
434 }
435
436 pub fn init_state(&self) -> LstmState {
438 LstmState::zeros(self.hidden_size)
439 }
440
441 pub fn num_parameters(&self) -> usize {
443 let gates = 4;
444 gates * self.hidden_size * self.input_size + gates * self.hidden_size * self.hidden_size + gates * self.hidden_size + gates * self.hidden_size }
449}
450
451#[derive(Debug, Clone)]
467pub struct GruCell {
468 pub input_size: usize,
470 pub hidden_size: usize,
472 pub w_ih: Array2<f64>,
474 pub w_hh: Array2<f64>,
476 pub b_ih: Array1<f64>,
478 pub b_hh: Array1<f64>,
480}
481
482impl GruCell {
483 pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, RecurrentError> {
485 if input_size == 0 {
486 return Err(RecurrentError::InvalidInputSize(input_size));
487 }
488 if hidden_size == 0 {
489 return Err(RecurrentError::InvalidHiddenSize(hidden_size));
490 }
491 let scale = 0.1_f64;
492 let mut state: u64 = 0xc0ffee00_87654321_u64;
493 let gates = 3;
494 let w_ih = lcg_fill_2d(gates * hidden_size, input_size, scale, &mut state);
495 let w_hh = lcg_fill_2d(gates * hidden_size, hidden_size, scale, &mut state);
496 let b_ih = lcg_fill_1d(gates * hidden_size, scale, &mut state);
497 let b_hh = lcg_fill_1d(gates * hidden_size, scale, &mut state);
498 Ok(Self {
499 input_size,
500 hidden_size,
501 w_ih,
502 w_hh,
503 b_ih,
504 b_hh,
505 })
506 }
507
508 pub fn from_weights(
510 w_ih: Array2<f64>,
511 w_hh: Array2<f64>,
512 b_ih: Array1<f64>,
513 b_hh: Array1<f64>,
514 ) -> Result<Self, RecurrentError> {
515 let input_size = w_ih.ncols();
516 if input_size == 0 {
517 return Err(RecurrentError::InvalidInputSize(input_size));
518 }
519 let combined_rows = w_ih.nrows();
520 if combined_rows == 0 || !combined_rows.is_multiple_of(3) {
521 return Err(RecurrentError::ShapeMismatch {
522 expected: vec![0 , input_size],
523 got: vec![combined_rows, input_size],
524 });
525 }
526 let hidden_size = combined_rows / 3;
527 if w_hh.nrows() != combined_rows || w_hh.ncols() != hidden_size {
528 return Err(RecurrentError::ShapeMismatch {
529 expected: vec![combined_rows, hidden_size],
530 got: vec![w_hh.nrows(), w_hh.ncols()],
531 });
532 }
533 if b_ih.len() != combined_rows {
534 return Err(RecurrentError::ShapeMismatch {
535 expected: vec![combined_rows],
536 got: vec![b_ih.len()],
537 });
538 }
539 if b_hh.len() != combined_rows {
540 return Err(RecurrentError::ShapeMismatch {
541 expected: vec![combined_rows],
542 got: vec![b_hh.len()],
543 });
544 }
545 Ok(Self {
546 input_size,
547 hidden_size,
548 w_ih,
549 w_hh,
550 b_ih,
551 b_hh,
552 })
553 }
554
555 pub fn forward(
564 &self,
565 input: &Array1<f64>,
566 hidden: &Array1<f64>,
567 ) -> Result<Array1<f64>, RecurrentError> {
568 if input.len() != self.input_size {
569 return Err(RecurrentError::ShapeMismatch {
570 expected: vec![self.input_size],
571 got: vec![input.len()],
572 });
573 }
574 if hidden.len() != self.hidden_size {
575 return Err(RecurrentError::ShapeMismatch {
576 expected: vec![self.hidden_size],
577 got: vec![hidden.len()],
578 });
579 }
580
581 let h = self.hidden_size;
582
583 let x_pre = self.w_ih.dot(input) + &self.b_ih;
585 let h_pre = self.w_hh.dot(hidden) + &self.b_hh;
587
588 let r_pre = x_pre.slice(scirs2_core::ndarray::s![..h]).to_owned()
590 + h_pre.slice(scirs2_core::ndarray::s![..h]).to_owned();
591 let z_pre = x_pre.slice(scirs2_core::ndarray::s![h..2 * h]).to_owned()
592 + h_pre.slice(scirs2_core::ndarray::s![h..2 * h]).to_owned();
593
594 let r_gate = r_pre.mapv(sigmoid);
595 let z_gate = z_pre.mapv(sigmoid);
596
597 let n_x = x_pre.slice(scirs2_core::ndarray::s![2 * h..]).to_owned();
599 let n_h = h_pre.slice(scirs2_core::ndarray::s![2 * h..]).to_owned();
600 let n_pre = n_x + &r_gate * n_h;
601 let n_gate = n_pre.mapv(f64::tanh);
602
603 let ones = Array1::<f64>::ones(h);
605 let new_h = (&ones - &z_gate) * &n_gate + &z_gate * hidden;
606 Ok(new_h)
607 }
608
609 pub fn init_hidden(&self) -> Array1<f64> {
611 Array1::zeros(self.hidden_size)
612 }
613
614 pub fn num_parameters(&self) -> usize {
616 let gates = 3;
617 gates * self.hidden_size * self.input_size + gates * self.hidden_size * self.hidden_size + gates * self.hidden_size + gates * self.hidden_size }
622}
623
624pub fn rnn_sequence(
637 cell: &RnnCell,
638 inputs: &[Array1<f64>],
639) -> Result<Vec<Array1<f64>>, RecurrentError> {
640 if inputs.is_empty() {
641 return Err(RecurrentError::EmptySequence);
642 }
643 let mut hidden = cell.init_hidden();
644 let mut outputs = Vec::with_capacity(inputs.len());
645 for x in inputs {
646 hidden = cell.forward(x, &hidden)?;
647 outputs.push(hidden.clone());
648 }
649 Ok(outputs)
650}
651
652pub fn lstm_sequence(
661 cell: &LstmCell,
662 inputs: &[Array1<f64>],
663) -> Result<(Vec<Array1<f64>>, LstmState), RecurrentError> {
664 if inputs.is_empty() {
665 return Err(RecurrentError::EmptySequence);
666 }
667 let mut state = cell.init_state();
668 let mut hidden_states = Vec::with_capacity(inputs.len());
669 for x in inputs {
670 state = cell.forward(x, &state)?;
671 hidden_states.push(state.h.clone());
672 }
673 Ok((hidden_states, state))
674}
675
676pub fn gru_sequence(
685 cell: &GruCell,
686 inputs: &[Array1<f64>],
687) -> Result<Vec<Array1<f64>>, RecurrentError> {
688 if inputs.is_empty() {
689 return Err(RecurrentError::EmptySequence);
690 }
691 let mut hidden = cell.init_hidden();
692 let mut outputs = Vec::with_capacity(inputs.len());
693 for x in inputs {
694 hidden = cell.forward(x, &hidden)?;
695 outputs.push(hidden.clone());
696 }
697 Ok(outputs)
698}
699
/// Descriptive metadata about a recurrent layer, rendered by
/// [`RecurrentStats::summary`].
#[derive(Debug, Clone)]
pub struct RecurrentStats {
    /// Free-form cell label, e.g. `"LSTM"`.
    pub cell_type: String,
    /// Input vector length.
    pub input_size: usize,
    /// Hidden vector length.
    pub hidden_size: usize,
    /// Total scalar parameter count.
    pub num_parameters: usize,
    /// Sequence length, when one applies.
    pub sequence_length: Option<usize>,
}

impl RecurrentStats {
    /// Formats the stats on one line, e.g.
    /// `"LSTM | input=4 hidden=8 params=416 seq_len=10"`. A missing sequence
    /// length renders as `seq_len=n/a`.
    pub fn summary(&self) -> String {
        let seq = self
            .sequence_length
            .map_or_else(|| String::from("seq_len=n/a"), |t| format!("seq_len={t}"));
        let Self {
            cell_type,
            input_size,
            hidden_size,
            num_parameters,
            ..
        } = self;
        format!(
            "{cell_type} | input={input_size} hidden={hidden_size} params={num_parameters} {seq}"
        )
    }
}
732
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    // ── RnnCell ──────────────────────────────────────────────────────────────

    #[test]
    fn test_rnn_cell_new() {
        let cell = RnnCell::new(4, 8);
        assert!(cell.is_ok(), "RnnCell::new should succeed");
    }

    #[test]
    fn test_rnn_cell_forward_shape() {
        let cell = RnnCell::new(4, 8).expect("construct rnn");
        let x = Array1::zeros(4);
        let h = cell.init_hidden();
        let h_new = cell.forward(&x, &h).expect("rnn forward");
        assert_eq!(h_new.len(), 8);
    }

    #[test]
    fn test_rnn_cell_init_hidden() {
        let cell = RnnCell::new(3, 5).expect("construct rnn");
        let h = cell.init_hidden();
        assert_eq!(h.len(), 5);
        assert!(h.iter().all(|&v| v == 0.0), "init hidden should be zeros");
    }

    #[test]
    fn test_rnn_cell_num_parameters() {
        let input_size = 4;
        let hidden_size = 8;
        let cell = RnnCell::new(input_size, hidden_size).expect("construct rnn");
        let expected =
            hidden_size * input_size + hidden_size * hidden_size + hidden_size + hidden_size;
        assert_eq!(cell.num_parameters(), expected);
    }

    // ── LstmCell ─────────────────────────────────────────────────────────────

    #[test]
    fn test_lstm_cell_new() {
        let cell = LstmCell::new(4, 8);
        assert!(cell.is_ok(), "LstmCell::new should succeed");
    }

    #[test]
    fn test_lstm_cell_forward_shape() {
        let cell = LstmCell::new(4, 8).expect("construct lstm");
        let x = Array1::zeros(4);
        let state = cell.init_state();
        let new_state = cell.forward(&x, &state).expect("lstm forward");
        assert_eq!(new_state.h.len(), 8);
        assert_eq!(new_state.c.len(), 8);
    }

    #[test]
    fn test_lstm_cell_init_state() {
        let cell = LstmCell::new(3, 6).expect("construct lstm");
        let state = cell.init_state();
        assert_eq!(state.h.len(), 6);
        assert_eq!(state.c.len(), 6);
        assert!(state.h.iter().all(|&v| v == 0.0));
        assert!(state.c.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_lstm_cell_gate_bounds() {
        let cell = LstmCell::new(4, 8).expect("construct lstm");
        let x = Array1::from_elem(4, 0.5);
        let state = cell.init_state();
        let new_state = cell.forward(&x, &state).expect("lstm forward");
        for &v in new_state.h.iter() {
            // h = o * tanh(c) with o in (0, 1), so |h| < 1.
            assert!(v > -1.0 && v < 1.0, "h element out of (-1,1): {v}");
        }
    }

    #[test]
    fn test_lstm_cell_num_parameters() {
        let input_size = 4;
        let hidden_size = 8;
        let cell = LstmCell::new(input_size, hidden_size).expect("construct lstm");
        let gates = 4;
        let expected = gates * hidden_size * input_size
            + gates * hidden_size * hidden_size
            + gates * hidden_size
            + gates * hidden_size;
        assert_eq!(cell.num_parameters(), expected);
    }

    // ── GruCell ──────────────────────────────────────────────────────────────

    #[test]
    fn test_gru_cell_new() {
        let cell = GruCell::new(4, 8);
        assert!(cell.is_ok(), "GruCell::new should succeed");
    }

    #[test]
    fn test_gru_cell_forward_shape() {
        let cell = GruCell::new(4, 8).expect("construct gru");
        let x = Array1::zeros(4);
        let h = cell.init_hidden();
        let h_new = cell.forward(&x, &h).expect("gru forward");
        assert_eq!(h_new.len(), 8);
    }

    #[test]
    fn test_gru_cell_hidden_init_zeros() {
        let cell = GruCell::new(3, 5).expect("construct gru");
        let h = cell.init_hidden();
        assert_eq!(h.len(), 5);
        assert!(h.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_gru_cell_num_parameters() {
        let input_size = 4;
        let hidden_size = 8;
        let cell = GruCell::new(input_size, hidden_size).expect("construct gru");
        let gates = 3;
        let expected = gates * hidden_size * input_size
            + gates * hidden_size * hidden_size
            + gates * hidden_size
            + gates * hidden_size;
        assert_eq!(cell.num_parameters(), expected);
    }

    // ── Sequence helpers ─────────────────────────────────────────────────────

    #[test]
    fn test_rnn_sequence_length() {
        let cell = RnnCell::new(4, 8).expect("rnn");
        let inputs: Vec<Array1<f64>> = (0..7).map(|_| Array1::zeros(4)).collect();
        let out = rnn_sequence(&cell, &inputs).expect("rnn sequence");
        assert_eq!(out.len(), 7, "T inputs → T outputs");
    }

    #[test]
    fn test_rnn_sequence_empty_error() {
        let cell = RnnCell::new(4, 8).expect("rnn");
        let result = rnn_sequence(&cell, &[]);
        assert!(
            matches!(result, Err(RecurrentError::EmptySequence)),
            "expected EmptySequence error"
        );
    }

    #[test]
    fn test_lstm_sequence_length() {
        let cell = LstmCell::new(4, 8).expect("lstm");
        let inputs: Vec<Array1<f64>> = (0..5).map(|_| Array1::zeros(4)).collect();
        let (hidden_states, _) = lstm_sequence(&cell, &inputs).expect("lstm sequence");
        assert_eq!(hidden_states.len(), 5);
    }

    #[test]
    fn test_lstm_sequence_final_state_nonzero() {
        let cell = LstmCell::new(4, 8).expect("lstm");
        let inputs: Vec<Array1<f64>> = (0..3).map(|_| Array1::from_elem(4, 1.0)).collect();
        let (_, final_state) = lstm_sequence(&cell, &inputs).expect("lstm sequence");
        let h_norm: f64 = final_state.h.iter().map(|v| v * v).sum::<f64>().sqrt();
        assert!(
            h_norm > 1e-12,
            "final h should be non-zero for non-zero inputs"
        );
    }

    #[test]
    fn test_gru_sequence_length() {
        let cell = GruCell::new(4, 8).expect("gru");
        let inputs: Vec<Array1<f64>> = (0..6).map(|_| Array1::zeros(4)).collect();
        let out = gru_sequence(&cell, &inputs).expect("gru sequence");
        assert_eq!(out.len(), 6);
    }

    // ── Stats ────────────────────────────────────────────────────────────────

    #[test]
    fn test_recurrent_stats_summary_nonempty() {
        let stats = RecurrentStats {
            cell_type: "LSTM".to_string(),
            input_size: 4,
            hidden_size: 8,
            num_parameters: 416,
            sequence_length: Some(10),
        };
        let s = stats.summary();
        assert!(!s.is_empty(), "summary should not be empty");
        assert!(s.contains("LSTM"));
        assert!(s.contains("416"));
    }

    // ── from_weights validation ──────────────────────────────────────────────

    #[test]
    fn test_lstm_cell_from_weights_shape_mismatch() {
        use scirs2_core::ndarray::Array2;
        let w_ih = Array2::zeros((8, 4));
        // w_hh must be (8, 2) for hidden_size 2; (8, 3) is deliberately wrong.
        let w_hh = Array2::zeros((8, 3));
        let b_ih = Array1::zeros(8);
        let b_hh = Array1::zeros(8);
        let result = LstmCell::from_weights(w_ih, w_hh, b_ih, b_hh);
        assert!(result.is_err(), "should fail due to w_hh shape mismatch");
    }

    // ── Additional coverage ──────────────────────────────────────────────────

    #[test]
    fn test_rnn_cell_new_rejects_zero_sizes() {
        assert!(matches!(
            RnnCell::new(0, 4),
            Err(RecurrentError::InvalidInputSize(0))
        ));
        assert!(matches!(
            RnnCell::new(4, 0),
            Err(RecurrentError::InvalidHiddenSize(0))
        ));
    }

    #[test]
    fn test_rnn_cell_forward_rejects_wrong_input_len() {
        let cell = RnnCell::new(4, 8).expect("construct rnn");
        let result = cell.forward(&Array1::zeros(3), &cell.init_hidden());
        match result {
            Err(RecurrentError::ShapeMismatch { expected, got }) => {
                assert_eq!(expected, vec![4]);
                assert_eq!(got, vec![3]);
            }
            other => panic!("expected ShapeMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_rnn_cell_from_weights_zero_weights_gives_tanh_of_bias() {
        // With zero weight matrices the step reduces to tanh(b_ih + b_hh).
        let cell = RnnCell::from_weights(
            Array2::zeros((2, 3)),
            Array2::zeros((2, 2)),
            Array1::from_vec(vec![0.25, -0.5]),
            Array1::from_vec(vec![0.25, 0.0]),
        )
        .expect("valid rnn weights");
        assert_eq!(cell.input_size, 3);
        assert_eq!(cell.hidden_size, 2);
        let out = cell
            .forward(&Array1::from_elem(3, 7.0), &Array1::from_elem(2, -3.0))
            .expect("rnn forward");
        assert!((out[0] - 0.5_f64.tanh()).abs() < 1e-12);
        assert!((out[1] - (-0.5_f64).tanh()).abs() < 1e-12);
    }

    #[test]
    fn test_rnn_cell_from_weights_rejects_bad_bias() {
        let result = RnnCell::from_weights(
            Array2::zeros((2, 3)),
            Array2::zeros((2, 2)),
            Array1::zeros(3),
            Array1::zeros(2),
        );
        assert!(matches!(result, Err(RecurrentError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_lstm_cell_zero_weights_matches_gate_equations() {
        // All-zero parameters give i = f = o = sigmoid(0) = 0.5 and
        // g = tanh(0) = 0, so c' = 0.5 * c and h' = 0.5 * tanh(c').
        let cell = LstmCell::from_weights(
            Array2::zeros((4, 2)),
            Array2::zeros((4, 1)),
            Array1::zeros(4),
            Array1::zeros(4),
        )
        .expect("valid lstm weights");
        assert_eq!(cell.hidden_size, 1);
        let state = LstmState {
            h: Array1::from_vec(vec![0.3]),
            c: Array1::from_vec(vec![2.0]),
        };
        let next = cell
            .forward(&Array1::from_elem(2, 1.0), &state)
            .expect("lstm forward");
        assert!((next.c[0] - 1.0).abs() < 1e-12);
        assert!((next.h[0] - 0.5 * 1.0_f64.tanh()).abs() < 1e-12);
    }

    #[test]
    fn test_lstm_cell_from_weights_rows_not_multiple_of_four() {
        let result = LstmCell::from_weights(
            Array2::zeros((7, 2)),
            Array2::zeros((7, 1)),
            Array1::zeros(7),
            Array1::zeros(7),
        );
        assert!(matches!(result, Err(RecurrentError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_gru_cell_zero_weights_halves_hidden() {
        // All-zero parameters give r = z = sigmoid(0) = 0.5 and n = tanh(0) = 0,
        // so h' = (1 - z) * n + z * h = 0.5 * h.
        let cell = GruCell::from_weights(
            Array2::zeros((6, 3)),
            Array2::zeros((6, 2)),
            Array1::zeros(6),
            Array1::zeros(6),
        )
        .expect("valid gru weights");
        assert_eq!(cell.hidden_size, 2);
        let h = Array1::from_vec(vec![1.0, -2.0]);
        let out = cell
            .forward(&Array1::from_elem(3, 4.0), &h)
            .expect("gru forward");
        assert!((out[0] - 0.5).abs() < 1e-12);
        assert!((out[1] + 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_gru_cell_from_weights_rows_not_multiple_of_three() {
        let result = GruCell::from_weights(
            Array2::zeros((5, 2)),
            Array2::zeros((5, 1)),
            Array1::zeros(5),
            Array1::zeros(5),
        );
        assert!(matches!(result, Err(RecurrentError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_lstm_and_gru_sequence_empty_error() {
        let lstm = LstmCell::new(2, 3).expect("lstm");
        assert!(matches!(
            lstm_sequence(&lstm, &[]),
            Err(RecurrentError::EmptySequence)
        ));
        let gru = GruCell::new(2, 3).expect("gru");
        assert!(matches!(
            gru_sequence(&gru, &[]),
            Err(RecurrentError::EmptySequence)
        ));
    }

    #[test]
    fn test_rnn_sequence_matches_manual_unroll() {
        let cell = RnnCell::new(2, 3).expect("rnn");
        let inputs = vec![
            Array1::from_vec(vec![1.0, -1.0]),
            Array1::from_vec(vec![0.5, 2.0]),
        ];
        let outputs = rnn_sequence(&cell, &inputs).expect("rnn sequence");
        let h1 = cell
            .forward(&inputs[0], &cell.init_hidden())
            .expect("step 1");
        let h2 = cell.forward(&inputs[1], &h1).expect("step 2");
        assert_eq!(outputs, vec![h1, h2]);
    }

    #[test]
    fn test_recurrent_error_display() {
        let err = RecurrentError::ShapeMismatch {
            expected: vec![4],
            got: vec![3],
        };
        assert_eq!(err.to_string(), "shape mismatch: expected [4], got [3]");
        assert_eq!(
            RecurrentError::EmptySequence.to_string(),
            "input sequence must not be empty"
        );
    }

    #[test]
    fn test_recurrent_stats_summary_without_sequence_length() {
        let stats = RecurrentStats {
            cell_type: "GRU".to_string(),
            input_size: 2,
            hidden_size: 3,
            num_parameters: 63,
            sequence_length: None,
        };
        assert_eq!(stats.summary(), "GRU | input=2 hidden=3 params=63 seq_len=n/a");
    }
}