scirs2_series/neural_forecasting/transformer.rs

use scirs2_core::ndarray::{Array1, Array2, Array3};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use super::config::ActivationFunction;
use super::lstm::LSTMCell;
use crate::error::{Result, TimeSeriesError};

/// Multi-head self-attention with learned query, key, value, and output
/// projections.
#[derive(Debug)]
pub struct MultiHeadAttention<F: Float + Debug> {
    #[allow(dead_code)]
    num_heads: usize,
    #[allow(dead_code)]
    model_dim: usize,
    #[allow(dead_code)]
    head_dim: usize,
    #[allow(dead_code)]
    w_query: Array2<F>,
    #[allow(dead_code)]
    w_key: Array2<F>,
    #[allow(dead_code)]
    w_value: Array2<F>,
    w_output: Array2<F>,
}

impl<F: Float + Debug + Clone + FromPrimitive> MultiHeadAttention<F> {
    /// Creates a new attention layer; `model_dim` must be divisible by `num_heads`.
    pub fn new(model_dim: usize, num_heads: usize) -> Result<Self> {
        if !model_dim.is_multiple_of(num_heads) {
            return Err(TimeSeriesError::InvalidInput(
                "Model dimension must be divisible by number of heads".to_string(),
            ));
        }

        let head_dim = model_dim / num_heads;
        // Kaiming-style initialization: std_dev = sqrt(2 / model_dim).
        let scale = F::from(2.0).unwrap() / F::from(model_dim).unwrap();
        let std_dev = scale.sqrt();

        Ok(Self {
            num_heads,
            model_dim,
            head_dim,
            w_query: LSTMCell::random_matrix(model_dim, model_dim, std_dev),
            w_key: LSTMCell::random_matrix(model_dim, model_dim, std_dev),
            w_value: LSTMCell::random_matrix(model_dim, model_dim, std_dev),
            w_output: LSTMCell::random_matrix(model_dim, model_dim, std_dev),
        })
    }

    /// Applies the attention layer to `input` of shape `(seq_len, model_dim)`.
    ///
    /// Note: this is currently a simplified pass-through; the projection
    /// weights are not yet applied.
    pub fn forward(&self, input: &Array2<F>) -> Result<Array2<F>> {
        let (_seq_len, model_dim) = input.dim();

        if model_dim != self.model_dim {
            return Err(TimeSeriesError::DimensionMismatch {
                expected: self.model_dim,
                actual: model_dim,
            });
        }

        // Simplified placeholder: return the input unchanged after the shape check.
        Ok(input.clone())
    }
}
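
// `MultiHeadAttention::forward` above is currently a pass-through placeholder.
// The free function below is a minimal sketch of single-head scaled dot-product
// attention, softmax(Q K^T / sqrt(d)) V, written with plain loops. It is
// illustrative only (not the crate's production kernel) and assumes that
// `scirs2_core::numeric::Float` mirrors `num_traits::Float` and that `q`, `k`,
// and `v` all have shape `(seq_len, head_dim)`.
#[allow(dead_code)]
fn scaled_dot_product_attention_sketch<F: Float + Debug + Clone + FromPrimitive>(
    q: &Array2<F>,
    k: &Array2<F>,
    v: &Array2<F>,
) -> Array2<F> {
    let (seq_len, head_dim) = q.dim();
    if seq_len == 0 {
        return Array2::zeros((0, head_dim));
    }
    let scale = F::one() / F::from(head_dim).unwrap().sqrt();

    // Attention scores: (seq_len, seq_len) matrix of scaled dot products.
    let mut scores = Array2::zeros((seq_len, seq_len));
    for i in 0..seq_len {
        for j in 0..seq_len {
            let mut dot = F::zero();
            for d in 0..head_dim {
                dot = dot + q[[i, d]] * k[[j, d]];
            }
            scores[[i, j]] = dot * scale;
        }
    }

    // Row-wise softmax, subtracting the row maximum for numerical stability.
    for i in 0..seq_len {
        let mut row_max = scores[[i, 0]];
        for j in 1..seq_len {
            if scores[[i, j]] > row_max {
                row_max = scores[[i, j]];
            }
        }
        let mut row_sum = F::zero();
        for j in 0..seq_len {
            scores[[i, j]] = (scores[[i, j]] - row_max).exp();
            row_sum = row_sum + scores[[i, j]];
        }
        for j in 0..seq_len {
            scores[[i, j]] = scores[[i, j]] / row_sum;
        }
    }

    // Weighted sum of the value rows.
    let mut output = Array2::zeros((seq_len, head_dim));
    for i in 0..seq_len {
        for j in 0..seq_len {
            for d in 0..head_dim {
                output[[i, d]] = output[[i, d]] + scores[[i, j]] * v[[j, d]];
            }
        }
    }
    output
}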

/// Position-wise feed-forward network with two linear layers and a bias on
/// each layer.
#[derive(Debug)]
pub struct FeedForwardNetwork<F: Float + Debug> {
    #[allow(dead_code)]
    w1: Array2<F>,
    #[allow(dead_code)]
    w2: Array2<F>,
    #[allow(dead_code)]
    b1: Array1<F>,
    #[allow(dead_code)]
    b2: Array1<F>,
    #[allow(dead_code)]
    activation: ActivationFunction,
}

impl<F: Float + Debug + Clone + FromPrimitive> FeedForwardNetwork<F> {
    /// Creates a two-layer network mapping `input_dim -> hidden_dim -> input_dim`.
    pub fn new(input_dim: usize, hidden_dim: usize, activation: ActivationFunction) -> Self {
        // Kaiming-style initialization for each layer: std_dev = sqrt(2 / fan_in).
        let scale1 = F::from(2.0).unwrap() / F::from(input_dim).unwrap();
        let std_dev1 = scale1.sqrt();
        let scale2 = F::from(2.0).unwrap() / F::from(hidden_dim).unwrap();
        let std_dev2 = scale2.sqrt();

        Self {
            w1: LSTMCell::random_matrix(hidden_dim, input_dim, std_dev1),
            w2: LSTMCell::random_matrix(input_dim, hidden_dim, std_dev2),
            b1: Array1::zeros(hidden_dim),
            b2: Array1::zeros(input_dim),
            activation,
        }
    }

    /// Applies the network to `input`.
    ///
    /// Note: this is currently a simplified pass-through; the weights,
    /// biases, and activation are not yet applied.
    pub fn forward(&self, input: &Array2<F>) -> Array2<F> {
        // Simplified placeholder: return the input unchanged.
        input.clone()
    }
}
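
// `FeedForwardNetwork::forward` above is a pass-through placeholder. The usual
// position-wise computation is w2 * activation(w1 * x_t + b1) + b2, applied to
// each time step (row) of the input independently. The sketch below hard-codes
// a ReLU activation and assumes the weights are laid out as (out_dim, in_dim),
// matching the `LSTMCell::random_matrix(rows, cols, ..)` calls in `new`; it is
// illustrative only.
#[allow(dead_code)]
fn position_wise_ffn_sketch<F: Float + Debug + Clone + FromPrimitive>(
    input: &Array2<F>, // (seq_len, input_dim)
    w1: &Array2<F>,    // (hidden_dim, input_dim)
    b1: &Array1<F>,    // (hidden_dim,)
    w2: &Array2<F>,    // (input_dim, hidden_dim)
    b2: &Array1<F>,    // (input_dim,)
) -> Array2<F> {
    let (seq_len, input_dim) = input.dim();
    let hidden_dim = w1.dim().0;
    let mut output = Array2::zeros((seq_len, input_dim));
    for t in 0..seq_len {
        // Hidden layer: relu(w1 * x_t + b1).
        let mut hidden = Array1::zeros(hidden_dim);
        for h in 0..hidden_dim {
            let mut acc = b1[h];
            for j in 0..input_dim {
                acc = acc + w1[[h, j]] * input[[t, j]];
            }
            hidden[h] = if acc > F::zero() { acc } else { F::zero() };
        }
        // Output layer: w2 * hidden + b2.
        for o in 0..input_dim {
            let mut acc = b2[o];
            for h in 0..hidden_dim {
                acc = acc + w2[[o, h]] * hidden[h];
            }
            output[[t, o]] = acc;
        }
    }
    output
}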

/// A single transformer encoder block: multi-head attention followed by a
/// feed-forward network, with layer-normalization parameters for each sublayer.
#[derive(Debug)]
pub struct TransformerBlock<F: Float + Debug> {
    #[allow(dead_code)]
    attention: MultiHeadAttention<F>,
    #[allow(dead_code)]
    ffn: FeedForwardNetwork<F>,
    #[allow(dead_code)]
    ln1_gamma: Array1<F>,
    #[allow(dead_code)]
    ln1_beta: Array1<F>,
    #[allow(dead_code)]
    ln2_gamma: Array1<F>,
    #[allow(dead_code)]
    ln2_beta: Array1<F>,
}

impl<F: Float + Debug + Clone + FromPrimitive> TransformerBlock<F> {
    pub fn new(model_dim: usize, num_heads: usize, ffn_hidden_dim: usize) -> Result<Self> {
        let attention = MultiHeadAttention::new(model_dim, num_heads)?;
        let ffn = FeedForwardNetwork::new(model_dim, ffn_hidden_dim, ActivationFunction::ReLU);

        Ok(Self {
            attention,
            ffn,
            ln1_gamma: Array1::ones(model_dim),
            ln1_beta: Array1::zeros(model_dim),
            ln2_gamma: Array1::ones(model_dim),
            ln2_beta: Array1::zeros(model_dim),
        })
    }

    /// Applies the transformer block to `input`.
    ///
    /// Note: this is currently a simplified pass-through; attention,
    /// feed-forward, residual connections, and layer norm are not yet applied.
    pub fn forward(&self, input: &Array2<F>) -> Result<Array2<F>> {
        // Simplified placeholder: return the input unchanged.
        Ok(input.clone())
    }
}
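
// `TransformerBlock::forward` above is a pass-through placeholder. One common
// (post-norm) composition applies residual-plus-layer-norm around each sublayer:
//     x = layer_norm(x + attention(x))    using ln1_gamma / ln1_beta
//     x = layer_norm(x + ffn(x))          using ln2_gamma / ln2_beta
// The helper below is a minimal layer-norm sketch over the feature axis,
// written with plain loops; it is illustrative only.
#[allow(dead_code)]
fn layer_norm_sketch<F: Float + Debug + Clone + FromPrimitive>(
    x: &Array2<F>,     // (seq_len, model_dim)
    gamma: &Array1<F>, // (model_dim,)
    beta: &Array1<F>,  // (model_dim,)
    eps: F,
) -> Array2<F> {
    let (rows, cols) = x.dim();
    let mut out = Array2::zeros((rows, cols));
    let n = F::from(cols).unwrap();
    for i in 0..rows {
        // Mean and (biased) variance of row i.
        let mut mean = F::zero();
        for j in 0..cols {
            mean = mean + x[[i, j]];
        }
        mean = mean / n;
        let mut var = F::zero();
        for j in 0..cols {
            let d = x[[i, j]] - mean;
            var = var + d * d;
        }
        var = var / n;
        let denom = (var + eps).sqrt();
        // Normalize, then apply the learned scale (gamma) and shift (beta).
        for j in 0..cols {
            out[[i, j]] = gamma[j] * (x[[i, j]] - mean) / denom + beta[j];
        }
    }
    out
}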

/// Transformer-based time series forecaster: input embedding, positional
/// encoding, a stack of transformer blocks, and an output projection.
#[derive(Debug)]
pub struct TransformerForecaster<F: Float + Debug> {
    #[allow(dead_code)]
    blocks: Vec<TransformerBlock<F>>,
    #[allow(dead_code)]
    input_embedding: Array2<F>,
    #[allow(dead_code)]
    positional_encoding: Array2<F>,
    #[allow(dead_code)]
    output_projection: Array2<F>,
}

impl<F: Float + Debug + Clone + FromPrimitive> TransformerForecaster<F> {
    /// Builds a forecaster with `num_layers` transformer blocks plus input
    /// embedding, positional encoding, and output projection matrices.
    pub fn new(
        input_dim: usize,
        model_dim: usize,
        num_layers: usize,
        num_heads: usize,
        ffn_hidden_dim: usize,
        max_seq_len: usize,
        output_dim: usize,
    ) -> Result<Self> {
        let mut blocks = Vec::new();
        for _ in 0..num_layers {
            blocks.push(TransformerBlock::new(model_dim, num_heads, ffn_hidden_dim)?);
        }

        let scale = F::from(2.0).unwrap() / F::from(input_dim).unwrap();
        let std_dev = scale.sqrt();

        Ok(Self {
            blocks,
            input_embedding: LSTMCell::random_matrix(model_dim, input_dim, std_dev),
            // Placeholder: the positional-encoding table is zero-initialized.
            positional_encoding: Array2::zeros((max_seq_len, model_dim)),
            output_projection: LSTMCell::random_matrix(output_dim, model_dim, std_dev),
        })
    }
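
    /// A minimal sketch of the standard sinusoidal positional encoding,
    /// sin/cos at geometrically spaced frequencies. `new` above currently
    /// stores an all-zero table; this shows how such a table could be filled
    /// and is illustrative only, not necessarily how the crate populates it.
    #[allow(dead_code)]
    fn sinusoidal_positional_encoding_sketch(max_seq_len: usize, model_dim: usize) -> Array2<F> {
        let mut pe = Array2::zeros((max_seq_len, model_dim));
        for pos in 0..max_seq_len {
            for i in (0..model_dim).step_by(2) {
                // angle = pos / 10000^(i / model_dim)
                let exponent = F::from(i).unwrap() / F::from(model_dim).unwrap();
                let denom = F::from(10000.0).unwrap().powf(exponent);
                let angle = F::from(pos).unwrap() / denom;
                pe[[pos, i]] = angle.sin();
                if i + 1 < model_dim {
                    pe[[pos, i + 1]] = angle.cos();
                }
            }
        }
        pe
    }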

    /// Runs the forecaster over `input` of shape `(seq_len, input_dim)`.
    ///
    /// Note: this is currently a simplified pass-through; embedding,
    /// positional encoding, and the transformer blocks are not yet applied.
    pub fn forward(&self, input: &Array2<F>) -> Result<Array2<F>> {
        // Simplified placeholder: return the input unchanged.
        Ok(input.clone())
    }

    /// Produces a forecast of length `forecast_steps`.
    ///
    /// Note: this is currently a placeholder that ignores `input` and returns
    /// a zero forecast.
    pub fn forecast(&self, input: &Array2<F>, forecast_steps: usize) -> Result<Array1<F>> {
        let _ = input; // placeholder: the history is not used yet
        Ok(Array1::zeros(forecast_steps))
    }
}
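
// A few minimal sanity tests, sketched against the placeholder implementations
// above (construction, error handling, and output shapes only). They assume the
// crate's usual `cargo test` setup and would need extending once the real
// forward passes are implemented.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn attention_requires_divisible_model_dim() {
        // 10 is not divisible by 3, so construction should fail.
        assert!(MultiHeadAttention::<f64>::new(10, 3).is_err());
        assert!(MultiHeadAttention::<f64>::new(12, 3).is_ok());
    }

    #[test]
    fn attention_forward_checks_input_dim() {
        let attn = MultiHeadAttention::<f64>::new(8, 2).unwrap();
        // Wrong trailing dimension (4 instead of 8) should be rejected.
        assert!(attn.forward(&Array2::zeros((5, 4))).is_err());
        assert!(attn.forward(&Array2::zeros((5, 8))).is_ok());
    }

    #[test]
    fn forecaster_returns_requested_horizon() {
        let model = TransformerForecaster::<f64>::new(3, 16, 2, 4, 32, 64, 1).unwrap();
        let history = Array2::zeros((10, 3));
        let forecast = model.forecast(&history, 5).unwrap();
        assert_eq!(forecast.len(), 5);
    }
}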