1use ndarray::Array2;
2use rand_distr::{Distribution, Normal};
3
4pub struct PeepholeLSTMCell {
6 pub w_xi: Array2<f64>,
8 pub w_hi: Array2<f64>,
9 pub b_i: Array2<f64>,
10 pub w_ci: Array2<f64>,
11
12 pub w_xf: Array2<f64>,
14 pub w_hf: Array2<f64>,
15 pub b_f: Array2<f64>,
16 pub w_cf: Array2<f64>,
17
18 pub w_xc: Array2<f64>,
20 pub w_hc: Array2<f64>,
21 pub b_c: Array2<f64>,
22
23 pub w_xo: Array2<f64>,
25 pub w_ho: Array2<f64>,
26 pub b_o: Array2<f64>,
27 pub w_co: Array2<f64>,
28}
29
30impl PeepholeLSTMCell {
31 pub fn new(input_size: usize, hidden_size: usize) -> Self {
33 let dist = Normal::new(0.0, 0.1).unwrap();
34 let mut rng = rand::thread_rng();
35
36 let w_xi = Self::random_matrix(&dist, &mut rng, hidden_size, input_size);
37 let w_hi = Self::random_matrix(&dist, &mut rng, hidden_size, hidden_size);
38 let b_i = Self::random_vector_2d(&dist, &mut rng, hidden_size);
39 let w_ci = Self::random_vector_2d(&dist, &mut rng, hidden_size);
40
41 let w_xf = Self::random_matrix(&dist, &mut rng, hidden_size, input_size);
42 let w_hf = Self::random_matrix(&dist, &mut rng, hidden_size, hidden_size);
43 let b_f = Self::random_vector_2d(&dist, &mut rng, hidden_size);
44 let w_cf = Self::random_vector_2d(&dist, &mut rng, hidden_size);
45
46 let w_xc = Self::random_matrix(&dist, &mut rng, hidden_size, input_size);
47 let w_hc = Self::random_matrix(&dist, &mut rng, hidden_size, hidden_size);
48 let b_c = Self::random_vector_2d(&dist, &mut rng, hidden_size);
49
50 let w_xo = Self::random_matrix(&dist, &mut rng, hidden_size, input_size);
51 let w_ho = Self::random_matrix(&dist, &mut rng, hidden_size, hidden_size);
52 let b_o = Self::random_vector_2d(&dist, &mut rng, hidden_size);
53 let w_co = Self::random_vector_2d(&dist, &mut rng, hidden_size);
54
55 Self {
56 w_xi, w_hi, b_i, w_ci,
57 w_xf, w_hf, b_f, w_cf,
58 w_xc, w_hc, b_c,
59 w_xo, w_ho, b_o, w_co,
60 }
61 }
62
63 fn random_matrix(dist: &Normal<f64>, rng: &mut impl rand::Rng, rows: usize, cols: usize) -> Array2<f64> {
64 let mut arr = Array2::<f64>::zeros((rows, cols));
65 for val in arr.iter_mut() {
66 *val = dist.sample(rng);
67 }
68 arr
69 }
70
71 fn random_vector_2d(dist: &Normal<f64>, rng: &mut impl rand::Rng, len: usize) -> Array2<f64> {
72 let mut arr = Array2::<f64>::zeros((len, 1));
73 for val in arr.iter_mut() {
74 *val = dist.sample(rng);
75 }
76 arr
77 }
78
79 pub fn forward(
81 &self,
82 input: &Array2<f64>,
83 h_prev: &Array2<f64>,
84 c_prev: &Array2<f64>,
85 ) -> (Array2<f64>, Array2<f64>) {
86 let i_t = &self.w_xi.dot(input)
87 + &self.w_hi.dot(h_prev)
88 + &self.b_i
89 + &(&self.w_ci * c_prev);
90 let i_t = i_t.map(|&x| sigmoid(x));
91
92 let f_t = &self.w_xf.dot(input)
93 + &self.w_hf.dot(h_prev)
94 + &self.b_f
95 + &(&self.w_cf * c_prev);
96 let f_t = f_t.map(|&x| sigmoid(x));
97
98 let g_t = (&self.w_xc.dot(input) + &self.w_hc.dot(h_prev) + &self.b_c)
99 .map(|&x| x.tanh());
100
101 let c_t = f_t * c_prev + i_t * g_t;
102
103 let o_t = &self.w_xo.dot(input)
104 + &self.w_ho.dot(h_prev)
105 + &self.b_o
106 + &(&self.w_co * &c_t);
107 let o_t = o_t.map(|&x| sigmoid(x));
108
109 let h_t = o_t * c_t.map(|&x| x.tanh());
110
111 (h_t, c_t)
112 }
113}
114
/// Numerically stable logistic sigmoid, `σ(x) = 1 / (1 + e^{-x})`.
///
/// For negative `x`, this uses the algebraically equivalent form `e^x / (1 + e^x)`. That way
/// `exp` is only ever evaluated at a non-positive argument: it can underflow gracefully
/// towards zero but never overflows to `+inf`.
///
/// The naive form returns exactly `0.0` once `e^{-x}` overflows (`x < ≈-709.78`). This
/// version still returns a tiny positive (subnormal) value down to about `x ≈ -745`. A NaN
/// input propagates to NaN.
fn sigmoid(x: f64) -> f64 {
    if x >= 0.0 {
        1.0 / (1.0 + (-x).exp())
    } else {
        let e = x.exp();
        e / (1.0 + e)
    }
}
118
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr2, Array2};

    /// Builds a cell whose parameters are all zero, so every gate pre-activation is 0.
    fn zero_cell(input_size: usize, hidden_size: usize) -> PeepholeLSTMCell {
        let w_x = || Array2::<f64>::zeros((hidden_size, input_size));
        let w_h = || Array2::<f64>::zeros((hidden_size, hidden_size));
        let col = || Array2::<f64>::zeros((hidden_size, 1));
        PeepholeLSTMCell {
            w_xi: w_x(),
            w_hi: w_h(),
            b_i: col(),
            w_ci: col(),
            w_xf: w_x(),
            w_hf: w_h(),
            b_f: col(),
            w_cf: col(),
            w_xc: w_x(),
            w_hc: w_h(),
            b_c: col(),
            w_xo: w_x(),
            w_ho: w_h(),
            b_o: col(),
            w_co: col(),
        }
    }

    #[test]
    fn test_forward_shape() {
        let input_size = 3;
        let hidden_size = 2;
        let cell = PeepholeLSTMCell::new(input_size, hidden_size);

        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let h_prev = Array2::zeros((hidden_size, 1));
        let c_prev = Array2::zeros((hidden_size, 1));

        let (h_t, c_t) = cell.forward(&input, &h_prev, &c_prev);
        assert_eq!(h_t.shape(), &[hidden_size, 1]);
        assert_eq!(c_t.shape(), &[hidden_size, 1]);
    }

    #[test]
    fn test_forward_zero_weights_matches_closed_form() {
        // All-zero parameters give i = f = o = σ(0) = 0.5 and g = tanh(0) = 0, hence
        // c_t = 0.5 · c_prev and h_t = 0.5 · tanh(c_t), independent of input and h_prev.
        let cell = zero_cell(3, 2);
        let input = arr2(&[[0.5], [0.1], [-0.3]]);
        let h_prev = arr2(&[[0.7], [-0.4]]);
        let c_prev = arr2(&[[1.0], [-2.0]]);

        let (h_t, c_t) = cell.forward(&input, &h_prev, &c_prev);

        for (row, &expected_c) in [0.5_f64, -1.0].iter().enumerate() {
            assert!((c_t[[row, 0]] - expected_c).abs() < 1e-12, "c_t[{}]", row);
            assert!(
                (h_t[[row, 0]] - 0.5 * expected_c.tanh()).abs() < 1e-12,
                "h_t[{}]",
                row
            );
        }
    }

    #[test]
    fn test_multiple_timesteps() {
        let input_size = 3;
        let hidden_size = 2;
        let cell = PeepholeLSTMCell::new(input_size, hidden_size);

        let sequence = [
            arr2(&[[0.5], [0.1], [-0.3]]),
            arr2(&[[0.2], [0.8], [0.05]]),
            arr2(&[[0.0], [-0.1], [0.3]]),
        ];

        let mut h_prev = Array2::zeros((hidden_size, 1));
        let mut c_prev = Array2::zeros((hidden_size, 1));

        for (t, x_t) in sequence.iter().enumerate() {
            let (h_t, c_t) = cell.forward(x_t, &h_prev, &c_prev);

            assert_eq!(h_t.shape(), &[hidden_size, 1], "h_t shape mismatch at timestep {}", t);
            assert_eq!(c_t.shape(), &[hidden_size, 1], "c_t shape mismatch at timestep {}", t);
            // h_t = o_t ⊙ tanh(c_t) with o_t ∈ (0, 1), so |h_t| must stay strictly below 1.
            assert!(
                h_t.iter().all(|v| v.is_finite() && v.abs() < 1.0),
                "h_t out of range at timestep {}",
                t
            );
            assert!(c_t.iter().all(|v| v.is_finite()), "c_t not finite at timestep {}", t);

            h_prev = h_t;
            c_prev = c_t;
        }
    }
}