tract_linalg/generic/
gelu.rs1#![allow(clippy::excessive_precision)]
2use crate::frame::element_wise::ElementWiseKer;
3use tract_data::internal::*;
4
/// `sqrt(2 / π)`: factor applied to the argument of `tanh` in the GELU
/// approximation computed by the kernels in this file.
const SQRT_2_OVER_PI: f32 = 0.797_884_560_802_865_4;
/// Weight of the cubic term (`v * v * v`) inside the `tanh` argument.
const COEF: f32 = 0.044_715;
15
/// Data-less marker type naming the portable `f32` GELU kernel.
///
/// All behaviour lives in its `ElementWiseKer<f32>` implementation.
#[derive(Debug, Clone)]
pub struct SGelu4;
18
19impl ElementWiseKer<f32> for SGelu4 {
20 fn name() -> &'static str {
21 "generic"
22 }
23
24 fn alignment_bytes() -> usize {
25 16
26 }
27
28 fn alignment_items() -> usize {
29 4
30 }
31
32 fn nr() -> usize {
33 4
34 }
35
36 fn run(x: &mut [f32], _: ()) {
37 debug_assert!(x.len() % Self::nr() == 0);
38 debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
39 x.iter_mut().for_each(|px| {
40 let v = *px;
41 let inner = SQRT_2_OVER_PI * (v + COEF * v * v * v);
42 *px = 0.5 * v * (1.0 + inner.tanh());
43 });
44 }
45}
46
/// Data-less marker type naming the portable `f16` GELU kernel.
///
/// All behaviour lives in its `ElementWiseKer<f16>` implementation.
#[derive(Debug, Clone)]
pub struct HGelu8;
49
50impl ElementWiseKer<f16> for HGelu8 {
51 fn name() -> &'static str {
52 "generic"
53 }
54
55 fn alignment_bytes() -> usize {
56 16
57 }
58
59 fn alignment_items() -> usize {
60 4
61 }
62
63 fn nr() -> usize {
64 8
65 }
66
67 fn run(x: &mut [f16], _: ()) {
68 debug_assert!(x.len() % Self::nr() == 0);
69 debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
70 x.iter_mut().for_each(|px| {
71 let v = px.to_f32();
72 let inner = SQRT_2_OVER_PI * (v + COEF * v * v * v);
73 *px = f16::from_f32(0.5 * v * (1.0 + inner.tanh()));
74 });
75 }
76}
77
/// Test-only module that instantiates `gelu_frame_tests!` for the `f32`
/// kernel `SGelu4`. The leading `true` is forwarded to the macro as-is; its
/// meaning is defined by the macro, which is declared outside this file.
#[cfg(test)]
#[macro_use]
pub mod s {
    gelu_frame_tests! { true, f32, crate::generic::gelu::SGelu4 }
}
83
/// Test-only module that instantiates `gelu_frame_tests!` for the `f16`
/// kernel `HGelu8`. The leading `true` is forwarded to the macro as-is; its
/// meaning is defined by the macro, which is declared outside this file.
#[cfg(test)]
#[macro_use]
pub mod h {
    gelu_frame_tests! { true, tract_data::internal::f16, crate::generic::gelu::HGelu8 }
}