scirs2_series/neural_forecasting/
lstm.rs

//! LSTM Network Components for Time Series Forecasting
//!
//! This module provides Long Short-Term Memory (LSTM) network implementations
//! for time series forecasting, including LSTM cells, states, and multi-layer networks.

6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use super::config::ActivationFunction;
11use crate::error::{Result, TimeSeriesError};
12
/// LSTM cell state and hidden state
///
/// Holds the pair of vectors an LSTM cell carries between time steps:
/// the hidden state (which is also the cell's per-step output) and the
/// internal cell state (the long-term memory). Both have length
/// `hidden_size` of the cell that produced them.
#[derive(Debug, Clone)]
pub struct LSTMState<F: Float> {
    /// Hidden state (the cell's output at the previous time step)
    pub hidden: Array1<F>,
    /// Cell state (long-term memory carried across time steps)
    pub cell: Array1<F>,
}
21
/// LSTM cell implementation
///
/// A single LSTM cell holding the weight matrices for the four gates
/// (forget, input, candidate, output) plus one shared bias vector. Each
/// weight matrix maps the concatenation `[input, previous hidden]` of
/// length `input_size + hidden_size` to a gate pre-activation of length
/// `hidden_size`.
#[derive(Debug)]
pub struct LSTMCell<F: Float + Debug> {
    /// Input size (dimensionality of the per-step input vector)
    #[allow(dead_code)]
    input_size: usize,
    /// Hidden size (dimensionality of the hidden and cell states)
    #[allow(dead_code)]
    hidden_size: usize,
    /// Forget gate weights, shape (hidden_size, input_size + hidden_size)
    #[allow(dead_code)]
    w_forget: Array2<F>,
    /// Input gate weights, shape (hidden_size, input_size + hidden_size)
    #[allow(dead_code)]
    w_input: Array2<F>,
    /// Candidate gate weights, shape (hidden_size, input_size + hidden_size)
    #[allow(dead_code)]
    w_candidate: Array2<F>,
    /// Output gate weights, shape (hidden_size, input_size + hidden_size)
    #[allow(dead_code)]
    w_output: Array2<F>,
    /// Bias terms for all four gates concatenated (length 4 * hidden_size;
    /// offsets used by `compute_gate`: forget 0, input h, candidate 2h, output 3h)
    #[allow(dead_code)]
    bias: Array1<F>,
}
47
48impl<F: Float + Debug + Clone + FromPrimitive> LSTMCell<F> {
49    /// Create new LSTM cell with random initialization
50    pub fn new(_input_size: usize, hiddensize: usize) -> Self {
51        let total_input_size = _input_size + hiddensize;
52
53        // Initialize weights with Xavier/Glorot initialization
54        let scale = F::from(2.0).unwrap() / F::from(total_input_size).unwrap();
55        let std_dev = scale.sqrt();
56
57        Self {
58            input_size: _input_size,
59            hidden_size: hiddensize,
60            w_forget: Self::random_matrix(hiddensize, total_input_size, std_dev),
61            w_input: Self::random_matrix(hiddensize, total_input_size, std_dev),
62            w_candidate: Self::random_matrix(hiddensize, total_input_size, std_dev),
63            w_output: Self::random_matrix(hiddensize, total_input_size, std_dev),
64            bias: Array1::zeros(4 * hiddensize), // Bias for all gates
65        }
66    }
67
68    /// Initialize random matrix with given standard deviation
69    pub fn random_matrix(_rows: usize, cols: usize, stddev: F) -> Array2<F> {
70        let mut matrix = Array2::zeros((_rows, cols));
71
72        // Simple pseudo-random initialization (for production, use proper RNG)
73        let mut seed: u32 = 12345;
74        for i in 0.._rows {
75            for j in 0..cols {
76                // Linear congruential generator
77                seed = (seed.wrapping_mul(1103515245).wrapping_add(12345)) & 0x7fffffff;
78                let rand_val = F::from(seed as f64 / 2147483647.0).unwrap();
79                let normalized = (rand_val - F::from(0.5).unwrap()) * F::from(2.0).unwrap();
80                matrix[[i, j]] = normalized * stddev;
81            }
82        }
83
84        matrix
85    }
86
87    /// Forward pass through LSTM cell
88    pub fn forward(&self, input: &Array1<F>, prevstate: &LSTMState<F>) -> Result<LSTMState<F>> {
89        if input.len() != self.input_size {
90            return Err(TimeSeriesError::DimensionMismatch {
91                expected: self.input_size,
92                actual: input.len(),
93            });
94        }
95
96        if prevstate.hidden.len() != self.hidden_size || prevstate.cell.len() != self.hidden_size {
97            return Err(TimeSeriesError::DimensionMismatch {
98                expected: self.hidden_size,
99                actual: prevstate.hidden.len(),
100            });
101        }
102
103        // Concatenate input and previous hidden _state
104        let mut combined_input = Array1::zeros(self.input_size + self.hidden_size);
105        for (i, &val) in input.iter().enumerate() {
106            combined_input[i] = val;
107        }
108        for (i, &val) in prevstate.hidden.iter().enumerate() {
109            combined_input[self.input_size + i] = val;
110        }
111
112        // Compute gate values
113        let forget_gate = self.compute_gate(&self.w_forget, &combined_input, 0);
114        let input_gate = self.compute_gate(&self.w_input, &combined_input, self.hidden_size);
115        let candidate_gate =
116            self.compute_gate(&self.w_candidate, &combined_input, 2 * self.hidden_size);
117        let output_gate = self.compute_gate(&self.w_output, &combined_input, 3 * self.hidden_size);
118
119        // Apply activations
120        let forget_activated = forget_gate.mapv(|x| ActivationFunction::Sigmoid.apply(x));
121        let input_activated = input_gate.mapv(|x| ActivationFunction::Sigmoid.apply(x));
122        let candidate_activated = candidate_gate.mapv(|x| ActivationFunction::Tanh.apply(x));
123        let output_activated = output_gate.mapv(|x| ActivationFunction::Sigmoid.apply(x));
124
125        // Update cell _state
126        let mut new_cell = Array1::zeros(self.hidden_size);
127        for i in 0..self.hidden_size {
128            new_cell[i] = forget_activated[i] * prevstate.cell[i]
129                + input_activated[i] * candidate_activated[i];
130        }
131
132        // Update hidden _state
133        let cell_tanh = new_cell.mapv(|x| x.tanh());
134        let mut new_hidden = Array1::zeros(self.hidden_size);
135        for i in 0..self.hidden_size {
136            new_hidden[i] = output_activated[i] * cell_tanh[i];
137        }
138
139        Ok(LSTMState {
140            hidden: new_hidden,
141            cell: new_cell,
142        })
143    }
144
145    /// Compute gate output (linear transformation)
146    fn compute_gate(
147        &self,
148        weights: &Array2<F>,
149        input: &Array1<F>,
150        bias_offset: usize,
151    ) -> Array1<F> {
152        let mut output = Array1::zeros(self.hidden_size);
153
154        for i in 0..self.hidden_size {
155            let mut sum = self.bias[bias_offset + i];
156            for j in 0..input.len() {
157                sum = sum + weights[[i, j]] * input[j];
158            }
159            output[i] = sum;
160        }
161
162        output
163    }
164
165    /// Initialize zero state
166    pub fn init_state(&self) -> LSTMState<F> {
167        LSTMState {
168            hidden: Array1::zeros(self.hidden_size),
169            cell: Array1::zeros(self.hidden_size),
170        }
171    }
172}
173
/// Multi-layer LSTM network
///
/// Stacks one or more [`LSTMCell`]s followed by a linear output projection.
/// At each time step the hidden output of every layer feeds the next, and
/// the final layer's hidden state is projected through `output_layer` plus
/// `output_bias` to produce the network output.
#[derive(Debug)]
pub struct LSTMNetwork<F: Float + Debug> {
    /// LSTM layers, applied in order at every time step
    #[allow(dead_code)]
    layers: Vec<LSTMCell<F>>,
    /// Output projection layer, shape (output_size, final_hidden_size)
    #[allow(dead_code)]
    output_layer: Array2<F>,
    /// Output bias, length output_size
    #[allow(dead_code)]
    output_bias: Array1<F>,
    /// Dropout probability; `forward` applies it as a simple scaling of the
    /// last layer's output by (1 - dropout_prob)
    #[allow(dead_code)]
    dropout_prob: F,
}
190
191impl<F: Float + Debug + Clone + FromPrimitive> LSTMNetwork<F> {
192    /// Create new multi-layer LSTM network
193    pub fn new(
194        input_size: usize,
195        hidden_sizes: Vec<usize>,
196        output_size: usize,
197        dropout_prob: F,
198    ) -> Self {
199        let mut layers = Vec::new();
200
201        // First layer
202        if !hidden_sizes.is_empty() {
203            layers.push(LSTMCell::new(input_size, hidden_sizes[0]));
204
205            // Additional layers
206            for i in 1..hidden_sizes.len() {
207                layers.push(LSTMCell::new(hidden_sizes[i - 1], hidden_sizes[i]));
208            }
209        }
210
211        let final_hidden_size = hidden_sizes.last().copied().unwrap_or(input_size);
212
213        // Output layer initialization
214        let output_scale = F::from(2.0).unwrap() / F::from(final_hidden_size).unwrap();
215        let output_std = output_scale.sqrt();
216        let output_layer = LSTMCell::random_matrix(output_size, final_hidden_size, output_std);
217
218        Self {
219            layers,
220            output_layer,
221            output_bias: Array1::zeros(output_size),
222            dropout_prob,
223        }
224    }
225
226    /// Forward pass through the network
227    pub fn forward(&self, inputsequence: &Array2<F>) -> Result<Array2<F>> {
228        let (seqlen, _input_size) = inputsequence.dim();
229
230        if self.layers.is_empty() {
231            return Err(TimeSeriesError::InvalidModel(
232                "No LSTM layers defined".to_string(),
233            ));
234        }
235
236        let output_size = self.output_layer.nrows();
237        let mut outputs = Array2::zeros((seqlen, output_size));
238
239        // Initialize states for all layers
240        let mut states: Vec<LSTMState<F>> =
241            self.layers.iter().map(|layer| layer.init_state()).collect();
242
243        // Process each time step
244        for t in 0..seqlen {
245            let mut layer_input = inputsequence.row(t).to_owned();
246
247            // Forward through LSTM layers
248            for (i, layer) in self.layers.iter().enumerate() {
249                let new_state = layer.forward(&layer_input, &states[i])?;
250                layer_input = new_state.hidden.clone();
251                states[i] = new_state;
252            }
253
254            // Apply dropout (simplified - just scaling)
255            if self.dropout_prob > F::zero() {
256                let keep_prob = F::one() - self.dropout_prob;
257                layer_input = layer_input.mapv(|x| x * keep_prob);
258            }
259
260            // Output projection
261            let output = self.compute_output(&layer_input);
262            for (j, &val) in output.iter().enumerate() {
263                outputs[[t, j]] = val;
264            }
265        }
266
267        Ok(outputs)
268    }
269
270    /// Compute final output projection
271    fn compute_output(&self, hidden: &Array1<F>) -> Array1<F> {
272        let mut output = self.output_bias.clone();
273
274        for i in 0..self.output_layer.nrows() {
275            for j in 0..self.output_layer.ncols() {
276                output[i] = output[i] + self.output_layer[[i, j]] * hidden[j];
277            }
278        }
279
280        output
281    }
282
283    /// Generate forecast for multiple steps
284    pub fn forecast(&self, input_sequence: &Array2<F>, forecaststeps: usize) -> Result<Array1<F>> {
285        let (seqlen, _) = input_sequence.dim();
286
287        // Get the last hidden states from input _sequence
288        let _ = self.forward(input_sequence)?;
289
290        // Initialize states for forecasting
291        let mut states: Vec<LSTMState<F>> =
292            self.layers.iter().map(|layer| layer.init_state()).collect();
293
294        // Re-run forward pass to get final states
295        for t in 0..seqlen {
296            let mut layer_input = input_sequence.row(t).to_owned();
297            for (i, layer) in self.layers.iter().enumerate() {
298                let new_state = layer.forward(&layer_input, &states[i])?;
299                layer_input = new_state.hidden.clone();
300                states[i] = new_state;
301            }
302        }
303
304        let mut forecasts = Array1::zeros(forecaststeps);
305        let mut last_output = input_sequence.row(seqlen - 1).to_owned();
306
307        // Generate forecasts step by step
308        for step in 0..forecaststeps {
309            let mut layer_input = last_output.clone();
310
311            // Forward through LSTM layers
312            for (i, layer) in self.layers.iter().enumerate() {
313                let new_state = layer.forward(&layer_input, &states[i])?;
314                layer_input = new_state.hidden.clone();
315                states[i] = new_state;
316            }
317
318            // Compute output
319            let output = self.compute_output(&layer_input);
320            forecasts[step] = output[0]; // Assuming single output for forecasting
321
322            // Use forecast as input for next step (assuming univariate)
323            if last_output.len() == 1 {
324                last_output[0] = output[0];
325            } else {
326                // For multivariate, use the forecast as the first feature
327                last_output[0] = output[0];
328            }
329        }
330
331        Ok(forecasts)
332    }
333}