1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
pub mod cost_model;
#[macro_use]
pub(crate) mod fuse;
#[macro_use]
pub(crate) mod kernel;
pub(crate) mod input_store;
#[macro_use]
#[allow(clippy::module_inception)]
pub(crate) mod mmm;
mod scratch;
mod storage;
#[cfg(test)]
#[macro_use]
pub mod tests;

pub use cost_model::*;
pub use fuse::*;
pub use input_store::*;
pub use kernel::MatMatMulKer;
pub use mmm::*;
pub use scratch::*;
pub use storage::*;

/// Prefetch hook that performs no prefetching at all.
///
/// Its signature matches what `MMMKernel!` expects for its `$prefetch` argument (the
/// generated `prefetch` method calls it as `($prefetch)(ptr, len)` with a `*const u8`
/// and a `usize`), so it can be passed there by kernels that want no prefetch.
pub fn no_prefetch(_ptr: *const u8, _len: usize) {
    // Deliberately empty: both the address and the length are ignored.
}

/// Declares a matrix-multiplication kernel backed by an external routine.
///
/// Expanding `MMMKernel!` produces:
/// * a module `sys_<func>` declaring the external symbol `$func` through `extern_kernel!`,
///   taking a `*const FusedKerSpec<$ti>` and returning `isize`;
/// * a unit struct `$func` implementing `MatMatMulKer<$ti>`, whose methods return the
///   constants given to the macro and forward `kernel`/`prefetch` calls;
/// * the kernel's test suite, via `test_mmm_kernel!($ti, $func, $cond)`.
///
/// Arguments:
/// * `$ti` – element type of the kernel. `test_mmm_kernel!` only accepts `f16`, `f32`
///   and `i32`.
/// * `$func` – name of both the external symbol and the generated struct.
/// * `$mr`, `$nr` – values returned by `mr()` and `nr()`.
/// * `$alignment_bytes_packed_a`, `$alignment_bytes_packed_b` – values returned by the
///   `alignment_bytes_packed_a()` / `alignment_bytes_packed_b()` methods.
/// * `$end_padding_packed_a`, `$end_padding_packed_b` – values returned by the
///   `end_padding_packed_a()` / `end_padding_packed_b()` methods.
/// * `$prefetch` – function called as `($prefetch)(ptr, len)` with `ptr: *const u8` and
///   `len: usize`; `no_prefetch` fits.
/// * `$cond` – forwarded untouched to `test_mmm_kernel!`. Presumably a condition gating
///   whether the generated tests run (e.g. a CPU feature check); confirm at the
///   type-specific test macros.
/// * `can_fuse: <expr>` (optional) – a callable taking `&FusedSpec` and returning `bool`,
///   used as the `can_fuse` method. When omitted, `can_fuse` is not overridden.
///
/// The expansion uses `paste!` and the `new` derive unqualified, so those must be in
/// scope where the macro is invoked.
macro_rules! MMMKernel {
    ($ti:ident, $func:ident; $mr: expr, $nr: expr; $alignment_bytes_packed_a: expr, $alignment_bytes_packed_b: expr; $end_padding_packed_a: expr, $end_padding_packed_b: expr ; $prefetch: ident, $cond: expr $(, can_fuse: $can_fuse:expr)?) => {
        paste! {
            mod [<sys_ $func>] {
                use $crate::frame::mmm::*;
                // `f16` is only needed when `$ti` is `f16`.
                #[allow(unused_imports)]
                use tract_data::prelude::f16;
                extern_kernel!(fn $func(op: *const FusedKerSpec<$ti>) -> isize);
            }

            // The struct carries the (snake_case) name of the external kernel symbol.
            #[allow(non_camel_case_types)]
            #[derive(Copy, Clone, Debug, new)]
            pub struct $func;

            impl $crate::frame::mmm::MatMatMulKer<$ti> for $func {
                #[inline(always)]
                fn name() -> &'static str {
                    stringify!($func)
                }
                #[inline(always)]
                fn mr() -> usize {
                    $mr
                }
                #[inline(always)]
                fn nr() -> usize {
                    $nr
                }
                #[inline(always)]
                fn alignment_bytes_packed_a() -> usize {
                    $alignment_bytes_packed_a
                }
                #[inline(always)]
                fn alignment_bytes_packed_b() -> usize {
                    $alignment_bytes_packed_b
                }
                #[inline(always)]
                fn end_padding_packed_a() -> usize {
                    $end_padding_packed_a
                }
                #[inline(always)]
                fn end_padding_packed_b() -> usize {
                    $end_padding_packed_b
                }
                #[inline(always)]
                fn kernel(spec: &[$crate::frame::mmm::FusedKerSpec<$ti>]) -> isize {
                    // The spec list must be non-empty and end with `Done`. This is only
                    // checked in debug builds.
                    debug_assert!(
                        matches!(spec.last(), Some($crate::frame::mmm::FusedKerSpec::Done)),
                        "kernel spec must be non-empty and terminated by FusedKerSpec::Done"
                    );
                    // SAFETY: the external kernel receives a pointer to `spec.len()`
                    // initialised elements that stay borrowed for the whole call. It
                    // presumably reads the list until it reaches `Done` (hence the check
                    // above). In release builds, callers must guarantee that terminator.
                    unsafe { [<sys_ $func>]::$func(spec.as_ptr()) }
                }
                #[inline(always)]
                fn prefetch(ptr: *const u8, len: usize) {
                    ($prefetch)(ptr, len)
                }
                $(
                    // Fully qualified so the expansion compiles even where `FusedSpec`
                    // is not imported at the invocation site.
                    #[inline(always)]
                    fn can_fuse(spec: &$crate::frame::mmm::FusedSpec) -> bool {
                        ($can_fuse)(spec)
                    }
                )?
            }
        }
        // Generate the type-specific test suite for this kernel.
        test_mmm_kernel!($ti, $func, $cond);
    };
}

/// Routes a kernel to the test-suite macro matching its element type.
///
/// The first token selects the suite: `f16` → `test_mmm_kernel_f16!`,
/// `f32` → `test_mmm_kernel_f32!`, `i32` → `test_mmm_kernel_i32!`. The kernel identifier
/// and the condition expression are forwarded unchanged. Any other element type matches
/// no arm and is therefore a compile error.
macro_rules! test_mmm_kernel {
    (f16, $ker:ident, $gate:expr) => { test_mmm_kernel_f16!($ker, $gate); };
    (f32, $ker:ident, $gate:expr) => { test_mmm_kernel_f32!($ker, $gate); };
    (i32, $ker:ident, $gate:expr) => { test_mmm_kernel_i32!($ker, $gate); };
}