tract_linalg/generic/
tanh.rs

1#![allow(clippy::excessive_precision)]
2use crate::frame::element_wise::ElementWiseKer;
3use tract_data::internal::*;
4
/// Approximates the hyperbolic tangent of `x` in single precision.
///
/// The input is clamped to `[-8.9, 8.9]`, then the result is computed as the
/// rational function `x * P(x²) / Q(x²)`, where `P` has degree 6 and `Q` has
/// degree 3 in `x²`. Both polynomials are evaluated with Horner's scheme.
pub fn stanh(x: f32) -> f32 {
    // Symmetric clamping bound applied to the input before evaluation.
    const BOUND: f32 = 8.9;

    // Numerator polynomial in x²: leading coefficient (x^13 term of the
    // final odd polynomial), followed by the remaining ones down to x^1.
    const ALPHA_LEAD: f32 = -8.488492677e-14;
    const ALPHA_REST: [f32; 6] = [
        5.277853000e-11,
        -2.022500419e-8,
        0.00001115424833,
        0.003103950131,
        0.1308400453,
        0.9999999934,
    ];

    // Denominator polynomial in x²: leading coefficient (x^6), then x^4,
    // x^2 and the constant term.
    const BETA_LEAD: f32 = 0.0002546136580;
    const BETA_REST: [f32; 3] = [0.02449515379, 0.4641733162, 1.0];

    let clamped = x.clamp(-BOUND, BOUND);
    let squared = clamped * clamped;

    // Horner evaluation: each step multiplies the running value by x² and
    // adds the next lower-order coefficient.
    let numerator =
        ALPHA_REST.iter().fold(ALPHA_LEAD, |acc, &coeff| squared * acc + coeff) * clamped;
    let denominator = BETA_REST.iter().fold(BETA_LEAD, |acc, &coeff| squared * acc + coeff);

    numerator / denominator
}
42
43pub fn htanh(x: f16) -> f16 {
44    const LOW: f16 = f16::from_f32_const(-3.84);
45    const HIGH: f16 = f16::from_f32_const(3.84);
46
47    const ALPHA_3: f16 = f16::from_f32_const(0.082654955);
48    const ALPHA_1: f16 = f16::from_f32_const(0.99963124);
49
50    const BETA_4: f16 = f16::from_f32_const(0.0065383179);
51    const BETA_2: f16 = f16::from_f32_const(0.41401828);
52    const BETA_0: f16 = f16::from_f32_const(1.0);
53
54    let x = x.clamp(LOW, HIGH);
55
56    let x2 = x * x;
57
58    let p = ALPHA_3;
59    let p = x2 * p + ALPHA_1;
60    let p = p * x;
61
62    let q = BETA_4;
63    let q = x2 * q + BETA_2;
64    let q = x2 * q + BETA_0;
65
66    p / q
67}
68
69#[derive(Clone, Debug)]
70pub struct STanh4;
71
72impl ElementWiseKer<f32> for STanh4 {
73    fn name() -> &'static str {
74        "generic"
75    }
76
77    fn alignment_items() -> usize {
78        16
79    }
80
81    fn alignment_bytes() -> usize {
82        16
83    }
84
85    fn nr() -> usize {
86        4
87    }
88
89    fn run(x: &mut [f32], _: ()) {
90        debug_assert!(x.len() % Self::nr() == 0);
91        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
92        x.iter_mut().for_each(|px| *px = stanh(*px))
93    }
94}
95
/// Tests for the `f32` kernel `STanh4`, generated by the `tanh_frame_tests!`
/// macro (defined elsewhere in the crate). Only compiled for test builds.
#[cfg(test)]
#[macro_use]
pub mod s {
    // NOTE(review): the meaning of the leading `true` argument is defined by
    // the macro and is not visible from this file — confirm at the macro site.
    tanh_frame_tests!(true, f32, crate::generic::tanh::STanh4);
}
101
102#[derive(Clone, Debug)]
103pub struct HTanh8;
104
105impl ElementWiseKer<f16> for HTanh8 {
106    fn name() -> &'static str {
107        "generic"
108    }
109
110    fn alignment_items() -> usize {
111        16
112    }
113
114    fn alignment_bytes() -> usize {
115        16
116    }
117
118    fn nr() -> usize {
119        8
120    }
121
122    fn run(x: &mut [f16], _: ()) {
123        debug_assert!(x.len() % Self::nr() == 0);
124        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
125        x.iter_mut().for_each(|px| *px = htanh(*px))
126    }
127}
128
/// Tests for the `f16` kernel `HTanh8`, generated by the `tanh_frame_tests!`
/// macro (defined elsewhere in the crate). Only compiled for test builds.
#[cfg(test)]
#[macro_use]
pub mod h {
    // NOTE(review): the meaning of the leading `true` argument is defined by
    // the macro and is not visible from this file — confirm at the macro site.
    tanh_frame_tests!(true, tract_data::internal::f16, crate::generic::tanh::HTanh8);
}