tract_linalg/frame/mmm/
kernel.rs

1use pack::PackedFormat;
2
3use super::*;
4use std::borrow::Cow;
5use std::fmt::Debug;
6
7use crate::LADatum;
8
9pub trait MatMatMulKer: Clone + Debug + Send + Sync + 'static {
10    type Acc: LADatum;
11    fn name(&self) -> &str;
12    fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize;
13    fn mr(&self) -> usize;
14    fn nr(&self) -> usize;
15
16    #[allow(clippy::type_complexity)]
17    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
18    fn stores(&self) -> Cow<[DatumType]>;
19
20    #[allow(unused_variables)]
21    fn can_fuse(&self, spec: &FusedSpec) -> bool {
22        true
23    }
24
25    #[allow(unused_variables)]
26    fn is_supported_here(&self) -> bool {
27        true
28    }
29}
30
31type Kernel<Acc> = unsafe fn(&[FusedKerSpec<Acc>]) -> isize;
32
33#[derive(Clone)]
34pub struct DynKernel<const MR: usize, const NR: usize, Acc: LADatum> {
35    pub name: String,
36    pub kernel: Kernel<Acc>,
37    pub packings: Vec<(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)>,
38    pub stores: Vec<DatumType>,
39    pub supported_predicate: fn() -> bool,
40    pub can_fuse: fn(&FusedSpec) -> bool,
41}
42
43impl<const MR: usize, const NR: usize, Acc: LADatum> DynKernel<MR, NR, Acc> {
44    pub fn new(
45        name: &str,
46        kernel: Kernel<Acc>,
47        packing_a: PackedFormat,
48        packing_b: PackedFormat,
49    ) -> Self {
50        let kernel = DynKernel {
51            name: name.to_string(),
52            kernel,
53            packings: vec![],
54            stores: vec![Acc::datum_type()],
55            supported_predicate: || true,
56            can_fuse: |_| true,
57        };
58        kernel.with_packing(packing_a, packing_b)
59    }
60
61    pub fn with_platform_condition(mut self, f: fn() -> bool) -> Self {
62        self.supported_predicate = f;
63        self
64    }
65
66    pub fn with_packing(mut self, a: impl MMMInputFormat, b: impl MMMInputFormat) -> Self {
67        self.packings.push((Box::new(a), Box::new(b)));
68        self
69    }
70
71    pub fn with_packing_a(self, a: impl MMMInputFormat) -> Self {
72        let b = self.regular_pack_b();
73        self.with_packing(a, b)
74    }
75
76    pub fn regular_pack_a(&self) -> PackedFormat {
77        *self.packings[0].0.clone().downcast::<PackedFormat>().unwrap()
78    }
79
80    pub fn regular_pack_b(&self) -> PackedFormat {
81        *self.packings[0].1.clone().downcast::<PackedFormat>().unwrap()
82    }
83
84    pub fn with_can_fuse(self, can_fuse: fn(&FusedSpec) -> bool) -> Self {
85        Self { can_fuse, ..self }
86    }
87
88    pub fn with_store<D: LADatum>(mut self) -> Self {
89        self.stores.push(D::datum_type());
90        self
91    }
92
93    pub fn mmm(&self) -> Box<dyn MatMatMul> {
94        Box::new(self.clone())
95    }
96}
97
98impl<const MR: usize, const NR: usize, Acc: LADatum> Debug for DynKernel<MR, NR, Acc> {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        write!(f, "{}", self.name)
101    }
102}
103
104impl<const MR: usize, const NR: usize, Acc: LADatum> MatMatMulKer for DynKernel<MR, NR, Acc> {
105    type Acc = Acc;
106    fn name(&self) -> &str {
107        &self.name
108    }
109
110    fn mr(&self) -> usize {
111        MR
112    }
113
114    fn nr(&self) -> usize {
115        NR
116    }
117
118    fn is_supported_here(&self) -> bool {
119        (self.supported_predicate)()
120    }
121
122    fn can_fuse(&self, spec: &FusedSpec) -> bool {
123        (self.can_fuse)(spec)
124    }
125
126    fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize {
127        unsafe { (self.kernel)(op) }
128    }
129
130    #[allow(clippy::type_complexity)]
131    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
132        &self.packings
133    }
134
135    fn stores(&self) -> Cow<[DatumType]> {
136        Cow::Borrowed(&self.stores)
137    }
138}