// tensorlogic_scirs_backend/recurrent.rs
//! Recurrent neural network cells: RNN, LSTM, GRU.
//!
//! Provides cell-level and sequence-level forward pass implementations for the
//! three most common recurrent architectures.  All weights are stored as plain
//! `ndarray` arrays so they can be loaded from external checkpoints or
//! initialised with the built-in deterministic LCG scheme.
//!
//! ## Cell types
//! - [`RnnCell`]  – vanilla tanh-RNN
//! - [`LstmCell`] – Long Short-Term Memory (LSTM)
//! - [`GruCell`]  – Gated Recurrent Unit (GRU)
//!
//! ## Sequence helpers
//! - [`rnn_sequence`]  – run `RnnCell` over a slice of inputs
//! - [`lstm_sequence`] – run `LstmCell` over a slice of inputs
//! - [`gru_sequence`]  – run `GruCell` over a slice of inputs

use scirs2_core::ndarray::{Array1, Array2};

// ─────────────────────────────────────────────────────────────────────────────
// Error type
// ─────────────────────────────────────────────────────────────────────────────

/// Errors that can arise from recurrent cell operations.
///
/// `PartialEq`/`Eq` are derived (all fields are plain integers/`Vec<usize>`)
/// so callers and tests can compare errors directly instead of matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecurrentError {
    /// A matrix or vector had the wrong shape.
    ShapeMismatch {
        /// The shape that was expected.
        expected: Vec<usize>,
        /// The shape that was actually provided.
        got: Vec<usize>,
    },
    /// `hidden_size` was zero or otherwise invalid.
    InvalidHiddenSize(usize),
    /// `input_size` was zero or otherwise invalid.
    InvalidInputSize(usize),
    /// The input sequence had length zero.
    EmptySequence,
    /// The input sequence length was invalid for some reason.
    InvalidSequenceLength {
        /// The problematic length that was encountered.
        got: usize,
    },
}

impl std::fmt::Display for RecurrentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ShapeMismatch { expected, got } => {
                write!(f, "shape mismatch: expected {:?}, got {:?}", expected, got)
            }
            Self::InvalidHiddenSize(s) => write!(f, "invalid hidden_size: {s}"),
            Self::InvalidInputSize(s) => write!(f, "invalid input_size: {s}"),
            Self::EmptySequence => write!(f, "input sequence must not be empty"),
            Self::InvalidSequenceLength { got } => write!(f, "invalid sequence length: {got}"),
        }
    }
}

impl std::error::Error for RecurrentError {}

// ─────────────────────────────────────────────────────────────────────────────
// Internal helpers
// ─────────────────────────────────────────────────────────────────────────────

/// Sigmoid activation: σ(x) = 1 / (1 + e^{-x}).
#[inline]
fn sigmoid(x: f64) -> f64 {
    // recip() is exactly 1.0 / x, so this matches the textbook formula.
    (1.0 + (-x).exp()).recip()
}

/// Deterministic LCG-based pseudo-random value in [-scale, scale].
///
/// Uses the same constants as the rest of the crate so that different modules
/// produce consistent-looking weight matrices.
#[inline]
fn lcg_value(state: &mut u64, scale: f64) -> f64 {
    // Knuth/MMIX LCG step; the caller-owned state advances on every call.
    let advanced = state
        .wrapping_mul(6364136223846793005_u64)
        .wrapping_add(1442695040888963407_u64);
    *state = advanced;
    // Map [0, 2^64) onto [0, 1], then stretch to [-scale, scale].
    let unit = (advanced as f64) / (u64::MAX as f64);
    (unit * 2.0 - 1.0) * scale
}

95/// Fill an `Array2<f64>` with LCG-generated values in [-scale, scale].
96fn lcg_fill_2d(rows: usize, cols: usize, scale: f64, state: &mut u64) -> Array2<f64> {
97    let data: Vec<f64> = (0..rows * cols).map(|_| lcg_value(state, scale)).collect();
98    // unwrap-free: we construct the vec with exactly rows*cols elements
99    Array2::from_shape_vec((rows, cols), data).unwrap_or_else(|_| Array2::zeros((rows, cols)))
100}
101
102/// Fill an `Array1<f64>` with LCG-generated values in [-scale, scale].
103fn lcg_fill_1d(len: usize, scale: f64, state: &mut u64) -> Array1<f64> {
104    let data: Vec<f64> = (0..len).map(|_| lcg_value(state, scale)).collect();
105    Array1::from_vec(data)
106}
107
108// ─────────────────────────────────────────────────────────────────────────────
109// RnnCell
110// ─────────────────────────────────────────────────────────────────────────────
111
112/// Vanilla RNN cell.
113///
114/// Computes: `h_t = tanh(W_ih @ x_t + b_ih + W_hh @ h_{t-1} + b_hh)`
115#[derive(Debug, Clone)]
116pub struct RnnCell {
117    /// Number of input features.
118    pub input_size: usize,
119    /// Number of hidden units.
120    pub hidden_size: usize,
121    /// Input-hidden weight matrix `[hidden_size, input_size]`.
122    pub w_ih: Array2<f64>,
123    /// Hidden-hidden weight matrix `[hidden_size, hidden_size]`.
124    pub w_hh: Array2<f64>,
125    /// Input-hidden bias `[hidden_size]`.
126    pub b_ih: Array1<f64>,
127    /// Hidden-hidden bias `[hidden_size]`.
128    pub b_hh: Array1<f64>,
129}
130
131impl RnnCell {
132    /// Construct a new RNN cell with small deterministic random weights.
133    pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, RecurrentError> {
134        if input_size == 0 {
135            return Err(RecurrentError::InvalidInputSize(input_size));
136        }
137        if hidden_size == 0 {
138            return Err(RecurrentError::InvalidHiddenSize(hidden_size));
139        }
140        let scale = 0.1_f64;
141        let mut state: u64 = 0xdeadbeef_12345678_u64;
142        let w_ih = lcg_fill_2d(hidden_size, input_size, scale, &mut state);
143        let w_hh = lcg_fill_2d(hidden_size, hidden_size, scale, &mut state);
144        let b_ih = lcg_fill_1d(hidden_size, scale, &mut state);
145        let b_hh = lcg_fill_1d(hidden_size, scale, &mut state);
146        Ok(Self {
147            input_size,
148            hidden_size,
149            w_ih,
150            w_hh,
151            b_ih,
152            b_hh,
153        })
154    }
155
156    /// Construct an RNN cell from pre-existing weight arrays.
157    pub fn from_weights(
158        w_ih: Array2<f64>,
159        w_hh: Array2<f64>,
160        b_ih: Array1<f64>,
161        b_hh: Array1<f64>,
162    ) -> Result<Self, RecurrentError> {
163        let hidden_size = w_ih.nrows();
164        let input_size = w_ih.ncols();
165        if hidden_size == 0 {
166            return Err(RecurrentError::InvalidHiddenSize(hidden_size));
167        }
168        if input_size == 0 {
169            return Err(RecurrentError::InvalidInputSize(input_size));
170        }
171        // Validate w_hh shape
172        if w_hh.nrows() != hidden_size || w_hh.ncols() != hidden_size {
173            return Err(RecurrentError::ShapeMismatch {
174                expected: vec![hidden_size, hidden_size],
175                got: vec![w_hh.nrows(), w_hh.ncols()],
176            });
177        }
178        if b_ih.len() != hidden_size {
179            return Err(RecurrentError::ShapeMismatch {
180                expected: vec![hidden_size],
181                got: vec![b_ih.len()],
182            });
183        }
184        if b_hh.len() != hidden_size {
185            return Err(RecurrentError::ShapeMismatch {
186                expected: vec![hidden_size],
187                got: vec![b_hh.len()],
188            });
189        }
190        Ok(Self {
191            input_size,
192            hidden_size,
193            w_ih,
194            w_hh,
195            b_ih,
196            b_hh,
197        })
198    }
199
200    /// Run one step forward.
201    ///
202    /// # Arguments
203    /// * `input`  – shape `[input_size]`
204    /// * `hidden` – shape `[hidden_size]`
205    ///
206    /// # Returns
207    /// New hidden state, shape `[hidden_size]`.
208    pub fn forward(
209        &self,
210        input: &Array1<f64>,
211        hidden: &Array1<f64>,
212    ) -> Result<Array1<f64>, RecurrentError> {
213        if input.len() != self.input_size {
214            return Err(RecurrentError::ShapeMismatch {
215                expected: vec![self.input_size],
216                got: vec![input.len()],
217            });
218        }
219        if hidden.len() != self.hidden_size {
220            return Err(RecurrentError::ShapeMismatch {
221                expected: vec![self.hidden_size],
222                got: vec![hidden.len()],
223            });
224        }
225        // h_t = tanh(W_ih @ x + b_ih + W_hh @ h + b_hh)
226        let pre_act = self.w_ih.dot(input) + &self.b_ih + self.w_hh.dot(hidden) + &self.b_hh;
227        Ok(pre_act.mapv(f64::tanh))
228    }
229
230    /// Return an all-zeros initial hidden state.
231    pub fn init_hidden(&self) -> Array1<f64> {
232        Array1::zeros(self.hidden_size)
233    }
234
235    /// Total number of learnable parameters.
236    pub fn num_parameters(&self) -> usize {
237        self.hidden_size * self.input_size    // w_ih
238            + self.hidden_size * self.hidden_size // w_hh
239            + self.hidden_size                    // b_ih
240            + self.hidden_size // b_hh
241    }
242}
243
244// ─────────────────────────────────────────────────────────────────────────────
245// LstmState
246// ─────────────────────────────────────────────────────────────────────────────
247
248/// Combined hidden and cell state for an LSTM.
249#[derive(Debug, Clone)]
250pub struct LstmState {
251    /// Hidden state `h`, shape `[hidden_size]`.
252    pub h: Array1<f64>,
253    /// Cell state `c`, shape `[hidden_size]`.
254    pub c: Array1<f64>,
255}
256
257impl LstmState {
258    /// Create an all-zeros initial state.
259    pub fn zeros(hidden_size: usize) -> Self {
260        Self {
261            h: Array1::zeros(hidden_size),
262            c: Array1::zeros(hidden_size),
263        }
264    }
265}
266
267// ─────────────────────────────────────────────────────────────────────────────
268// LstmCell
269// ─────────────────────────────────────────────────────────────────────────────
270
271/// LSTM cell with combined gate weight matrices.
272///
273/// Weight row ordering: `[input_gate, forget_gate, cell_gate, output_gate]`.
274///
275/// Forward pass:
276/// ```text
277/// i = σ(W_ii @ x + b_ii + W_hi @ h + b_hi)   [rows 0   .. h]
278/// f = σ(W_if @ x + b_if + W_hf @ h + b_hf)   [rows h   .. 2h]
279/// g = tanh(W_ig @ x + b_ig + W_hg @ h + b_hg) [rows 2h .. 3h]
280/// o = σ(W_io @ x + b_io + W_ho @ h + b_ho)   [rows 3h .. 4h]
281/// c' = f ⊙ c + i ⊙ g
282/// h' = o ⊙ tanh(c')
283/// ```
284#[derive(Debug, Clone)]
285pub struct LstmCell {
286    /// Number of input features.
287    pub input_size: usize,
288    /// Number of hidden units.
289    pub hidden_size: usize,
290    /// Combined input-hidden weight matrix `[4*hidden_size, input_size]`.
291    pub w_ih: Array2<f64>,
292    /// Combined hidden-hidden weight matrix `[4*hidden_size, hidden_size]`.
293    pub w_hh: Array2<f64>,
294    /// Combined input-hidden bias `[4*hidden_size]`.
295    pub b_ih: Array1<f64>,
296    /// Combined hidden-hidden bias `[4*hidden_size]`.
297    pub b_hh: Array1<f64>,
298}
299
300impl LstmCell {
301    /// Construct a new LSTM cell with small deterministic random weights.
302    pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, RecurrentError> {
303        if input_size == 0 {
304            return Err(RecurrentError::InvalidInputSize(input_size));
305        }
306        if hidden_size == 0 {
307            return Err(RecurrentError::InvalidHiddenSize(hidden_size));
308        }
309        let scale = 0.1_f64;
310        let mut state: u64 = 0xfeedface_abcd1234_u64;
311        let gates = 4;
312        let w_ih = lcg_fill_2d(gates * hidden_size, input_size, scale, &mut state);
313        let w_hh = lcg_fill_2d(gates * hidden_size, hidden_size, scale, &mut state);
314        let b_ih = lcg_fill_1d(gates * hidden_size, scale, &mut state);
315        let b_hh = lcg_fill_1d(gates * hidden_size, scale, &mut state);
316        Ok(Self {
317            input_size,
318            hidden_size,
319            w_ih,
320            w_hh,
321            b_ih,
322            b_hh,
323        })
324    }
325
326    /// Construct an LSTM cell from pre-existing weight arrays.
327    pub fn from_weights(
328        w_ih: Array2<f64>,
329        w_hh: Array2<f64>,
330        b_ih: Array1<f64>,
331        b_hh: Array1<f64>,
332    ) -> Result<Self, RecurrentError> {
333        let input_size = w_ih.ncols();
334        if input_size == 0 {
335            return Err(RecurrentError::InvalidInputSize(input_size));
336        }
337        let combined_rows = w_ih.nrows();
338        if combined_rows == 0 || !combined_rows.is_multiple_of(4) {
339            return Err(RecurrentError::ShapeMismatch {
340                expected: vec![0 /* 4*h */, input_size],
341                got: vec![combined_rows, input_size],
342            });
343        }
344        let hidden_size = combined_rows / 4;
345        // Validate all other tensors
346        if w_hh.nrows() != combined_rows || w_hh.ncols() != hidden_size {
347            return Err(RecurrentError::ShapeMismatch {
348                expected: vec![combined_rows, hidden_size],
349                got: vec![w_hh.nrows(), w_hh.ncols()],
350            });
351        }
352        if b_ih.len() != combined_rows {
353            return Err(RecurrentError::ShapeMismatch {
354                expected: vec![combined_rows],
355                got: vec![b_ih.len()],
356            });
357        }
358        if b_hh.len() != combined_rows {
359            return Err(RecurrentError::ShapeMismatch {
360                expected: vec![combined_rows],
361                got: vec![b_hh.len()],
362            });
363        }
364        Ok(Self {
365            input_size,
366            hidden_size,
367            w_ih,
368            w_hh,
369            b_ih,
370            b_hh,
371        })
372    }
373
374    /// Run one LSTM step.
375    ///
376    /// # Arguments
377    /// * `input` – shape `[input_size]`
378    /// * `state` – current `(h, c)` state
379    ///
380    /// # Returns
381    /// Updated `LstmState`.
382    pub fn forward(
383        &self,
384        input: &Array1<f64>,
385        state: &LstmState,
386    ) -> Result<LstmState, RecurrentError> {
387        if input.len() != self.input_size {
388            return Err(RecurrentError::ShapeMismatch {
389                expected: vec![self.input_size],
390                got: vec![input.len()],
391            });
392        }
393        if state.h.len() != self.hidden_size {
394            return Err(RecurrentError::ShapeMismatch {
395                expected: vec![self.hidden_size],
396                got: vec![state.h.len()],
397            });
398        }
399        if state.c.len() != self.hidden_size {
400            return Err(RecurrentError::ShapeMismatch {
401                expected: vec![self.hidden_size],
402                got: vec![state.c.len()],
403            });
404        }
405
406        // Combined pre-activations: shape [4*hidden]
407        let gates_pre = self.w_ih.dot(input) + &self.b_ih + self.w_hh.dot(&state.h) + &self.b_hh;
408
409        let h = self.hidden_size;
410
411        // Slice into per-gate vectors (views → owned)
412        let i_pre = gates_pre.slice(scirs2_core::ndarray::s![..h]).to_owned();
413        let f_pre = gates_pre
414            .slice(scirs2_core::ndarray::s![h..2 * h])
415            .to_owned();
416        let g_pre = gates_pre
417            .slice(scirs2_core::ndarray::s![2 * h..3 * h])
418            .to_owned();
419        let o_pre = gates_pre
420            .slice(scirs2_core::ndarray::s![3 * h..])
421            .to_owned();
422
423        let i_gate = i_pre.mapv(sigmoid);
424        let f_gate = f_pre.mapv(sigmoid);
425        let g_gate = g_pre.mapv(f64::tanh);
426        let o_gate = o_pre.mapv(sigmoid);
427
428        // c' = f ⊙ c + i ⊙ g
429        let new_c = &f_gate * &state.c + &i_gate * &g_gate;
430        // h' = o ⊙ tanh(c')
431        let new_h = &o_gate * new_c.mapv(f64::tanh);
432
433        Ok(LstmState { h: new_h, c: new_c })
434    }
435
436    /// Return an all-zeros initial state.
437    pub fn init_state(&self) -> LstmState {
438        LstmState::zeros(self.hidden_size)
439    }
440
441    /// Total number of learnable parameters.
442    pub fn num_parameters(&self) -> usize {
443        let gates = 4;
444        gates * self.hidden_size * self.input_size    // w_ih
445            + gates * self.hidden_size * self.hidden_size // w_hh
446            + gates * self.hidden_size                    // b_ih
447            + gates * self.hidden_size // b_hh
448    }
449}
450
451// ─────────────────────────────────────────────────────────────────────────────
452// GruCell
453// ─────────────────────────────────────────────────────────────────────────────
454
455/// GRU cell with combined gate weight matrices.
456///
457/// Weight row ordering: `[reset_gate, update_gate, new_gate]`.
458///
459/// Forward pass:
460/// ```text
461/// r = σ(W_ir @ x + b_ir + W_hr @ h + b_hr)           [rows 0   .. h]
462/// z = σ(W_iz @ x + b_iz + W_hz @ h + b_hz)           [rows h   .. 2h]
463/// n = tanh(W_in @ x + b_in + r ⊙ (W_hn @ h + b_hn)) [rows 2h .. 3h]
464/// h' = (1 - z) ⊙ n + z ⊙ h
465/// ```
466#[derive(Debug, Clone)]
467pub struct GruCell {
468    /// Number of input features.
469    pub input_size: usize,
470    /// Number of hidden units.
471    pub hidden_size: usize,
472    /// Combined input-hidden weight matrix `[3*hidden_size, input_size]`.
473    pub w_ih: Array2<f64>,
474    /// Combined hidden-hidden weight matrix `[3*hidden_size, hidden_size]`.
475    pub w_hh: Array2<f64>,
476    /// Combined input-hidden bias `[3*hidden_size]`.
477    pub b_ih: Array1<f64>,
478    /// Combined hidden-hidden bias `[3*hidden_size]`.
479    pub b_hh: Array1<f64>,
480}
481
482impl GruCell {
483    /// Construct a new GRU cell with small deterministic random weights.
484    pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, RecurrentError> {
485        if input_size == 0 {
486            return Err(RecurrentError::InvalidInputSize(input_size));
487        }
488        if hidden_size == 0 {
489            return Err(RecurrentError::InvalidHiddenSize(hidden_size));
490        }
491        let scale = 0.1_f64;
492        let mut state: u64 = 0xc0ffee00_87654321_u64;
493        let gates = 3;
494        let w_ih = lcg_fill_2d(gates * hidden_size, input_size, scale, &mut state);
495        let w_hh = lcg_fill_2d(gates * hidden_size, hidden_size, scale, &mut state);
496        let b_ih = lcg_fill_1d(gates * hidden_size, scale, &mut state);
497        let b_hh = lcg_fill_1d(gates * hidden_size, scale, &mut state);
498        Ok(Self {
499            input_size,
500            hidden_size,
501            w_ih,
502            w_hh,
503            b_ih,
504            b_hh,
505        })
506    }
507
508    /// Construct a GRU cell from pre-existing weight arrays.
509    pub fn from_weights(
510        w_ih: Array2<f64>,
511        w_hh: Array2<f64>,
512        b_ih: Array1<f64>,
513        b_hh: Array1<f64>,
514    ) -> Result<Self, RecurrentError> {
515        let input_size = w_ih.ncols();
516        if input_size == 0 {
517            return Err(RecurrentError::InvalidInputSize(input_size));
518        }
519        let combined_rows = w_ih.nrows();
520        if combined_rows == 0 || !combined_rows.is_multiple_of(3) {
521            return Err(RecurrentError::ShapeMismatch {
522                expected: vec![0 /* 3*h */, input_size],
523                got: vec![combined_rows, input_size],
524            });
525        }
526        let hidden_size = combined_rows / 3;
527        if w_hh.nrows() != combined_rows || w_hh.ncols() != hidden_size {
528            return Err(RecurrentError::ShapeMismatch {
529                expected: vec![combined_rows, hidden_size],
530                got: vec![w_hh.nrows(), w_hh.ncols()],
531            });
532        }
533        if b_ih.len() != combined_rows {
534            return Err(RecurrentError::ShapeMismatch {
535                expected: vec![combined_rows],
536                got: vec![b_ih.len()],
537            });
538        }
539        if b_hh.len() != combined_rows {
540            return Err(RecurrentError::ShapeMismatch {
541                expected: vec![combined_rows],
542                got: vec![b_hh.len()],
543            });
544        }
545        Ok(Self {
546            input_size,
547            hidden_size,
548            w_ih,
549            w_hh,
550            b_ih,
551            b_hh,
552        })
553    }
554
555    /// Run one GRU step.
556    ///
557    /// # Arguments
558    /// * `input`  – shape `[input_size]`
559    /// * `hidden` – shape `[hidden_size]`
560    ///
561    /// # Returns
562    /// New hidden state, shape `[hidden_size]`.
563    pub fn forward(
564        &self,
565        input: &Array1<f64>,
566        hidden: &Array1<f64>,
567    ) -> Result<Array1<f64>, RecurrentError> {
568        if input.len() != self.input_size {
569            return Err(RecurrentError::ShapeMismatch {
570                expected: vec![self.input_size],
571                got: vec![input.len()],
572            });
573        }
574        if hidden.len() != self.hidden_size {
575            return Err(RecurrentError::ShapeMismatch {
576                expected: vec![self.hidden_size],
577                got: vec![hidden.len()],
578            });
579        }
580
581        let h = self.hidden_size;
582
583        // Input-side pre-activations: [3h]
584        let x_pre = self.w_ih.dot(input) + &self.b_ih;
585        // Hidden-side pre-activations: [3h]
586        let h_pre = self.w_hh.dot(hidden) + &self.b_hh;
587
588        // Reset and update gates use the sum of both sides
589        let r_pre = x_pre.slice(scirs2_core::ndarray::s![..h]).to_owned()
590            + h_pre.slice(scirs2_core::ndarray::s![..h]).to_owned();
591        let z_pre = x_pre.slice(scirs2_core::ndarray::s![h..2 * h]).to_owned()
592            + h_pre.slice(scirs2_core::ndarray::s![h..2 * h]).to_owned();
593
594        let r_gate = r_pre.mapv(sigmoid);
595        let z_gate = z_pre.mapv(sigmoid);
596
597        // New gate: x part + r ⊙ h part
598        let n_x = x_pre.slice(scirs2_core::ndarray::s![2 * h..]).to_owned();
599        let n_h = h_pre.slice(scirs2_core::ndarray::s![2 * h..]).to_owned();
600        let n_pre = n_x + &r_gate * n_h;
601        let n_gate = n_pre.mapv(f64::tanh);
602
603        // h' = (1 - z) ⊙ n + z ⊙ h
604        let ones = Array1::<f64>::ones(h);
605        let new_h = (&ones - &z_gate) * &n_gate + &z_gate * hidden;
606        Ok(new_h)
607    }
608
609    /// Return an all-zeros initial hidden state.
610    pub fn init_hidden(&self) -> Array1<f64> {
611        Array1::zeros(self.hidden_size)
612    }
613
614    /// Total number of learnable parameters.
615    pub fn num_parameters(&self) -> usize {
616        let gates = 3;
617        gates * self.hidden_size * self.input_size    // w_ih
618            + gates * self.hidden_size * self.hidden_size // w_hh
619            + gates * self.hidden_size                    // b_ih
620            + gates * self.hidden_size // b_hh
621    }
622}
623
624// ─────────────────────────────────────────────────────────────────────────────
625// Sequence helpers
626// ─────────────────────────────────────────────────────────────────────────────
627
628/// Run an [`RnnCell`] over a sequence of inputs.
629///
630/// # Arguments
631/// * `cell`   – configured RNN cell
632/// * `inputs` – slice of length `T`, each element shape `[input_size]`
633///
634/// # Returns
635/// `Vec` of `T` hidden states, each shape `[hidden_size]`.
636pub fn rnn_sequence(
637    cell: &RnnCell,
638    inputs: &[Array1<f64>],
639) -> Result<Vec<Array1<f64>>, RecurrentError> {
640    if inputs.is_empty() {
641        return Err(RecurrentError::EmptySequence);
642    }
643    let mut hidden = cell.init_hidden();
644    let mut outputs = Vec::with_capacity(inputs.len());
645    for x in inputs {
646        hidden = cell.forward(x, &hidden)?;
647        outputs.push(hidden.clone());
648    }
649    Ok(outputs)
650}
651
652/// Run an [`LstmCell`] over a sequence of inputs.
653///
654/// # Arguments
655/// * `cell`   – configured LSTM cell
656/// * `inputs` – slice of length `T`, each element shape `[input_size]`
657///
658/// # Returns
659/// `(all_hidden_states, final_state)` where `all_hidden_states` has length `T`.
660pub fn lstm_sequence(
661    cell: &LstmCell,
662    inputs: &[Array1<f64>],
663) -> Result<(Vec<Array1<f64>>, LstmState), RecurrentError> {
664    if inputs.is_empty() {
665        return Err(RecurrentError::EmptySequence);
666    }
667    let mut state = cell.init_state();
668    let mut hidden_states = Vec::with_capacity(inputs.len());
669    for x in inputs {
670        state = cell.forward(x, &state)?;
671        hidden_states.push(state.h.clone());
672    }
673    Ok((hidden_states, state))
674}
675
676/// Run a [`GruCell`] over a sequence of inputs.
677///
678/// # Arguments
679/// * `cell`   – configured GRU cell
680/// * `inputs` – slice of length `T`, each element shape `[input_size]`
681///
682/// # Returns
683/// `Vec` of `T` hidden states, each shape `[hidden_size]`.
684pub fn gru_sequence(
685    cell: &GruCell,
686    inputs: &[Array1<f64>],
687) -> Result<Vec<Array1<f64>>, RecurrentError> {
688    if inputs.is_empty() {
689        return Err(RecurrentError::EmptySequence);
690    }
691    let mut hidden = cell.init_hidden();
692    let mut outputs = Vec::with_capacity(inputs.len());
693    for x in inputs {
694        hidden = cell.forward(x, &hidden)?;
695        outputs.push(hidden.clone());
696    }
697    Ok(outputs)
698}
699
// ─────────────────────────────────────────────────────────────────────────────
// RecurrentStats
// ─────────────────────────────────────────────────────────────────────────────

/// Diagnostic statistics for a recurrent cell / sequence run.
#[derive(Debug, Clone)]
pub struct RecurrentStats {
    /// Human-readable cell type label (e.g. `"RNN"`, `"LSTM"`, `"GRU"`).
    pub cell_type: String,
    /// Number of input features.
    pub input_size: usize,
    /// Number of hidden units.
    pub hidden_size: usize,
    /// Total number of learnable parameters.
    pub num_parameters: usize,
    /// Length of the most recently processed sequence, if applicable.
    pub sequence_length: Option<usize>,
}

impl RecurrentStats {
    /// Return a single-line human-readable summary.
    pub fn summary(&self) -> String {
        let seq = self
            .sequence_length
            .map_or_else(|| "seq_len=n/a".to_string(), |t| format!("seq_len={t}"));
        format!(
            "{} | input={} hidden={} params={} {}",
            self.cell_type, self.input_size, self.hidden_size, self.num_parameters, seq
        )
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    // ── RnnCell ────────────────────────────────────────────────────────────

    #[test]
    fn test_rnn_cell_new() {
        assert!(RnnCell::new(4, 8).is_ok(), "RnnCell::new should succeed");
    }

    #[test]
    fn test_rnn_cell_forward_shape() {
        let cell = RnnCell::new(4, 8).expect("construct rnn");
        let h_new = cell
            .forward(&Array1::zeros(4), &cell.init_hidden())
            .expect("rnn forward");
        assert_eq!(h_new.len(), 8);
    }

    #[test]
    fn test_rnn_cell_init_hidden() {
        let cell = RnnCell::new(3, 5).expect("construct rnn");
        let h = cell.init_hidden();
        assert_eq!(h.len(), 5);
        assert!(h.iter().all(|&v| v == 0.0), "init hidden should be zeros");
    }

    #[test]
    fn test_rnn_cell_num_parameters() {
        let (input_size, hidden_size) = (4, 8);
        let cell = RnnCell::new(input_size, hidden_size).expect("construct rnn");
        // w_ih: 8*4, w_hh: 8*8, b_ih: 8, b_hh: 8  = 32+64+8+8 = 112
        let expected =
            hidden_size * input_size + hidden_size * hidden_size + hidden_size + hidden_size;
        assert_eq!(cell.num_parameters(), expected);
    }

    // ── LstmCell ───────────────────────────────────────────────────────────

    #[test]
    fn test_lstm_cell_new() {
        assert!(LstmCell::new(4, 8).is_ok(), "LstmCell::new should succeed");
    }

    #[test]
    fn test_lstm_cell_forward_shape() {
        let cell = LstmCell::new(4, 8).expect("construct lstm");
        let next = cell
            .forward(&Array1::zeros(4), &cell.init_state())
            .expect("lstm forward");
        assert_eq!(next.h.len(), 8);
        assert_eq!(next.c.len(), 8);
    }

    #[test]
    fn test_lstm_cell_init_state() {
        let cell = LstmCell::new(3, 6).expect("construct lstm");
        let state = cell.init_state();
        assert_eq!(state.h.len(), 6);
        assert_eq!(state.c.len(), 6);
        assert!(state.h.iter().all(|&v| v == 0.0));
        assert!(state.c.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_lstm_cell_gate_bounds() {
        let cell = LstmCell::new(4, 8).expect("construct lstm");
        let next = cell
            .forward(&Array1::from_elem(4, 0.5), &cell.init_state())
            .expect("lstm forward");
        // h' = o ⊙ tanh(c') so each element is in (-1, 1)
        for &v in next.h.iter() {
            assert!(v > -1.0 && v < 1.0, "h element out of (-1,1): {v}");
        }
    }

    #[test]
    fn test_lstm_cell_num_parameters() {
        let (input_size, hidden_size, gates) = (4, 8, 4);
        let cell = LstmCell::new(input_size, hidden_size).expect("construct lstm");
        let expected = gates * hidden_size * input_size
            + gates * hidden_size * hidden_size
            + gates * hidden_size
            + gates * hidden_size;
        assert_eq!(cell.num_parameters(), expected);
    }

    // ── GruCell ────────────────────────────────────────────────────────────

    #[test]
    fn test_gru_cell_new() {
        assert!(GruCell::new(4, 8).is_ok(), "GruCell::new should succeed");
    }

    #[test]
    fn test_gru_cell_forward_shape() {
        let cell = GruCell::new(4, 8).expect("construct gru");
        let h_new = cell
            .forward(&Array1::zeros(4), &cell.init_hidden())
            .expect("gru forward");
        assert_eq!(h_new.len(), 8);
    }

    #[test]
    fn test_gru_cell_hidden_init_zeros() {
        let cell = GruCell::new(3, 5).expect("construct gru");
        let h = cell.init_hidden();
        assert_eq!(h.len(), 5);
        assert!(h.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_gru_cell_num_parameters() {
        let (input_size, hidden_size, gates) = (4, 8, 3);
        let cell = GruCell::new(input_size, hidden_size).expect("construct gru");
        let expected = gates * hidden_size * input_size
            + gates * hidden_size * hidden_size
            + gates * hidden_size
            + gates * hidden_size;
        assert_eq!(cell.num_parameters(), expected);
    }

    // ── Sequence helpers ───────────────────────────────────────────────────

    #[test]
    fn test_rnn_sequence_length() {
        let cell = RnnCell::new(4, 8).expect("rnn");
        let inputs: Vec<Array1<f64>> = (0..7).map(|_| Array1::zeros(4)).collect();
        let out = rnn_sequence(&cell, &inputs).expect("rnn sequence");
        assert_eq!(out.len(), 7, "T inputs → T outputs");
    }

    #[test]
    fn test_rnn_sequence_empty_error() {
        let cell = RnnCell::new(4, 8).expect("rnn");
        assert!(
            matches!(rnn_sequence(&cell, &[]), Err(RecurrentError::EmptySequence)),
            "expected EmptySequence error"
        );
    }

    #[test]
    fn test_lstm_sequence_length() {
        let cell = LstmCell::new(4, 8).expect("lstm");
        let inputs: Vec<Array1<f64>> = (0..5).map(|_| Array1::zeros(4)).collect();
        let (hidden_states, _) = lstm_sequence(&cell, &inputs).expect("lstm sequence");
        assert_eq!(hidden_states.len(), 5);
    }

    #[test]
    fn test_lstm_sequence_final_state_nonzero() {
        let cell = LstmCell::new(4, 8).expect("lstm");
        // Non-zero inputs so that the state is driven away from zero
        let inputs: Vec<Array1<f64>> = (0..3).map(|_| Array1::from_elem(4, 1.0)).collect();
        let (_, final_state) = lstm_sequence(&cell, &inputs).expect("lstm sequence");
        let h_norm: f64 = final_state.h.iter().map(|v| v * v).sum::<f64>().sqrt();
        assert!(
            h_norm > 1e-12,
            "final h should be non-zero for non-zero inputs"
        );
    }

    #[test]
    fn test_gru_sequence_length() {
        let cell = GruCell::new(4, 8).expect("gru");
        let inputs: Vec<Array1<f64>> = (0..6).map(|_| Array1::zeros(4)).collect();
        assert_eq!(gru_sequence(&cell, &inputs).expect("gru sequence").len(), 6);
    }

    // ── RecurrentStats ─────────────────────────────────────────────────────

    #[test]
    fn test_recurrent_stats_summary_nonempty() {
        let stats = RecurrentStats {
            cell_type: "LSTM".to_string(),
            input_size: 4,
            hidden_size: 8,
            num_parameters: 416,
            sequence_length: Some(10),
        };
        let s = stats.summary();
        assert!(!s.is_empty(), "summary should not be empty");
        assert!(s.contains("LSTM"));
        assert!(s.contains("416"));
    }

    // ── from_weights shape mismatch ─────────────────────────────────────────

    #[test]
    fn test_lstm_cell_from_weights_shape_mismatch() {
        use scirs2_core::ndarray::Array2;
        // w_ih has 4*hidden=8 rows but w_hh has the wrong column count
        let w_ih = Array2::zeros((8, 4));
        let w_hh = Array2::zeros((8, 3)); // should be (8, 2) for hidden=2
        let result = LstmCell::from_weights(w_ih, w_hh, Array1::zeros(8), Array1::zeros(8));
        assert!(result.is_err(), "should fail due to w_hh shape mismatch");
    }
}