tract_linalg/frame/mmm/
kernel.rs1use crate::frame::pack::PackedFormat;
2
3use super::*;
4use std::borrow::Cow;
5use std::fmt::Debug;
6
7use crate::LADatum;
8
9pub trait MatMatMulKer: Clone + Debug + Send + Sync + 'static {
10 type Acc: LADatum;
11 fn name(&self) -> &str;
12 fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize;
13 fn mr(&self) -> usize;
14 fn nr(&self) -> usize;
15
16 fn quality(&self) -> ImplementationQuality;
17 fn dynamic_boost(&self) -> isize;
18
19 #[allow(clippy::type_complexity)]
20 fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
21 fn stores(&self) -> Cow<'_, [DatumType]>;
22
23 #[allow(unused_variables)]
24 fn can_fuse(&self, spec: &FusedSpec) -> bool {
25 true
26 }
27
28 #[allow(unused_variables)]
29 fn is_supported_here(&self) -> bool {
30 true
31 }
32}
33
34type Kernel<Acc> = unsafe fn(&[FusedKerSpec<Acc>]) -> isize;
35
36#[derive(Clone)]
37pub struct DynKernel<const MR: usize, const NR: usize, Acc: LADatum> {
38 pub name: String,
39 pub kernel: Kernel<Acc>,
40 pub quality: ImplementationQuality,
41 pub packings: Vec<(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)>,
42 pub stores: Vec<DatumType>,
43 pub supported_predicate: fn() -> bool,
44 pub boost: fn() -> isize,
45 pub can_fuse: fn(&FusedSpec) -> bool,
46}
47
48impl<const MR: usize, const NR: usize, Acc: LADatum> DynKernel<MR, NR, Acc> {
49 pub fn new(
50 name: &str,
51 kernel: Kernel<Acc>,
52 packing_a: PackedFormat,
53 packing_b: PackedFormat,
54 quality: ImplementationQuality,
55 ) -> Self {
56 let kernel = DynKernel {
57 name: name.to_string(),
58 kernel,
59 quality,
60 packings: vec![],
61 stores: vec![Acc::datum_type()],
62 supported_predicate: || true,
63 boost: || 0,
64 can_fuse: |_| true,
65 };
66 kernel.with_packing(packing_a, packing_b)
67 }
68
69 pub fn with_platform_condition(mut self, f: fn() -> bool) -> Self {
70 self.supported_predicate = f;
71 self
72 }
73
74 pub fn with_boost(mut self, f: fn() -> isize) -> Self {
75 self.boost = f;
76 self
77 }
78
79 pub fn with_packing(mut self, a: impl MMMInputFormat, b: impl MMMInputFormat) -> Self {
80 self.packings.push((Box::new(a), Box::new(b)));
81 self
82 }
83
84 pub fn with_packing_a(self, a: impl MMMInputFormat) -> Self {
85 let b = self.regular_pack_b();
86 self.with_packing(a, b)
87 }
88
89 pub fn regular_pack_a(&self) -> PackedFormat {
90 *self.packings[0].0.clone().downcast::<PackedFormat>().unwrap()
91 }
92
93 pub fn regular_pack_b(&self) -> PackedFormat {
94 *self.packings[0].1.clone().downcast::<PackedFormat>().unwrap()
95 }
96
97 pub fn with_can_fuse(self, can_fuse: fn(&FusedSpec) -> bool) -> Self {
98 Self { can_fuse, ..self }
99 }
100
101 pub fn with_store<D: LADatum>(mut self) -> Self {
102 self.stores.push(D::datum_type());
103 self
104 }
105
106 pub fn mmm(&self) -> Box<dyn MatMatMul> {
107 Box::new(self.clone())
108 }
109}
110
111impl<const MR: usize, const NR: usize, Acc: LADatum> Debug for DynKernel<MR, NR, Acc> {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 write!(f, "{}", self.name)
114 }
115}
116
117impl<const MR: usize, const NR: usize, Acc: LADatum> MatMatMulKer for DynKernel<MR, NR, Acc> {
118 type Acc = Acc;
119 fn name(&self) -> &str {
120 &self.name
121 }
122
123 fn mr(&self) -> usize {
124 MR
125 }
126
127 fn nr(&self) -> usize {
128 NR
129 }
130
131 fn quality(&self) -> ImplementationQuality {
132 self.quality
133 }
134
135 fn is_supported_here(&self) -> bool {
136 (self.supported_predicate)()
137 }
138
139 fn can_fuse(&self, spec: &FusedSpec) -> bool {
140 (self.can_fuse)(spec)
141 }
142
143 fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize {
144 unsafe { (self.kernel)(op) }
145 }
146
147 #[allow(clippy::type_complexity)]
148 fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
149 &self.packings
150 }
151
152 fn stores(&self) -> Cow<'_, [DatumType]> {
153 Cow::Borrowed(&self.stores)
154 }
155
156 fn dynamic_boost(&self) -> isize {
157 (self.boost)()
158 }
159}