ruvector_cnn/layers/
activation.rs1use crate::{simd, CnnResult, Tensor};
11
12use super::Layer;
13
/// Element-wise non-linearities understood by [`Activation`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ActivationType {
    /// `max(x, 0)`.
    ReLU,
    /// `min(max(x, 0), 6)`.
    ReLU6,
    /// `x * sigmoid(x)`.
    Swish,
    /// `x * min(max(x + 3, 0), 6) / 6`.
    HardSwish,
    /// `1 / (1 + exp(-x))`.
    Sigmoid,
    /// Pass-through: values are left unchanged.
    Identity,
}

/// A layer that applies the function selected by an [`ActivationType`],
/// computed with plain scalar `f32` arithmetic.
#[derive(Clone, Debug)]
pub struct Activation {
    /// The function applied by [`Activation::apply_inplace`].
    activation_type: ActivationType,
}

impl Activation {
    /// Builds a layer that applies `activation_type`.
    pub fn new(activation_type: ActivationType) -> Self {
        Activation { activation_type }
    }

    /// Returns the function this layer applies.
    pub fn activation_type(&self) -> ActivationType {
        self.activation_type
    }

    /// Overwrites every element of `data` with the activation of its value.
    ///
    /// An empty slice is a no-op. NaN inputs become `0.0` for `ReLU` and
    /// `ReLU6` (because `f32::max` returns the non-NaN operand) and stay NaN
    /// for `Swish`, `HardSwish`, `Sigmoid` and `Identity`.
    pub fn apply_inplace(&self, data: &mut [f32]) {
        // Logistic function, shared by the Swish and Sigmoid arms.
        fn logistic(v: f32) -> f32 {
            1.0 / (1.0 + (-v).exp())
        }

        // Replaces each element with `f(element)`; generic so `f` is inlined
        // per call site rather than dispatched through a pointer.
        fn map_each(data: &mut [f32], f: impl Fn(f32) -> f32) {
            data.iter_mut().for_each(|v| *v = f(*v));
        }

        match self.activation_type {
            ActivationType::ReLU => map_each(data, |v| v.max(0.0)),
            ActivationType::ReLU6 => map_each(data, |v| v.max(0.0).min(6.0)),
            ActivationType::Swish => map_each(data, |v| v * logistic(v)),
            ActivationType::HardSwish => {
                map_each(data, |v| v * ((v + 3.0).max(0.0).min(6.0) / 6.0))
            }
            ActivationType::Sigmoid => map_each(data, logistic),
            // Nothing to compute: the buffer is left exactly as it was.
            ActivationType::Identity => {}
        }
    }
}
81
82impl Layer for Activation {
83 fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
84 let mut output = input.clone();
85 self.apply_inplace(output.data_mut());
86 Ok(output)
87 }
88
89 fn name(&self) -> &'static str {
90 match self.activation_type {
91 ActivationType::ReLU => "ReLU",
92 ActivationType::ReLU6 => "ReLU6",
93 ActivationType::Swish => "Swish",
94 ActivationType::HardSwish => "HardSwish",
95 ActivationType::Sigmoid => "Sigmoid",
96 ActivationType::Identity => "Identity",
97 }
98 }
99}
100
101#[derive(Debug, Clone, Default)]
103pub struct ReLU;
104
105impl ReLU {
106 pub fn new() -> Self {
108 Self
109 }
110}
111
112impl Layer for ReLU {
113 fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
114 let mut output = Tensor::zeros(input.shape());
115 simd::relu_simd(input.data(), output.data_mut());
116 Ok(output)
117 }
118
119 fn name(&self) -> &'static str {
120 "ReLU"
121 }
122}
123
124#[derive(Debug, Clone, Default)]
127pub struct ReLU6;
128
129impl ReLU6 {
130 pub fn new() -> Self {
132 Self
133 }
134}
135
136impl Layer for ReLU6 {
137 fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
138 let mut output = Tensor::zeros(input.shape());
139 simd::relu6_simd(input.data(), output.data_mut());
140 Ok(output)
141 }
142
143 fn name(&self) -> &'static str {
144 "ReLU6"
145 }
146}
147
148#[derive(Debug, Clone, Default)]
151pub struct Swish;
152
153impl Swish {
154 pub fn new() -> Self {
156 Self
157 }
158}
159
160impl Layer for Swish {
161 fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
162 let mut output = Tensor::zeros(input.shape());
163 simd::scalar::swish_scalar(input.data(), output.data_mut());
164 Ok(output)
165 }
166
167 fn name(&self) -> &'static str {
168 "Swish"
169 }
170}
171
172#[derive(Debug, Clone, Default)]
175pub struct HardSwish;
176
177impl HardSwish {
178 pub fn new() -> Self {
180 Self
181 }
182}
183
184impl Layer for HardSwish {
185 fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
186 let mut output = Tensor::zeros(input.shape());
187 simd::scalar::hard_swish_scalar(input.data(), output.data_mut());
188 Ok(output)
189 }
190
191 fn name(&self) -> &'static str {
192 "HardSwish"
193 }
194}
195
196#[derive(Debug, Clone, Default)]
198pub struct Sigmoid;
199
200impl Sigmoid {
201 pub fn new() -> Self {
203 Self
204 }
205}
206
207impl Layer for Sigmoid {
208 fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
209 let mut output = Tensor::zeros(input.shape());
210 simd::scalar::sigmoid_scalar(input.data(), output.data_mut());
211 Ok(output)
212 }
213
214 fn name(&self) -> &'static str {
215 "Sigmoid"
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `actual` and `expected` have equal length and agree
    /// element-wise to within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "element {}: got {}, expected {}",
                i,
                a,
                e
            );
        }
    }

    /// Checks that a dedicated layer agrees with the generic `Activation`
    /// layer for the same function, in both output values and reported name.
    fn assert_matches_activation<L: Layer>(layer: &L, kind: ActivationType) {
        let input = Tensor::from_data(
            vec![-7.0, -4.0, -1.5, -0.5, 0.0, 0.5, 1.5, 4.0, 7.0],
            &[9],
        )
        .unwrap();
        let reference = Activation::new(kind);
        let expected = reference.forward(&input).unwrap();
        let actual = layer.forward(&input).unwrap();
        assert_close(actual.data(), expected.data(), 1e-3);
        assert_eq!(layer.name(), reference.name());
    }

    #[test]
    fn test_relu() {
        let relu = ReLU::new();
        let input = Tensor::from_data(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5]).unwrap();
        let output = relu.forward(&input).unwrap();

        assert_eq!(output.data(), &[0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_relu6() {
        let relu6 = ReLU6::new();
        let input = Tensor::from_data(vec![-2.0, 0.0, 3.0, 6.0, 10.0], &[5]).unwrap();
        let output = relu6.forward(&input).unwrap();

        assert_eq!(output.data(), &[0.0, 0.0, 3.0, 6.0, 6.0]);
    }

    #[test]
    fn test_sigmoid() {
        let sigmoid = Sigmoid::new();
        let input = Tensor::from_data(vec![0.0], &[1]).unwrap();
        let output = sigmoid.forward(&input).unwrap();

        assert!((output.data()[0] - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_swish() {
        let swish = Swish::new();
        let input = Tensor::from_data(vec![0.0, 1.0, -1.0], &[3]).unwrap();
        let output = swish.forward(&input).unwrap();

        // swish(0) = 0
        assert!(output.data()[0].abs() < 0.001);
        // swish(1) = sigmoid(1) ≈ 0.731
        assert!((output.data()[1] - 0.731).abs() < 0.01);
        // swish(-1) = -sigmoid(-1) ≈ -0.269 (input was previously unchecked)
        assert!((output.data()[2] + 0.269).abs() < 0.01);
    }

    #[test]
    fn test_hard_swish() {
        let hs = HardSwish::new();
        let input = Tensor::from_data(vec![-4.0, -3.0, 0.0, 3.0, 4.0], &[5]).unwrap();
        let output = hs.forward(&input).unwrap();

        // x <= -3 saturates the gate at 0.
        assert!(output.data()[0].abs() < 0.001);
        assert!(output.data()[1].abs() < 0.001);
        // 0 * relu6(3) / 6 = 0
        assert!(output.data()[2].abs() < 0.001);
        // x >= 3 saturates the gate at 1, so the output equals the input.
        assert!((output.data()[3] - 3.0).abs() < 0.001);
        assert!((output.data()[4] - 4.0).abs() < 0.001);
    }

    #[test]
    fn test_identity_passes_values_through() {
        let layer = Activation::new(ActivationType::Identity);
        let input = Tensor::from_data(vec![-3.5, 0.0, 2.25], &[3]).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.data(), input.data());
        assert_eq!(layer.name(), "Identity");
    }

    #[test]
    fn test_activation_forward_leaves_input_untouched() {
        let input = Tensor::from_data(vec![-1.0, 2.0], &[2]).unwrap();
        let output = Activation::new(ActivationType::ReLU).forward(&input).unwrap();

        assert_eq!(input.data(), &[-1.0, 2.0]);
        assert_eq!(output.data(), &[0.0, 2.0]);
    }

    #[test]
    fn test_dedicated_layers_match_activation() {
        assert_matches_activation(&ReLU::new(), ActivationType::ReLU);
        assert_matches_activation(&ReLU6::new(), ActivationType::ReLU6);
        assert_matches_activation(&Swish::new(), ActivationType::Swish);
        assert_matches_activation(&HardSwish::new(), ActivationType::HardSwish);
        assert_matches_activation(&Sigmoid::new(), ActivationType::Sigmoid);
    }
}