// tract_linalg/x86_64_fma/max.rs

1reduce_impl_wrap!(
2    f32,
3    x86_64_fma_max_f32_32n,
4    32,
5    8,
6    (),
7    f32::MIN,
8    #[inline(never)]
9    fn run(buf: &[f32], _: ()) -> f32 {
10        assert!(buf.len() % 32 == 0);
11        assert!(buf.len() > 0);
12        unsafe { x86_64_fma_max_f32_32n_run(buf) }
13    },
14    #[inline(never)]
15    fn reduce_two(a: f32, b: f32) -> f32 {
16        a.max(b)
17    }
18);
19
/// Returns the maximum value found in `buf`, consuming 32 `f32` values per loop
/// iteration with hand-written AVX assembly.
///
/// Four YMM accumulators (8 lanes each) are seeded with `f32::MIN`, updated with
/// `vmaxps` against four aligned 32-byte loads per iteration, and finally folded
/// down to a single lane. Because the seed is `f32::MIN` (the most negative finite
/// value, not negative infinity), the returned value is never below `f32::MIN`.
///
/// Only AVX instructions are used, matching the `target_feature` declared below:
/// the seed is splatted with `vpermilps` + `vinsertf128` instead of the
/// register-source form of `vbroadcastss`, which is an AVX2 encoding.
///
/// # Safety
///
/// * The CPU executing this function must support AVX.
/// * `buf.len()` must be a non-zero multiple of 32: the loop counter is decremented
///   by 32 and the loop exits only when it reaches exactly zero, so any other length
///   reads past the end of `buf`.
/// * `buf.as_ptr()` must be 32-byte aligned, as required by the `vmovaps` loads.
#[target_feature(enable = "avx")]
unsafe fn x86_64_fma_max_f32_32n_run(buf: &[f32]) -> f32 {
    let len = buf.len();
    let ptr = buf.as_ptr();
    debug_assert!(
        len > 0 && len % 32 == 0,
        "buffer length must be a non-zero multiple of 32, got {len}"
    );
    debug_assert!(
        (ptr as usize) % 32 == 0,
        "buffer must be 32-byte aligned for vmovaps"
    );

    let mut acc = f32::MIN;
    // SAFETY: per this function's contract the CPU supports AVX, `ptr` is 32-byte
    // aligned and `len` is a non-zero multiple of 32, so every `vmovaps` below reads
    // 32 aligned bytes inside `buf` and the loop stops exactly at its end. Every
    // register the template writes is declared as an output or clobber.
    unsafe {
        std::arch::asm!("
            vpermilps xmm0, xmm0, 0             // splat the seed over the 4 lanes of xmm0
            vinsertf128 ymm0, ymm0, xmm0, 1     // mirror into the upper half: 8 lanes of seed
            vmovaps ymm1, ymm0
            vmovaps ymm2, ymm0
            vmovaps ymm3, ymm0
            2:
                vmovaps ymm4, [{ptr}]
                vmovaps ymm5, [{ptr} + 32]
                vmovaps ymm6, [{ptr} + 64]
                vmovaps ymm7, [{ptr} + 96]
                vmaxps ymm0, ymm0, ymm4
                vmaxps ymm1, ymm1, ymm5
                vmaxps ymm2, ymm2, ymm6
                vmaxps ymm3, ymm3, ymm7
                add {ptr}, 128
                sub {len}, 32
                jnz 2b
            vmaxps ymm0, ymm0, ymm1
            vmaxps ymm2, ymm2, ymm3
            vmaxps ymm0, ymm0, ymm2
            vperm2f128 ymm1, ymm0, ymm0, 1      // copy second half (4xf32) of ymm0 to ymm1
            vmaxps xmm0, xmm0, xmm1             // xmm0 contains 4 values to max
            vpermilps xmm1, xmm0, 2 + (3 << 2)  // second 2x32 bit half moved to top
            vmaxps xmm0, xmm0, xmm1             // xmm0 contains 2 values
            vpermilps xmm1, xmm0, 1             // second f32 to top
            vmaxps xmm0, xmm0, xmm1
            ",
        len = inout(reg) len => _,
        ptr = inout(reg) ptr => _,
        inout("ymm0") acc,
        out("ymm1") _, out("ymm2") _, out("ymm3") _,
        out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _
        );
    }
    acc
}
62
#[cfg(test)]
mod test_x86_64_fma_max_f32_32n {
    use super::*;
    // Instantiates the crate's shared max-kernel test suite for the f32 kernel.
    // The first argument is a runtime AVX2 detection check.
    // NOTE(review): presumably the macro uses it to skip the tests on CPUs without
    // the feature -- confirm against the `max_frame_tests!` definition.
    crate::max_frame_tests!(is_x86_feature_detected!("avx2"), f32, x86_64_fma_max_f32_32n);
}