Skip to main content

svod_tensor/nn/
lstm_cell.rs

1use svod_dtype::DType;
2
3use crate::Tensor;
4
5type Result<T> = crate::Result<T>;
6
7/// LSTM cell with PyTorch's `[i, f, g, o]` gate order.
8///
9/// `weight_ih` shape: `[4*hidden, input]`; `weight_hh` shape: `[4*hidden, hidden]`.
10/// `bias_ih` and `bias_hh` both `[4*hidden]` — summed in [`Self::step`] to match
11/// `nn.LSTM`'s packing, so PyTorch checkpoints load without remapping.
12///
13/// Not a [`Layer`](crate::nn::Layer) — cells take `(x, h, c)`, not a single tensor.
14#[derive(Clone)]
15pub struct LSTMCell {
16    pub weight_ih: Tensor,
17    pub weight_hh: Tensor,
18    pub bias_ih: Tensor,
19    pub bias_hh: Tensor,
20    hidden_size: usize,
21}
22
23impl LSTMCell {
24    /// Create an LSTM cell from existing weight/bias tensors. `hidden_size` is
25    /// derived from `weight_ih.shape()[0] / 4`.
26    pub fn new(weight_ih: Tensor, weight_hh: Tensor, bias_ih: Tensor, bias_hh: Tensor) -> Self {
27        let shape = weight_ih.shape().expect("lstm_cell: weight_ih shape");
28        let four_hidden = shape[0].as_const().expect("lstm_cell: 4*hidden must be concrete");
29        Self { weight_ih, weight_hh, bias_ih, bias_hh, hidden_size: four_hidden / 4 }
30    }
31
32    /// Create an LSTM cell with deterministic `sin()` initialization, zero biases.
33    pub fn with_dims(input_size: usize, hidden_size: usize, dtype: DType) -> Self {
34        let four_hidden = 4 * hidden_size;
35        let w_ih_data: Vec<f32> = (0..four_hidden * input_size).map(|i| ((i as f32) * 0.1).sin() * 0.1).collect();
36        let weight_ih = Tensor::from_slice(&w_ih_data)
37            .try_reshape([four_hidden as isize, input_size as isize])
38            .expect("lstm_cell weight_ih reshape failed");
39        let w_hh_data: Vec<f32> = (0..four_hidden * hidden_size).map(|i| ((i as f32) * 0.1).sin() * 0.1).collect();
40        let weight_hh = Tensor::from_slice(&w_hh_data)
41            .try_reshape([four_hidden as isize, hidden_size as isize])
42            .expect("lstm_cell weight_hh reshape failed");
43        let bias_ih = Tensor::full(&[four_hidden], 0.0, dtype.clone()).expect("lstm_cell bias_ih creation");
44        let bias_hh = Tensor::full(&[four_hidden], 0.0, dtype).expect("lstm_cell bias_hh creation");
45        Self { weight_ih, weight_hh, bias_ih, bias_hh, hidden_size }
46    }
47
48    pub fn hidden_size(&self) -> usize {
49        self.hidden_size
50    }
51
52    /// One LSTM step. Returns `(h_next, c_next)`.
53    ///
54    /// Shapes: `x: [B, input]`, `h, c: [B, hidden]`.
55    pub fn step(&self, x: &Tensor, h: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
56        let gates_x = x.linear().weight(&self.weight_ih).bias(&self.bias_ih).call()?;
57        let gates_h = h.linear().weight(&self.weight_hh).bias(&self.bias_hh).call()?;
58        let gates = gates_x.try_add(&gates_h)?;
59
60        let h_sz = self.hidden_size;
61        let parts = gates.split(&[h_sz, h_sz, h_sz, h_sz], 1)?;
62        let i = parts[0].sigmoid()?;
63        let f = parts[1].sigmoid()?;
64        let g = parts[2].tanh()?;
65        let o = parts[3].sigmoid()?;
66
67        let new_c = f.try_mul(c)?.try_add(&i.try_mul(&g)?)?;
68        let new_h = o.try_mul(&new_c.tanh()?)?;
69        Ok((new_h, new_c))
70    }
71}