rust_lstm/layers/
peephole_lstm_cell.rs

1use ndarray::Array2;
2use rand_distr::{Distribution, Normal};
3
4/// Peephole LSTM cell with direct connections from cell state to gates
5pub struct PeepholeLSTMCell {
6    // Input gate
7    pub w_xi: Array2<f64>,
8    pub w_hi: Array2<f64>,
9    pub b_i:  Array2<f64>,
10    pub w_ci: Array2<f64>,
11
12    // Forget gate
13    pub w_xf: Array2<f64>,
14    pub w_hf: Array2<f64>,
15    pub b_f:  Array2<f64>,
16    pub w_cf: Array2<f64>,
17
18    // Cell update
19    pub w_xc: Array2<f64>,
20    pub w_hc: Array2<f64>,
21    pub b_c:  Array2<f64>,
22
23    // Output gate
24    pub w_xo: Array2<f64>,
25    pub w_ho: Array2<f64>,
26    pub b_o:  Array2<f64>,
27    pub w_co: Array2<f64>,
28}
29
30impl PeepholeLSTMCell {
31    /// Create new peephole LSTM cell with Gaussian weight initialization
32    pub fn new(input_size: usize, hidden_size: usize) -> Self {
33        let dist = Normal::new(0.0, 0.1).unwrap();
34        let mut rng = rand::thread_rng();
35
36        let w_xi = Self::random_matrix(&dist, &mut rng, hidden_size, input_size);
37        let w_hi = Self::random_matrix(&dist, &mut rng, hidden_size, hidden_size);
38        let b_i = Self::random_vector_2d(&dist, &mut rng, hidden_size);
39        let w_ci = Self::random_vector_2d(&dist, &mut rng, hidden_size);
40
41        let w_xf = Self::random_matrix(&dist, &mut rng, hidden_size, input_size);
42        let w_hf = Self::random_matrix(&dist, &mut rng, hidden_size, hidden_size);
43        let b_f = Self::random_vector_2d(&dist, &mut rng, hidden_size);
44        let w_cf = Self::random_vector_2d(&dist, &mut rng, hidden_size);
45
46        let w_xc = Self::random_matrix(&dist, &mut rng, hidden_size, input_size);
47        let w_hc = Self::random_matrix(&dist, &mut rng, hidden_size, hidden_size);
48        let b_c = Self::random_vector_2d(&dist, &mut rng, hidden_size);
49
50        let w_xo = Self::random_matrix(&dist, &mut rng, hidden_size, input_size);
51        let w_ho = Self::random_matrix(&dist, &mut rng, hidden_size, hidden_size);
52        let b_o = Self::random_vector_2d(&dist, &mut rng, hidden_size);
53        let w_co = Self::random_vector_2d(&dist, &mut rng, hidden_size);
54
55        Self {
56            w_xi, w_hi, b_i, w_ci,
57            w_xf, w_hf, b_f, w_cf,
58            w_xc, w_hc, b_c,
59            w_xo, w_ho, b_o, w_co,
60        }
61    }
62
63    fn random_matrix(dist: &Normal<f64>, rng: &mut impl rand::Rng, rows: usize, cols: usize) -> Array2<f64> {
64        let mut arr = Array2::<f64>::zeros((rows, cols));
65        for val in arr.iter_mut() {
66            *val = dist.sample(rng);
67        }
68        arr
69    }
70
71    fn random_vector_2d(dist: &Normal<f64>, rng: &mut impl rand::Rng, len: usize) -> Array2<f64> {
72        let mut arr = Array2::<f64>::zeros((len, 1));
73        for val in arr.iter_mut() {
74            *val = dist.sample(rng);
75        }
76        arr
77    }
78
79    /// Forward pass implementing peephole LSTM equations
80    pub fn forward(
81        &self,
82        input: &Array2<f64>,
83        h_prev: &Array2<f64>,
84        c_prev: &Array2<f64>,
85    ) -> (Array2<f64>, Array2<f64>) {
86        let i_t = &self.w_xi.dot(input)
87            + &self.w_hi.dot(h_prev)
88            + &self.b_i
89            + &(&self.w_ci * c_prev);
90        let i_t = i_t.map(|&x| sigmoid(x));
91
92        let f_t = &self.w_xf.dot(input)
93            + &self.w_hf.dot(h_prev)
94            + &self.b_f
95            + &(&self.w_cf * c_prev);
96        let f_t = f_t.map(|&x| sigmoid(x));
97
98        let g_t = (&self.w_xc.dot(input) + &self.w_hc.dot(h_prev) + &self.b_c)
99            .map(|&x| x.tanh());
100
101        let c_t = f_t * c_prev + i_t * g_t;
102
103        let o_t = &self.w_xo.dot(input)
104            + &self.w_ho.dot(h_prev)
105            + &self.b_o
106            + &(&self.w_co * &c_t);
107        let o_t = o_t.map(|&x| sigmoid(x));
108
109        let h_t = o_t * c_t.map(|&x| x.tanh());
110
111        (h_t, c_t)
112    }
113}
114
/// Logistic sigmoid `σ(x) = 1 / (1 + e^(-x))`, with values in `[0, 1]`.
///
/// For very negative `x` the exponential overflows to infinity and the result
/// saturates to exactly `0.0`; for very positive `x` it saturates to `1.0`.
fn sigmoid(x: f64) -> f64 {
    let denominator = (-x).exp() + 1.0;
    denominator.recip()
}
118
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr2, Array2};

    /// Absolute tolerance for comparing against hand-computed activations.
    const TOL: f64 = 1e-12;

    /// Reference logistic function, written independently of the module's `sigmoid`.
    fn logistic(x: f64) -> f64 {
        1.0 / (1.0 + (-x).exp())
    }

    /// Builds a cell whose weights, biases and peepholes are all zero, so a test
    /// can switch on individual connections and predict the output by hand.
    fn zeroed_cell(input_size: usize, hidden_size: usize) -> PeepholeLSTMCell {
        let mut cell = PeepholeLSTMCell::new(input_size, hidden_size);
        for param in [
            &mut cell.w_xi,
            &mut cell.w_hi,
            &mut cell.b_i,
            &mut cell.w_ci,
            &mut cell.w_xf,
            &mut cell.w_hf,
            &mut cell.b_f,
            &mut cell.w_cf,
            &mut cell.w_xc,
            &mut cell.w_hc,
            &mut cell.b_c,
            &mut cell.w_xo,
            &mut cell.w_ho,
            &mut cell.b_o,
            &mut cell.w_co,
        ]
        .iter_mut()
        {
            param.fill(0.0);
        }
        cell
    }

    /// Asserts equal shapes and element-wise agreement within `TOL`.
    fn assert_close(actual: &Array2<f64>, expected: &Array2<f64>) {
        assert_eq!(actual.shape(), expected.shape());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < TOL, "expected {}, got {}", e, a);
        }
    }

    #[test]
    fn test_parameter_shapes() {
        let cell = PeepholeLSTMCell::new(3, 2);
        for w in [&cell.w_xi, &cell.w_xf, &cell.w_xc, &cell.w_xo].iter() {
            assert_eq!(w.shape(), &[2, 3]);
        }
        for w in [&cell.w_hi, &cell.w_hf, &cell.w_hc, &cell.w_ho].iter() {
            assert_eq!(w.shape(), &[2, 2]);
        }
        for v in [
            &cell.b_i, &cell.b_f, &cell.b_c, &cell.b_o,
            &cell.w_ci, &cell.w_cf, &cell.w_co,
        ]
        .iter()
        {
            assert_eq!(v.shape(), &[2, 1]);
        }
    }

    #[test]
    fn test_forward_shape() {
        let input_size = 3;
        let hidden_size = 2;
        let cell = PeepholeLSTMCell::new(input_size, hidden_size);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let h_prev = Array2::zeros((hidden_size, 1));
        let c_prev = Array2::zeros((hidden_size, 1));

        let (h_t, c_t) = cell.forward(&input, &h_prev, &c_prev);
        assert_eq!(h_t.shape(), &[hidden_size, 1]);
        assert_eq!(c_t.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_multiple_timesteps() {
        let input_size = 3;
        let hidden_size = 2;
        let cell = PeepholeLSTMCell::new(input_size, hidden_size);

        let sequence = vec![
            arr2(&[[0.5], [0.1], [-0.3]]),
            arr2(&[[0.2], [0.8], [0.05]]),
            arr2(&[[0.0], [-0.1], [0.3]]),
        ];

        let mut h_prev = Array2::zeros((hidden_size, 1));
        let mut c_prev = Array2::zeros((hidden_size, 1));

        for (t, x_t) in sequence.iter().enumerate() {
            let (h_t, c_t) = cell.forward(x_t, &h_prev, &c_prev);

            assert_eq!(h_t.shape(), &[hidden_size, 1], "h_t shape mismatch at timestep {}", t);
            assert_eq!(c_t.shape(), &[hidden_size, 1], "c_t shape mismatch at timestep {}", t);

            h_prev = h_t;
            c_prev = c_t;
        }
    }

    #[test]
    fn test_zero_weights_halve_cell_state() {
        // All parameters zero: every gate is sigmoid(0) = 0.5 and the candidate is
        // tanh(0) = 0, so c_t = 0.5 * c_prev and h_t = 0.5 * tanh(c_t),
        // independent of input and h_prev.
        let cell = zeroed_cell(3, 2);
        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let h_prev = arr2(&[[0.7], [-0.4]]);
        let c_prev = arr2(&[[1.0], [-2.0]]);

        let (h_t, c_t) = cell.forward(&input, &h_prev, &c_prev);

        assert_close(&c_t, &arr2(&[[0.5], [-1.0]]));
        assert_close(
            &h_t,
            &arr2(&[[0.5 * 0.5_f64.tanh()], [0.5 * (-1.0_f64).tanh()]]),
        );
    }

    #[test]
    fn test_output_peephole_reads_updated_cell_state() {
        let mut cell = zeroed_cell(3, 2);
        cell.w_co.fill(1.0);
        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let h_prev = Array2::zeros((2, 1));
        let c_prev = arr2(&[[2.0], [0.0]]);

        let (h_t, c_t) = cell.forward(&input, &h_prev, &c_prev);

        // c_t = 0.5 * c_prev = [1, 0]; o_t must be sigmoid(c_t), not sigmoid(c_prev).
        assert_close(&c_t, &arr2(&[[1.0], [0.0]]));
        assert_close(&h_t, &arr2(&[[logistic(1.0) * 1.0_f64.tanh()], [0.0]]));
    }

    #[test]
    fn test_input_and_forget_peepholes_read_previous_cell_state() {
        let mut cell = zeroed_cell(3, 2);
        cell.w_ci.fill(1.0);
        cell.w_cf.fill(1.0);
        cell.b_c.fill(0.5);
        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let h_prev = Array2::zeros((2, 1));
        let c_prev = arr2(&[[2.0], [-1.0]]);

        let (_, c_t) = cell.forward(&input, &h_prev, &c_prev);

        // i_t = f_t = sigmoid(c_prev) and g_t = tanh(0.5), so
        // c_t = sigmoid(c_prev) * c_prev + sigmoid(c_prev) * tanh(0.5).
        let g = 0.5_f64.tanh();
        let expected = c_prev.mapv(|c| logistic(c) * c + logistic(c) * g);
        assert_close(&c_t, &expected);
    }
}
164}