tract_linalg/frame/mmm/
kernel.rs1use pack::PackedFormat;
2
3use super::*;
4use std::borrow::Cow;
5use std::fmt::Debug;
6
7use crate::LADatum;
8
9pub trait MatMatMulKer: Clone + Debug + Send + Sync + 'static {
10 type Acc: LADatum;
11 fn name(&self) -> &str;
12 fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize;
13 fn mr(&self) -> usize;
14 fn nr(&self) -> usize;
15
16 #[allow(clippy::type_complexity)]
17 fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
18 fn stores(&self) -> Cow<[DatumType]>;
19
20 #[allow(unused_variables)]
21 fn can_fuse(&self, spec: &FusedSpec) -> bool {
22 true
23 }
24
25 #[allow(unused_variables)]
26 fn is_supported_here(&self) -> bool {
27 true
28 }
29}
30
31type Kernel<Acc> = unsafe fn(&[FusedKerSpec<Acc>]) -> isize;
32
33#[derive(Clone)]
34pub struct DynKernel<const MR: usize, const NR: usize, Acc: LADatum> {
35 pub name: String,
36 pub kernel: Kernel<Acc>,
37 pub packings: Vec<(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)>,
38 pub stores: Vec<DatumType>,
39 pub supported_predicate: fn() -> bool,
40 pub can_fuse: fn(&FusedSpec) -> bool,
41}
42
43impl<const MR: usize, const NR: usize, Acc: LADatum> DynKernel<MR, NR, Acc> {
44 pub fn new(
45 name: &str,
46 kernel: Kernel<Acc>,
47 packing_a: PackedFormat,
48 packing_b: PackedFormat,
49 ) -> Self {
50 let kernel = DynKernel {
51 name: name.to_string(),
52 kernel,
53 packings: vec![],
54 stores: vec![Acc::datum_type()],
55 supported_predicate: || true,
56 can_fuse: |_| true,
57 };
58 kernel.with_packing(packing_a, packing_b)
59 }
60
61 pub fn with_platform_condition(mut self, f: fn() -> bool) -> Self {
62 self.supported_predicate = f;
63 self
64 }
65
66 pub fn with_packing(mut self, a: impl MMMInputFormat, b: impl MMMInputFormat) -> Self {
67 self.packings.push((Box::new(a), Box::new(b)));
68 self
69 }
70
71 pub fn with_packing_a(self, a: impl MMMInputFormat) -> Self {
72 let b = self.regular_pack_b();
73 self.with_packing(a, b)
74 }
75
76 pub fn regular_pack_a(&self) -> PackedFormat {
77 *self.packings[0].0.clone().downcast::<PackedFormat>().unwrap()
78 }
79
80 pub fn regular_pack_b(&self) -> PackedFormat {
81 *self.packings[0].1.clone().downcast::<PackedFormat>().unwrap()
82 }
83
84 pub fn with_can_fuse(self, can_fuse: fn(&FusedSpec) -> bool) -> Self {
85 Self { can_fuse, ..self }
86 }
87
88 pub fn with_store<D: LADatum>(mut self) -> Self {
89 self.stores.push(D::datum_type());
90 self
91 }
92
93 pub fn mmm(&self) -> Box<dyn MatMatMul> {
94 Box::new(self.clone())
95 }
96}
97
98impl<const MR: usize, const NR: usize, Acc: LADatum> Debug for DynKernel<MR, NR, Acc> {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 write!(f, "{}", self.name)
101 }
102}
103
104impl<const MR: usize, const NR: usize, Acc: LADatum> MatMatMulKer for DynKernel<MR, NR, Acc> {
105 type Acc = Acc;
106 fn name(&self) -> &str {
107 &self.name
108 }
109
110 fn mr(&self) -> usize {
111 MR
112 }
113
114 fn nr(&self) -> usize {
115 NR
116 }
117
118 fn is_supported_here(&self) -> bool {
119 (self.supported_predicate)()
120 }
121
122 fn can_fuse(&self, spec: &FusedSpec) -> bool {
123 (self.can_fuse)(spec)
124 }
125
126 fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize {
127 unsafe { (self.kernel)(op) }
128 }
129
130 #[allow(clippy::type_complexity)]
131 fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
132 &self.packings
133 }
134
135 fn stores(&self) -> Cow<[DatumType]> {
136 Cow::Borrowed(&self.stores)
137 }
138}