tract_linalg/frame/mmm/
kernel.rs

1use crate::frame::pack::PackedFormat;
2
3use super::*;
4use std::borrow::Cow;
5use std::fmt::Debug;
6
7use crate::LADatum;
8
9pub trait MatMatMulKer: Clone + Debug + Send + Sync + 'static {
10    type Acc: LADatum;
11    fn name(&self) -> &str;
12    fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize;
13    fn mr(&self) -> usize;
14    fn nr(&self) -> usize;
15
16    fn quality(&self) -> ImplementationQuality;
17    fn dynamic_boost(&self) -> isize;
18
19    #[allow(clippy::type_complexity)]
20    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
21    fn stores(&self) -> Cow<'_, [DatumType]>;
22
23    #[allow(unused_variables)]
24    fn can_fuse(&self, spec: &FusedSpec) -> bool {
25        true
26    }
27
28    #[allow(unused_variables)]
29    fn is_supported_here(&self) -> bool {
30        true
31    }
32}
33
/// Signature of a raw kernel entry point: an `unsafe` function taking a slice of
/// `FusedKerSpec<Acc>` instructions and returning an `isize`.
///
/// NOTE(review): the safety contract and the meaning of the returned value are defined by the
/// kernel implementations, not by this file.
type Kernel<Acc> = unsafe fn(&[FusedKerSpec<Acc>]) -> isize;
35
/// Kernel description made of plain data and function pointers, built through
/// `DynKernel::new` and the `with_*` builder methods.
///
/// `MR` and `NR` are the values reported by the `MatMatMulKer::mr` and `MatMatMulKer::nr`
/// implementations for this type; `Acc` becomes `MatMatMulKer::Acc`.
#[derive(Clone)]
pub struct DynKernel<const MR: usize, const NR: usize, Acc: LADatum> {
    /// Name returned by `MatMatMulKer::name` and printed by the `Debug` implementation.
    pub name: String,
    /// Raw entry point invoked by `MatMatMulKer::kernel`.
    pub kernel: Kernel<Acc>,
    /// Value returned by `MatMatMulKer::quality`.
    pub quality: ImplementationQuality,
    /// Supported input format pairs. `regular_pack_a` / `regular_pack_b` read the first entry
    /// and expect both of its elements to be `PackedFormat`s.
    pub packings: Vec<(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)>,
    /// Datum types returned by `MatMatMulKer::stores`; `new` seeds it with `Acc::datum_type()`.
    pub stores: Vec<DatumType>,
    /// Called by `MatMatMulKer::is_supported_here`.
    pub supported_predicate: fn() -> bool,
    /// Called by `MatMatMulKer::dynamic_boost`.
    pub boost: fn() -> isize,
    /// Called by `MatMatMulKer::can_fuse`.
    pub can_fuse: fn(&FusedSpec) -> bool,
}
47
48impl<const MR: usize, const NR: usize, Acc: LADatum> DynKernel<MR, NR, Acc> {
49    pub fn new(
50        name: &str,
51        kernel: Kernel<Acc>,
52        packing_a: PackedFormat,
53        packing_b: PackedFormat,
54        quality: ImplementationQuality,
55    ) -> Self {
56        let kernel = DynKernel {
57            name: name.to_string(),
58            kernel,
59            quality,
60            packings: vec![],
61            stores: vec![Acc::datum_type()],
62            supported_predicate: || true,
63            boost: || 0,
64            can_fuse: |_| true,
65        };
66        kernel.with_packing(packing_a, packing_b)
67    }
68
69    pub fn with_platform_condition(mut self, f: fn() -> bool) -> Self {
70        self.supported_predicate = f;
71        self
72    }
73
74    pub fn with_boost(mut self, f: fn() -> isize) -> Self {
75        self.boost = f;
76        self
77    }
78
79    pub fn with_packing(mut self, a: impl MMMInputFormat, b: impl MMMInputFormat) -> Self {
80        self.packings.push((Box::new(a), Box::new(b)));
81        self
82    }
83
84    pub fn with_packing_a(self, a: impl MMMInputFormat) -> Self {
85        let b = self.regular_pack_b();
86        self.with_packing(a, b)
87    }
88
89    pub fn regular_pack_a(&self) -> PackedFormat {
90        *self.packings[0].0.clone().downcast::<PackedFormat>().unwrap()
91    }
92
93    pub fn regular_pack_b(&self) -> PackedFormat {
94        *self.packings[0].1.clone().downcast::<PackedFormat>().unwrap()
95    }
96
97    pub fn with_can_fuse(self, can_fuse: fn(&FusedSpec) -> bool) -> Self {
98        Self { can_fuse, ..self }
99    }
100
101    pub fn with_store<D: LADatum>(mut self) -> Self {
102        self.stores.push(D::datum_type());
103        self
104    }
105
106    pub fn mmm(&self) -> Box<dyn MatMatMul> {
107        Box::new(self.clone())
108    }
109}
110
111impl<const MR: usize, const NR: usize, Acc: LADatum> Debug for DynKernel<MR, NR, Acc> {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        write!(f, "{}", self.name)
114    }
115}
116
117impl<const MR: usize, const NR: usize, Acc: LADatum> MatMatMulKer for DynKernel<MR, NR, Acc> {
118    type Acc = Acc;
119    fn name(&self) -> &str {
120        &self.name
121    }
122
123    fn mr(&self) -> usize {
124        MR
125    }
126
127    fn nr(&self) -> usize {
128        NR
129    }
130
131    fn quality(&self) -> ImplementationQuality {
132        self.quality
133    }
134
135    fn is_supported_here(&self) -> bool {
136        (self.supported_predicate)()
137    }
138
139    fn can_fuse(&self, spec: &FusedSpec) -> bool {
140        (self.can_fuse)(spec)
141    }
142
143    fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize {
144        unsafe { (self.kernel)(op) }
145    }
146
147    #[allow(clippy::type_complexity)]
148    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
149        &self.packings
150    }
151
152    fn stores(&self) -> Cow<'_, [DatumType]> {
153        Cow::Borrowed(&self.stores)
154    }
155
156    fn dynamic_boost(&self) -> isize {
157        (self.boost)()
158    }
159}