scirs2_series/neural_forecasting/config.rs

use scirs2_core::ndarray::{Array1, Array2, Array3};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::{Result, TimeSeriesError};

/// Activation functions available to the neural forecasting layers.
#[derive(Debug, Clone)]
pub enum ActivationFunction {
    /// Logistic sigmoid: 1 / (1 + e^(-x))
    Sigmoid,
    /// Hyperbolic tangent
    Tanh,
    /// Rectified linear unit: max(0, x)
    ReLU,
    /// Gaussian error linear unit (tanh approximation)
    GELU,
    /// Swish (SiLU): x * sigmoid(x)
    Swish,
    /// Identity (no-op)
    Linear,
}

impl ActivationFunction {
    /// Apply the activation function to a scalar input.
    pub fn apply<F: Float>(&self, x: F) -> F {
        match self {
            ActivationFunction::Sigmoid => {
                let one = F::one();
                one / (one + (-x).exp())
            }
            ActivationFunction::Tanh => x.tanh(),
            ActivationFunction::ReLU => x.max(F::zero()),
            ActivationFunction::GELU => {
                // Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
                let half = F::from(0.5).unwrap();
                let one = F::one();
                let sqrt_2_pi = F::from(0.7978845608).unwrap();
                let coeff = F::from(0.044715).unwrap();

                half * x * (one + (sqrt_2_pi * (x + coeff * x * x * x)).tanh())
            }
            ActivationFunction::Swish => {
                let sigmoid = F::one() / (F::one() + (-x).exp());
                x * sigmoid
            }
            ActivationFunction::Linear => x,
        }
    }

    /// Derivative of the activation function at a scalar input.
    pub fn derivative<F: Float>(&self, x: F) -> F {
        match self {
            ActivationFunction::Sigmoid => {
                let sigmoid = self.apply(x);
                sigmoid * (F::one() - sigmoid)
            }
            ActivationFunction::Tanh => {
                let tanh_x = x.tanh();
                F::one() - tanh_x * tanh_x
            }
            ActivationFunction::ReLU => {
                if x > F::zero() {
                    F::one()
                } else {
                    F::zero()
                }
            }
            ActivationFunction::GELU => {
                // Derivative of the tanh approximation used in `apply`:
                // 0.5 * (1 + tanh(u)) + 0.5 * x * sech^2(u) * u'(x),
                // where u = sqrt(2/pi) * (x + 0.044715 * x^3).
                let half = F::from(0.5).unwrap();
                let one = F::one();
                let three = F::from(3.0).unwrap();
                let sqrt_2_pi = F::from(0.7978845608).unwrap();
                let coeff = F::from(0.044715).unwrap();
                let inner = sqrt_2_pi * (x + coeff * x * x * x);
                let tanh_inner = inner.tanh();
                let sech2 = one - tanh_inner * tanh_inner;
                let d_inner = sqrt_2_pi * (one + three * coeff * x * x);
                half * (one + tanh_inner) + half * x * sech2 * d_inner
            }
            ActivationFunction::Swish => {
                // d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
                let sigmoid = F::one() / (F::one() + (-x).exp());
                sigmoid * (F::one() + x * (F::one() - sigmoid))
            }
            ActivationFunction::Linear => F::one(),
        }
    }
}
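
// A minimal sanity-check sketch, not part of the original file: it assumes
// this module compiles standalone, that `f64` satisfies the crate's `Float`
// bound, and that the re-exported `Array1` exposes ndarray's `From<Vec<_>>`
// and `mapv`. It verifies each `derivative` against a central finite
// difference of `apply` and demonstrates elementwise application.
#[cfg(test)]
mod activation_tests {
    use super::ActivationFunction;

    #[test]
    fn derivative_matches_finite_difference() {
        let functions = [
            ActivationFunction::Sigmoid,
            ActivationFunction::Tanh,
            ActivationFunction::ReLU,
            ActivationFunction::GELU,
            ActivationFunction::Swish,
            ActivationFunction::Linear,
        ];
        let h = 1e-5_f64;
        for f in &functions {
            // Probe points avoid x = 0, where ReLU is not differentiable.
            for &x in &[-2.0_f64, -0.5, 0.3, 1.7] {
                let numeric = (f.apply(x + h) - f.apply(x - h)) / (2.0 * h);
                let analytic = f.derivative(x);
                assert!(
                    (numeric - analytic).abs() < 1e-4,
                    "{:?} at x = {}: analytic {} vs numeric {}",
                    f, x, analytic, numeric
                );
            }
        }
    }

    #[test]
    fn applies_elementwise_over_array() {
        use scirs2_core::ndarray::Array1;

        let act = ActivationFunction::ReLU;
        let xs = Array1::from(vec![-1.0_f64, 0.0, 2.0]);
        let ys = xs.mapv(|v| act.apply(v));
        assert_eq!(ys, Array1::from(vec![0.0, 0.0, 2.0]));
    }
}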