// tract_linalg/generic/hardswish.rs

1#![allow(clippy::excessive_precision)]
2use crate::frame::element_wise::ElementWiseKer;
3use tract_data::internal::*;
4use tract_num_traits::Zero;
5
/// Marker type for the portable hard-swish kernel operating on `f32` slices.
///
/// The type carries no data; the computation lives in its
/// `ElementWiseKer<f32>` implementation.
#[derive(Clone, Debug)]
pub struct SHardSwish4;
8
9impl ElementWiseKer<f32> for SHardSwish4 {
10    fn name() -> &'static str {
11        "generic"
12    }
13
14    fn alignment_bytes() -> usize {
15        16
16    }
17
18    fn alignment_items() -> usize {
19        4
20    }
21
22    fn nr() -> usize {
23        4
24    }
25
26    fn run(x: &mut [f32], _: ()) {
27        debug_assert!(x.len() % Self::nr() == 0);
28        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
29        const INV6: f32 = 1.0 / 6.0;
30        x.iter_mut().for_each(|px| {
31            let relu6 = ((*px + 3.0).min(6.0)).max(0.0);
32            *px = *px * relu6 * INV6;
33        });
34    }
35}
36
/// Marker type for the portable hard-swish kernel operating on `f16` slices.
///
/// The type carries no data; the computation lives in its
/// `ElementWiseKer<f16>` implementation.
#[derive(Clone, Debug)]
pub struct HHardSwish8;
39
40impl ElementWiseKer<f16> for HHardSwish8 {
41    fn name() -> &'static str {
42        "generic"
43    }
44
45    fn alignment_bytes() -> usize {
46        16
47    }
48
49    fn alignment_items() -> usize {
50        4
51    }
52
53    fn nr() -> usize {
54        8
55    }
56
57    fn run(x: &mut [f16], _: ()) {
58        debug_assert!(x.len() % Self::nr() == 0);
59        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
60        let three = f16::from_f32(3.0);
61        let six = f16::from_f32(6.0);
62        let inv6 = f16::from_f32(1.0 / 6.0);
63        x.iter_mut().for_each(|px| {
64            let relu6 = ((*px + three).min(six)).max(f16::zero());
65            *px = *px * relu6 * inv6;
66        });
67    }
68}
69
// Test-only module: instantiates `hardswish_frame_tests!` (defined outside
// this file) for the `f32` kernel `SHardSwish4`.
#[cfg(test)]
#[macro_use]
pub mod s {
    hardswish_frame_tests!(true, f32, crate::generic::hardswish::SHardSwish4);
}
75
// Test-only module: instantiates `hardswish_frame_tests!` (defined outside
// this file) for the `f16` kernel `HHardSwish8`.
#[cfg(test)]
#[macro_use]
pub mod h {
    hardswish_frame_tests!(true, tract_data::internal::f16, crate::generic::hardswish::HHardSwish8);
}