tract_linalg/generic/
leaky_relu.rs1#![allow(clippy::excessive_precision)]
2use crate::frame::element_wise::ElementWiseKer;
3use tract_data::internal::*;
4use tract_num_traits::Zero;
5
/// Zero-sized marker type for the portable `f32` leaky-ReLU kernel.
///
/// The kernel itself lives in this type's `ElementWiseKer<f32, f32>` implementation,
/// which reports itself under the name `"generic"`.
#[derive(Debug, Clone)]
pub struct SLeakyRelu4;
8
9impl ElementWiseKer<f32, f32> for SLeakyRelu4 {
10 fn name() -> &'static str {
11 "generic"
12 }
13
14 fn alignment_bytes() -> usize {
15 16
16 }
17
18 fn alignment_items() -> usize {
19 4
20 }
21
22 fn nr() -> usize {
23 4
24 }
25
26 fn run(x: &mut [f32], alpha: f32) {
27 debug_assert!(x.len() % Self::nr() == 0);
28 debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
29 x.iter_mut().for_each(|px| *px = if *px < 0. { *px * alpha } else { *px });
30 }
31}
32
/// Zero-sized marker type for the portable `f16` leaky-ReLU kernel.
///
/// The kernel itself lives in this type's `ElementWiseKer<f16, f16>` implementation,
/// which reports itself under the name `"generic"`.
#[derive(Debug, Clone)]
pub struct HLeakyRelu8;
35
36impl ElementWiseKer<f16, f16> for HLeakyRelu8 {
37 fn name() -> &'static str {
38 "generic"
39 }
40
41 fn alignment_bytes() -> usize {
42 16
43 }
44
45 fn alignment_items() -> usize {
46 4
47 }
48
49 fn nr() -> usize {
50 8
51 }
52
53 fn run(x: &mut [f16], alpha: f16) {
54 debug_assert!(x.len() % Self::nr() == 0);
55 debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
56 x.iter_mut().for_each(|px| *px = if *px < f16::zero() { *px * alpha } else { *px })
57 }
58}
59
// Test-only module: expands the `leaky_relu_frame_tests!` macro for the generic `f32`
// kernel `SLeakyRelu4`.
// NOTE(review): the leading `true` looks like a run condition for the generated tests —
// confirm against the macro definition, which is not visible from this file.
#[cfg(test)]
#[macro_use]
pub mod s {
    leaky_relu_frame_tests!(true, f32, crate::generic::leaky_relu::SLeakyRelu4);
}
65
// Test-only module: expands the `leaky_relu_frame_tests!` macro for the generic `f16`
// kernel `HLeakyRelu8`.
// NOTE(review): the leading `true` looks like a run condition for the generated tests —
// confirm against the macro definition, which is not visible from this file.
#[cfg(test)]
#[macro_use]
pub mod h {
    leaky_relu_frame_tests!(
        true,
        tract_data::internal::f16,
        crate::generic::leaky_relu::HLeakyRelu8
    );
}