scirs2_neural/
activations_minimal.rs

//! Minimal activation functions without Layer trait dependencies

use crate::error::Result;
use scirs2_core::ndarray::{Array, Zip};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

/// Trait for activation functions
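///
/// # Example
///
/// An illustrative sketch (marked `ignore`, so it is not compiled as a doctest) of
/// how an implementor is driven; it assumes dynamic-dimension `f64` arrays as used
/// throughout this module:
///
/// ```ignore
/// use scirs2_core::ndarray::Array;
///
/// let relu = ReLU::new();
/// let x = Array::from_vec(vec![-1.0_f64, 0.5, 2.0]).into_dyn();
/// let y = relu.forward(&x)?;                                // [0.0, 0.5, 2.0]
/// let grad = relu.backward(&Array::ones(x.raw_dim()), &x)?; // [0.0, 1.0, 1.0]
/// ```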
pub trait Activation<F> {
    /// Forward pass of the activation function
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>>;

    /// Backward pass of the activation function: given the upstream gradient and the
    /// original input, returns the gradient with respect to the input
    fn backward(
        &self,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>>;
}

/// GELU (Gaussian Error Linear Unit) activation function
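///
/// Both code paths evaluate a tanh-based curve of the form
/// `0.5 * x * (1 + tanh(c * (x + 0.044715 * x^3)))`: `GELU::fast()` uses the scaling
/// `c = sqrt(2/pi)` (about 0.79788), while `GELU::new()` uses `c = sqrt(pi/2)`
/// (about 1.25331), matching the constants in `forward` and `backward` below.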
#[derive(Debug, Clone, Copy)]
pub struct GELU {
    /// Selects between the two scaling constants described above.
    fast: bool,
}

impl GELU {
    /// Create a GELU with `fast = false`.
    pub fn new() -> Self {
        Self { fast: false }
    }

    /// Create a GELU with `fast = true`.
    pub fn fast() -> Self {
        Self { fast: true }
    }
}

impl Default for GELU {
    fn default() -> Self {
        Self::new()
    }
}

impl<F: Float + Debug> Activation<F> for GELU {
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut output = input.clone();

        if self.fast {
            // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            let sqrt_2_over_pi = F::from(0.7978845608028654).unwrap();
            let coeff = F::from(0.044715).unwrap();
            let half = F::from(0.5).unwrap();
            let one = F::one();

            Zip::from(&mut output).for_each(|x| {
                let x3 = *x * *x * *x;
                let inner = sqrt_2_over_pi * (*x + coeff * x3);
                *x = half * *x * (one + inner.tanh());
            });
        } else {
            // Same form with a sqrt(pi/2) scaling: inner = sqrt(pi/2) * x * (1 + 0.044715 * x^2)
            let sqrt_pi_over_2 = F::from(1.2533141373155).unwrap();
            let coeff = F::from(0.044715).unwrap();
            let half = F::from(0.5).unwrap();
            let one = F::one();

            Zip::from(&mut output).for_each(|x| {
                let x2 = *x * *x;
                let inner = sqrt_pi_over_2 * *x * (one + coeff * x2);
                *x = half * *x * (one + inner.tanh());
            });
        }

        Ok(output)
    }

    fn backward(
        &self,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut grad_input = Array::zeros(grad_output.raw_dim());

        if self.fast {
            let sqrt_2_over_pi = F::from(0.7978845608028654).unwrap();
            let coeff = F::from(0.044715).unwrap();
            let half = F::from(0.5).unwrap();
            let one = F::one();
            let three = F::from(3.0).unwrap();

            // d/dx [0.5 * x * (1 + tanh(u))] = 0.5 * (1 + tanh(u)) + 0.5 * x * sech^2(u) * du/dx
            Zip::from(&mut grad_input)
                .and(grad_output)
                .and(input)
                .for_each(|grad_in, &grad_out, &x| {
                    let x2 = x * x;
                    let x3 = x2 * x;
                    let inner = sqrt_2_over_pi * (x + coeff * x3);
                    let tanh_inner = inner.tanh();
                    let sech_sq = one - tanh_inner * tanh_inner;
                    let d_inner_dx = sqrt_2_over_pi * (one + three * coeff * x2);
                    let dgelu_dx = half * (one + tanh_inner) + half * x * sech_sq * d_inner_dx;
                    *grad_in = grad_out * dgelu_dx;
                });
        } else {
            let sqrt_pi_over_2 = F::from(1.2533141373155).unwrap();
            let coeff = F::from(0.044715).unwrap();
            let half = F::from(0.5).unwrap();
            let one = F::one();
            let three = F::from(3.0).unwrap();

            // Same derivative with the sqrt(pi/2) scaling used in `forward`.
            Zip::from(&mut grad_input)
                .and(grad_output)
                .and(input)
                .for_each(|grad_in, &grad_out, &x| {
                    let x2 = x * x;
                    let inner = sqrt_pi_over_2 * x * (one + coeff * x2);
                    let tanh_inner = inner.tanh();
                    let sech_sq = one - tanh_inner * tanh_inner;
                    let d_inner_dx = sqrt_pi_over_2 * (one + three * coeff * x2);
                    let dgelu_dx = half * (one + tanh_inner) + half * x * sech_sq * d_inner_dx;
                    *grad_in = grad_out * dgelu_dx;
                });
        }

        Ok(grad_input)
    }
}

/// Tanh activation function
#[derive(Debug, Clone, Copy)]
pub struct Tanh;

impl Tanh {
    /// Create a new `Tanh` activation.
    pub fn new() -> Self {
        Self
    }
}

impl Default for Tanh {
    fn default() -> Self {
        Self::new()
    }
}

impl<F: Float + Debug> Activation<F> for Tanh {
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut output = input.clone();
        Zip::from(&mut output).for_each(|x| {
            *x = x.tanh();
        });
        Ok(output)
    }

    fn backward(
        &self,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut grad_input = Array::zeros(grad_output.raw_dim());

        // d/dx tanh(x) = 1 - tanh(x)^2
        Zip::from(&mut grad_input)
            .and(grad_output)
            .and(input)
            .for_each(|grad_in, &grad_out, &x| {
                let tanh_x = x.tanh();
                let derivative = F::one() - tanh_x * tanh_x;
                *grad_in = grad_out * derivative;
            });

        Ok(grad_input)
    }
}

/// Sigmoid activation function
#[derive(Debug, Clone, Copy)]
pub struct Sigmoid;

impl Sigmoid {
    /// Create a new `Sigmoid` activation.
    pub fn new() -> Self {
        Self
    }
}

impl Default for Sigmoid {
    fn default() -> Self {
        Self::new()
    }
}

impl<F: Float + Debug> Activation<F> for Sigmoid {
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut output = input.clone();
        let one = F::one();
        // sigmoid(x) = 1 / (1 + exp(-x))
        Zip::from(&mut output).for_each(|x| {
            *x = one / (one + (-*x).exp());
        });
        Ok(output)
    }

    fn backward(
        &self,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut grad_input = Array::zeros(grad_output.raw_dim());
        let one = F::one();

        // d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
        Zip::from(&mut grad_input)
            .and(grad_output)
            .and(input)
            .for_each(|grad_in, &grad_out, &x| {
                let sigmoid_x = one / (one + (-x).exp());
                let derivative = sigmoid_x * (one - sigmoid_x);
                *grad_in = grad_out * derivative;
            });

        Ok(grad_input)
    }
}

/// ReLU activation function
#[derive(Debug, Clone, Copy)]
pub struct ReLU {
    /// Slope applied to negative inputs; 0.0 gives the standard ReLU.
    alpha: f64,
}

impl ReLU {
    /// Create a standard ReLU (`alpha = 0.0`).
    pub fn new() -> Self {
        Self { alpha: 0.0 }
    }

    /// Create a leaky ReLU with the given negative slope `alpha`.
    pub fn leaky(alpha: f64) -> Self {
        Self { alpha }
    }
}

impl Default for ReLU {
    fn default() -> Self {
        Self::new()
    }
}

impl<F: Float + Debug> Activation<F> for ReLU {
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut output = input.clone();
        let zero = F::zero();
        let alpha = F::from(self.alpha).unwrap_or(zero);

        // Negative inputs are scaled by `alpha` (zeroed for the standard ReLU).
        Zip::from(&mut output).for_each(|x| {
            if *x < zero {
                *x = alpha * *x;
            }
        });
        Ok(output)
    }

    fn backward(
        &self,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut grad_input = Array::zeros(grad_output.raw_dim());
        let zero = F::zero();
        let one = F::one();
        let alpha = F::from(self.alpha).unwrap_or(zero);

        // Subgradient: 1 for x > 0, `alpha` otherwise (including x == 0).
        Zip::from(&mut grad_input)
            .and(grad_output)
            .and(input)
            .for_each(|grad_in, &grad_out, &x| {
                let derivative = if x > zero { one } else { alpha };
                *grad_in = grad_out * derivative;
            });

        Ok(grad_input)
    }
}

/// Softmax activation function
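///
/// This minimal implementation normalizes over all elements of the input, which is
/// exact for 1-D inputs; when `axis` is neither -1 nor the last axis, `forward`
/// returns the input unchanged.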
#[derive(Debug, Clone, Copy)]
pub struct Softmax {
    /// Axis along which to normalize; -1 denotes the last axis.
    axis: isize,
}

impl Softmax {
    /// Create a softmax over the given axis (-1 for the last axis).
    pub fn new(axis: isize) -> Self {
        Self { axis }
    }
}

impl Default for Softmax {
    fn default() -> Self {
        Self::new(-1)
    }
}

impl<F: Float + Debug> Activation<F> for Softmax {
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut output = input.clone();

        // Simplified softmax: when targeting the last axis, normalize over all
        // elements of the array (exact for 1-D inputs)
        if self.axis == -1 || self.axis as usize == input.ndim() - 1 {
            // Global maximum, used for numerical stability
            let max_val = input.fold(F::neg_infinity(), |acc, &x| if x > acc { x } else { acc });

            // Subtract the maximum and exponentiate
            Zip::from(&mut output).for_each(|x| {
                *x = (*x - max_val).exp();
            });

            // Sum of all exponentials
            let sum = output.sum();

            // Normalize so the entries sum to one
            Zip::from(&mut output).for_each(|x| {
                *x = *x / sum;
            });
        }

        Ok(output)
    }

    fn backward(
        &self,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        // Recompute the softmax output via the forward pass
        let softmax_output = self.forward(input)?;
        let mut grad_input = Array::zeros(grad_output.raw_dim());

        // Jacobian-vector product for softmax: grad = softmax * (grad_out - (softmax * grad_out).sum())
        let sum_grad = Zip::from(&softmax_output)
            .and(grad_output)
            .fold(F::zero(), |acc, &s, &g| acc + s * g);

        Zip::from(&mut grad_input)
            .and(&softmax_output)
            .and(grad_output)
            .for_each(|grad_in, &s, &grad_out| {
                *grad_in = s * (grad_out - sum_grad);
            });

        Ok(grad_input)
    }
}
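
// Hedged sanity checks: a minimal sketch assuming this file builds inside the
// scirs2_neural crate and that `scirs2_core::ndarray` re-exports the usual ndarray
// constructors (`Array::from_vec`, `into_dyn`). Expected values are hand-computed
// from the formulas implemented above.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 1-D dynamic-dimension input from a vector of f64 values.
    fn input(values: Vec<f64>) -> Array<f64, scirs2_core::ndarray::IxDyn> {
        Array::from_vec(values).into_dyn()
    }

    #[test]
    fn sigmoid_of_zero_is_one_half() {
        let out = Sigmoid::new().forward(&input(vec![0.0])).unwrap();
        let v = *out.iter().next().unwrap();
        assert!((v - 0.5).abs() < 1e-12);
    }

    #[test]
    fn relu_zeroes_negatives_and_keeps_positives() {
        let out = ReLU::new().forward(&input(vec![-2.0, 3.0])).unwrap();
        let vals: Vec<f64> = out.iter().copied().collect();
        assert!(vals[0].abs() < 1e-12);
        assert!((vals[1] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn tanh_backward_matches_one_minus_tanh_squared() {
        let x = input(vec![0.5]);
        let grad = Tanh::new().backward(&input(vec![1.0]), &x).unwrap();
        let expected = 1.0 - 0.5_f64.tanh().powi(2);
        assert!((*grad.iter().next().unwrap() - expected).abs() < 1e-12);
    }

    #[test]
    fn softmax_over_last_axis_sums_to_one() {
        let out = Softmax::default().forward(&input(vec![1.0, 2.0, 3.0])).unwrap();
        let sum: f64 = out.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12);
    }

    #[test]
    fn gelu_fast_is_close_to_the_reference_value_at_one() {
        let out = GELU::fast().forward(&input(vec![1.0])).unwrap();
        // The sqrt(2/pi) tanh approximation gives roughly 0.8412 at x = 1.0.
        assert!((*out.iter().next().unwrap() - 0.8412).abs() < 1e-3);
    }
}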