// tract_core/ops/nn/mod.rs

1mod data_formats;
2pub mod gelu_approximate;
3pub mod grid_sample;
4mod reduce;
5pub mod resize;
6pub mod rms_norm;
7pub mod silu;
8mod softmax;
9
10pub use self::data_formats::{BaseDataShape, DataFormat, DataShape, SymDataShape};
11pub use self::gelu_approximate::GeluApproximate;
12pub use self::grid_sample::{GridSample, InterpolationMode, PaddingMode};
13pub use self::reduce::{Reduce, Reducer, expand_mean_of_squares};
14pub use self::resize::{CoordTransformer, Interpolator, Nearest, Resize};
15pub use self::rms_norm::RmsNorm;
16pub use self::silu::Silu;
17pub use self::softmax::{Softmax, SoftmaxExp, SoftmaxKind};
18
19pub use crate::internal::*;
20
21use tract_num_traits::AsPrimitive;
22
23element_wise!(sigmoid, Sigmoid,
24 [f16] => |_, xs| { (tract_linalg::ops().sigmoid_f16)().run(xs) },
25 [f32] => |_, xs| { (tract_linalg::ops().sigmoid_f32)().run(xs) };
26 q: [i8, u8, i32, i32] => |x: f32| 1.0 / (1.0+(-x).exp());
27 cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))};
28 declutter: silu::detect_silu
29);
30
31element_wise!(hard_swish, HardSwish,
32[f16] => |_, xs| { xs.iter_mut().for_each(|x| *x = *x * f16::from_f32(0.0).max(f16::from_f32(1.0).min(f16::from_f32(1. / 6.) * *x + f16::from_f32(0.5)))); Ok(()) },
33[f32] => |_, xs| { (tract_linalg::ops().hardswish_f32)().run(xs) }
34                                         );
35
36element_wise!(leaky_relu, LeakyRelu { alpha: f32 },
37 [f16] => |op, xs| { (tract_linalg::ops().leaky_relu_f16)().run_with_params(xs, f16::from_f32(op.alpha)) },
38 [f32] => |op, xs| { (tract_linalg::ops().leaky_relu_f32)().run_with_params(xs, op.alpha) }
39);