tensorlogic_scirs_backend/activations.rs

//! Activation functions for neural network layers.
//!
//! Provides element-wise, output, gradient, and scalar activation functions
//! backed by ndarray operations, as well as a unified [`ActivationType`] enum
//! for dispatch and an [`ActivationBenchmark`] for statistical summaries.
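//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`; exact crate paths may differ):
//!
//! ```ignore
//! use scirs2_core::ndarray::arr1;
//!
//! let x = arr1(&[-1.0, 0.0, 2.0]).into_dyn();
//! let y = relu(&x);                                  // direct function call
//! let z = ActivationType::LeakyRelu(0.1).apply(&x)?; // enum dispatch
//! assert_eq!(y.as_slice().unwrap(), &[0.0, 0.0, 2.0]);
//! ```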

use scirs2_core::ndarray::{ArrayD, Zip};

// ─────────────────────────────────────────────────────────────────────────────
// Error type
// ─────────────────────────────────────────────────────────────────────────────

/// Errors that can arise during activation-function computation.
#[derive(Debug, Clone)]
pub enum ActivationError {
    /// The input tensor has no elements.
    EmptyInput,
    /// A hyperparameter has an illegal value.
    InvalidParameter {
        name: String,
        value: f64,
        reason: String,
    },
    /// Tensor shapes are incompatible (e.g. PReLU weights vs. input).
    ShapeMismatch {
        expected: Vec<usize>,
        got: Vec<usize>,
    },
}

impl std::fmt::Display for ActivationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyInput => write!(f, "activation: input tensor is empty"),
            Self::InvalidParameter {
                name,
                value,
                reason,
            } => {
                write!(
                    f,
                    "activation: invalid parameter '{name}' = {value}: {reason}"
                )
            }
            Self::ShapeMismatch { expected, got } => {
                write!(
                    f,
                    "activation: shape mismatch: expected {expected:?}, got {got:?}"
                )
            }
        }
    }
}

impl std::error::Error for ActivationError {}

// ─────────────────────────────────────────────────────────────────────────────
// Internal helpers
// ─────────────────────────────────────────────────────────────────────────────

/// Numerically stable `erf` via the Abramowitz & Stegun 7.1.26 rational
/// approximation (maximum absolute error ≈ 1.5 × 10⁻⁷).
#[inline]
fn erf_approx(x: f64) -> f64 {
    // erf(x) ≈ 1 - (a1*t + a2*t² + a3*t³ + a4*t⁴ + a5*t⁵) * exp(-x²),
    // with t = 1 / (1 + p*x) for x >= 0; odd symmetry handles x < 0.
    const A1: f64 = 0.254_829_592;
    const A2: f64 = -0.284_496_736;
    const A3: f64 = 1.421_413_741;
    const A4: f64 = -1.453_152_027;
    const A5: f64 = 1.061_405_429;
    const P: f64 = 0.327_591_1;
    let sign = x.signum();
    let x = x.abs();
    let t = 1.0 / (1.0 + P * x);
    // Horner evaluation of the degree-5 polynomial in t.
    let poly = ((((A5 * t + A4) * t + A3) * t + A2) * t + A1) * t;
    sign * (1.0 - poly * (-x * x).exp())
}

#[inline]
fn sigmoid_scalar_impl(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

#[inline]
fn softplus_scalar(x: f64, beta: f64) -> f64 {
    // For large beta * x, softplus(x) = ln(1 + exp(beta * x)) / beta ≈ x;
    // returning x directly avoids overflow in exp.
    let bx = beta * x;
    if bx > 30.0 {
        x
    } else {
        (1.0 + bx.exp()).ln() / beta
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Scalar helpers (public)
// ─────────────────────────────────────────────────────────────────────────────

/// Scalar ReLU.
#[inline]
pub fn relu_scalar(x: f64) -> f64 {
    x.max(0.0)
}

/// Scalar GELU: `x * 0.5 * (1 + erf(x / sqrt(2)))`.
#[inline]
pub fn gelu_scalar(x: f64) -> f64 {
    x * 0.5 * (1.0 + erf_approx(x / std::f64::consts::SQRT_2))
}

/// Scalar Swish / SiLU: `x * sigmoid(x)`.
#[inline]
pub fn swish_scalar(x: f64) -> f64 {
    x * sigmoid_scalar_impl(x)
}

/// Scalar sigmoid: `1 / (1 + exp(-x))`.
#[inline]
pub fn sigmoid_scalar(x: f64) -> f64 {
    sigmoid_scalar_impl(x)
}

// ─────────────────────────────────────────────────────────────────────────────
// Element-wise activation functions
// ─────────────────────────────────────────────────────────────────────────────

/// Rectified Linear Unit: `max(0, x)`.
pub fn relu(input: &ArrayD<f64>) -> ArrayD<f64> {
    input.mapv(relu_scalar)
}

/// ReLU6: `min(max(0, x), 6)`.
pub fn relu6(input: &ArrayD<f64>) -> ArrayD<f64> {
    input.mapv(|x| x.clamp(0.0, 6.0))
}

/// Leaky ReLU: `x` if `x >= 0`, else `negative_slope * x`.
pub fn leaky_relu(input: &ArrayD<f64>, negative_slope: f64) -> ArrayD<f64> {
    input.mapv(|x| if x >= 0.0 { x } else { negative_slope * x })
}

/// Exponential Linear Unit: `x` if `x >= 0`, else `alpha * (exp(x) - 1)`.
///
/// Returns `ActivationError::InvalidParameter` when `alpha < 0`.
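///
/// # Example
///
/// A sketch of the error path (marked `ignore`):
///
/// ```ignore
/// use scirs2_core::ndarray::arr1;
///
/// let x = arr1(&[1.0]).into_dyn();
/// assert!(elu(&x, -0.5).is_err()); // negative alpha is rejected
/// assert!(elu(&x, 1.0).is_ok());
/// ```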
pub fn elu(input: &ArrayD<f64>, alpha: f64) -> Result<ArrayD<f64>, ActivationError> {
    if alpha < 0.0 {
        return Err(ActivationError::InvalidParameter {
            name: "alpha".into(),
            value: alpha,
            reason: "alpha must be non-negative for ELU".into(),
        });
    }
    Ok(input.mapv(|x| if x >= 0.0 { x } else { alpha * (x.exp() - 1.0) }))
}

/// Scaled ELU with fixed constants:
/// `scale * x` for `x >= 0`, `scale * alpha * (exp(x) - 1)` otherwise,
/// with alpha = 1.6732632423543772 and scale = 1.0507009873554805.
pub fn selu(input: &ArrayD<f64>) -> ArrayD<f64> {
    const ALPHA: f64 = 1.673_263_242_354_377_2;
    const SCALE: f64 = 1.050_700_987_355_480_5;
    input.mapv(|x| SCALE * if x >= 0.0 { x } else { ALPHA * (x.exp() - 1.0) })
}

/// Gaussian Error Linear Unit (exact): `x * 0.5 * (1 + erf(x / sqrt(2)))`.
pub fn gelu(input: &ArrayD<f64>) -> ArrayD<f64> {
    input.mapv(gelu_scalar)
}

/// GELU fast approximation via tanh:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
pub fn gelu_approx(input: &ArrayD<f64>) -> ArrayD<f64> {
    const C: f64 = 0.797_884_560_802_865_4; // sqrt(2/pi)
    input.mapv(|x| {
        let inner = C * (x + 0.044_715 * x * x * x);
        0.5 * x * (1.0 + inner.tanh())
    })
}

/// Swish / SiLU: `x * sigmoid(x)`.
pub fn swish(input: &ArrayD<f64>) -> ArrayD<f64> {
    input.mapv(swish_scalar)
}

/// Alias for [`swish`].
pub fn silu(input: &ArrayD<f64>) -> ArrayD<f64> {
    swish(input)
}

/// Mish: `x * tanh(ln(1 + exp(x)))`.
pub fn mish(input: &ArrayD<f64>) -> ArrayD<f64> {
    input.mapv(|x| {
        let sp = softplus_scalar(x, 1.0);
        x * sp.tanh()
    })
}

/// Softplus: `(1/beta) * ln(1 + exp(beta * x))`.
///
/// Returns `ActivationError::InvalidParameter` when `beta <= 0`.
pub fn softplus(input: &ArrayD<f64>, beta: f64) -> Result<ArrayD<f64>, ActivationError> {
    if beta <= 0.0 {
        return Err(ActivationError::InvalidParameter {
            name: "beta".into(),
            value: beta,
            reason: "beta must be positive for Softplus".into(),
        });
    }
    Ok(input.mapv(|x| softplus_scalar(x, beta)))
}

/// Softsign: `x / (1 + |x|)`.
pub fn softsign(input: &ArrayD<f64>) -> ArrayD<f64> {
    input.mapv(|x| x / (1.0 + x.abs()))
}

/// Hard-Swish: `x * relu6(x + 3) / 6`.
pub fn hardswish(input: &ArrayD<f64>) -> ArrayD<f64> {
    input.mapv(|x| x * (x + 3.0).clamp(0.0, 6.0) / 6.0)
}

/// Hard-Sigmoid: `relu6(x + 3) / 6`.
pub fn hardsigmoid(input: &ArrayD<f64>) -> ArrayD<f64> {
    input.mapv(|x| (x + 3.0).clamp(0.0, 6.0) / 6.0)
}

/// Sigmoid: `1 / (1 + exp(-x))`.
pub fn sigmoid(input: &ArrayD<f64>) -> ArrayD<f64> {
    input.mapv(sigmoid_scalar_impl)
}

/// Hyperbolic tangent activation (named `tanh_activation` to avoid confusion
/// with the `f64::tanh` method).
pub fn tanh_activation(input: &ArrayD<f64>) -> ArrayD<f64> {
    input.mapv(|x| x.tanh())
}

/// Parametric ReLU: `x` if `x >= 0`, else `weights[channel] * x`.
///
/// The channel count is `input.shape()[0]` (or 1 for a 0-D input), and
/// `weights` must contain either exactly one element (a single slope shared by
/// all channels) or one element per channel. For a 1-D input each element is
/// treated as its own channel.
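///
/// # Example
///
/// A usage sketch (marked `ignore`; shapes chosen for illustration only):
///
/// ```ignore
/// use scirs2_core::ndarray::arr1;
///
/// let x = arr1(&[-2.0, -1.0, 3.0]).into_dyn();
/// // One shared slope for every channel.
/// let w = arr1(&[0.25]).into_dyn();
/// let y = prelu(&x, &w).expect("weights broadcast");
/// assert_eq!(y.as_slice().unwrap(), &[-0.5, -0.25, 3.0]);
/// ```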
pub fn prelu(input: &ArrayD<f64>, weights: &ArrayD<f64>) -> Result<ArrayD<f64>, ActivationError> {
    if input.is_empty() {
        return Err(ActivationError::EmptyInput);
    }

    // Determine channel count: axis 0 for ndim >= 1, else 1.
    let channels = if input.ndim() == 0 {
        1
    } else {
        input.shape()[0]
    };
    let w_len = weights.len();

    if w_len != channels && w_len != 1 {
        return Err(ActivationError::ShapeMismatch {
            expected: vec![channels],
            got: weights.shape().to_vec(),
        });
    }

    let weights_flat: Vec<f64> = weights.iter().copied().collect();
    let get_w = |ch: usize| -> f64 {
        if w_len == 1 {
            weights_flat[0]
        } else {
            weights_flat[ch]
        }
    };

    if input.ndim() <= 1 {
        // 0-D or 1-D: channel index = element index (or 0)
        let out: Vec<f64> = input
            .iter()
            .enumerate()
            .map(|(i, &x)| {
                let ch = if w_len == 1 { 0 } else { i };
                if x >= 0.0 {
                    x
                } else {
                    get_w(ch) * x
                }
            })
            .collect();
        // `out` has exactly `input.len()` elements, so rebuilding with the same
        // shape cannot fail; the ReLU fallback is purely defensive.
        return Ok(ArrayD::from_shape_vec(input.raw_dim(), out)
            .unwrap_or_else(|_| input.mapv(relu_scalar)));
    }

    // N-D: channel = first axis index. Iteration is in logical (row-major)
    // order, so dividing the flat index by the product of the trailing
    // dimensions yields the axis-0 index.
    let shape = input.shape().to_vec();
    let mut result = input.clone();
    let stride: usize = shape[1..].iter().product();

    for (idx, val) in result.iter_mut().enumerate() {
        let ch = (idx / stride) % channels;
        if *val < 0.0 {
            *val *= get_w(ch);
        }
    }
    Ok(result)
}

// ─────────────────────────────────────────────────────────────────────────────
// Output activations
// ─────────────────────────────────────────────────────────────────────────────

/// Softmax along `axis`.  Subtracts the max for numerical stability.
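///
/// # Example
///
/// A sketch of the invariant the function maintains (marked `ignore`):
///
/// ```ignore
/// use scirs2_core::ndarray::arr1;
///
/// let x = arr1(&[1.0, 2.0, 3.0]).into_dyn();
/// let p = softmax(&x, 0).expect("axis 0 is valid");
/// let total: f64 = p.iter().sum();
/// assert!((total - 1.0).abs() < 1e-12);
/// ```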
pub fn softmax(input: &ArrayD<f64>, axis: usize) -> Result<ArrayD<f64>, ActivationError> {
    if input.is_empty() {
        return Err(ActivationError::EmptyInput);
    }
    if axis >= input.ndim() {
        return Err(ActivationError::InvalidParameter {
            name: "axis".into(),
            value: axis as f64,
            reason: format!("axis {} out of range for ndim {}", axis, input.ndim()),
        });
    }

    // max along axis for stability
    let max_vals = input.map_axis(scirs2_core::ndarray::Axis(axis), |lane| {
        lane.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
    });

    let mut shifted = input.clone();
    // Broadcast-subtract max along the given axis
    Zip::from(&mut shifted)
        .and_broadcast(&max_vals.insert_axis(scirs2_core::ndarray::Axis(axis)))
        .for_each(|s, &m| *s -= m);

    let mut exped = shifted.mapv(f64::exp);

    let sum_vals = exped.map_axis(scirs2_core::ndarray::Axis(axis), |lane| {
        lane.iter().cloned().sum::<f64>()
    });

    Zip::from(&mut exped)
        .and_broadcast(&sum_vals.insert_axis(scirs2_core::ndarray::Axis(axis)))
        .for_each(|e, &s| *e /= s);

    Ok(exped)
}

/// Numerically stable log-softmax along `axis`.
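///
/// # Example
///
/// A sketch of the relationship to [`softmax`] (marked `ignore`):
///
/// ```ignore
/// use scirs2_core::ndarray::arr1;
///
/// let x = arr1(&[1.0, 2.0, 3.0]).into_dyn();
/// let p = softmax(&x, 0).expect("axis 0 is valid");
/// let lp = log_softmax(&x, 0).expect("axis 0 is valid");
/// for (a, b) in p.iter().zip(lp.iter()) {
///     assert!((a.ln() - b).abs() < 1e-9);
/// }
/// ```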
pub fn log_softmax(input: &ArrayD<f64>, axis: usize) -> Result<ArrayD<f64>, ActivationError> {
    if input.is_empty() {
        return Err(ActivationError::EmptyInput);
    }
    if axis >= input.ndim() {
        return Err(ActivationError::InvalidParameter {
            name: "axis".into(),
            value: axis as f64,
            reason: format!("axis {} out of range for ndim {}", axis, input.ndim()),
        });
    }

    let max_vals = input.map_axis(scirs2_core::ndarray::Axis(axis), |lane| {
        lane.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
    });

    let mut shifted = input.clone();
    Zip::from(&mut shifted)
        .and_broadcast(&max_vals.insert_axis(scirs2_core::ndarray::Axis(axis)))
        .for_each(|s, &m| *s -= m);

    let log_sum_exp = shifted
        .mapv(f64::exp)
        .map_axis(scirs2_core::ndarray::Axis(axis), |lane| {
            lane.iter().cloned().sum::<f64>().ln()
        });

    Zip::from(&mut shifted)
        .and_broadcast(&log_sum_exp.insert_axis(scirs2_core::ndarray::Axis(axis)))
        .for_each(|s, &lse| *s -= lse);

    Ok(shifted)
}

// ─────────────────────────────────────────────────────────────────────────────
// Gradient functions
// ─────────────────────────────────────────────────────────────────────────────

/// ReLU gradient: `grad_output` where `input > 0`, else `0`.
pub fn relu_grad(input: &ArrayD<f64>, grad_output: &ArrayD<f64>) -> ArrayD<f64> {
    let mut out = grad_output.clone();
    Zip::from(&mut out).and(input).for_each(|g, &x| {
        if x <= 0.0 {
            *g = 0.0;
        }
    });
    out
}

/// Sigmoid gradient: `output * (1 - output) * grad_output`.
///
/// `output` should be the **result** of `sigmoid(x)`, not the raw input.
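///
/// # Example
///
/// A sketch of the intended call pattern (marked `ignore`):
///
/// ```ignore
/// use scirs2_core::ndarray::arr1;
///
/// let x = arr1(&[0.0]).into_dyn();
/// let s = sigmoid(&x);                 // forward pass: s = 0.5
/// let upstream = arr1(&[1.0]).into_dyn();
/// let g = sigmoid_grad(&s, &upstream); // 0.5 * (1 - 0.5) * 1.0 = 0.25
/// assert!((g[0] - 0.25).abs() < 1e-12);
/// ```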
pub fn sigmoid_grad(output: &ArrayD<f64>, grad_output: &ArrayD<f64>) -> ArrayD<f64> {
    let mut out = grad_output.clone();
    Zip::from(&mut out)
        .and(output)
        .for_each(|g, &s| *g *= s * (1.0 - s));
    out
}

/// Tanh gradient: `(1 - output^2) * grad_output`.
///
/// `output` should be the **result** of `tanh(x)`, not the raw input.
pub fn tanh_grad(output: &ArrayD<f64>, grad_output: &ArrayD<f64>) -> ArrayD<f64> {
    let mut out = grad_output.clone();
    Zip::from(&mut out)
        .and(output)
        .for_each(|g, &t| *g *= 1.0 - t * t);
    out
}

// ─────────────────────────────────────────────────────────────────────────────
// Unified dispatch enum
// ─────────────────────────────────────────────────────────────────────────────

/// Enumeration of supported activation functions for unified dispatch.
#[derive(Debug, Clone, PartialEq)]
pub enum ActivationType {
    Relu,
    Relu6,
    LeakyRelu(f64),
    Elu(f64),
    Selu,
    Gelu,
    GeluApprox,
    Swish,
    Mish,
    Softplus(f64),
    Softsign,
    Hardswish,
    Hardsigmoid,
    Sigmoid,
    Tanh,
}

impl ActivationType {
    /// Apply this activation to `input`.
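    ///
    /// # Example
    ///
    /// A dispatch sketch (marked `ignore`):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::arr1;
    ///
    /// let x = arr1(&[-1.0, 2.0]).into_dyn();
    /// let y = ActivationType::LeakyRelu(0.1)
    ///     .apply(&x)
    ///     .expect("leaky ReLU has no parameter to reject");
    /// assert!((y[0] + 0.1).abs() < 1e-12 && (y[1] - 2.0).abs() < 1e-12);
    /// ```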
    pub fn apply(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>, ActivationError> {
        match self {
            Self::Relu => Ok(relu(input)),
            Self::Relu6 => Ok(relu6(input)),
            Self::LeakyRelu(s) => Ok(leaky_relu(input, *s)),
            Self::Elu(a) => elu(input, *a),
            Self::Selu => Ok(selu(input)),
            Self::Gelu => Ok(gelu(input)),
            Self::GeluApprox => Ok(gelu_approx(input)),
            Self::Swish => Ok(swish(input)),
            Self::Mish => Ok(mish(input)),
            Self::Softplus(b) => softplus(input, *b),
            Self::Softsign => Ok(softsign(input)),
            Self::Hardswish => Ok(hardswish(input)),
            Self::Hardsigmoid => Ok(hardsigmoid(input)),
            Self::Sigmoid => Ok(sigmoid(input)),
            Self::Tanh => Ok(tanh_activation(input)),
        }
    }

    /// Human-readable name of this activation.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Relu => "relu",
            Self::Relu6 => "relu6",
            Self::LeakyRelu(_) => "leaky_relu",
            Self::Elu(_) => "elu",
            Self::Selu => "selu",
            Self::Gelu => "gelu",
            Self::GeluApprox => "gelu_approx",
            Self::Swish => "swish",
            Self::Mish => "mish",
            Self::Softplus(_) => "softplus",
            Self::Softsign => "softsign",
            Self::Hardswish => "hardswish",
            Self::Hardsigmoid => "hardsigmoid",
            Self::Sigmoid => "sigmoid",
            Self::Tanh => "tanh",
        }
    }

    /// Whether this activation is a monotonically non-decreasing function.
    ///
    /// GELU, Swish/SiLU, Mish, and Hard-Swish are excluded: each dips below
    /// zero and rises again for negative inputs, so they are not monotone.
    pub fn is_monotone(&self) -> bool {
        matches!(
            self,
            Self::Relu
                | Self::Relu6
                | Self::LeakyRelu(_)
                | Self::Elu(_)
                | Self::Selu
                | Self::Softplus(_)
                | Self::Softsign
                | Self::Hardsigmoid
                | Self::Sigmoid
                | Self::Tanh
        )
    }

    /// Approximate output range `(min, max)`.
    pub fn output_range(&self) -> (f64, f64) {
        match self {
            Self::Relu => (0.0, f64::INFINITY),
            Self::Relu6 => (0.0, 6.0),
            Self::LeakyRelu(_) => (f64::NEG_INFINITY, f64::INFINITY),
            // ELU is bounded below by -alpha; SELU by -scale * alpha.
            Self::Elu(a) => (-*a, f64::INFINITY),
            Self::Selu => (-1.758_1, f64::INFINITY),
            // GELU, Swish, and Mish dip slightly below zero but are bounded below.
            Self::Gelu | Self::GeluApprox => (-0.171, f64::INFINITY),
            Self::Swish => (-0.279, f64::INFINITY),
            Self::Mish => (-0.309, f64::INFINITY),
            Self::Softplus(_) => (0.0, f64::INFINITY),
            Self::Softsign => (-1.0, 1.0),
            // Hard-Swish reaches its minimum of -3/8 at x = -1.5.
            Self::Hardswish => (-0.375, f64::INFINITY),
            Self::Hardsigmoid => (0.0, 1.0),
            Self::Sigmoid => (0.0, 1.0),
            Self::Tanh => (-1.0, 1.0),
        }
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Benchmark helper
// ─────────────────────────────────────────────────────────────────────────────

/// Statistical summary of an activation applied to a sample input.
#[derive(Debug, Clone)]
pub struct ActivationBenchmark {
    pub name: String,
    pub input_size: usize,
    pub mean_output: f64,
    pub std_output: f64,
    pub min_output: f64,
    pub max_output: f64,
}

impl ActivationBenchmark {
    /// Run `activation` on `input` and collect statistics.
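    ///
    /// # Example
    ///
    /// A sketch of collecting a summary (marked `ignore`):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::arr1;
    ///
    /// let x = arr1(&[-2.0, -1.0, 0.0, 1.0, 2.0]).into_dyn();
    /// let bench = ActivationBenchmark::compute(&ActivationType::Relu, &x)?;
    /// println!("{}", bench.summary()); // e.g. "relu [n=5] mean=0.6000 ..."
    /// ```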
    pub fn compute(
        activation: &ActivationType,
        input: &ArrayD<f64>,
    ) -> Result<Self, ActivationError> {
        if input.is_empty() {
            return Err(ActivationError::EmptyInput);
        }
        let output = activation.apply(input)?;
        let n = output.len() as f64;
        let values: Vec<f64> = output.iter().copied().collect();

        let mean = values.iter().sum::<f64>() / n;
        let variance = values.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / n;
        let std_output = variance.sqrt();
        let min_output = values.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_output = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

        Ok(Self {
            name: activation.name().to_owned(),
            input_size: input.len(),
            mean_output: mean,
            std_output,
            min_output,
            max_output,
        })
    }

    /// One-line human-readable summary.
    pub fn summary(&self) -> String {
        format!(
            "{} [n={}] mean={:.4} std={:.4} min={:.4} max={:.4}",
            self.name,
            self.input_size,
            self.mean_output,
            self.std_output,
            self.min_output,
            self.max_output,
        )
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, Array2};

    const EPS: f64 = 1e-6;

    fn arr(v: &[f64]) -> ArrayD<f64> {
        arr1(v).into_dyn()
    }

    fn check_close(a: f64, b: f64, eps: f64, msg: &str) {
        assert!((a - b).abs() < eps, "{msg}: |{a} - {b}| >= {eps}");
    }

    #[test]
    fn test_relu_zeros_negative() {
        let input = arr(&[-3.0, -1.0, 0.0]);
        let out = relu(&input);
        for &v in out.iter() {
            assert_eq!(v, 0.0, "ReLU of non-positive must be 0");
        }
    }

    #[test]
    fn test_relu_positive_unchanged() {
        let input = arr(&[1.0, 2.5, 100.0]);
        let out = relu(&input);
        for (&i, &o) in input.iter().zip(out.iter()) {
            assert_eq!(i, o, "ReLU must preserve positive values");
        }
    }

    #[test]
    fn test_relu6_clamp() {
        let input = arr(&[7.0, 6.0, 5.0, -1.0]);
        let out = relu6(&input);
        assert_eq!(out[0], 6.0, "values > 6 must be clamped to 6");
        assert_eq!(out[1], 6.0);
        assert_eq!(
            out[2], 5.0,
            "values <= 6 must be unchanged (if non-negative)"
        );
        assert_eq!(out[3], 0.0, "negative values must be 0");
    }

    #[test]
    fn test_leaky_relu_negative_slope() {
        let slope = 0.1;
        let input = arr(&[-4.0, -1.0, 0.0, 2.0]);
        let out = leaky_relu(&input, slope);
        check_close(out[0], -0.4, EPS, "leaky_relu(-4, 0.1)");
        check_close(out[1], -0.1, EPS, "leaky_relu(-1, 0.1)");
        check_close(out[2], 0.0, EPS, "leaky_relu(0, 0.1)");
        check_close(out[3], 2.0, EPS, "leaky_relu(2, 0.1)");
    }

    #[test]
    fn test_elu_positive_unchanged() {
        let input = arr(&[0.5, 1.0, 3.0]);
        let out = elu(&input, 1.0).expect("elu should succeed");
        for (&i, &o) in input.iter().zip(out.iter()) {
            check_close(i, o, EPS, "ELU positive must be identity");
        }
    }

    #[test]
    fn test_elu_negative_approaches_minus_alpha() {
        let alpha = 1.0;
        let input = arr(&[-50.0]);
        let out = elu(&input, alpha).expect("elu should succeed");
        // alpha*(exp(-50) - 1) ≈ -alpha
        check_close(
            out[0],
            -alpha,
            1e-10,
            "ELU large-negative approaches -alpha",
        );
    }

    #[test]
    fn test_selu_scale() {
        const SCALE: f64 = 1.050_700_987_355_480_5;
        let input = arr(&[1.0, 2.0, 3.0]);
        let out = selu(&input);
        for (&i, &o) in input.iter().zip(out.iter()) {
            check_close(o, SCALE * i, EPS, "SELU positive = scale * x");
        }
    }

    #[test]
    fn test_gelu_near_zero() {
        let input = arr(&[0.0]);
        let out = gelu(&input);
        check_close(out[0], 0.0, EPS, "gelu(0) must be 0");
    }

    #[test]
    fn test_gelu_positive() {
        // For large positive x, gelu(x) ≈ x
        let x = 10.0_f64;
        let result = gelu_scalar(x);
        check_close(result, x, 1e-4, "gelu(large positive) ≈ large positive");
    }

    #[test]
    fn test_swish_zero() {
        let input = arr(&[0.0]);
        let out = swish(&input);
        check_close(out[0], 0.0, EPS, "swish(0) must be 0");
    }

    #[test]
    fn test_sigmoid_midpoint() {
        let input = arr(&[0.0]);
        let out = sigmoid(&input);
        check_close(out[0], 0.5, EPS, "sigmoid(0) must be 0.5");
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let data = Array2::from_shape_vec((2, 4), vec![1.0, 2.0, 3.0, 4.0, 0.5, 1.5, 2.5, 3.5])
            .expect("shape ok")
            .into_dyn();
        let out = softmax(&data, 1).expect("softmax ok");
        // Each row must sum to 1
        for row_idx in 0..2_usize {
            let row_sum: f64 = (0..4).map(|c| out[[row_idx, c]]).sum();
            check_close(row_sum, 1.0, EPS, "softmax row sum");
        }
    }

    #[test]
    fn test_log_softmax_matches() {
        let data = arr(&[1.0, 2.0, 3.0, 4.0]);
        let sm = softmax(&data, 0).expect("softmax ok");
        let lsm = log_softmax(&data, 0).expect("log_softmax ok");
        for (&s, &ls) in sm.iter().zip(lsm.iter()) {
            check_close(s.ln(), ls, 1e-9, "log(softmax) == log_softmax");
        }
    }

    #[test]
    fn test_relu_grad_mask() {
        let input = arr(&[-2.0, 0.0, 3.0]);
        let grad = arr(&[1.0, 1.0, 1.0]);
        let out = relu_grad(&input, &grad);
        assert_eq!(out[0], 0.0, "grad must be 0 for negative input");
        assert_eq!(out[1], 0.0, "grad must be 0 for zero input");
        assert_eq!(out[2], 1.0, "grad must pass through for positive input");
    }

    #[test]
    fn test_sigmoid_grad_formula() {
        // sigmoid_grad at x=0: s=0.5, s*(1-s)=0.25
        let s_out = arr(&[0.5]);
        let grad = arr(&[2.0]);
        let out = sigmoid_grad(&s_out, &grad);
        check_close(out[0], 0.5, EPS, "sigmoid_grad(0.5) * 2.0 == 0.5");
    }

    #[test]
    fn test_activation_type_apply_relu() {
        let input = arr(&[-1.0, 0.0, 1.0, 2.0]);
        let expected = relu(&input);
        let got = ActivationType::Relu.apply(&input).expect("apply ok");
        for (&e, &g) in expected.iter().zip(got.iter()) {
            check_close(e, g, EPS, "ActivationType::Relu.apply == relu");
        }
    }

    #[test]
    fn test_activation_type_name() {
        let variants = [
            ActivationType::Relu,
            ActivationType::Relu6,
            ActivationType::LeakyRelu(0.1),
            ActivationType::Elu(1.0),
            ActivationType::Selu,
            ActivationType::Gelu,
            ActivationType::GeluApprox,
            ActivationType::Swish,
            ActivationType::Mish,
            ActivationType::Softplus(1.0),
            ActivationType::Softsign,
            ActivationType::Hardswish,
            ActivationType::Hardsigmoid,
            ActivationType::Sigmoid,
            ActivationType::Tanh,
        ];
        for v in &variants {
            assert!(!v.name().is_empty(), "name must not be empty: {:?}", v);
        }
    }

    #[test]
    fn test_activation_type_output_range() {
        // Check that the range min <= max for all variants
        let variants = [
            ActivationType::Relu,
            ActivationType::Relu6,
            ActivationType::Softsign,
            ActivationType::Hardsigmoid,
            ActivationType::Sigmoid,
            ActivationType::Tanh,
            ActivationType::Softplus(1.0),
        ];
        for v in &variants {
            let (lo, hi) = v.output_range();
            assert!(lo <= hi, "output_range lo <= hi for {:?}", v);
        }
        // Bounded activations
        let (lo, hi) = ActivationType::Relu6.output_range();
        assert_eq!(lo, 0.0);
        assert_eq!(hi, 6.0);
        let (lo, hi) = ActivationType::Sigmoid.output_range();
        assert_eq!(lo, 0.0);
        assert_eq!(hi, 1.0);
    }

    #[test]
    fn test_activation_benchmark_compute() {
        let input = arr(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
        let bench =
            ActivationBenchmark::compute(&ActivationType::Relu, &input).expect("benchmark ok");
        assert_eq!(bench.name, "relu");
        assert_eq!(bench.input_size, 5);
        assert!(bench.min_output >= 0.0, "ReLU output must be non-negative");
        assert!(bench.max_output >= bench.min_output);
        assert!(!bench.summary().is_empty());
    }

    #[test]
    fn test_hardswish_bounds() {
        // hardswish(x) = x * relu6(x+3) / 6
        // For x <= -3: relu6(x+3) = 0, so the output is 0.
        // For x >= 3:  relu6(x+3) = 6, so the output is x.
        let input = arr(&[-10.0, -3.0, 0.0, 3.0, 10.0]);
        let out = hardswish(&input);
        check_close(out[0], 0.0, EPS, "hardswish(-10) = 0");
        check_close(out[1], 0.0, EPS, "hardswish(-3) = 0");
        // x = 0: 0 * relu6(3) / 6 = 0
        check_close(out[2], 0.0, EPS, "hardswish(0) = 0");
        // x = 3: 3 * relu6(6) / 6 = 3 * 1 = 3
        check_close(out[3], 3.0, EPS, "hardswish(3) = 3");
        // x = 10: 10 * relu6(13) / 6 = 10 * 1 = 10
        check_close(out[4], 10.0, EPS, "hardswish(10) = 10");
    }
}