svod_tensor/nn/rnn.rs
1//! Recurrent neural network layers (RNN, GRU, LSTM).
2
3use bon::bon;
4
5use crate::error::{NdimExactSnafu, ParamRangeSnafu};
6
7use super::*;
8
9/// Output of an RNN forward pass.
10pub struct RnnOutput {
11 /// All hidden states: `[seq_length, num_directions, batch, hidden_size]`
12 pub y: Tensor,
13 /// Final hidden state: `[num_directions, batch, hidden_size]`
14 pub y_h: Tensor,
15}
16
17/// Output of a GRU forward pass.
18pub struct GruOutput {
19 /// All hidden states: `[seq_length, num_directions, batch, hidden_size]`
20 pub y: Tensor,
21 /// Final hidden state: `[num_directions, batch, hidden_size]`
22 pub y_h: Tensor,
23}
24
25/// Output of an LSTM forward pass.
26pub struct LstmOutput {
27 /// All hidden states: `[seq_length, num_directions, batch, hidden_size]`
28 pub y: Tensor,
29 /// Final hidden state: `[num_directions, batch, hidden_size]`
30 pub y_h: Tensor,
31 /// Final cell state: `[num_directions, batch, hidden_size]`
32 pub y_c: Tensor,
33}
34
35#[bon]
36impl Tensor {
37 /// Simple RNN (Elman network).
38 ///
39 /// `H_t = tanh(X_t @ W^T + H_{t-1} @ R^T + Wb + Rb)`
40 ///
41 /// - `x`: input `[seq_length, batch_size, input_size]` (layout=0) or
42 /// `[batch_size, seq_length, input_size]` (layout=1)
43 /// - `w`: input weights `[num_directions, hidden_size, input_size]`
44 /// - `r`: recurrence weights `[num_directions, hidden_size, hidden_size]`
45 /// - `bias`: optional bias `[num_directions, 2 * hidden_size]` (Wb ++ Rb)
46 /// - `initial_h`: optional initial hidden state `[num_directions, batch_size, hidden_size]`
47 /// - `layout`: 0 = seq-first (default), 1 = batch-first
48 ///
49 /// # Examples
50 ///
51 /// ```
52 /// # use svod_tensor::Tensor;
53 /// # use ndarray::{array, Array3};
54 /// // seq=2, batch=1, input=3
55 /// let x = Tensor::from_ndarray(&Array3::from_elem((2, 1, 3), 0.1f32));
56 /// let w = Tensor::from_ndarray(&Array3::from_elem((1, 4, 3), 0.1f32)); // [1, hidden=4, input=3]
57 /// let r = Tensor::from_ndarray(&Array3::from_elem((1, 4, 4), 0.1f32)); // [1, hidden=4, hidden=4]
58 /// let out = x.rnn().w(&w).r(&r).hidden_size(4).call().unwrap();
59 /// let y_shape: Vec<usize> = out.y.shape().unwrap().iter()
60 /// .map(|d| d.as_const().unwrap()).collect();
61 /// assert_eq!(y_shape, vec![2, 1, 1, 4]); // [seq, num_directions, batch, hidden]
62 /// let yh_shape: Vec<usize> = out.y_h.shape().unwrap().iter()
63 /// .map(|d| d.as_const().unwrap()).collect();
64 /// assert_eq!(yh_shape, vec![1, 1, 4]); // [num_directions, batch, hidden]
65 /// ```
66 #[builder]
67 pub fn rnn(
68 &self,
69 w: &Tensor,
70 r: &Tensor,
71 hidden_size: usize,
72 bias: Option<&Tensor>,
73 initial_h: Option<&Tensor>,
74 #[builder(default = 0)] layout: usize,
75 ) -> Result<RnnOutput> {
76 let ndim = self.ndim()?;
77 snafu::ensure!(ndim == 3, NdimExactSnafu { op: "rnn", expected: 3_usize, actual: ndim });
78 snafu::ensure!(
79 hidden_size > 0,
80 ParamRangeSnafu { op: "rnn", param: "hidden_size", value: hidden_size.to_string(), constraint: "> 0" }
81 );
82 let x = if layout != 0 { self.try_permute(&[1, 0, 2])? } else { self.clone() };
83 let x_shape = x.shape()?;
84 let seq_length = x_shape[0].as_const().expect("static seq_length");
85 let batch_size = x_shape[1].as_const().expect("static batch_size");
86 let input_size = x_shape[2].as_const().expect("static input_size");
87 let num_directions = w.shape()?[0].as_const().expect("static num_directions");
88 let dtype = x.uop().dtype();
89
90 snafu::ensure!(
91 num_directions == 1,
92 ParamRangeSnafu {
93 op: "rnn",
94 param: "num_directions",
95 value: num_directions.to_string(),
96 constraint: "== 1"
97 }
98 );
99
100 let w0 = w.try_squeeze(Some(0))?; // [hidden, input]
101 let r0 = r.try_squeeze(Some(0))?; // [hidden, hidden]
102 let wt = w0.try_permute(&[1, 0])?; // [input, hidden]
103 let rt = r0.try_permute(&[1, 0])?; // [hidden, hidden]
104
105 let combined_bias = if let Some(b) = bias {
106 let b0 = b.try_squeeze(Some(0))?; // [2*hidden]
107 let parts = b0.split(&[hidden_size, hidden_size], 0)?;
108 Some(parts[0].try_add(&parts[1])?) // [hidden]
109 } else {
110 None
111 };
112
113 let mut h_t = if let Some(h0) = initial_h {
114 h0.try_squeeze(Some(0))? // [batch, hidden]
115 } else {
116 Tensor::full(&[batch_size, hidden_size], 0.0f32, dtype)?
117 };
118
119 let mut h_list = Vec::with_capacity(seq_length);
120 for t in 0..seq_length {
121 let x_t =
122 x.try_shrink([(t as isize, t as isize + 1), (0, batch_size as isize), (0, input_size as isize)])?;
123 let x_t = x_t.try_squeeze(Some(0))?; // [batch, input]
124
125 let mut gate = x_t.matmul(&wt)?.try_add(&h_t.matmul(&rt)?)?;
126 if let Some(ref b) = combined_bias {
127 gate = gate.try_add(b)?;
128 }
129 h_t = gate.tanh()?;
130 h_list.push(h_t.clone());
131 }
132
133 let h_refs: Vec<&Tensor> = h_list.iter().collect();
134 let y_seq = Tensor::stack(&h_refs, 0)?; // [seq, batch, hidden]
135 let y = y_seq.try_unsqueeze(1)?; // [seq, 1, batch, hidden]
136
137 let y = if layout != 0 {
138 y.try_permute(&[2, 0, 1, 3])? // [batch, seq, 1, hidden]
139 } else {
140 y
141 };
142
143 let y_h = if layout != 0 {
144 h_t.try_unsqueeze(1)? // [batch, 1, hidden]
145 } else {
146 h_t.try_unsqueeze(0)? // [1, batch, hidden]
147 };
148
149 Ok(RnnOutput { y, y_h })
150 }
151
152 /// GRU (Gated Recurrent Unit).
153 ///
154 /// Gate order: `[z, r, h]` (update, reset, hidden).
155 ///
156 /// Equations (default, `linear_before_reset=0`):
157 /// - `z = sigmoid(X @ W_z^T + H @ R_z^T + w_bz + r_bz)`
158 /// - `r = sigmoid(X @ W_r^T + H @ R_r^T + w_br + r_br)`
159 /// - `h = tanh(X @ W_h^T + (r * H) @ R_h^T + w_bh + r_bh)`
160 /// - `H_new = (1 - z) * h + z * H_prev`
161 ///
162 /// When `linear_before_reset=1`:
163 /// - `h = tanh(X @ W_h^T + r * (H @ R_h^T + r_bh) + w_bh)`
164 ///
165 /// - `x`: input `[seq_length, batch_size, input_size]` (layout=0) or
166 /// `[batch_size, seq_length, input_size]` (layout=1)
167 /// - `w`: input weights `[num_directions, 3*hidden_size, input_size]`
168 /// - `r_weights`: recurrence weights `[num_directions, 3*hidden_size, hidden_size]`
169 /// - `bias`: optional `[num_directions, 6*hidden_size]` (Wb ++ Rb)
170 /// - `initial_h`: optional `[num_directions, batch_size, hidden_size]`
171 /// - `linear_before_reset`: 0 (default) or 1
172 /// - `layout`: 0 = seq-first (default), 1 = batch-first
173 ///
174 /// # Examples
175 ///
176 /// ```
177 /// # use svod_tensor::Tensor;
178 /// # use ndarray::{array, Array3};
179 /// // seq=2, batch=1, input=3, hidden=4
180 /// let x = Tensor::from_ndarray(&Array3::from_elem((2, 1, 3), 0.1f32));
181 /// // GRU: w is [num_directions, 3*hidden_size, input_size]
182 /// let w = Tensor::from_ndarray(&Array3::from_elem((1, 12, 3), 0.1f32));
183 /// // GRU: r is [num_directions, 3*hidden_size, hidden_size]
184 /// let r = Tensor::from_ndarray(&Array3::from_elem((1, 12, 4), 0.1f32));
185 /// let out = x.gru().w(&w).r_weights(&r).hidden_size(4).call().unwrap();
186 /// let y_shape: Vec<usize> = out.y.shape().unwrap().iter()
187 /// .map(|d| d.as_const().unwrap()).collect();
188 /// assert_eq!(y_shape, vec![2, 1, 1, 4]); // [seq, num_directions, batch, hidden]
189 /// ```
190 #[builder]
191 pub fn gru(
192 &self,
193 w: &Tensor,
194 r_weights: &Tensor,
195 hidden_size: usize,
196 bias: Option<&Tensor>,
197 initial_h: Option<&Tensor>,
198 #[builder(default = 0)] linear_before_reset: usize,
199 #[builder(default = 0)] layout: usize,
200 ) -> Result<GruOutput> {
201 let ndim = self.ndim()?;
202 snafu::ensure!(ndim == 3, NdimExactSnafu { op: "gru", expected: 3_usize, actual: ndim });
203 snafu::ensure!(
204 hidden_size > 0,
205 ParamRangeSnafu { op: "gru", param: "hidden_size", value: hidden_size.to_string(), constraint: "> 0" }
206 );
207 let x = if layout != 0 { self.try_permute(&[1, 0, 2])? } else { self.clone() };
208 let x_shape = x.shape()?;
209 let seq_length = x_shape[0].as_const().expect("static seq_length");
210 let batch_size = x_shape[1].as_const().expect("static batch_size");
211 let input_size = x_shape[2].as_const().expect("static input_size");
212 let num_directions = w.shape()?[0].as_const().expect("static num_directions");
213 let dtype = x.uop().dtype();
214
215 snafu::ensure!(
216 num_directions == 1,
217 ParamRangeSnafu {
218 op: "gru",
219 param: "num_directions",
220 value: num_directions.to_string(),
221 constraint: "== 1"
222 }
223 );
224
225 let w0 = w.try_squeeze(Some(0))?; // [3*hidden, input]
226 let r0 = r_weights.try_squeeze(Some(0))?; // [3*hidden, hidden]
227
228 // Split W into [W_z, W_r, W_h] and R into [R_z, R_r, R_h]
229 let w_parts = w0.split(&[hidden_size; 3], 0)?;
230 let r_parts = r0.split(&[hidden_size; 3], 0)?;
231
232 // Combine z,r weights for joint computation: gates_w = [W_z; W_r]^T
233 let gates_w = Tensor::cat(&[&w_parts[0], &w_parts[1]], 0)?.try_permute(&[1, 0])?;
234 let gates_r = Tensor::cat(&[&r_parts[0], &r_parts[1]], 0)?.try_permute(&[1, 0])?;
235
236 // W_h and R_h kept separate (reset gate interacts differently)
237 let w_h_t = w_parts[2].try_permute(&[1, 0])?; // [input, hidden]
238 let r_h_t = r_parts[2].try_permute(&[1, 0])?; // [hidden, hidden]
239
240 // Bias: [6*hidden] → [w_bz, w_br, w_bh, r_bz, r_br, r_bh]
241 let (gates_b, w_bh, r_bh) = if let Some(b) = bias {
242 let b0 = b.try_squeeze(Some(0))?;
243 let parts = b0.split(&[hidden_size; 6], 0)?;
244 // gates_b = (w_bz + r_bz) ++ (w_br + r_br)
245 let gates_b = Tensor::cat(&[&parts[0].try_add(&parts[3])?, &parts[1].try_add(&parts[4])?], 0)?;
246 (Some(gates_b), Some(parts[2].clone()), Some(parts[5].clone()))
247 } else {
248 (None, None, None)
249 };
250
251 let mut h_t = if let Some(h0) = initial_h {
252 h0.try_squeeze(Some(0))?
253 } else {
254 Tensor::full(&[batch_size, hidden_size], 0.0f32, dtype)?
255 };
256
257 let mut h_list = Vec::with_capacity(seq_length);
258 for t in 0..seq_length {
259 let x_t =
260 x.try_shrink([(t as isize, t as isize + 1), (0, batch_size as isize), (0, input_size as isize)])?;
261 let x_t = x_t.try_squeeze(Some(0))?; // [batch, input]
262
263 // z, r gates: combined matmul
264 let mut gates = x_t.matmul(&gates_w)?.try_add(&h_t.matmul(&gates_r)?)?;
265 if let Some(ref gb) = gates_b {
266 gates = gates.try_add(gb)?;
267 }
268 let zr = gates.split(&[hidden_size; 2], -1)?;
269 let z = zr[0].sigmoid()?;
270 let r = zr[1].sigmoid()?;
271
272 // Hidden candidate
273 let h_candidate = if linear_before_reset != 0 {
274 // h = tanh(x @ W_h^T + r * (H @ R_h^T + r_bh) + w_bh)
275 let mut rh = h_t.matmul(&r_h_t)?;
276 if let Some(ref rb) = r_bh {
277 rh = rh.try_add(rb)?;
278 }
279 let mut h = x_t.matmul(&w_h_t)?.try_add(&r.try_mul(&rh)?)?;
280 if let Some(ref wb) = w_bh {
281 h = h.try_add(wb)?;
282 }
283 h.tanh()?
284 } else {
285 // h = tanh(x @ W_h^T + (r * H) @ R_h^T + w_bh + r_bh)
286 let mut h = x_t.matmul(&w_h_t)?.try_add(&r.try_mul(&h_t)?.matmul(&r_h_t)?)?;
287 if let Some(ref wb) = w_bh {
288 h = h.try_add(wb)?;
289 }
290 if let Some(ref rb) = r_bh {
291 h = h.try_add(rb)?;
292 }
293 h.tanh()?
294 };
295
296 // H = (1 - z) * h_candidate + z * H_prev
297 let one = Tensor::full(&[1], 1.0f32, z.uop().dtype())?;
298 h_t = one.try_sub(&z)?.try_mul(&h_candidate)?.try_add(&z.try_mul(&h_t)?)?;
299 h_list.push(h_t.clone());
300 }
301
302 let h_refs: Vec<&Tensor> = h_list.iter().collect();
303 let y_seq = Tensor::stack(&h_refs, 0)?; // [seq, batch, hidden]
304 let y = y_seq.try_unsqueeze(1)?; // [seq, 1, batch, hidden]
305
306 let y = if layout != 0 {
307 y.try_permute(&[2, 0, 1, 3])? // [batch, seq, 1, hidden]
308 } else {
309 y
310 };
311
312 let y_h = if layout != 0 {
313 h_t.try_unsqueeze(1)? // [batch, 1, hidden]
314 } else {
315 h_t.try_unsqueeze(0)? // [1, batch, hidden]
316 };
317
318 Ok(GruOutput { y, y_h })
319 }
320
321 /// LSTM (Long Short-Term Memory).
322 ///
323 /// Gate order: `[i, o, f, c]` (input, output, forget, cell).
324 ///
325 /// - `x`: input `[seq_length, batch_size, input_size]` (layout=0) or
326 /// `[batch_size, seq_length, input_size]` (layout=1)
327 /// - `w`: input weights `[num_directions, 4*hidden_size, input_size]`
328 /// - `r`: recurrence weights `[num_directions, 4*hidden_size, hidden_size]`
329 /// - `bias`: optional `[num_directions, 8*hidden_size]` (Wb ++ Rb)
330 /// - `initial_h`: optional `[num_directions, batch_size, hidden_size]`
331 /// - `initial_c`: optional `[num_directions, batch_size, hidden_size]`
332 /// - `peepholes`: optional `[num_directions, 3*hidden_size]` (p_i, p_o, p_f)
333 /// - `layout`: 0 = seq-first (default), 1 = batch-first
334 ///
335 /// # Examples
336 ///
337 /// ```
338 /// # use svod_tensor::Tensor;
339 /// # use ndarray::Array3;
340 /// // seq=2, batch=1, input=3, hidden=4
341 /// let x = Tensor::from_ndarray(&Array3::from_elem((2, 1, 3), 0.1f32));
342 /// // LSTM: w is [num_directions, 4*hidden_size, input_size]
343 /// let w = Tensor::from_ndarray(&Array3::from_elem((1, 16, 3), 0.1f32));
344 /// // LSTM: r is [num_directions, 4*hidden_size, hidden_size]
345 /// let r = Tensor::from_ndarray(&Array3::from_elem((1, 16, 4), 0.1f32));
346 /// let out = x.lstm().w(&w).r(&r).hidden_size(4).call().unwrap();
347 /// let y_shape: Vec<usize> = out.y.shape().unwrap().iter()
348 /// .map(|d| d.as_const().unwrap()).collect();
349 /// assert_eq!(y_shape, vec![2, 1, 1, 4]); // [seq, num_directions, batch, hidden]
350 /// let yc_shape: Vec<usize> = out.y_c.shape().unwrap().iter()
351 /// .map(|d| d.as_const().unwrap()).collect();
352 /// assert_eq!(yc_shape, vec![1, 1, 4]); // [num_directions, batch, hidden]
353 /// ```
354 #[builder]
355 pub fn lstm(
356 &self,
357 w: &Tensor,
358 r: &Tensor,
359 hidden_size: usize,
360 bias: Option<&Tensor>,
361 initial_h: Option<&Tensor>,
362 initial_c: Option<&Tensor>,
363 peepholes: Option<&Tensor>,
364 #[builder(default = 0)] layout: usize,
365 ) -> Result<LstmOutput> {
366 let ndim = self.ndim()?;
367 snafu::ensure!(ndim == 3, NdimExactSnafu { op: "lstm", expected: 3_usize, actual: ndim });
368 snafu::ensure!(
369 hidden_size > 0,
370 ParamRangeSnafu { op: "lstm", param: "hidden_size", value: hidden_size.to_string(), constraint: "> 0" }
371 );
372 let x = if layout != 0 {
373 self.try_permute(&[1, 0, 2])? // batch-first → seq-first
374 } else {
375 self.clone()
376 };
377 let x_shape = x.shape()?;
378 let seq_length = x_shape[0].as_const().expect("static seq_length");
379 let batch_size = x_shape[1].as_const().expect("static batch_size");
380 let input_size = x_shape[2].as_const().expect("static input_size");
381 let num_directions = w.shape()?[0].as_const().expect("static num_directions");
382 let dtype = x.uop().dtype();
383
384 snafu::ensure!(
385 num_directions == 1,
386 ParamRangeSnafu {
387 op: "lstm",
388 param: "num_directions",
389 value: num_directions.to_string(),
390 constraint: "== 1"
391 }
392 );
393
394 let w0 = w.try_squeeze(Some(0))?; // [4*hidden, input]
395 let r0 = r.try_squeeze(Some(0))?; // [4*hidden, hidden]
396 let wt = w0.try_permute(&[1, 0])?; // [input, 4*hidden]
397 let rt = r0.try_permute(&[1, 0])?; // [hidden, 4*hidden]
398
399 // Bias: [8*hidden] → split into Wb [4*hidden] and Rb [4*hidden], add together
400 let combined_bias = if let Some(b) = bias {
401 let b0 = b.try_squeeze(Some(0))?;
402 let hs4 = 4 * hidden_size;
403 let parts = b0.split(&[hs4, hs4], 0)?;
404 Some(parts[0].try_add(&parts[1])?)
405 } else {
406 None
407 };
408
409 // Peepholes: [3*hidden] → [p_i, p_o, p_f]
410 let (p_i, p_o, p_f) = if let Some(p) = peepholes {
411 let p0 = p.try_squeeze(Some(0))?;
412 let parts = p0.split(&[hidden_size, hidden_size, hidden_size], 0)?;
413 (Some(parts[0].clone()), Some(parts[1].clone()), Some(parts[2].clone()))
414 } else {
415 (None, None, None)
416 };
417
418 let mut h_t = if let Some(h0) = initial_h {
419 h0.try_squeeze(Some(0))?
420 } else {
421 Tensor::full(&[batch_size, hidden_size], 0.0f32, dtype.clone())?
422 };
423 let mut c_t = if let Some(c0) = initial_c {
424 c0.try_squeeze(Some(0))?
425 } else {
426 Tensor::full(&[batch_size, hidden_size], 0.0f32, dtype)?
427 };
428
429 let mut h_list = Vec::with_capacity(seq_length);
430 for t in 0..seq_length {
431 let x_t =
432 x.try_shrink([(t as isize, t as isize + 1), (0, batch_size as isize), (0, input_size as isize)])?;
433 let x_t = x_t.try_squeeze(Some(0))?; // [batch, input]
434
435 // gates = X_t @ W^T + H_{t-1} @ R^T + bias
436 let mut gates = x_t.matmul(&wt)?.try_add(&h_t.matmul(&rt)?)?;
437 if let Some(ref b) = combined_bias {
438 gates = gates.try_add(b)?;
439 }
440
441 // Split into [i, o, f, c] — each [batch, hidden]
442 let gate_parts = gates.split(&[hidden_size; 4], -1)?;
443 let (mut gi, mut go, mut gf, gc) =
444 (gate_parts[0].clone(), gate_parts[1].clone(), gate_parts[2].clone(), gate_parts[3].clone());
445
446 // Peephole connections: i and f use previous cell state
447 if let Some(ref pi) = p_i {
448 gi = gi.try_add(&c_t.try_mul(pi)?)?;
449 }
450 if let Some(ref pf) = p_f {
451 gf = gf.try_add(&c_t.try_mul(pf)?)?;
452 }
453
454 let i = gi.sigmoid()?;
455 let f = gf.sigmoid()?;
456 let c = gc.tanh()?;
457
458 // C = f * C_prev + i * c
459 c_t = f.try_mul(&c_t)?.try_add(&i.try_mul(&c)?)?;
460
461 // Peephole: o uses NEW cell state
462 if let Some(ref po) = p_o {
463 go = go.try_add(&c_t.try_mul(po)?)?;
464 }
465 let o = go.sigmoid()?;
466
467 // H = o * tanh(C)
468 h_t = o.try_mul(&c_t.tanh()?)?;
469 h_list.push(h_t.clone());
470 }
471
472 let h_refs: Vec<&Tensor> = h_list.iter().collect();
473 let y_seq = Tensor::stack(&h_refs, 0)?; // [seq, batch, hidden]
474 let y = y_seq.try_unsqueeze(1)?; // [seq, 1, batch, hidden]
475
476 // Apply layout transform to output
477 let y = if layout != 0 {
478 y.try_permute(&[2, 0, 1, 3])? // [batch, seq, 1, hidden]
479 } else {
480 y
481 };
482
483 let (y_h, y_c) = if layout != 0 {
484 // layout=1: Y_h/Y_c are [batch, num_directions, hidden]
485 (h_t.try_unsqueeze(1)?, c_t.try_unsqueeze(1)?)
486 } else {
487 // layout=0: Y_h/Y_c are [num_directions, batch, hidden]
488 (h_t.try_unsqueeze(0)?, c_t.try_unsqueeze(0)?)
489 };
490
491 Ok(LstmOutput { y, y_h, y_c })
492 }
493}