tract_linalg/generic/
sigmoid.rs

1#![allow(clippy::excessive_precision)]
2use crate::frame::element_wise::ElementWiseKer;
3use tract_data::internal::*;
4
/// Fast rational approximation of the logistic sigmoid `1 / (1 + e^-x)` for `f32`.
///
/// The input is first saturated to `[-18.6, 18.6]`. The result is
/// `x * P(x²) / Q(x²) + 0.5`, where `P` (degree 6 in `x²`) and `Q` (degree 3
/// in `x²`) are evaluated with Horner's scheme. A NaN input yields NaN.
pub fn ssigmoid(x: f32) -> f32 {
    const LOW: f32 = -18.6;
    const HIGH: f32 = -LOW;

    // Numerator coefficients, highest power first (x^13 down to x^1).
    const ALPHAS: [f32; 7] = [
        -4.433153405e-18, // x^13
        1.169974371e-14,  // x^11
        -1.875289645e-11, // x^9
        4.257889523e-8,   // x^7
        0.00004811817576, // x^5
        0.008163842030,   // x^3
        0.2499999971,     // x^1
    ];
    // Denominator coefficients, highest power first (x^6 down to x^0).
    const BETAS: [f32; 4] = [
        3.922935744e-6, // x^6
        0.001524872358, // x^4
        0.1159886749,   // x^2
        1.0,            // x^0
    ];

    /// Horner evaluation of `coeffs` (highest power first) in the variable `t`.
    /// Each step is `t * acc + c`, the exact operation order of an unrolled chain.
    fn horner(t: f32, coeffs: &[f32]) -> f32 {
        let (&lead, rest) = coeffs.split_first().expect("coefficient table is non-empty");
        rest.iter().fold(lead, |acc, &c| t * acc + c)
    }

    let clamped = x.clamp(LOW, HIGH);
    let sq = clamped * clamped;

    // P is odd in x: evaluate in x² then multiply once by x.
    let numerator = horner(sq, &ALPHAS) * clamped;
    let denominator = horner(sq, &BETAS);

    numerator / denominator + 0.5
}
41
42pub fn hsigmoid(x: f16) -> f16 {
43    /*
44     * (x (0.249895 + x^2 (0.00400222 - 0.0000124702 x^2)))
45     * /
46     * (1. + 0.098734 x^2)
47     */
48
49    const LOW: f16 = f16::from_f32_const(-6.92);
50    const HIGH: f16 = f16::from_f32_const(6.92);
51
52    const ALPHA_5: f16 = f16::from_f32_const(-0.0000124702);
53    const ALPHA_3: f16 = f16::from_f32_const(0.00400222);
54    const ALPHA_1: f16 = f16::from_f32_const(0.249895);
55
56    const BETA_2: f16 = f16::from_f32_const(0.098734);
57    const BETA_0: f16 = f16::from_f32_const(1.0);
58
59    let x = x.clamp(LOW, HIGH);
60
61    let x2 = x * x;
62
63    let p = ALPHA_5;
64    let p = x2 * p + ALPHA_3;
65    let p = x2 * p + ALPHA_1;
66    let p = p * x;
67
68    let q = BETA_2;
69    let q = x2 * q + BETA_0;
70
71    p / q + f16::from_f32_const(0.5)
72}
73
74#[derive(Clone, Debug)]
75pub struct SSigmoid4;
76
77impl ElementWiseKer<f32> for SSigmoid4 {
78    fn name() -> &'static str {
79        "generic"
80    }
81
82    fn alignment_bytes() -> usize {
83        16
84    }
85
86    fn alignment_items() -> usize {
87        4
88    }
89
90    fn nr() -> usize {
91        4
92    }
93
94    fn run(x: &mut [f32], _: ()) {
95        debug_assert!(x.len() % Self::nr() == 0);
96        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
97        x.iter_mut().for_each(|px| *px = ssigmoid(*px))
98    }
99}
100
101#[derive(Clone, Debug)]
102pub struct HSigmoid8;
103
104impl ElementWiseKer<f16> for HSigmoid8 {
105    fn name() -> &'static str {
106        "generic"
107    }
108
109    fn alignment_bytes() -> usize {
110        16
111    }
112
113    fn alignment_items() -> usize {
114        4
115    }
116
117    fn nr() -> usize {
118        8
119    }
120
121    fn run(x: &mut [f16], _: ()) {
122        debug_assert!(x.len() % Self::nr() == 0);
123        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
124        x.iter_mut().for_each(|px| *px = hsigmoid(*px))
125    }
126}
127
// Test-only module: instantiates the crate's `sigmoid_frame_tests!` suite
// (macro defined elsewhere in the crate) for the generic f32 kernel
// `SSigmoid4`. The leading `true` is forwarded to the macro unchanged;
// presumably it is the condition enabling these tests — TODO confirm against
// the macro definition.
#[cfg(test)]
#[macro_use]
pub mod s {
    sigmoid_frame_tests!(true, f32, crate::generic::sigmoid::SSigmoid4);
}
133
// Test-only module: instantiates the crate's `sigmoid_frame_tests!` suite
// (macro defined elsewhere in the crate) for the generic f16 kernel
// `HSigmoid8`. The element type is spelled with its full path because this
// module does not import `tract_data::internal::*`. The leading `true` is
// forwarded to the macro unchanged; presumably it is the condition enabling
// these tests — TODO confirm against the macro definition.
#[cfg(test)]
#[macro_use]
pub mod h {
    sigmoid_frame_tests!(true, tract_data::internal::f16, crate::generic::sigmoid::HSigmoid8);
}