scirs2_series/neural_forecasting/
attention.rs1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use super::lstm::LSTMCell;
11use crate::error::Result; #[derive(Debug)]
15pub struct FlashAttention<F: Float + Debug> {
16 #[allow(dead_code)]
18 model_dim: usize,
19 #[allow(dead_code)]
21 num_heads: usize,
22 #[allow(dead_code)]
24 w_query: Array2<F>,
25 #[allow(dead_code)]
27 w_key: Array2<F>,
28 #[allow(dead_code)]
30 w_value: Array2<F>,
31}
32
33impl<F: Float + Debug + Clone + FromPrimitive> FlashAttention<F> {
34 pub fn new(model_dim: usize, num_heads: usize) -> Self {
36 let scale = F::from(2.0).unwrap() / F::from(model_dim).unwrap();
37 let std_dev = scale.sqrt();
38
39 Self {
40 model_dim,
41 num_heads,
42 w_query: LSTMCell::random_matrix(model_dim, model_dim, std_dev),
43 w_key: LSTMCell::random_matrix(model_dim, model_dim, std_dev),
44 w_value: LSTMCell::random_matrix(model_dim, model_dim, std_dev),
45 }
46 }
47
48 pub fn forward(&self, input: &Array2<F>) -> Result<Array2<F>> {
50 Ok(input.clone())
52 }
53}