Skip to main content

svod_tensor/nn/
rnn.rs

1//! Recurrent neural network layers (RNN, GRU, LSTM).
2
3use bon::bon;
4
5use crate::error::{NdimExactSnafu, ParamRangeSnafu};
6
7use super::*;
8
9/// Output of an RNN forward pass.
10pub struct RnnOutput {
11    /// All hidden states: `[seq_length, num_directions, batch, hidden_size]`
12    pub y: Tensor,
13    /// Final hidden state: `[num_directions, batch, hidden_size]`
14    pub y_h: Tensor,
15}
16
17/// Output of a GRU forward pass.
18pub struct GruOutput {
19    /// All hidden states: `[seq_length, num_directions, batch, hidden_size]`
20    pub y: Tensor,
21    /// Final hidden state: `[num_directions, batch, hidden_size]`
22    pub y_h: Tensor,
23}
24
25/// Output of an LSTM forward pass.
26pub struct LstmOutput {
27    /// All hidden states: `[seq_length, num_directions, batch, hidden_size]`
28    pub y: Tensor,
29    /// Final hidden state: `[num_directions, batch, hidden_size]`
30    pub y_h: Tensor,
31    /// Final cell state: `[num_directions, batch, hidden_size]`
32    pub y_c: Tensor,
33}
34
35#[bon]
36impl Tensor {
37    /// Simple RNN (Elman network).
38    ///
39    /// `H_t = tanh(X_t @ W^T + H_{t-1} @ R^T + Wb + Rb)`
40    ///
41    /// - `x`: input `[seq_length, batch_size, input_size]` (layout=0) or
42    ///         `[batch_size, seq_length, input_size]` (layout=1)
43    /// - `w`: input weights `[num_directions, hidden_size, input_size]`
44    /// - `r`: recurrence weights `[num_directions, hidden_size, hidden_size]`
45    /// - `bias`: optional bias `[num_directions, 2 * hidden_size]` (Wb ++ Rb)
46    /// - `initial_h`: optional initial hidden state `[num_directions, batch_size, hidden_size]`
47    /// - `layout`: 0 = seq-first (default), 1 = batch-first
48    ///
49    /// # Examples
50    ///
51    /// ```
52    /// # use svod_tensor::Tensor;
53    /// # use ndarray::{array, Array3};
54    /// // seq=2, batch=1, input=3
55    /// let x = Tensor::from_ndarray(&Array3::from_elem((2, 1, 3), 0.1f32));
56    /// let w = Tensor::from_ndarray(&Array3::from_elem((1, 4, 3), 0.1f32)); // [1, hidden=4, input=3]
57    /// let r = Tensor::from_ndarray(&Array3::from_elem((1, 4, 4), 0.1f32)); // [1, hidden=4, hidden=4]
58    /// let out = x.rnn().w(&w).r(&r).hidden_size(4).call().unwrap();
59    /// let y_shape: Vec<usize> = out.y.shape().unwrap().iter()
60    ///     .map(|d| d.as_const().unwrap()).collect();
61    /// assert_eq!(y_shape, vec![2, 1, 1, 4]); // [seq, num_directions, batch, hidden]
62    /// let yh_shape: Vec<usize> = out.y_h.shape().unwrap().iter()
63    ///     .map(|d| d.as_const().unwrap()).collect();
64    /// assert_eq!(yh_shape, vec![1, 1, 4]); // [num_directions, batch, hidden]
65    /// ```
66    #[builder]
67    pub fn rnn(
68        &self,
69        w: &Tensor,
70        r: &Tensor,
71        hidden_size: usize,
72        bias: Option<&Tensor>,
73        initial_h: Option<&Tensor>,
74        #[builder(default = 0)] layout: usize,
75    ) -> Result<RnnOutput> {
76        let ndim = self.ndim()?;
77        snafu::ensure!(ndim == 3, NdimExactSnafu { op: "rnn", expected: 3_usize, actual: ndim });
78        snafu::ensure!(
79            hidden_size > 0,
80            ParamRangeSnafu { op: "rnn", param: "hidden_size", value: hidden_size.to_string(), constraint: "> 0" }
81        );
82        let x = if layout != 0 { self.try_permute(&[1, 0, 2])? } else { self.clone() };
83        let x_shape = x.shape()?;
84        let seq_length = x_shape[0].as_const().expect("static seq_length");
85        let batch_size = x_shape[1].as_const().expect("static batch_size");
86        let input_size = x_shape[2].as_const().expect("static input_size");
87        let num_directions = w.shape()?[0].as_const().expect("static num_directions");
88        let dtype = x.uop().dtype();
89
90        snafu::ensure!(
91            num_directions == 1,
92            ParamRangeSnafu {
93                op: "rnn",
94                param: "num_directions",
95                value: num_directions.to_string(),
96                constraint: "== 1"
97            }
98        );
99
100        let w0 = w.try_squeeze(Some(0))?; // [hidden, input]
101        let r0 = r.try_squeeze(Some(0))?; // [hidden, hidden]
102        let wt = w0.try_permute(&[1, 0])?; // [input, hidden]
103        let rt = r0.try_permute(&[1, 0])?; // [hidden, hidden]
104
105        let combined_bias = if let Some(b) = bias {
106            let b0 = b.try_squeeze(Some(0))?; // [2*hidden]
107            let parts = b0.split(&[hidden_size, hidden_size], 0)?;
108            Some(parts[0].try_add(&parts[1])?) // [hidden]
109        } else {
110            None
111        };
112
113        let mut h_t = if let Some(h0) = initial_h {
114            h0.try_squeeze(Some(0))? // [batch, hidden]
115        } else {
116            Tensor::full(&[batch_size, hidden_size], 0.0f32, dtype)?
117        };
118
119        let mut h_list = Vec::with_capacity(seq_length);
120        for t in 0..seq_length {
121            let x_t =
122                x.try_shrink([(t as isize, t as isize + 1), (0, batch_size as isize), (0, input_size as isize)])?;
123            let x_t = x_t.try_squeeze(Some(0))?; // [batch, input]
124
125            let mut gate = x_t.matmul(&wt)?.try_add(&h_t.matmul(&rt)?)?;
126            if let Some(ref b) = combined_bias {
127                gate = gate.try_add(b)?;
128            }
129            h_t = gate.tanh()?;
130            h_list.push(h_t.clone());
131        }
132
133        let h_refs: Vec<&Tensor> = h_list.iter().collect();
134        let y_seq = Tensor::stack(&h_refs, 0)?; // [seq, batch, hidden]
135        let y = y_seq.try_unsqueeze(1)?; // [seq, 1, batch, hidden]
136
137        let y = if layout != 0 {
138            y.try_permute(&[2, 0, 1, 3])? // [batch, seq, 1, hidden]
139        } else {
140            y
141        };
142
143        let y_h = if layout != 0 {
144            h_t.try_unsqueeze(1)? // [batch, 1, hidden]
145        } else {
146            h_t.try_unsqueeze(0)? // [1, batch, hidden]
147        };
148
149        Ok(RnnOutput { y, y_h })
150    }
151
152    /// GRU (Gated Recurrent Unit).
153    ///
154    /// Gate order: `[z, r, h]` (update, reset, hidden).
155    ///
156    /// Equations (default, `linear_before_reset=0`):
157    /// - `z = sigmoid(X @ W_z^T + H @ R_z^T + w_bz + r_bz)`
158    /// - `r = sigmoid(X @ W_r^T + H @ R_r^T + w_br + r_br)`
159    /// - `h = tanh(X @ W_h^T + (r * H) @ R_h^T + w_bh + r_bh)`
160    /// - `H_new = (1 - z) * h + z * H_prev`
161    ///
162    /// When `linear_before_reset=1`:
163    /// - `h = tanh(X @ W_h^T + r * (H @ R_h^T + r_bh) + w_bh)`
164    ///
165    /// - `x`: input `[seq_length, batch_size, input_size]` (layout=0) or
166    ///         `[batch_size, seq_length, input_size]` (layout=1)
167    /// - `w`: input weights `[num_directions, 3*hidden_size, input_size]`
168    /// - `r_weights`: recurrence weights `[num_directions, 3*hidden_size, hidden_size]`
169    /// - `bias`: optional `[num_directions, 6*hidden_size]` (Wb ++ Rb)
170    /// - `initial_h`: optional `[num_directions, batch_size, hidden_size]`
171    /// - `linear_before_reset`: 0 (default) or 1
172    /// - `layout`: 0 = seq-first (default), 1 = batch-first
173    ///
174    /// # Examples
175    ///
176    /// ```
177    /// # use svod_tensor::Tensor;
178    /// # use ndarray::{array, Array3};
179    /// // seq=2, batch=1, input=3, hidden=4
180    /// let x = Tensor::from_ndarray(&Array3::from_elem((2, 1, 3), 0.1f32));
181    /// // GRU: w is [num_directions, 3*hidden_size, input_size]
182    /// let w = Tensor::from_ndarray(&Array3::from_elem((1, 12, 3), 0.1f32));
183    /// // GRU: r is [num_directions, 3*hidden_size, hidden_size]
184    /// let r = Tensor::from_ndarray(&Array3::from_elem((1, 12, 4), 0.1f32));
185    /// let out = x.gru().w(&w).r_weights(&r).hidden_size(4).call().unwrap();
186    /// let y_shape: Vec<usize> = out.y.shape().unwrap().iter()
187    ///     .map(|d| d.as_const().unwrap()).collect();
188    /// assert_eq!(y_shape, vec![2, 1, 1, 4]); // [seq, num_directions, batch, hidden]
189    /// ```
190    #[builder]
191    pub fn gru(
192        &self,
193        w: &Tensor,
194        r_weights: &Tensor,
195        hidden_size: usize,
196        bias: Option<&Tensor>,
197        initial_h: Option<&Tensor>,
198        #[builder(default = 0)] linear_before_reset: usize,
199        #[builder(default = 0)] layout: usize,
200    ) -> Result<GruOutput> {
201        let ndim = self.ndim()?;
202        snafu::ensure!(ndim == 3, NdimExactSnafu { op: "gru", expected: 3_usize, actual: ndim });
203        snafu::ensure!(
204            hidden_size > 0,
205            ParamRangeSnafu { op: "gru", param: "hidden_size", value: hidden_size.to_string(), constraint: "> 0" }
206        );
207        let x = if layout != 0 { self.try_permute(&[1, 0, 2])? } else { self.clone() };
208        let x_shape = x.shape()?;
209        let seq_length = x_shape[0].as_const().expect("static seq_length");
210        let batch_size = x_shape[1].as_const().expect("static batch_size");
211        let input_size = x_shape[2].as_const().expect("static input_size");
212        let num_directions = w.shape()?[0].as_const().expect("static num_directions");
213        let dtype = x.uop().dtype();
214
215        snafu::ensure!(
216            num_directions == 1,
217            ParamRangeSnafu {
218                op: "gru",
219                param: "num_directions",
220                value: num_directions.to_string(),
221                constraint: "== 1"
222            }
223        );
224
225        let w0 = w.try_squeeze(Some(0))?; // [3*hidden, input]
226        let r0 = r_weights.try_squeeze(Some(0))?; // [3*hidden, hidden]
227
228        // Split W into [W_z, W_r, W_h] and R into [R_z, R_r, R_h]
229        let w_parts = w0.split(&[hidden_size; 3], 0)?;
230        let r_parts = r0.split(&[hidden_size; 3], 0)?;
231
232        // Combine z,r weights for joint computation: gates_w = [W_z; W_r]^T
233        let gates_w = Tensor::cat(&[&w_parts[0], &w_parts[1]], 0)?.try_permute(&[1, 0])?;
234        let gates_r = Tensor::cat(&[&r_parts[0], &r_parts[1]], 0)?.try_permute(&[1, 0])?;
235
236        // W_h and R_h kept separate (reset gate interacts differently)
237        let w_h_t = w_parts[2].try_permute(&[1, 0])?; // [input, hidden]
238        let r_h_t = r_parts[2].try_permute(&[1, 0])?; // [hidden, hidden]
239
240        // Bias: [6*hidden] → [w_bz, w_br, w_bh, r_bz, r_br, r_bh]
241        let (gates_b, w_bh, r_bh) = if let Some(b) = bias {
242            let b0 = b.try_squeeze(Some(0))?;
243            let parts = b0.split(&[hidden_size; 6], 0)?;
244            // gates_b = (w_bz + r_bz) ++ (w_br + r_br)
245            let gates_b = Tensor::cat(&[&parts[0].try_add(&parts[3])?, &parts[1].try_add(&parts[4])?], 0)?;
246            (Some(gates_b), Some(parts[2].clone()), Some(parts[5].clone()))
247        } else {
248            (None, None, None)
249        };
250
251        let mut h_t = if let Some(h0) = initial_h {
252            h0.try_squeeze(Some(0))?
253        } else {
254            Tensor::full(&[batch_size, hidden_size], 0.0f32, dtype)?
255        };
256
257        let mut h_list = Vec::with_capacity(seq_length);
258        for t in 0..seq_length {
259            let x_t =
260                x.try_shrink([(t as isize, t as isize + 1), (0, batch_size as isize), (0, input_size as isize)])?;
261            let x_t = x_t.try_squeeze(Some(0))?; // [batch, input]
262
263            // z, r gates: combined matmul
264            let mut gates = x_t.matmul(&gates_w)?.try_add(&h_t.matmul(&gates_r)?)?;
265            if let Some(ref gb) = gates_b {
266                gates = gates.try_add(gb)?;
267            }
268            let zr = gates.split(&[hidden_size; 2], -1)?;
269            let z = zr[0].sigmoid()?;
270            let r = zr[1].sigmoid()?;
271
272            // Hidden candidate
273            let h_candidate = if linear_before_reset != 0 {
274                // h = tanh(x @ W_h^T + r * (H @ R_h^T + r_bh) + w_bh)
275                let mut rh = h_t.matmul(&r_h_t)?;
276                if let Some(ref rb) = r_bh {
277                    rh = rh.try_add(rb)?;
278                }
279                let mut h = x_t.matmul(&w_h_t)?.try_add(&r.try_mul(&rh)?)?;
280                if let Some(ref wb) = w_bh {
281                    h = h.try_add(wb)?;
282                }
283                h.tanh()?
284            } else {
285                // h = tanh(x @ W_h^T + (r * H) @ R_h^T + w_bh + r_bh)
286                let mut h = x_t.matmul(&w_h_t)?.try_add(&r.try_mul(&h_t)?.matmul(&r_h_t)?)?;
287                if let Some(ref wb) = w_bh {
288                    h = h.try_add(wb)?;
289                }
290                if let Some(ref rb) = r_bh {
291                    h = h.try_add(rb)?;
292                }
293                h.tanh()?
294            };
295
296            // H = (1 - z) * h_candidate + z * H_prev
297            let one = Tensor::full(&[1], 1.0f32, z.uop().dtype())?;
298            h_t = one.try_sub(&z)?.try_mul(&h_candidate)?.try_add(&z.try_mul(&h_t)?)?;
299            h_list.push(h_t.clone());
300        }
301
302        let h_refs: Vec<&Tensor> = h_list.iter().collect();
303        let y_seq = Tensor::stack(&h_refs, 0)?; // [seq, batch, hidden]
304        let y = y_seq.try_unsqueeze(1)?; // [seq, 1, batch, hidden]
305
306        let y = if layout != 0 {
307            y.try_permute(&[2, 0, 1, 3])? // [batch, seq, 1, hidden]
308        } else {
309            y
310        };
311
312        let y_h = if layout != 0 {
313            h_t.try_unsqueeze(1)? // [batch, 1, hidden]
314        } else {
315            h_t.try_unsqueeze(0)? // [1, batch, hidden]
316        };
317
318        Ok(GruOutput { y, y_h })
319    }
320
321    /// LSTM (Long Short-Term Memory).
322    ///
323    /// Gate order: `[i, o, f, c]` (input, output, forget, cell).
324    ///
325    /// - `x`: input `[seq_length, batch_size, input_size]` (layout=0) or
326    ///         `[batch_size, seq_length, input_size]` (layout=1)
327    /// - `w`: input weights `[num_directions, 4*hidden_size, input_size]`
328    /// - `r`: recurrence weights `[num_directions, 4*hidden_size, hidden_size]`
329    /// - `bias`: optional `[num_directions, 8*hidden_size]` (Wb ++ Rb)
330    /// - `initial_h`: optional `[num_directions, batch_size, hidden_size]`
331    /// - `initial_c`: optional `[num_directions, batch_size, hidden_size]`
332    /// - `peepholes`: optional `[num_directions, 3*hidden_size]` (p_i, p_o, p_f)
333    /// - `layout`: 0 = seq-first (default), 1 = batch-first
334    ///
335    /// # Examples
336    ///
337    /// ```
338    /// # use svod_tensor::Tensor;
339    /// # use ndarray::Array3;
340    /// // seq=2, batch=1, input=3, hidden=4
341    /// let x = Tensor::from_ndarray(&Array3::from_elem((2, 1, 3), 0.1f32));
342    /// // LSTM: w is [num_directions, 4*hidden_size, input_size]
343    /// let w = Tensor::from_ndarray(&Array3::from_elem((1, 16, 3), 0.1f32));
344    /// // LSTM: r is [num_directions, 4*hidden_size, hidden_size]
345    /// let r = Tensor::from_ndarray(&Array3::from_elem((1, 16, 4), 0.1f32));
346    /// let out = x.lstm().w(&w).r(&r).hidden_size(4).call().unwrap();
347    /// let y_shape: Vec<usize> = out.y.shape().unwrap().iter()
348    ///     .map(|d| d.as_const().unwrap()).collect();
349    /// assert_eq!(y_shape, vec![2, 1, 1, 4]); // [seq, num_directions, batch, hidden]
350    /// let yc_shape: Vec<usize> = out.y_c.shape().unwrap().iter()
351    ///     .map(|d| d.as_const().unwrap()).collect();
352    /// assert_eq!(yc_shape, vec![1, 1, 4]); // [num_directions, batch, hidden]
353    /// ```
354    #[builder]
355    pub fn lstm(
356        &self,
357        w: &Tensor,
358        r: &Tensor,
359        hidden_size: usize,
360        bias: Option<&Tensor>,
361        initial_h: Option<&Tensor>,
362        initial_c: Option<&Tensor>,
363        peepholes: Option<&Tensor>,
364        #[builder(default = 0)] layout: usize,
365    ) -> Result<LstmOutput> {
366        let ndim = self.ndim()?;
367        snafu::ensure!(ndim == 3, NdimExactSnafu { op: "lstm", expected: 3_usize, actual: ndim });
368        snafu::ensure!(
369            hidden_size > 0,
370            ParamRangeSnafu { op: "lstm", param: "hidden_size", value: hidden_size.to_string(), constraint: "> 0" }
371        );
372        let x = if layout != 0 {
373            self.try_permute(&[1, 0, 2])? // batch-first → seq-first
374        } else {
375            self.clone()
376        };
377        let x_shape = x.shape()?;
378        let seq_length = x_shape[0].as_const().expect("static seq_length");
379        let batch_size = x_shape[1].as_const().expect("static batch_size");
380        let input_size = x_shape[2].as_const().expect("static input_size");
381        let num_directions = w.shape()?[0].as_const().expect("static num_directions");
382        let dtype = x.uop().dtype();
383
384        snafu::ensure!(
385            num_directions == 1,
386            ParamRangeSnafu {
387                op: "lstm",
388                param: "num_directions",
389                value: num_directions.to_string(),
390                constraint: "== 1"
391            }
392        );
393
394        let w0 = w.try_squeeze(Some(0))?; // [4*hidden, input]
395        let r0 = r.try_squeeze(Some(0))?; // [4*hidden, hidden]
396        let wt = w0.try_permute(&[1, 0])?; // [input, 4*hidden]
397        let rt = r0.try_permute(&[1, 0])?; // [hidden, 4*hidden]
398
399        // Bias: [8*hidden] → split into Wb [4*hidden] and Rb [4*hidden], add together
400        let combined_bias = if let Some(b) = bias {
401            let b0 = b.try_squeeze(Some(0))?;
402            let hs4 = 4 * hidden_size;
403            let parts = b0.split(&[hs4, hs4], 0)?;
404            Some(parts[0].try_add(&parts[1])?)
405        } else {
406            None
407        };
408
409        // Peepholes: [3*hidden] → [p_i, p_o, p_f]
410        let (p_i, p_o, p_f) = if let Some(p) = peepholes {
411            let p0 = p.try_squeeze(Some(0))?;
412            let parts = p0.split(&[hidden_size, hidden_size, hidden_size], 0)?;
413            (Some(parts[0].clone()), Some(parts[1].clone()), Some(parts[2].clone()))
414        } else {
415            (None, None, None)
416        };
417
418        let mut h_t = if let Some(h0) = initial_h {
419            h0.try_squeeze(Some(0))?
420        } else {
421            Tensor::full(&[batch_size, hidden_size], 0.0f32, dtype.clone())?
422        };
423        let mut c_t = if let Some(c0) = initial_c {
424            c0.try_squeeze(Some(0))?
425        } else {
426            Tensor::full(&[batch_size, hidden_size], 0.0f32, dtype)?
427        };
428
429        let mut h_list = Vec::with_capacity(seq_length);
430        for t in 0..seq_length {
431            let x_t =
432                x.try_shrink([(t as isize, t as isize + 1), (0, batch_size as isize), (0, input_size as isize)])?;
433            let x_t = x_t.try_squeeze(Some(0))?; // [batch, input]
434
435            // gates = X_t @ W^T + H_{t-1} @ R^T + bias
436            let mut gates = x_t.matmul(&wt)?.try_add(&h_t.matmul(&rt)?)?;
437            if let Some(ref b) = combined_bias {
438                gates = gates.try_add(b)?;
439            }
440
441            // Split into [i, o, f, c] — each [batch, hidden]
442            let gate_parts = gates.split(&[hidden_size; 4], -1)?;
443            let (mut gi, mut go, mut gf, gc) =
444                (gate_parts[0].clone(), gate_parts[1].clone(), gate_parts[2].clone(), gate_parts[3].clone());
445
446            // Peephole connections: i and f use previous cell state
447            if let Some(ref pi) = p_i {
448                gi = gi.try_add(&c_t.try_mul(pi)?)?;
449            }
450            if let Some(ref pf) = p_f {
451                gf = gf.try_add(&c_t.try_mul(pf)?)?;
452            }
453
454            let i = gi.sigmoid()?;
455            let f = gf.sigmoid()?;
456            let c = gc.tanh()?;
457
458            // C = f * C_prev + i * c
459            c_t = f.try_mul(&c_t)?.try_add(&i.try_mul(&c)?)?;
460
461            // Peephole: o uses NEW cell state
462            if let Some(ref po) = p_o {
463                go = go.try_add(&c_t.try_mul(po)?)?;
464            }
465            let o = go.sigmoid()?;
466
467            // H = o * tanh(C)
468            h_t = o.try_mul(&c_t.tanh()?)?;
469            h_list.push(h_t.clone());
470        }
471
472        let h_refs: Vec<&Tensor> = h_list.iter().collect();
473        let y_seq = Tensor::stack(&h_refs, 0)?; // [seq, batch, hidden]
474        let y = y_seq.try_unsqueeze(1)?; // [seq, 1, batch, hidden]
475
476        // Apply layout transform to output
477        let y = if layout != 0 {
478            y.try_permute(&[2, 0, 1, 3])? // [batch, seq, 1, hidden]
479        } else {
480            y
481        };
482
483        let (y_h, y_c) = if layout != 0 {
484            // layout=1: Y_h/Y_c are [batch, num_directions, hidden]
485            (h_t.try_unsqueeze(1)?, c_t.try_unsqueeze(1)?)
486        } else {
487            // layout=0: Y_h/Y_c are [num_directions, batch, hidden]
488            (h_t.try_unsqueeze(0)?, c_t.try_unsqueeze(0)?)
489        };
490
491        Ok(LstmOutput { y, y_h, y_c })
492    }
493}