tract_linalg/frame/mmm/
kernel.rs1use crate::frame::pack::PackedFormat;
2
3use super::*;
4use std::borrow::Cow;
5use std::fmt::Debug;
6
7use crate::LADatum;
8
9pub trait MatMatMulKer: Clone + Debug + Send + Sync + 'static {
10 type Acc: LADatum;
11 fn name(&self) -> &str;
12 fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize;
13 fn mr(&self) -> usize;
14 fn nr(&self) -> usize;
15
16 fn quality(&self) -> ImplementationQuality;
17
18 #[allow(clippy::type_complexity)]
19 fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
20 fn stores(&self) -> Cow<[DatumType]>;
21
22 #[allow(unused_variables)]
23 fn can_fuse(&self, spec: &FusedSpec) -> bool {
24 true
25 }
26
27 #[allow(unused_variables)]
28 fn is_supported_here(&self) -> bool {
29 true
30 }
31}
32
33type Kernel<Acc> = unsafe fn(&[FusedKerSpec<Acc>]) -> isize;
34
/// A `MatMatMulKer` backed by a plain kernel function pointer.
///
/// The row/column counts are fixed as the const generics `MR` / `NR` (returned
/// by `mr()` / `nr()`); everything else is runtime data configured through the
/// `with_*` builders.
#[derive(Clone)]
pub struct DynKernel<const MR: usize, const NR: usize, Acc: LADatum> {
    /// Kernel name; also what the `Debug` impl prints.
    pub name: String,
    /// Entry point invoked by `MatMatMulKer::kernel`.
    pub kernel: Kernel<Acc>,
    /// Value reported by `MatMatMulKer::quality`.
    pub quality: ImplementationQuality,
    /// Supported (A, B) input-format pairs. The first pair is the one read by
    /// `regular_pack_a` / `regular_pack_b`.
    pub packings: Vec<(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)>,
    /// Datum types reported by `MatMatMulKer::stores`; `new` seeds it with `Acc`'s.
    pub stores: Vec<DatumType>,
    /// Predicate backing `MatMatMulKer::is_supported_here`.
    pub supported_predicate: fn() -> bool,
    /// Predicate backing `MatMatMulKer::can_fuse`.
    pub can_fuse: fn(&FusedSpec) -> bool,
}
45
46impl<const MR: usize, const NR: usize, Acc: LADatum> DynKernel<MR, NR, Acc> {
47 pub fn new(
48 name: &str,
49 kernel: Kernel<Acc>,
50 packing_a: PackedFormat,
51 packing_b: PackedFormat,
52 quality: ImplementationQuality,
53 ) -> Self {
54 let kernel = DynKernel {
55 name: name.to_string(),
56 kernel,
57 quality,
58 packings: vec![],
59 stores: vec![Acc::datum_type()],
60 supported_predicate: || true,
61 can_fuse: |_| true,
62 };
63 kernel.with_packing(packing_a, packing_b)
64 }
65
66 pub fn with_platform_condition(mut self, f: fn() -> bool) -> Self {
67 self.supported_predicate = f;
68 self
69 }
70
71 pub fn with_packing(mut self, a: impl MMMInputFormat, b: impl MMMInputFormat) -> Self {
72 self.packings.push((Box::new(a), Box::new(b)));
73 self
74 }
75
76 pub fn with_packing_a(self, a: impl MMMInputFormat) -> Self {
77 let b = self.regular_pack_b();
78 self.with_packing(a, b)
79 }
80
81 pub fn regular_pack_a(&self) -> PackedFormat {
82 *self.packings[0].0.clone().downcast::<PackedFormat>().unwrap()
83 }
84
85 pub fn regular_pack_b(&self) -> PackedFormat {
86 *self.packings[0].1.clone().downcast::<PackedFormat>().unwrap()
87 }
88
89 pub fn with_can_fuse(self, can_fuse: fn(&FusedSpec) -> bool) -> Self {
90 Self { can_fuse, ..self }
91 }
92
93 pub fn with_store<D: LADatum>(mut self) -> Self {
94 self.stores.push(D::datum_type());
95 self
96 }
97
98 pub fn mmm(&self) -> Box<dyn MatMatMul> {
99 Box::new(self.clone())
100 }
101}
102
103impl<const MR: usize, const NR: usize, Acc: LADatum> Debug for DynKernel<MR, NR, Acc> {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 write!(f, "{}", self.name)
106 }
107}
108
109impl<const MR: usize, const NR: usize, Acc: LADatum> MatMatMulKer for DynKernel<MR, NR, Acc> {
110 type Acc = Acc;
111 fn name(&self) -> &str {
112 &self.name
113 }
114
115 fn mr(&self) -> usize {
116 MR
117 }
118
119 fn nr(&self) -> usize {
120 NR
121 }
122
123 fn quality(&self) -> ImplementationQuality {
124 self.quality
125 }
126
127 fn is_supported_here(&self) -> bool {
128 (self.supported_predicate)()
129 }
130
131 fn can_fuse(&self, spec: &FusedSpec) -> bool {
132 (self.can_fuse)(spec)
133 }
134
135 fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize {
136 unsafe { (self.kernel)(op) }
137 }
138
139 #[allow(clippy::type_complexity)]
140 fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
141 &self.packings
142 }
143
144 fn stores(&self) -> Cow<[DatumType]> {
145 Cow::Borrowed(&self.stores)
146 }
147}