scirs2_series/neural_forecasting/config.rs

//! Configuration and Common Types for Neural Forecasting
//!
//! This module contains common configuration structures, enums, and utility types
//! used across all neural forecasting architectures.

use scirs2_core::ndarray::{Array1, Array2, Array3};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::{Result, TimeSeriesError};

/// Activation functions for neural networks
#[derive(Debug, Clone)]
pub enum ActivationFunction {
    /// Sigmoid activation
    Sigmoid,
    /// Hyperbolic tangent
    Tanh,
    /// Rectified Linear Unit
    ReLU,
    /// Gaussian Error Linear Unit
    GELU,
    /// Swish activation
    Swish,
    /// Linear activation (identity)
    Linear,
}

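// Illustrative sketch, not part of the original module: an activation can be
// applied elementwise to an ndarray vector via `mapv`. The free function
// `apply_activation_elementwise` is a hypothetical helper shown only to
// demonstrate how `ActivationFunction::apply` (defined below) composes with the
// `Array1` type imported above.
#[allow(dead_code)]
fn apply_activation_elementwise<F: Float>(
    activation: &ActivationFunction,
    input: &Array1<F>,
) -> Array1<F> {
    // `mapv` applies the closure to a copy of every element and collects the
    // results into a new owned array of the same shape.
    input.mapv(|x| activation.apply(x))
}
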
impl ActivationFunction {
    /// Apply activation function
    pub fn apply<F: Float>(&self, x: F) -> F {
        match self {
            ActivationFunction::Sigmoid => {
                let one = F::one();
                one / (one + (-x).exp())
            }
            ActivationFunction::Tanh => x.tanh(),
            ActivationFunction::ReLU => x.max(F::zero()),
            ActivationFunction::GELU => {
                // Tanh approximation of GELU:
                // 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
                let half = F::from(0.5).unwrap();
                let one = F::one();
                let sqrt_2_pi = F::from(0.7978845608).unwrap(); // sqrt(2/π)
                let coeff = F::from(0.044715).unwrap();

                half * x * (one + (sqrt_2_pi * (x + coeff * x * x * x)).tanh())
            }
            ActivationFunction::Swish => {
                let sigmoid = F::one() / (F::one() + (-x).exp());
                x * sigmoid
            }
            ActivationFunction::Linear => x,
        }
    }

    /// Apply derivative of activation function
    pub fn derivative<F: Float>(&self, x: F) -> F {
        match self {
            ActivationFunction::Sigmoid => {
                let sigmoid = self.apply(x);
                sigmoid * (F::one() - sigmoid)
            }
            ActivationFunction::Tanh => {
                let tanh_x = x.tanh();
                F::one() - tanh_x * tanh_x
            }
            ActivationFunction::ReLU => {
                if x > F::zero() {
                    F::one()
                } else {
                    F::zero()
                }
            }
            ActivationFunction::GELU => {
                // Simplified approximation: the exact derivative is Φ(x) + x·φ(x);
                // here the sigmoid is used as a cheap surrogate and the x·φ(x)
                // term is dropped.
                F::one() / (F::one() + (-x).exp())
            }
            ActivationFunction::Swish => {
                let sigmoid = F::one() / (F::one() + (-x).exp());
                sigmoid * (F::one() + x * (F::one() - sigmoid))
            }
            ActivationFunction::Linear => F::one(),
        }
    }
}
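
// Illustrative usage sketch, not in the original source: a few spot checks of
// `apply` and `derivative` at inputs with known closed-form values. The test
// module and test name are hypothetical additions for demonstration only.
#[cfg(test)]
mod activation_sketch_tests {
    use super::*;

    #[test]
    fn activation_spot_checks() {
        // Sigmoid(0) = 0.5, and its derivative there is 0.5 * (1 - 0.5) = 0.25.
        assert!((ActivationFunction::Sigmoid.apply(0.0_f64) - 0.5).abs() < 1e-12);
        assert!((ActivationFunction::Sigmoid.derivative(0.0_f64) - 0.25).abs() < 1e-12);

        // ReLU passes positive inputs through and zeroes out negative ones.
        assert_eq!(ActivationFunction::ReLU.apply(-1.5_f64), 0.0);
        assert_eq!(ActivationFunction::ReLU.apply(2.0_f64), 2.0);

        // Linear is the identity and has a unit derivative everywhere.
        assert_eq!(ActivationFunction::Linear.apply(3.0_f64), 3.0);
        assert_eq!(ActivationFunction::Linear.derivative(3.0_f64), 1.0);
    }
}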