// rust_mlp/activation.rs

1//! Activation functions.
2//!
3//! A dense layer computes a pre-activation value `z = W x + b` and then applies an
4//! activation function element-wise: `y = activation(z)`.
5//!
6//! In this crate we cache the *post-activation* outputs `y` in `Scratch`. During
7//! backprop we compute `dL/dz` from `dL/dy` using `y` (when possible). This keeps
8//! the per-sample hot path allocation-free without needing a separate `z` buffer.
9
10use crate::{Error, Result};
11
#[derive(Debug, Clone, Copy, PartialEq)]
/// Element-wise activation function.
pub enum Activation {
    /// Hyperbolic tangent: `x.tanh()`, output in (-1, 1).
    Tanh,
    /// Rectified linear unit: `max(x, 0)`.
    ReLU,
    /// `x` for `x > 0`, `alpha * x` otherwise.
    /// `alpha` must be finite and >= 0 (enforced by `validate`).
    LeakyReLU { alpha: f32 },
    /// Logistic function `1 / (1 + e^-x)`, output in (0, 1).
    Sigmoid,
    /// Pass-through: returns its input unchanged.
    Identity,
}
21
22impl Activation {
23    /// Validate activation parameters.
24    pub fn validate(self) -> Result<()> {
25        match self {
26            Activation::LeakyReLU { alpha } => {
27                if !(alpha.is_finite() && alpha >= 0.0) {
28                    return Err(Error::InvalidConfig(format!(
29                        "leaky ReLU alpha must be finite and >= 0, got {alpha}"
30                    )));
31                }
32            }
33            Activation::Tanh | Activation::ReLU | Activation::Sigmoid | Activation::Identity => {}
34        }
35
36        Ok(())
37    }
38
39    #[inline]
40    pub(crate) fn forward(self, x: f32) -> f32 {
41        match self {
42            Activation::Tanh => x.tanh(),
43            Activation::ReLU => x.max(0.0),
44            Activation::LeakyReLU { alpha } => {
45                if x > 0.0 {
46                    x
47                } else {
48                    alpha * x
49                }
50            }
51            Activation::Sigmoid => sigmoid(x),
52            Activation::Identity => x,
53        }
54    }
55
56    /// Derivative of the activation with respect to its input, expressed in terms
57    /// of the cached post-activation output `y`.
58    #[inline]
59    pub(crate) fn grad_from_output(self, y: f32) -> f32 {
60        match self {
61            Activation::Tanh => 1.0 - y * y,
62            Activation::ReLU => {
63                if y > 0.0 {
64                    1.0
65                } else {
66                    0.0
67                }
68            }
69            Activation::LeakyReLU { alpha } => {
70                if y > 0.0 {
71                    1.0
72                } else {
73                    alpha
74                }
75            }
76            Activation::Sigmoid => y * (1.0 - y),
77            Activation::Identity => 1.0,
78        }
79    }
80}
81
#[inline]
fn sigmoid(x: f32) -> f32 {
    // Numerically stable: never exponentiate a positive argument, so neither
    // tail can overflow. Both branches are algebraically the same logistic.
    let e = (-x.abs()).exp();
    if x >= 0.0 {
        1.0 / (1.0 + e)
    } else {
        e / (1.0 + e)
    }
}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_rejects_bad_leaky_relu_alpha() {
        // NaN and negative slopes are invalid; a small positive slope is fine.
        for bad in [f32::NAN, -0.1] {
            assert!(Activation::LeakyReLU { alpha: bad }.validate().is_err());
        }
        assert!(Activation::LeakyReLU { alpha: 0.1 }.validate().is_ok());
    }

    #[test]
    fn sigmoid_midpoint_and_saturation() {
        assert!((Activation::Sigmoid.forward(0.0) - 0.5).abs() < 1e-6);
        assert!(Activation::Sigmoid.forward(10.0) > 0.999);
        assert!(Activation::Sigmoid.forward(-10.0) < 0.001);
    }

    #[test]
    fn rectifier_forward_and_gradients() {
        // Plain ReLU clamps negatives to zero and is identity on positives.
        assert_eq!(Activation::ReLU.forward(-2.0), 0.0);
        assert_eq!(Activation::ReLU.forward(3.0), 3.0);
        assert_eq!(Activation::ReLU.grad_from_output(0.0), 0.0);
        assert_eq!(Activation::ReLU.grad_from_output(1.0), 1.0);

        // Leaky variant scales negatives by alpha; gradient recovered from y.
        let leaky = Activation::LeakyReLU { alpha: 0.1 };
        assert_eq!(leaky.forward(-2.0), -0.2);
        assert_eq!(leaky.forward(3.0), 3.0);
        assert_eq!(leaky.grad_from_output(-0.2), 0.1);
        assert_eq!(leaky.grad_from_output(3.0), 1.0);
    }

    #[test]
    fn smooth_activation_gradients_from_cached_output() {
        let y = Activation::Tanh.forward(0.3);
        let g = Activation::Tanh.grad_from_output(y);
        assert!((g - (1.0 - y * y)).abs() < 1e-6);

        // sigmoid'(0) = 0.25, reconstructed from y = sigmoid(0) = 0.5.
        let y = Activation::Sigmoid.forward(0.0);
        let g = Activation::Sigmoid.grad_from_output(y);
        assert!((g - 0.25).abs() < 1e-6);
    }
}