Skip to main content

tract_linalg/
x86_64_fma.rs

1use crate::frame::element_wise::ElementWiseKer;
2use crate::frame::reduce::{MapReduceKer, ReduceKer};
3use crate::x86_64_fma::softmax::x86_64_fma_softmax2_fastcompact_f32_32n;
4use crate::Ops;
5
6pub mod mmm;
7
8pub mod by_scalar;
9mod intel;
10pub mod max;
11pub mod panel_extract;
12pub mod softmax;
13
/// Runtime probe for AVX2 support, stored as a plain function pointer so the
/// check is performed each time the constant is called rather than once at
/// compile time.
const AVX2: fn() -> bool = || std::is_x86_feature_detected!("avx2");

/// Runtime probe for FMA support (same shape as [`AVX2`]).
const FMA: fn() -> bool = || std::is_x86_feature_detected!("fma");

/// Runtime probe for AVX-512 Foundation support (same shape as [`AVX2`]).
const AVX512F: fn() -> bool = || std::is_x86_feature_detected!("avx512f");
17
// Instantiate the crate's `tanh_impl!` / `sigmoid_impl!` macros for `f32`.
// The second argument is the name of the generated item; `plug_fma` below
// calls `fma_tanh_f32::ew()` and `fma_sigmoid_f32::ew()` on them.
// NOTE(review): the two `8` arguments are presumably lane count / alignment
// parameters — confirm against the macro definitions, which are not in this file.
// The final argument is a runtime FMA feature-detection expression handed to the macro.
tanh_impl!(f32, fma_tanh_f32, 8, 8, is_x86_feature_detected!("fma"));
sigmoid_impl!(f32, fma_sigmoid_f32, 8, 8, is_x86_feature_detected!("fma"));
20
21fn plug_avx2(_ops: &mut Ops) {}
22
23fn plug_fma(ops: &mut Ops) {
24    panel_extract::plug(ops);
25
26    ops.sigmoid_f32 = Box::new(|| fma_sigmoid_f32::ew());
27    ops.tanh_f32 = Box::new(|| fma_tanh_f32::ew());
28
29    ops.mul_by_scalar_f32 = Box::new(|| by_scalar::x86_64_avx_f32_mul_by_scalar_32n::ew());
30    ops.max_f32 = Box::new(|| max::x86_64_fma_max_f32_32n::red());
31    ops.softmax2_fastcompact_f32 = Box::new(|| x86_64_fma_softmax2_fastcompact_f32_32n::red());
32
33    log::info!("sigmoid_f32, tanh_f32: x86_64/fma activated");
34}
35
36fn plug_avx512f(_ops: &mut Ops) {}
37
38pub fn plug(ops: &mut Ops) {
39    mmm::plug(ops);
40    if is_x86_feature_detected!("avx2") {
41        plug_avx2(ops);
42        if is_x86_feature_detected!("fma") {
43            plug_fma(ops);
44            if is_x86_feature_detected!("avx512f") {
45                plug_avx512f(ops);
46            }
47        }
48    }
49}