tract_linalg/x86_64_fma/
max.rs1reduce_impl_wrap!(
2 f32,
3 x86_64_fma_max_f32_32n,
4 32,
5 8,
6 (),
7 f32::MIN,
8 #[inline(never)]
9 fn run(buf: &[f32], _: ()) -> f32 {
10 assert!(buf.len() % 32 == 0);
11 assert!(buf.len() > 0);
12 unsafe { x86_64_fma_max_f32_32n_run(buf) }
13 },
14 #[inline(never)]
15 fn reduce_two(a: f32, b: f32) -> f32 {
16 a.max(b)
17 }
18);
19
/// Returns the maximum element of `buf`, computed with AVX registers.
///
/// Four 8-lane `ymm` accumulators are seeded with `f32::MIN`. Each loop
/// iteration folds 32 consecutive values into them with `vmaxps`. The
/// accumulators are then merged and reduced horizontally down to one lane.
///
/// Only AVX instructions are used, matching the `target_feature` attribute
/// (the register-source form of `vbroadcastss` would need AVX2). Loads are
/// unaligned (`vmovups`), so `buf` needs no particular alignment.
///
/// NaN handling follows `vmaxps`: when either operand is NaN, the second
/// source operand is returned. This differs from [`f32::max`].
///
/// # Safety
///
/// * The CPU must support AVX.
/// * `buf.len()` must be a non-zero multiple of 32. The loop exits only when
///   its remaining-item counter reaches exactly zero, so any other length
///   reads past the end of `buf` indefinitely.
#[target_feature(enable = "avx")]
unsafe fn x86_64_fma_max_f32_32n_run(buf: &[f32]) -> f32 {
    debug_assert!(!buf.is_empty() && buf.len() % 32 == 0);
    let len = buf.len();
    let ptr = buf.as_ptr();
    let mut acc = f32::MIN;
    // SAFETY: the caller guarantees AVX support and that `len` is a non-zero
    // multiple of 32, so each 128-byte group of loads stays inside `buf`.
    // The asm only reads memory, never touches the stack, and every register
    // it writes is declared as an output below.
    unsafe {
        std::arch::asm!(
            // Broadcast the seed (lane 0 of xmm0) to all 8 lanes of ymm0.
            "vshufps xmm0, xmm0, xmm0, 0",
            "vinsertf128 ymm0, ymm0, xmm0, 1",
            // Four independent accumulators, so consecutive vmaxps do not
            // depend on each other.
            "vmovaps ymm1, ymm0",
            "vmovaps ymm2, ymm0",
            "vmovaps ymm3, ymm0",
            "2:",
            "vmovups ymm4, [{ptr}]",
            "vmovups ymm5, [{ptr} + 32]",
            "vmovups ymm6, [{ptr} + 64]",
            "vmovups ymm7, [{ptr} + 96]",
            "vmaxps ymm0, ymm0, ymm4",
            "vmaxps ymm1, ymm1, ymm5",
            "vmaxps ymm2, ymm2, ymm6",
            "vmaxps ymm3, ymm3, ymm7",
            "add {ptr}, 128",
            "sub {len}, 32",
            "jnz 2b",
            // Merge the four accumulators into ymm0.
            "vmaxps ymm0, ymm0, ymm1",
            "vmaxps ymm2, ymm2, ymm3",
            "vmaxps ymm0, ymm0, ymm2",
            // 8 lanes -> 4: fold the upper 128-bit half onto the lower half.
            "vextractf128 xmm1, ymm0, 1",
            "vmaxps xmm0, xmm0, xmm1",
            // 4 lanes -> 2: move lanes 2 and 3 down to lanes 0 and 1.
            "vpermilps xmm1, xmm0, 2 + (3 << 2)",
            "vmaxps xmm0, xmm0, xmm1",
            // 2 lanes -> 1: move lane 1 down to lane 0.
            "vpermilps xmm1, xmm0, 1",
            "vmaxps xmm0, xmm0, xmm1",
            len = inout(reg) len => _,
            ptr = inout(reg) ptr => _,
            inout("ymm0") acc,
            out("ymm1") _, out("ymm2") _, out("ymm3") _,
            out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _,
            options(nostack, readonly),
        );
    }
    acc
}

#[cfg(test)]
mod test_x86_64_fma_max_f32_32n {
    // Glob-import the parent module so the kernel name passed below (and
    // presumably anything else the shared test macro expands to) resolves.
    use super::*;

    // Instantiate the shared max-reduction test suite for this kernel. The
    // first argument is a runtime gate on the host CPU; presumably the tests
    // are skipped when AVX2 is unavailable — NOTE(review): confirm against
    // the `max_frame_tests!` definition.
    crate::max_frame_tests!(
        is_x86_feature_detected!("avx2"),
        f32,
        x86_64_fma_max_f32_32n
    );
}