// ruvector_cnn/layers/activation.rs

//! Activation Functions
//!
//! SIMD-optimized activation functions:
//! - ReLU: max(0, x)
//! - ReLU6: min(6, max(0, x))
//! - Swish: x * sigmoid(x)
//! - HardSwish: x * relu6(x + 3) / 6
//! - Sigmoid: 1 / (1 + exp(-x))
use crate::{simd, CnnResult, Tensor};

use super::Layer;
/// The set of supported activation functions.
///
/// Selects which element-wise non-linearity an [`Activation`] layer applies.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ActivationType {
    /// Rectified linear unit: `max(0, x)`.
    ReLU,
    /// ReLU capped at six: `min(6, max(0, x))`.
    ReLU6,
    /// Swish / SiLU: `x * sigmoid(x)`.
    Swish,
    /// Hard Swish: `x * relu6(x + 3) / 6`.
    HardSwish,
    /// Logistic sigmoid: `1 / (1 + exp(-x))`.
    Sigmoid,
    /// Pass-through; values are left untouched.
    Identity,
}
30
/// Generic activation layer that wraps an activation type.
///
/// Dispatches to the matching element-wise function in `apply_inplace`.
#[derive(Clone, Debug)]
pub struct Activation {
    // Which non-linearity this layer applies; fixed at construction.
    activation_type: ActivationType,
}
36
37impl Activation {
38    /// Creates a new activation layer.
39    pub fn new(activation_type: ActivationType) -> Self {
40        Self { activation_type }
41    }
42
43    /// Returns the activation type.
44    pub fn activation_type(&self) -> ActivationType {
45        self.activation_type
46    }
47
48    /// Applies the activation in-place.
49    pub fn apply_inplace(&self, data: &mut [f32]) {
50        match self.activation_type {
51            ActivationType::ReLU => {
52                for x in data.iter_mut() {
53                    *x = x.max(0.0);
54                }
55            }
56            ActivationType::ReLU6 => {
57                for x in data.iter_mut() {
58                    *x = x.max(0.0).min(6.0);
59                }
60            }
61            ActivationType::Swish => {
62                for x in data.iter_mut() {
63                    let sigmoid = 1.0 / (1.0 + (-*x).exp());
64                    *x *= sigmoid;
65                }
66            }
67            ActivationType::HardSwish => {
68                for x in data.iter_mut() {
69                    *x *= (*x + 3.0).max(0.0).min(6.0) / 6.0;
70                }
71            }
72            ActivationType::Sigmoid => {
73                for x in data.iter_mut() {
74                    *x = 1.0 / (1.0 + (-*x).exp());
75                }
76            }
77            ActivationType::Identity => {}
78        }
79    }
80}
81
82impl Layer for Activation {
83    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
84        let mut output = input.clone();
85        self.apply_inplace(output.data_mut());
86        Ok(output)
87    }
88
89    fn name(&self) -> &'static str {
90        match self.activation_type {
91            ActivationType::ReLU => "ReLU",
92            ActivationType::ReLU6 => "ReLU6",
93            ActivationType::Swish => "Swish",
94            ActivationType::HardSwish => "HardSwish",
95            ActivationType::Sigmoid => "Sigmoid",
96            ActivationType::Identity => "Identity",
97        }
98    }
99}
100
101/// ReLU activation: max(0, x)
102#[derive(Debug, Clone, Default)]
103pub struct ReLU;
104
105impl ReLU {
106    /// Create a new ReLU activation
107    pub fn new() -> Self {
108        Self
109    }
110}
111
112impl Layer for ReLU {
113    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
114        let mut output = Tensor::zeros(input.shape());
115        simd::relu_simd(input.data(), output.data_mut());
116        Ok(output)
117    }
118
119    fn name(&self) -> &'static str {
120        "ReLU"
121    }
122}
123
124/// ReLU6 activation: min(6, max(0, x))
125/// Used in MobileNet architectures
126#[derive(Debug, Clone, Default)]
127pub struct ReLU6;
128
129impl ReLU6 {
130    /// Create a new ReLU6 activation
131    pub fn new() -> Self {
132        Self
133    }
134}
135
136impl Layer for ReLU6 {
137    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
138        let mut output = Tensor::zeros(input.shape());
139        simd::relu6_simd(input.data(), output.data_mut());
140        Ok(output)
141    }
142
143    fn name(&self) -> &'static str {
144        "ReLU6"
145    }
146}
147
148/// Swish activation: x * sigmoid(x)
149/// Also known as SiLU (Sigmoid Linear Unit)
150#[derive(Debug, Clone, Default)]
151pub struct Swish;
152
153impl Swish {
154    /// Create a new Swish activation
155    pub fn new() -> Self {
156        Self
157    }
158}
159
160impl Layer for Swish {
161    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
162        let mut output = Tensor::zeros(input.shape());
163        simd::scalar::swish_scalar(input.data(), output.data_mut());
164        Ok(output)
165    }
166
167    fn name(&self) -> &'static str {
168        "Swish"
169    }
170}
171
172/// HardSwish activation: x * relu6(x + 3) / 6
173/// Efficient approximation of Swish for mobile inference
174#[derive(Debug, Clone, Default)]
175pub struct HardSwish;
176
177impl HardSwish {
178    /// Create a new HardSwish activation
179    pub fn new() -> Self {
180        Self
181    }
182}
183
184impl Layer for HardSwish {
185    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
186        let mut output = Tensor::zeros(input.shape());
187        simd::scalar::hard_swish_scalar(input.data(), output.data_mut());
188        Ok(output)
189    }
190
191    fn name(&self) -> &'static str {
192        "HardSwish"
193    }
194}
195
196/// Sigmoid activation: 1 / (1 + exp(-x))
197#[derive(Debug, Clone, Default)]
198pub struct Sigmoid;
199
200impl Sigmoid {
201    /// Create a new Sigmoid activation
202    pub fn new() -> Self {
203        Self
204    }
205}
206
207impl Layer for Sigmoid {
208    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
209        let mut output = Tensor::zeros(input.shape());
210        simd::scalar::sigmoid_scalar(input.data(), output.data_mut());
211        Ok(output)
212    }
213
214    fn name(&self) -> &'static str {
215        "Sigmoid"
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu() {
        let relu = ReLU::new();
        let input = Tensor::from_data(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5]).unwrap();
        let output = relu.forward(&input).unwrap();

        assert_eq!(output.data(), &[0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_relu6() {
        let relu6 = ReLU6::new();
        let input = Tensor::from_data(vec![-2.0, 0.0, 3.0, 6.0, 10.0], &[5]).unwrap();
        let output = relu6.forward(&input).unwrap();

        assert_eq!(output.data(), &[0.0, 0.0, 3.0, 6.0, 6.0]);
    }

    #[test]
    fn test_sigmoid() {
        let sigmoid = Sigmoid::new();
        let input = Tensor::from_data(vec![0.0], &[1]).unwrap();
        let output = sigmoid.forward(&input).unwrap();

        // sigmoid(0) = 0.5
        assert!((output.data()[0] - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_swish() {
        let swish = Swish::new();
        let input = Tensor::from_data(vec![0.0, 1.0, -1.0], &[3]).unwrap();
        let output = swish.forward(&input).unwrap();

        // swish(0) = 0 * 0.5 = 0
        assert!(output.data()[0].abs() < 0.001);
        // swish(1) = 1 * sigmoid(1) ≈ 0.731
        assert!((output.data()[1] - 0.731).abs() < 0.01);
        // swish(-1) = -1 * sigmoid(-1) ≈ -0.269 (previously unchecked)
        assert!((output.data()[2] + 0.269).abs() < 0.01);
    }

    #[test]
    fn test_hard_swish() {
        let hs = HardSwish::new();
        let input = Tensor::from_data(vec![-4.0, -3.0, 0.0, 3.0, 4.0], &[5]).unwrap();
        let output = hs.forward(&input).unwrap();

        // hardswish(-4) = -4 * relu6(-1) / 6 = 0
        assert!(output.data()[0].abs() < 0.001);
        // hardswish(-3) = -3 * relu6(0) / 6 = 0
        assert!(output.data()[1].abs() < 0.001);
        // hardswish(0) = 0 * relu6(3) / 6 = 0
        assert!(output.data()[2].abs() < 0.001);
        // hardswish(3) = 3 * relu6(6) / 6 = 3
        assert!((output.data()[3] - 3.0).abs() < 0.001);
        // hardswish(4) = 4 * relu6(7) / 6 = 4 — positive saturation,
        // previously unchecked
        assert!((output.data()[4] - 4.0).abs() < 0.001);
    }

    #[test]
    fn test_activation_wrapper() {
        // The generic wrapper was previously untested.
        let relu = Activation::new(ActivationType::ReLU);
        assert_eq!(relu.activation_type(), ActivationType::ReLU);

        let mut data = vec![-2.0_f32, -0.5, 0.0, 1.5];
        relu.apply_inplace(&mut data);
        assert_eq!(data, vec![0.0, 0.0, 0.0, 1.5]);

        // Identity must leave the buffer untouched.
        let ident = Activation::new(ActivationType::Identity);
        let mut data = vec![-1.0_f32, 2.0];
        ident.apply_inplace(&mut data);
        assert_eq!(data, vec![-1.0, 2.0]);
    }
}