tract_linalg/generic/
tanh.rs1#![allow(clippy::excessive_precision)]
2use crate::frame::element_wise::ElementWiseKer;
3use tract_data::internal::*;
4
/// Single-precision approximation of the hyperbolic tangent.
///
/// The input is first clamped to `[-8.9, 8.9]`. `tanh(x)` is then approximated
/// by the rational function `x * P(x²) / Q(x²)`, where `P` has seven
/// coefficients (odd powers `x¹` … `x¹³` once multiplied by `x`) and `Q` has
/// four (even powers `x⁰` … `x⁶`).
///
/// A NaN input produces NaN, because `f32::clamp` passes NaN through.
pub fn stanh(x: f32) -> f32 {
    // Saturation bounds: anything outside is evaluated at the nearest bound.
    const LOW: f32 = -8.9;
    const HIGH: f32 = 8.9;

    // Numerator coefficients, highest power first (x^13 term down to x^1 term).
    const NUMERATOR: [f32; 7] = [
        -8.488492677e-14,
        5.277853000e-11,
        -2.022500419e-8,
        0.00001115424833,
        0.003103950131,
        0.1308400453,
        0.9999999934,
    ];
    // Denominator coefficients, highest power first (x^6 term down to x^0 term).
    const DENOMINATOR: [f32; 4] = [0.0002546136580, 0.02449515379, 0.4641733162, 1.0];

    let clamped = x.clamp(LOW, HIGH);
    let squared = clamped * clamped;

    // Horner evaluation in `squared`, seeded with the leading coefficient so
    // the sequence of multiply-adds is the same as a hand-unrolled chain.
    let horner = |coeffs: &[f32]| -> f32 {
        coeffs[1..].iter().fold(coeffs[0], |acc, &c| squared * acc + c)
    };

    let numerator = horner(&NUMERATOR) * clamped;
    let denominator = horner(&DENOMINATOR);

    numerator / denominator
}
42
43pub fn htanh(x: f16) -> f16 {
44 const LOW: f16 = f16::from_f32_const(-3.84);
45 const HIGH: f16 = f16::from_f32_const(3.84);
46
47 const ALPHA_3: f16 = f16::from_f32_const(0.082654955);
48 const ALPHA_1: f16 = f16::from_f32_const(0.99963124);
49
50 const BETA_4: f16 = f16::from_f32_const(0.0065383179);
51 const BETA_2: f16 = f16::from_f32_const(0.41401828);
52 const BETA_0: f16 = f16::from_f32_const(1.0);
53
54 let x = x.clamp(LOW, HIGH);
55
56 let x2 = x * x;
57
58 let p = ALPHA_3;
59 let p = x2 * p + ALPHA_1;
60 let p = p * x;
61
62 let q = BETA_4;
63 let q = x2 * q + BETA_2;
64 let q = x2 * q + BETA_0;
65
66 p / q
67}
68
/// Data-less marker type for the `f32` tanh kernel; its `ElementWiseKer<f32>`
/// implementation in this file processes slices with `nr() == 4`.
#[derive(Debug, Clone)]
pub struct STanh4;
71
72impl ElementWiseKer<f32> for STanh4 {
73 fn name() -> &'static str {
74 "generic"
75 }
76
77 fn alignment_items() -> usize {
78 16
79 }
80
81 fn alignment_bytes() -> usize {
82 16
83 }
84
85 fn nr() -> usize {
86 4
87 }
88
89 fn run(x: &mut [f32], _: ()) {
90 debug_assert!(x.len() % Self::nr() == 0);
91 debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
92 x.iter_mut().for_each(|px| *px = stanh(*px))
93 }
94}
95
// Test-only module: invokes `tanh_frame_tests!` for the `f32` element type
// and the `STanh4` kernel.
// NOTE(review): the macro is defined outside this file; presumably it expands
// to the shared tanh kernel test suite. The meaning of the leading `true`
// argument should be confirmed at the macro definition.
#[cfg(test)]
#[macro_use]
pub mod s {
    tanh_frame_tests!(true, f32, crate::generic::tanh::STanh4);
}
101
/// Data-less marker type for the `f16` tanh kernel; its `ElementWiseKer<f16>`
/// implementation in this file processes slices with `nr() == 8`.
#[derive(Debug, Clone)]
pub struct HTanh8;
104
105impl ElementWiseKer<f16> for HTanh8 {
106 fn name() -> &'static str {
107 "generic"
108 }
109
110 fn alignment_items() -> usize {
111 16
112 }
113
114 fn alignment_bytes() -> usize {
115 16
116 }
117
118 fn nr() -> usize {
119 8
120 }
121
122 fn run(x: &mut [f16], _: ()) {
123 debug_assert!(x.len() % Self::nr() == 0);
124 debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
125 x.iter_mut().for_each(|px| *px = htanh(*px))
126 }
127}
128
// Test-only module: invokes `tanh_frame_tests!` for the `f16` element type
// and the `HTanh8` kernel.
// NOTE(review): the macro is defined outside this file; presumably it expands
// to the shared tanh kernel test suite. The meaning of the leading `true`
// argument should be confirmed at the macro definition.
#[cfg(test)]
#[macro_use]
pub mod h {
    tanh_frame_tests!(true, tract_data::internal::f16, crate::generic::tanh::HTanh8);
}