svod_tensor/nn/
lstm_cell.rs1use svod_dtype::DType;
2
3use crate::Tensor;
4
5type Result<T> = crate::Result<T>;
6
7#[derive(Clone)]
15pub struct LSTMCell {
16 pub weight_ih: Tensor,
17 pub weight_hh: Tensor,
18 pub bias_ih: Tensor,
19 pub bias_hh: Tensor,
20 hidden_size: usize,
21}
22
23impl LSTMCell {
24 pub fn new(weight_ih: Tensor, weight_hh: Tensor, bias_ih: Tensor, bias_hh: Tensor) -> Self {
27 let shape = weight_ih.shape().expect("lstm_cell: weight_ih shape");
28 let four_hidden = shape[0].as_const().expect("lstm_cell: 4*hidden must be concrete");
29 Self { weight_ih, weight_hh, bias_ih, bias_hh, hidden_size: four_hidden / 4 }
30 }
31
32 pub fn with_dims(input_size: usize, hidden_size: usize, dtype: DType) -> Self {
34 let four_hidden = 4 * hidden_size;
35 let w_ih_data: Vec<f32> = (0..four_hidden * input_size).map(|i| ((i as f32) * 0.1).sin() * 0.1).collect();
36 let weight_ih = Tensor::from_slice(&w_ih_data)
37 .try_reshape([four_hidden as isize, input_size as isize])
38 .expect("lstm_cell weight_ih reshape failed");
39 let w_hh_data: Vec<f32> = (0..four_hidden * hidden_size).map(|i| ((i as f32) * 0.1).sin() * 0.1).collect();
40 let weight_hh = Tensor::from_slice(&w_hh_data)
41 .try_reshape([four_hidden as isize, hidden_size as isize])
42 .expect("lstm_cell weight_hh reshape failed");
43 let bias_ih = Tensor::full(&[four_hidden], 0.0, dtype.clone()).expect("lstm_cell bias_ih creation");
44 let bias_hh = Tensor::full(&[four_hidden], 0.0, dtype).expect("lstm_cell bias_hh creation");
45 Self { weight_ih, weight_hh, bias_ih, bias_hh, hidden_size }
46 }
47
48 pub fn hidden_size(&self) -> usize {
49 self.hidden_size
50 }
51
52 pub fn step(&self, x: &Tensor, h: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
56 let gates_x = x.linear().weight(&self.weight_ih).bias(&self.bias_ih).call()?;
57 let gates_h = h.linear().weight(&self.weight_hh).bias(&self.bias_hh).call()?;
58 let gates = gates_x.try_add(&gates_h)?;
59
60 let h_sz = self.hidden_size;
61 let parts = gates.split(&[h_sz, h_sz, h_sz, h_sz], 1)?;
62 let i = parts[0].sigmoid()?;
63 let f = parts[1].sigmoid()?;
64 let g = parts[2].tanh()?;
65 let o = parts[3].sigmoid()?;
66
67 let new_c = f.try_mul(c)?.try_add(&i.try_mul(&g)?)?;
68 let new_h = o.try_mul(&new_c.tanh()?)?;
69 Ok((new_h, new_c))
70 }
71}