scirs2_series/neural_forecasting/
nbeats.rs1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use super::config::ActivationFunction;
11use super::lstm::LSTMCell;
12use crate::error::{Result, TimeSeriesError}; #[derive(Debug, Clone)]
/// Kind of an N-BEATS block.
///
/// The kind decides how many expansion coefficients a block allocates
/// when it is constructed.
pub enum NBeatsBlockType {
    /// Generic block.
    Generic,
    /// Trend block.
    Trend,
    /// Seasonality block.
    Seasonality,
}
24
25#[derive(Debug)]
27pub struct NBeatsBlock<F: Float + Debug> {
28 #[allow(dead_code)]
30 block_type: NBeatsBlockType,
31 #[allow(dead_code)]
33 input_size: usize,
34 #[allow(dead_code)]
36 output_size: usize,
37 #[allow(dead_code)]
39 num_layers: usize,
40 #[allow(dead_code)]
42 layer_widths: Vec<usize>,
43 #[allow(dead_code)]
45 weights: Vec<Array2<F>>,
46 #[allow(dead_code)]
48 biases: Vec<Array1<F>>,
49 #[allow(dead_code)]
51 theta_weights: Array2<F>,
52 #[allow(dead_code)]
54 theta_bias: Array1<F>,
55}
56
57impl<F: Float + Debug + Clone + FromPrimitive> NBeatsBlock<F> {
58 pub fn new(
60 block_type: NBeatsBlockType,
61 input_size: usize,
62 output_size: usize,
63 layer_widths: Vec<usize>,
64 ) -> Self {
65 let num_layers = layer_widths.len();
66 let mut weights = Vec::new();
67 let mut biases = Vec::new();
68
69 let mut prev_width = input_size;
71 for &width in &layer_widths {
72 let scale = F::from(2.0).unwrap() / F::from(prev_width).unwrap();
73 let std_dev = scale.sqrt();
74 weights.push(LSTMCell::random_matrix(width, prev_width, std_dev));
75 biases.push(Array1::zeros(width));
76 prev_width = width;
77 }
78
79 let theta_size = match block_type {
81 NBeatsBlockType::Generic => output_size + input_size,
82 NBeatsBlockType::Trend => 3, NBeatsBlockType::Seasonality => output_size / 2, };
85
86 let theta_scale = F::from(2.0).unwrap() / F::from(prev_width).unwrap();
87 let theta_std = theta_scale.sqrt();
88
89 Self {
90 block_type,
91 input_size,
92 output_size,
93 num_layers,
94 layer_widths,
95 weights,
96 biases,
97 theta_weights: LSTMCell::random_matrix(theta_size, prev_width, theta_std),
98 theta_bias: Array1::zeros(theta_size),
99 }
100 }
101
102 pub fn forward(&self, input: &Array1<F>) -> Result<(Array1<F>, Array1<F>)> {
104 if input.len() != self.input_size {
105 return Err(TimeSeriesError::DimensionMismatch {
106 expected: self.input_size,
107 actual: input.len(),
108 });
109 }
110
111 let backcast = Array1::zeros(self.input_size);
113 let forecast = Array1::zeros(self.output_size);
114 Ok((backcast, forecast))
115 }
116}
117
/// Kind of an N-BEATS stack.
///
/// Every block created for a stack receives the block kind of the same name.
#[derive(Debug, Clone)]
pub enum NBeatsStackType {
    /// Stack made of generic blocks.
    Generic,
    /// Stack made of trend blocks.
    Trend,
    /// Stack made of seasonality blocks.
    Seasonality,
}
128
129#[derive(Debug)]
131pub struct NBeatsStack<F: Float + Debug> {
132 #[allow(dead_code)]
134 stack_type: NBeatsStackType,
135 #[allow(dead_code)]
137 blocks: Vec<NBeatsBlock<F>>,
138}
139
140impl<F: Float + Debug + Clone + FromPrimitive> NBeatsStack<F> {
141 pub fn new(
143 stack_type: NBeatsStackType,
144 input_size: usize,
145 output_size: usize,
146 num_blocks: usize,
147 layer_widths: Vec<usize>,
148 ) -> Self {
149 let mut blocks = Vec::new();
150
151 let block_type = match stack_type {
152 NBeatsStackType::Generic => NBeatsBlockType::Generic,
153 NBeatsStackType::Trend => NBeatsBlockType::Trend,
154 NBeatsStackType::Seasonality => NBeatsBlockType::Seasonality,
155 };
156
157 for _ in 0..num_blocks {
158 blocks.push(NBeatsBlock::new(
159 block_type.clone(),
160 input_size,
161 output_size,
162 layer_widths.clone(),
163 ));
164 }
165
166 Self { stack_type, blocks }
167 }
168
169 pub fn forward(&self, input: &Array1<F>) -> Result<(Array1<F>, Array1<F>)> {
171 let residual = input.clone();
173 let forecast = Array1::zeros(0); Ok((residual, forecast))
175 }
176}
177
178#[derive(Debug)]
180pub struct NBeatsModel<F: Float + Debug> {
181 #[allow(dead_code)]
183 stacks: Vec<NBeatsStack<F>>,
184 #[allow(dead_code)]
186 input_size: usize,
187 #[allow(dead_code)]
189 output_size: usize,
190}
191
192impl<F: Float + Debug + Clone + FromPrimitive> NBeatsModel<F> {
193 pub fn new(
195 input_size: usize,
196 output_size: usize,
197 stack_configs: Vec<(NBeatsStackType, usize, Vec<usize>)>, ) -> Self {
199 let mut stacks = Vec::new();
200
201 for (stack_type, num_blocks, layer_widths) in stack_configs {
202 stacks.push(NBeatsStack::new(
203 stack_type,
204 input_size,
205 output_size,
206 num_blocks,
207 layer_widths,
208 ));
209 }
210
211 Self {
212 stacks,
213 input_size,
214 output_size,
215 }
216 }
217
218 pub fn forward(&self, input: &Array1<F>) -> Result<Array1<F>> {
220 if input.len() != self.input_size {
221 return Err(TimeSeriesError::DimensionMismatch {
222 expected: self.input_size,
223 actual: input.len(),
224 });
225 }
226
227 Ok(Array1::zeros(self.output_size))
229 }
230
231 pub fn forecast(&self, input: &Array1<F>) -> Result<Array1<F>> {
233 self.forward(input)
234 }
235}