tract_linalg/generic/leaky_relu.rs

1#![allow(clippy::excessive_precision)]
2use crate::frame::element_wise::ElementWiseKer;
3use tract_data::internal::*;
4use tract_num_traits::Zero;
5
6#[derive(Clone, Debug)]
7pub struct SLeakyRelu4;
8
9impl ElementWiseKer<f32, f32> for SLeakyRelu4 {
10    fn name() -> &'static str {
11        "generic"
12    }
13
14    fn alignment_bytes() -> usize {
15        16
16    }
17
18    fn alignment_items() -> usize {
19        4
20    }
21
22    fn nr() -> usize {
23        4
24    }
25
26    fn run(x: &mut [f32], alpha: f32) {
27        debug_assert!(x.len() % Self::nr() == 0);
28        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
29        x.iter_mut().for_each(|px| *px = if *px < 0. { *px * alpha } else { *px });
30    }
31}
32
33#[derive(Clone, Debug)]
34pub struct HLeakyRelu8;
35
36impl ElementWiseKer<f16, f16> for HLeakyRelu8 {
37    fn name() -> &'static str {
38        "generic"
39    }
40
41    fn alignment_bytes() -> usize {
42        16
43    }
44
45    fn alignment_items() -> usize {
46        4
47    }
48
49    fn nr() -> usize {
50        8
51    }
52
53    fn run(x: &mut [f16], alpha: f16) {
54        debug_assert!(x.len() % Self::nr() == 0);
55        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
56        x.iter_mut().for_each(|px| *px = if *px < f16::zero() { *px * alpha } else { *px })
57    }
58}
59
// Instantiates the shared leaky-ReLU frame test suite (the `leaky_relu_frame_tests!` macro,
// defined outside this file) for the generic f32 kernel `SLeakyRelu4`.
// NOTE(review): the leading `true` is forwarded to the macro unchanged — presumably a
// "run these tests on this platform" condition, always enabled for the generic kernel;
// confirm against the macro definition.
#[cfg(test)]
#[macro_use]
pub mod s {
    leaky_relu_frame_tests!(true, f32, crate::generic::leaky_relu::SLeakyRelu4);
}
65
// Instantiates the shared leaky-ReLU frame test suite (the `leaky_relu_frame_tests!` macro,
// defined outside this file) for the generic f16 kernel `HLeakyRelu8`.
// NOTE(review): as in `mod s`, the leading `true` is presumably an always-on enabling
// condition for the generic kernel; confirm against the macro definition.
#[cfg(test)]
#[macro_use]
pub mod h {
    leaky_relu_frame_tests!(
        true,
        tract_data::internal::f16,
        crate::generic::leaky_relu::HLeakyRelu8
    );
}