1use crate::error::Result;
4use scirs2_core::ndarray::{Array, Zip};
5use scirs2_core::numeric::Float;
6use std::fmt::Debug;
7
8pub trait Activation<F> {
10 fn forward(
12 &self,
13 input: &Array<F, scirs2_core::ndarray::IxDyn>,
14 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>>;
15
16 fn backward(
18 &self,
19 grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
20 input: &Array<F, scirs2_core::ndarray::IxDyn>,
21 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>>;
22}
23
/// GELU activation with two selectable code paths.
///
/// The private `fast` flag decides which branch the forward and backward
/// passes take; it is `false` for [`GELU::new`] / `Default` and `true` for
/// [`GELU::fast`].
#[derive(Debug, Clone, Copy, Default)]
pub struct GELU {
    /// `true` selects the "fast" branch of the forward/backward passes.
    fast: bool,
}

impl GELU {
    /// Creates a GELU that uses the regular (non-fast) code path.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a GELU that uses the fast code path.
    pub fn fast() -> Self {
        Self { fast: true }
    }
}
45
46impl<F: Float + Debug> Activation<F> for GELU {
47 fn forward(
48 &self,
49 input: &Array<F, scirs2_core::ndarray::IxDyn>,
50 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
51 let mut output = input.clone();
52
53 if self.fast {
54 let sqrt_2_over_pi = F::from(0.7978845608028654).unwrap();
55 let coeff = F::from(0.044715).unwrap();
56 let half = F::from(0.5).unwrap();
57 let one = F::one();
58
59 Zip::from(&mut output).for_each(|x| {
60 let x3 = *x * *x * *x;
61 let inner = sqrt_2_over_pi * (*x + coeff * x3);
62 *x = half * *x * (one + inner.tanh());
63 });
64 } else {
65 let sqrt_pi_over_2 = F::from(1.2533141373155).unwrap();
66 let coeff = F::from(0.044715).unwrap();
67 let half = F::from(0.5).unwrap();
68 let one = F::one();
69
70 Zip::from(&mut output).for_each(|x| {
71 let x2 = *x * *x;
72 let inner = sqrt_pi_over_2 * *x * (one + coeff * x2);
73 *x = half * *x * (one + inner.tanh());
74 });
75 }
76
77 Ok(output)
78 }
79
80 fn backward(
81 &self,
82 grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
83 input: &Array<F, scirs2_core::ndarray::IxDyn>,
84 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
85 let mut grad_input = Array::zeros(grad_output.raw_dim());
86
87 if self.fast {
88 let sqrt_2_over_pi = F::from(0.7978845608028654).unwrap();
89 let coeff = F::from(0.044715).unwrap();
90 let half = F::from(0.5).unwrap();
91 let one = F::one();
92 let three = F::from(3.0).unwrap();
93
94 Zip::from(&mut grad_input)
95 .and(grad_output)
96 .and(input)
97 .for_each(|grad_in, &grad_out, &x| {
98 let x2 = x * x;
99 let x3 = x2 * x;
100 let inner = sqrt_2_over_pi * (x + coeff * x3);
101 let tanh_inner = inner.tanh();
102 let sech_sq = one - tanh_inner * tanh_inner;
103 let d_inner_dx = sqrt_2_over_pi * (one + three * coeff * x2);
104 let dgelu_dx = half * (one + tanh_inner) + half * x * sech_sq * d_inner_dx;
105 *grad_in = grad_out * dgelu_dx;
106 });
107 } else {
108 let sqrt_pi_over_2 = F::from(1.2533141373155).unwrap();
109 let coeff = F::from(0.044715).unwrap();
110 let half = F::from(0.5).unwrap();
111 let one = F::one();
112 let three = F::from(3.0).unwrap();
113
114 Zip::from(&mut grad_input)
115 .and(grad_output)
116 .and(input)
117 .for_each(|grad_in, &grad_out, &x| {
118 let x2 = x * x;
119 let inner = sqrt_pi_over_2 * x * (one + coeff * x2);
120 let tanh_inner = inner.tanh();
121 let sech_sq = one - tanh_inner * tanh_inner;
122 let d_inner_dx = sqrt_pi_over_2 * (one + three * coeff * x2);
123 let dgelu_dx = half * (one + tanh_inner) + half * x * sech_sq * d_inner_dx;
124 *grad_in = grad_out * dgelu_dx;
125 });
126 }
127
128 Ok(grad_input)
129 }
130}
131
/// Hyperbolic tangent activation; a stateless unit struct.
#[derive(Debug, Clone, Copy, Default)]
pub struct Tanh;

impl Tanh {
    /// Creates a new `Tanh` activation.
    pub fn new() -> Self {
        Tanh
    }
}
147
148impl<F: Float + Debug> Activation<F> for Tanh {
149 fn forward(
150 &self,
151 input: &Array<F, scirs2_core::ndarray::IxDyn>,
152 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
153 let mut output = input.clone();
154 Zip::from(&mut output).for_each(|x| {
155 *x = x.tanh();
156 });
157 Ok(output)
158 }
159
160 fn backward(
161 &self,
162 grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
163 input: &Array<F, scirs2_core::ndarray::IxDyn>,
164 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
165 let mut grad_input = Array::zeros(grad_output.raw_dim());
166
167 Zip::from(&mut grad_input)
168 .and(grad_output)
169 .and(input)
170 .for_each(|grad_in, &grad_out, &x| {
171 let tanh_x = x.tanh();
172 let derivative = F::one() - tanh_x * tanh_x;
173 *grad_in = grad_out * derivative;
174 });
175
176 Ok(grad_input)
177 }
178}
179
/// Logistic sigmoid activation; a stateless unit struct.
#[derive(Debug, Clone, Copy, Default)]
pub struct Sigmoid;

impl Sigmoid {
    /// Creates a new `Sigmoid` activation.
    pub fn new() -> Self {
        Sigmoid
    }
}
195
196impl<F: Float + Debug> Activation<F> for Sigmoid {
197 fn forward(
198 &self,
199 input: &Array<F, scirs2_core::ndarray::IxDyn>,
200 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
201 let mut output = input.clone();
202 let one = F::one();
203 Zip::from(&mut output).for_each(|x| {
204 *x = one / (one + (-*x).exp());
205 });
206 Ok(output)
207 }
208
209 fn backward(
210 &self,
211 grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
212 input: &Array<F, scirs2_core::ndarray::IxDyn>,
213 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
214 let mut grad_input = Array::zeros(grad_output.raw_dim());
215 let one = F::one();
216
217 Zip::from(&mut grad_input)
218 .and(grad_output)
219 .and(input)
220 .for_each(|grad_in, &grad_out, &x| {
221 let sigmoid_x = one / (one + (-x).exp());
222 let derivative = sigmoid_x * (one - sigmoid_x);
223 *grad_in = grad_out * derivative;
224 });
225
226 Ok(grad_input)
227 }
228}
229
/// Rectified linear unit with an optional slope for negative inputs.
///
/// `alpha` is the factor negative inputs are multiplied by: `0.0` for the
/// plain ReLU built by [`ReLU::new`] / `Default`, caller-chosen for
/// [`ReLU::leaky`].
#[derive(Debug, Clone, Copy, Default)]
pub struct ReLU {
    /// Multiplier applied to negative inputs.
    alpha: f64,
}

impl ReLU {
    /// Creates a plain ReLU (negative slope `0.0`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a leaky ReLU whose negative-side slope is `alpha`.
    pub fn leaky(alpha: f64) -> Self {
        ReLU { alpha }
    }
}
251
252impl<F: Float + Debug> Activation<F> for ReLU {
253 fn forward(
254 &self,
255 input: &Array<F, scirs2_core::ndarray::IxDyn>,
256 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
257 let mut output = input.clone();
258 let zero = F::zero();
259 let alpha = F::from(self.alpha).unwrap_or(zero);
260
261 Zip::from(&mut output).for_each(|x| {
262 if *x < zero {
263 *x = alpha * *x;
264 }
265 });
266 Ok(output)
267 }
268
269 fn backward(
270 &self,
271 grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
272 input: &Array<F, scirs2_core::ndarray::IxDyn>,
273 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
274 let mut grad_input = Array::zeros(grad_output.raw_dim());
275 let zero = F::zero();
276 let one = F::one();
277 let alpha = F::from(self.alpha).unwrap_or(zero);
278
279 Zip::from(&mut grad_input)
280 .and(grad_output)
281 .and(input)
282 .for_each(|grad_in, &grad_out, &x| {
283 let derivative = if x > zero { one } else { alpha };
284 *grad_in = grad_out * derivative;
285 });
286
287 Ok(grad_input)
288 }
289}
290
/// Softmax activation configured with the axis it is requested for.
///
/// An axis of `-1`, which is what `Default` uses, designates the last axis
/// of the input.
#[derive(Debug, Clone, Copy)]
pub struct Softmax {
    /// Requested axis; `-1` stands for the last axis.
    axis: isize,
}

impl Softmax {
    /// Creates a softmax for the given `axis`.
    pub fn new(axis: isize) -> Self {
        Softmax { axis }
    }
}

impl Default for Softmax {
    /// Softmax over the last axis (`axis == -1`).
    fn default() -> Self {
        Softmax { axis: -1 }
    }
}
308
309impl<F: Float + Debug> Activation<F> for Softmax {
310 fn forward(
311 &self,
312 input: &Array<F, scirs2_core::ndarray::IxDyn>,
313 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
314 let mut output = input.clone();
315
316 if self.axis == -1 || self.axis as usize == input.ndim() - 1 {
318 let max_val = input.fold(F::neg_infinity(), |acc, &x| if x > acc { x } else { acc });
320
321 Zip::from(&mut output).for_each(|x| {
323 *x = (*x - max_val).exp();
324 });
325
326 let sum = output.sum();
328
329 Zip::from(&mut output).for_each(|x| {
331 *x = *x / sum;
332 });
333 }
334
335 Ok(output)
336 }
337
338 fn backward(
339 &self,
340 grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
341 input: &Array<F, scirs2_core::ndarray::IxDyn>,
342 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
343 let softmax_output = self.forward(input)?;
345 let mut grad_input = Array::zeros(grad_output.raw_dim());
346
347 let sum_grad = Zip::from(&softmax_output)
349 .and(grad_output)
350 .fold(F::zero(), |acc, &s, &g| acc + s * g);
351
352 Zip::from(&mut grad_input)
353 .and(&softmax_output)
354 .and(grad_output)
355 .for_each(|grad_in, &s, &grad_out| {
356 *grad_in = s * (grad_out - sum_grad);
357 });
358
359 Ok(grad_input)
360 }
361}