tract_linalg/frame/mmm/kernel.rs

1use crate::frame::pack::PackedFormat;
2
3use super::*;
4use std::borrow::Cow;
5use std::fmt::Debug;
6
7use crate::LADatum;
8
9pub trait MatMatMulKer: Clone + Debug + Send + Sync + 'static {
10    type Acc: LADatum;
11    fn name(&self) -> &str;
12    fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize;
13    fn mr(&self) -> usize;
14    fn nr(&self) -> usize;
15
16    fn quality(&self) -> ImplementationQuality;
17
18    #[allow(clippy::type_complexity)]
19    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
20    fn stores(&self) -> Cow<[DatumType]>;
21
22    #[allow(unused_variables)]
23    fn can_fuse(&self, spec: &FusedSpec) -> bool {
24        true
25    }
26
27    #[allow(unused_variables)]
28    fn is_supported_here(&self) -> bool {
29        true
30    }
31}
32
33type Kernel<Acc> = unsafe fn(&[FusedKerSpec<Acc>]) -> isize;
34
35#[derive(Clone)]
36pub struct DynKernel<const MR: usize, const NR: usize, Acc: LADatum> {
37    pub name: String,
38    pub kernel: Kernel<Acc>,
39    pub quality: ImplementationQuality,
40    pub packings: Vec<(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)>,
41    pub stores: Vec<DatumType>,
42    pub supported_predicate: fn() -> bool,
43    pub can_fuse: fn(&FusedSpec) -> bool,
44}
45
46impl<const MR: usize, const NR: usize, Acc: LADatum> DynKernel<MR, NR, Acc> {
47    pub fn new(
48        name: &str,
49        kernel: Kernel<Acc>,
50        packing_a: PackedFormat,
51        packing_b: PackedFormat,
52        quality: ImplementationQuality,
53    ) -> Self {
54        let kernel = DynKernel {
55            name: name.to_string(),
56            kernel,
57            quality,
58            packings: vec![],
59            stores: vec![Acc::datum_type()],
60            supported_predicate: || true,
61            can_fuse: |_| true,
62        };
63        kernel.with_packing(packing_a, packing_b)
64    }
65
66    pub fn with_platform_condition(mut self, f: fn() -> bool) -> Self {
67        self.supported_predicate = f;
68        self
69    }
70
71    pub fn with_packing(mut self, a: impl MMMInputFormat, b: impl MMMInputFormat) -> Self {
72        self.packings.push((Box::new(a), Box::new(b)));
73        self
74    }
75
76    pub fn with_packing_a(self, a: impl MMMInputFormat) -> Self {
77        let b = self.regular_pack_b();
78        self.with_packing(a, b)
79    }
80
81    pub fn regular_pack_a(&self) -> PackedFormat {
82        *self.packings[0].0.clone().downcast::<PackedFormat>().unwrap()
83    }
84
85    pub fn regular_pack_b(&self) -> PackedFormat {
86        *self.packings[0].1.clone().downcast::<PackedFormat>().unwrap()
87    }
88
89    pub fn with_can_fuse(self, can_fuse: fn(&FusedSpec) -> bool) -> Self {
90        Self { can_fuse, ..self }
91    }
92
93    pub fn with_store<D: LADatum>(mut self) -> Self {
94        self.stores.push(D::datum_type());
95        self
96    }
97
98    pub fn mmm(&self) -> Box<dyn MatMatMul> {
99        Box::new(self.clone())
100    }
101}
102
103impl<const MR: usize, const NR: usize, Acc: LADatum> Debug for DynKernel<MR, NR, Acc> {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        write!(f, "{}", self.name)
106    }
107}
108
109impl<const MR: usize, const NR: usize, Acc: LADatum> MatMatMulKer for DynKernel<MR, NR, Acc> {
110    type Acc = Acc;
111    fn name(&self) -> &str {
112        &self.name
113    }
114
115    fn mr(&self) -> usize {
116        MR
117    }
118
119    fn nr(&self) -> usize {
120        NR
121    }
122
123    fn quality(&self) -> ImplementationQuality {
124        self.quality
125    }
126
127    fn is_supported_here(&self) -> bool {
128        (self.supported_predicate)()
129    }
130
131    fn can_fuse(&self, spec: &FusedSpec) -> bool {
132        (self.can_fuse)(spec)
133    }
134
135    fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize {
136        unsafe { (self.kernel)(op) }
137    }
138
139    #[allow(clippy::type_complexity)]
140    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
141        &self.packings
142    }
143
144    fn stores(&self) -> Cow<[DatumType]> {
145        Cow::Borrowed(&self.stores)
146    }
147}