scirs2_series/neural_forecasting/
temporal_fusion.rs1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use super::lstm::LSTMCell;
11use crate::error::Result; #[derive(Debug)]
15pub struct TemporalFusionTransformer<F: Float + Debug> {
16 #[allow(dead_code)]
18 model_dim: usize,
19 #[allow(dead_code)]
21 variable_selection: VariableSelectionNetwork<F>,
22 #[allow(dead_code)]
24 grn_layers: Vec<GatedResidualNetwork<F>>,
25}
26
27impl<F: Float + Debug + Clone + FromPrimitive> TemporalFusionTransformer<F> {
28 pub fn new(input_dim: usize, model_dim: usize, num_layers: usize) -> Self {
30 let variable_selection = VariableSelectionNetwork::new(input_dim, model_dim);
31 let mut grn_layers = Vec::new();
32
33 for _ in 0..num_layers {
34 grn_layers.push(GatedResidualNetwork::new(model_dim));
35 }
36
37 Self {
38 model_dim,
39 variable_selection,
40 grn_layers,
41 }
42 }
43
44 pub fn forward(&self, input: &Array2<F>) -> Result<Array2<F>> {
46 Ok(input.clone())
48 }
49}
50
51#[derive(Debug)]
53pub struct VariableSelectionNetwork<F: Float + Debug> {
54 #[allow(dead_code)]
56 selection_weights: Array2<F>,
57 #[allow(dead_code)]
59 context_vectors: Array2<F>,
60}
61
62impl<F: Float + Debug + Clone + FromPrimitive> VariableSelectionNetwork<F> {
63 pub fn new(input_dim: usize, output_dim: usize) -> Self {
65 let scale = F::from(2.0).unwrap() / F::from(input_dim).unwrap();
66 let std_dev = scale.sqrt();
67
68 Self {
69 selection_weights: LSTMCell::random_matrix(output_dim, input_dim, std_dev),
70 context_vectors: LSTMCell::random_matrix(output_dim, input_dim, std_dev),
71 }
72 }
73
74 pub fn forward(&self, input: &Array2<F>) -> Result<Array2<F>> {
76 Ok(input.clone())
78 }
79}
80
81#[derive(Debug)]
83pub struct GatedResidualNetwork<F: Float + Debug> {
84 #[allow(dead_code)]
86 linear_weights: Array2<F>,
87 #[allow(dead_code)]
89 gate_weights: Array2<F>,
90}
91
92impl<F: Float + Debug + Clone + FromPrimitive> GatedResidualNetwork<F> {
93 pub fn new(dim: usize) -> Self {
95 let scale = F::from(2.0).unwrap() / F::from(dim).unwrap();
96 let std_dev = scale.sqrt();
97
98 Self {
99 linear_weights: LSTMCell::random_matrix(dim, dim, std_dev),
100 gate_weights: LSTMCell::random_matrix(dim, dim, std_dev),
101 }
102 }
103
104 pub fn forward(&self, input: &Array2<F>) -> Result<Array2<F>> {
106 Ok(input.clone())
108 }
109}