scirs2_series/neural_forecasting/
lstm.rs1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use super::config::ActivationFunction;
11use crate::error::{Result, TimeSeriesError};
12
13#[derive(Debug, Clone)]
15pub struct LSTMState<F: Float> {
16 pub hidden: Array1<F>,
18 pub cell: Array1<F>,
20}
21
22#[derive(Debug)]
24pub struct LSTMCell<F: Float + Debug> {
25 #[allow(dead_code)]
27 input_size: usize,
28 #[allow(dead_code)]
30 hidden_size: usize,
31 #[allow(dead_code)]
33 w_forget: Array2<F>,
34 #[allow(dead_code)]
36 w_input: Array2<F>,
37 #[allow(dead_code)]
39 w_candidate: Array2<F>,
40 #[allow(dead_code)]
42 w_output: Array2<F>,
43 #[allow(dead_code)]
45 bias: Array1<F>,
46}
47
48impl<F: Float + Debug + Clone + FromPrimitive> LSTMCell<F> {
49 pub fn new(_input_size: usize, hiddensize: usize) -> Self {
51 let total_input_size = _input_size + hiddensize;
52
53 let scale = F::from(2.0).unwrap() / F::from(total_input_size).unwrap();
55 let std_dev = scale.sqrt();
56
57 Self {
58 input_size: _input_size,
59 hidden_size: hiddensize,
60 w_forget: Self::random_matrix(hiddensize, total_input_size, std_dev),
61 w_input: Self::random_matrix(hiddensize, total_input_size, std_dev),
62 w_candidate: Self::random_matrix(hiddensize, total_input_size, std_dev),
63 w_output: Self::random_matrix(hiddensize, total_input_size, std_dev),
64 bias: Array1::zeros(4 * hiddensize), }
66 }
67
68 pub fn random_matrix(_rows: usize, cols: usize, stddev: F) -> Array2<F> {
70 let mut matrix = Array2::zeros((_rows, cols));
71
72 let mut seed: u32 = 12345;
74 for i in 0.._rows {
75 for j in 0..cols {
76 seed = (seed.wrapping_mul(1103515245).wrapping_add(12345)) & 0x7fffffff;
78 let rand_val = F::from(seed as f64 / 2147483647.0).unwrap();
79 let normalized = (rand_val - F::from(0.5).unwrap()) * F::from(2.0).unwrap();
80 matrix[[i, j]] = normalized * stddev;
81 }
82 }
83
84 matrix
85 }
86
87 pub fn forward(&self, input: &Array1<F>, prevstate: &LSTMState<F>) -> Result<LSTMState<F>> {
89 if input.len() != self.input_size {
90 return Err(TimeSeriesError::DimensionMismatch {
91 expected: self.input_size,
92 actual: input.len(),
93 });
94 }
95
96 if prevstate.hidden.len() != self.hidden_size || prevstate.cell.len() != self.hidden_size {
97 return Err(TimeSeriesError::DimensionMismatch {
98 expected: self.hidden_size,
99 actual: prevstate.hidden.len(),
100 });
101 }
102
103 let mut combined_input = Array1::zeros(self.input_size + self.hidden_size);
105 for (i, &val) in input.iter().enumerate() {
106 combined_input[i] = val;
107 }
108 for (i, &val) in prevstate.hidden.iter().enumerate() {
109 combined_input[self.input_size + i] = val;
110 }
111
112 let forget_gate = self.compute_gate(&self.w_forget, &combined_input, 0);
114 let input_gate = self.compute_gate(&self.w_input, &combined_input, self.hidden_size);
115 let candidate_gate =
116 self.compute_gate(&self.w_candidate, &combined_input, 2 * self.hidden_size);
117 let output_gate = self.compute_gate(&self.w_output, &combined_input, 3 * self.hidden_size);
118
119 let forget_activated = forget_gate.mapv(|x| ActivationFunction::Sigmoid.apply(x));
121 let input_activated = input_gate.mapv(|x| ActivationFunction::Sigmoid.apply(x));
122 let candidate_activated = candidate_gate.mapv(|x| ActivationFunction::Tanh.apply(x));
123 let output_activated = output_gate.mapv(|x| ActivationFunction::Sigmoid.apply(x));
124
125 let mut new_cell = Array1::zeros(self.hidden_size);
127 for i in 0..self.hidden_size {
128 new_cell[i] = forget_activated[i] * prevstate.cell[i]
129 + input_activated[i] * candidate_activated[i];
130 }
131
132 let cell_tanh = new_cell.mapv(|x| x.tanh());
134 let mut new_hidden = Array1::zeros(self.hidden_size);
135 for i in 0..self.hidden_size {
136 new_hidden[i] = output_activated[i] * cell_tanh[i];
137 }
138
139 Ok(LSTMState {
140 hidden: new_hidden,
141 cell: new_cell,
142 })
143 }
144
145 fn compute_gate(
147 &self,
148 weights: &Array2<F>,
149 input: &Array1<F>,
150 bias_offset: usize,
151 ) -> Array1<F> {
152 let mut output = Array1::zeros(self.hidden_size);
153
154 for i in 0..self.hidden_size {
155 let mut sum = self.bias[bias_offset + i];
156 for j in 0..input.len() {
157 sum = sum + weights[[i, j]] * input[j];
158 }
159 output[i] = sum;
160 }
161
162 output
163 }
164
165 pub fn init_state(&self) -> LSTMState<F> {
167 LSTMState {
168 hidden: Array1::zeros(self.hidden_size),
169 cell: Array1::zeros(self.hidden_size),
170 }
171 }
172}
173
174#[derive(Debug)]
176pub struct LSTMNetwork<F: Float + Debug> {
177 #[allow(dead_code)]
179 layers: Vec<LSTMCell<F>>,
180 #[allow(dead_code)]
182 output_layer: Array2<F>,
183 #[allow(dead_code)]
185 output_bias: Array1<F>,
186 #[allow(dead_code)]
188 dropout_prob: F,
189}
190
191impl<F: Float + Debug + Clone + FromPrimitive> LSTMNetwork<F> {
192 pub fn new(
194 input_size: usize,
195 hidden_sizes: Vec<usize>,
196 output_size: usize,
197 dropout_prob: F,
198 ) -> Self {
199 let mut layers = Vec::new();
200
201 if !hidden_sizes.is_empty() {
203 layers.push(LSTMCell::new(input_size, hidden_sizes[0]));
204
205 for i in 1..hidden_sizes.len() {
207 layers.push(LSTMCell::new(hidden_sizes[i - 1], hidden_sizes[i]));
208 }
209 }
210
211 let final_hidden_size = hidden_sizes.last().copied().unwrap_or(input_size);
212
213 let output_scale = F::from(2.0).unwrap() / F::from(final_hidden_size).unwrap();
215 let output_std = output_scale.sqrt();
216 let output_layer = LSTMCell::random_matrix(output_size, final_hidden_size, output_std);
217
218 Self {
219 layers,
220 output_layer,
221 output_bias: Array1::zeros(output_size),
222 dropout_prob,
223 }
224 }
225
226 pub fn forward(&self, inputsequence: &Array2<F>) -> Result<Array2<F>> {
228 let (seqlen, _input_size) = inputsequence.dim();
229
230 if self.layers.is_empty() {
231 return Err(TimeSeriesError::InvalidModel(
232 "No LSTM layers defined".to_string(),
233 ));
234 }
235
236 let output_size = self.output_layer.nrows();
237 let mut outputs = Array2::zeros((seqlen, output_size));
238
239 let mut states: Vec<LSTMState<F>> =
241 self.layers.iter().map(|layer| layer.init_state()).collect();
242
243 for t in 0..seqlen {
245 let mut layer_input = inputsequence.row(t).to_owned();
246
247 for (i, layer) in self.layers.iter().enumerate() {
249 let new_state = layer.forward(&layer_input, &states[i])?;
250 layer_input = new_state.hidden.clone();
251 states[i] = new_state;
252 }
253
254 if self.dropout_prob > F::zero() {
256 let keep_prob = F::one() - self.dropout_prob;
257 layer_input = layer_input.mapv(|x| x * keep_prob);
258 }
259
260 let output = self.compute_output(&layer_input);
262 for (j, &val) in output.iter().enumerate() {
263 outputs[[t, j]] = val;
264 }
265 }
266
267 Ok(outputs)
268 }
269
270 fn compute_output(&self, hidden: &Array1<F>) -> Array1<F> {
272 let mut output = self.output_bias.clone();
273
274 for i in 0..self.output_layer.nrows() {
275 for j in 0..self.output_layer.ncols() {
276 output[i] = output[i] + self.output_layer[[i, j]] * hidden[j];
277 }
278 }
279
280 output
281 }
282
283 pub fn forecast(&self, input_sequence: &Array2<F>, forecaststeps: usize) -> Result<Array1<F>> {
285 let (seqlen, _) = input_sequence.dim();
286
287 let _ = self.forward(input_sequence)?;
289
290 let mut states: Vec<LSTMState<F>> =
292 self.layers.iter().map(|layer| layer.init_state()).collect();
293
294 for t in 0..seqlen {
296 let mut layer_input = input_sequence.row(t).to_owned();
297 for (i, layer) in self.layers.iter().enumerate() {
298 let new_state = layer.forward(&layer_input, &states[i])?;
299 layer_input = new_state.hidden.clone();
300 states[i] = new_state;
301 }
302 }
303
304 let mut forecasts = Array1::zeros(forecaststeps);
305 let mut last_output = input_sequence.row(seqlen - 1).to_owned();
306
307 for step in 0..forecaststeps {
309 let mut layer_input = last_output.clone();
310
311 for (i, layer) in self.layers.iter().enumerate() {
313 let new_state = layer.forward(&layer_input, &states[i])?;
314 layer_input = new_state.hidden.clone();
315 states[i] = new_state;
316 }
317
318 let output = self.compute_output(&layer_input);
320 forecasts[step] = output[0]; if last_output.len() == 1 {
324 last_output[0] = output[0];
325 } else {
326 last_output[0] = output[0];
328 }
329 }
330
331 Ok(forecasts)
332 }
333}