1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#![allow(clippy::excessive_precision)]
use crate::frame::element_wise::ElementWiseKer;
use tract_data::internal::*;

/// Portable `f32` approximation of `tanh(x)` using a rational function.
///
/// The input is first saturated to `[-8.9, 8.9]`. The result is then
/// `x * P(x²) / Q(x²)`, where:
/// - `P` is a degree-6 polynomial in `x²`, giving the odd terms `x¹ … x¹³` once multiplied by `x`;
/// - `Q` is a degree-3 polynomial in `x²`, giving the even terms `x⁰ … x⁶`.
///
/// Both are evaluated with Horner's rule.
///
/// Because the clamp bounds are symmetric and only `x²` feeds the polynomials,
/// the function is exactly odd: `stanh(-x) == -stanh(x)`. A NaN input yields NaN,
/// since `f32::clamp` returns NaN for a NaN receiver.
pub fn stanh(x: f32) -> f32 {
    // Saturation bound. Inputs outside [-LIMIT, LIMIT] are evaluated at the bound.
    const LIMIT: f32 = 8.9;

    // Numerator coefficients in Horner order (highest power first):
    // x¹³, x¹¹, x⁹, x⁷, x⁵, x³, x¹.
    const NUMERATOR: [f32; 7] = [
        -8.488492677e-14,
        5.277853000e-11,
        -2.022500419e-8,
        0.00001115424833,
        0.003103950131,
        0.1308400453,
        0.9999999934,
    ];

    // Denominator coefficients in Horner order: x⁶, x⁴, x², x⁰.
    const DENOMINATOR: [f32; 4] = [0.0002546136580, 0.02449515379, 0.4641733162, 1.0];

    // Evaluates `coeffs` (highest degree first) as a polynomial in `t`.
    // Each step is `t * acc + c`: a separate multiply then add, no fusion.
    fn horner(t: f32, coeffs: &[f32]) -> f32 {
        coeffs[1..].iter().fold(coeffs[0], |acc, &c| t * acc + c)
    }

    let clamped = x.clamp(-LIMIT, LIMIT);
    let sq = clamped * clamped;

    horner(sq, &NUMERATOR) * clamped / horner(sq, &DENOMINATOR)
}

/// Portable `f16` approximation of `tanh(x)` using a low-order rational function.
///
/// The input is saturated to `[-3.84, 3.84]`. The function then returns
/// `x * (a3·x² + a1) / (b4·x⁴ + b2·x² + b0)`.
///
/// All intermediate arithmetic uses the `f16` type's own operators, in the same
/// multiply/add order as a Horner evaluation.
pub fn htanh(x: f16) -> f16 {
    // Saturation bounds.
    const LOW: f16 = f16::from_f32_const(-3.84);
    const HIGH: f16 = f16::from_f32_const(3.84);

    // Numerator coefficients: x³ and x¹ terms.
    const A3: f16 = f16::from_f32_const(0.082654955);
    const A1: f16 = f16::from_f32_const(0.99963124);

    // Denominator coefficients: x⁴, x² and constant terms.
    const B4: f16 = f16::from_f32_const(0.0065383179);
    const B2: f16 = f16::from_f32_const(0.41401828);
    const B0: f16 = f16::from_f32_const(1.0);

    let clamped = x.clamp(LOW, HIGH);
    let sq = clamped * clamped;

    // Horner form, written out: (sq·A3 + A1)·x  and  sq·(sq·B4 + B2) + B0.
    let numerator = (sq * A3 + A1) * clamped;
    let denominator = sq * (sq * B4 + B2) + B0;

    numerator / denominator
}

/// Portable (non-SIMD) `f32` tanh kernel that applies [`stanh`] element by element.
///
/// Reports the name `"generic"` and a register width (`nr`) of 4 items.
#[derive(Clone, Debug)]
pub struct STanh4;

impl ElementWiseKer<f32> for STanh4 {
    fn name() -> &'static str {
        "generic"
    }

    // NOTE(review): 16 items of f32 would be 64 bytes, but `alignment_bytes`
    // reports 16 (i.e. 4 items). Confirm which of the two the caller relies on.
    fn alignment_items() -> usize {
        16
    }

    fn alignment_bytes() -> usize {
        16
    }

    fn nr() -> usize {
        4
    }

    /// Replaces every element of `x` with its `stanh` approximation, in place.
    ///
    /// In debug builds, asserts that `x.len()` is a multiple of `nr()` and that
    /// `x` starts on an `alignment_bytes()` boundary.
    fn run(x: &mut [f32], _: ()) {
        debug_assert!(x.len() % Self::nr() == 0);
        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
        for value in x.iter_mut() {
            *value = stanh(*value);
        }
    }
}

/// Test-only module that instantiates the shared `tanh_frame_tests!` suite
/// for the `f32` kernel `STanh4`.
// NOTE(review): the meaning of the leading `true` argument is defined by the
// macro, which is not visible here. Confirm against its definition.
// NOTE(review): `#[macro_use]` on this module only has an effect if the macro
// expansion itself defines macros. Confirm it is still needed.
#[cfg(test)]
#[macro_use]
pub mod s {
    tanh_frame_tests!(true, f32, crate::generic::tanh::STanh4);
}

/// Portable (non-SIMD) `f16` tanh kernel that applies [`htanh`] element by element.
///
/// Reports the name `"generic"` and a register width (`nr`) of 8 items.
#[derive(Clone, Debug)]
pub struct HTanh8;

impl ElementWiseKer<f16> for HTanh8 {
    fn name() -> &'static str {
        "generic"
    }

    // NOTE(review): 16 items of a 2-byte f16 would be 32 bytes, but
    // `alignment_bytes` reports 16 (i.e. 8 items). Confirm which value the
    // caller relies on.
    fn alignment_items() -> usize {
        16
    }

    fn alignment_bytes() -> usize {
        16
    }

    fn nr() -> usize {
        8
    }

    /// Replaces every element of `x` with its `htanh` approximation, in place.
    ///
    /// In debug builds, asserts that `x.len()` is a multiple of `nr()` and that
    /// `x` starts on an `alignment_bytes()` boundary.
    fn run(x: &mut [f16], _: ()) {
        debug_assert!(x.len() % Self::nr() == 0);
        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
        for value in x.iter_mut() {
            *value = htanh(*value);
        }
    }
}

/// Test-only module that instantiates the shared `tanh_frame_tests!` suite
/// for the `f16` kernel `HTanh8`.
// The element type is spelled with its full path because this module has no
// `use` items of its own.
// NOTE(review): the meaning of the leading `true` argument is defined by the
// macro, which is not visible here. Confirm against its definition.
// NOTE(review): `#[macro_use]` on this module only has an effect if the macro
// expansion itself defines macros. Confirm it is still needed.
#[cfg(test)]
#[macro_use]
pub mod h {
    tanh_frame_tests!(true, tract_data::internal::f16, crate::generic::tanh::HTanh8);
}