1#[macro_use]
2mod macros;
3
4pub mod cost_model;
5#[macro_use]
6pub(crate) mod fuse;
7pub(crate) mod input_store;
8pub(crate) mod kernel;
9#[macro_use]
10pub(crate) mod panel_extract;
11mod scratch;
12mod storage;
13
14#[cfg(test)]
15#[macro_use]
16pub mod tests;
17
18use crate::multithread::Executor;
19#[cfg(feature = "multithread-mm")]
20use rayon::prelude::*;
21use std::borrow::Cow;
22use std::cmp::Ordering;
23use std::fmt::Debug;
24use tract_data::internal::*;
25
26pub use cost_model::*;
27pub use fuse::*;
28pub use input_store::*;
29pub use kernel::*;
30pub use panel_extract::*;
31pub use scratch::*;
32pub use storage::*;
33
/// Prefetch hook that does nothing.
///
/// Both arguments are accepted and ignored, so this can be used wherever a
/// prefetch function with this signature is expected but no prefetching is wanted.
pub fn no_prefetch(_ptr: *const u8, _len: usize) {
    // Intentionally empty.
}
35
/// Quality tier attached to an implementation.
///
/// Tiers are ranked by [`ImplementationQuality::best_to_worst`]; the rank is exposed
/// as a numeric [`cost`](ImplementationQuality::cost) (`0` is the best tier) and is
/// also what `PartialOrd` compares, so a *better* tier compares as *less than* a
/// worse one.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ImplementationQuality {
    /// Worst tier: last entry of `best_to_worst`, cost `4`.
    Dreadful,
    /// Cost `3`.
    Generic,
    /// Cost `2`.
    RustOptimized,
    /// Cost `1`.
    TargetOptimized,
    /// Best tier: first entry of `best_to_worst`, cost `0`.
    ManuallyOptimized,
}

impl ImplementationQuality {
    /// Returns every tier, ordered from the best one down to the worst one.
    pub fn best_to_worst() -> &'static [ImplementationQuality] {
        &[
            Self::ManuallyOptimized,
            Self::TargetOptimized,
            Self::RustOptimized,
            Self::Generic,
            Self::Dreadful,
        ]
    }

    /// Returns the index of `self` in [`Self::best_to_worst`]: `0` for the best tier,
    /// increasing as quality drops.
    pub fn cost(&self) -> usize {
        let ranking = Self::best_to_worst();
        // Every variant is listed in `best_to_worst`, so the lookup always succeeds.
        ranking.iter().position(|tier| tier == self).unwrap()
    }
}

impl PartialOrd for ImplementationQuality {
    /// Compares two tiers by their cost (lower cost, i.e. better quality, is "less").
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let (lhs, rhs) = (self.cost(), other.cost());
        Some(lhs.cmp(&rhs))
    }
}

impl From<ImplementationQuality> for usize {
    /// Converts a tier into its cost.
    fn from(value: ImplementationQuality) -> Self {
        ImplementationQuality::cost(&value)
    }
}
72
73pub trait MatMatMul: Debug + dyn_clone::DynClone + Send + Sync + std::any::Any {
74 fn name(&self) -> &str;
75 fn mr(&self) -> usize;
76 fn nr(&self) -> usize;
77
78 fn quality(&self) -> ImplementationQuality;
79
80 #[allow(clippy::type_complexity)]
81 fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
82
83 fn internal_type(&self) -> DatumType;
84
85 unsafe fn c_view(&self, m_axis: usize, n_axis: usize) -> OutputStoreSpec;
86 unsafe fn c_from_data_and_strides(
87 &self,
88 item_size: usize,
89 row_stride: isize,
90 col_stride: isize,
91 ) -> OutputStoreSpec;
92
93 fn can_fuse(&self, spec: &FusedSpec) -> bool;
94
95 fn stores(&self) -> Cow<[DatumType]>;
96
97 unsafe fn run(&self, m: usize, n: usize, non_linear: &[FusedSpec]) -> TractResult<()> {
98 let mut scratch = self.allocate_scratch_space();
99 self.run_with_scratch_space(m, n, &mut *scratch, non_linear)
100 }
101
102 unsafe fn allocate_scratch_space(&self) -> Box<dyn ScratchSpace>;
103 unsafe fn can_use_scratch_space(&self, scratch: &dyn ScratchSpace) -> bool;
104 unsafe fn run_with_scratch_space(
105 &self,
106 m: usize,
107 n: usize,
108 scratch: &mut dyn ScratchSpace,
109 non_linear: &[FusedSpec],
110 ) -> TractResult<()>;
111}
112
113dyn_clone::clone_trait_object!(MatMatMul);
114
115impl PartialEq for Box<dyn MatMatMul> {
116 fn eq(&self, other: &Box<dyn MatMatMul>) -> bool {
117 self.as_ref().type_id() == other.as_ref().type_id()
118 }
119}
120
121impl std::hash::Hash for Box<dyn MatMatMul> {
122 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123 self.as_ref().type_id().hash(state)
124 }
125}
126
127impl<K: MatMatMulKer> MatMatMul for K {
128 fn name(&self) -> &str {
129 self.name()
130 }
131 fn mr(&self) -> usize {
132 self.mr()
133 }
134 fn nr(&self) -> usize {
135 self.nr()
136 }
137
138 fn quality(&self) -> ImplementationQuality {
139 MatMatMulKer::quality(self)
140 }
141
142 fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
143 self.packings()
144 }
145
146 fn internal_type(&self) -> DatumType {
147 K::Acc::datum_type()
148 }
149
150 fn can_fuse(&self, spec: &FusedSpec) -> bool {
151 self.can_fuse(spec)
152 }
153
154 unsafe fn c_view(&self, m_axis: usize, n_axis: usize) -> OutputStoreSpec {
155 OutputStoreSpec::View { m_axis, n_axis, mr: self.mr(), nr: self.nr() }
156 }
157
158 unsafe fn c_from_data_and_strides(
159 &self,
160 item_size: usize,
161 row_stride: isize,
162 col_stride: isize,
163 ) -> OutputStoreSpec {
164 OutputStoreSpec::Strides {
165 row_byte_stride: row_stride * item_size as isize,
166 col_byte_stride: col_stride * item_size as isize,
167 mr: self.mr(),
168 nr: self.nr(),
169 }
170 }
171
172 fn stores(&self) -> Cow<[DatumType]> {
173 self.stores()
174 }
175
176 unsafe fn allocate_scratch_space(&self) -> Box<dyn ScratchSpace> {
177 Box::<ScratchSpaceImpl<K::Acc>>::default()
178 }
179
180 unsafe fn can_use_scratch_space(&self, scratch: &dyn ScratchSpace) -> bool {
181 scratch.downcast_ref::<ScratchSpaceImpl<K::Acc>>().is_some()
182 }
183
184 unsafe fn run_with_scratch_space(
185 &self,
186 m: usize,
187 n: usize,
188 scratch: &mut dyn ScratchSpace,
189 non_linear: &[FusedSpec],
190 ) -> TractResult<()> {
191 let scratch = scratch
192 .downcast_mut::<ScratchSpaceImpl<K::Acc>>()
193 .context("Wrong scratch space type")?;
194 scratch.prepare(self, m, n, non_linear)?;
195 if n == 1 && self.nr() == 1 {
196 run_with_scratch_space_vec(self, m, scratch, non_linear)
197 } else {
198 let (mut prefer_col, mut prefer_row) = (0, 0);
199 for uop in non_linear.iter() {
200 if let Some(col) = uop.prefer_col_outer() {
201 prefer_col = col as usize;
202 prefer_row = (!col) as usize;
203 }
204 }
205 if prefer_col > prefer_row {
206 run_with_scratch_space_col_outer(self, m, n, scratch, non_linear)
207 } else {
208 run_with_scratch_space_row_outer(self, m, n, scratch, non_linear)
209 }
210 }
211 }
212}
213
214unsafe fn run_with_scratch_space_vec<K: MatMatMulKer>(
215 ker: &K,
216 m: usize,
217 scratch: &mut ScratchSpaceImpl<K::Acc>,
218 non_linear: &[FusedSpec],
219) -> TractResult<()> {
220 match crate::multithread::current_tract_executor() {
221 Executor::SingleThread => {
222 for ia in 0..m.divceil(ker.mr()) {
223 scratch.run(ker, non_linear, ia, 0)?;
224 }
225 Ok(())
226 }
227 #[cfg(feature = "multithread-mm")]
228 Executor::MultiThread(pool) => pool.install(|| {
229 (0..m.div_ceil(ker.mr()))
230 .into_par_iter()
231 .try_for_each(|ia| scratch.run(ker, non_linear, ia, 0))
232 }),
233 }
234}
235
236unsafe fn run_with_scratch_space_col_outer<K: MatMatMulKer>(
237 ker: &K,
238 m: usize,
239 n: usize,
240 scratch: &mut ScratchSpaceImpl<K::Acc>,
241 non_linear: &[FusedSpec],
242) -> TractResult<()> {
243 match crate::multithread::current_tract_executor() {
244 Executor::SingleThread => {
245 for ib in 0..n.divceil(ker.nr()) {
246 for ia in 0..m.divceil(ker.mr()) {
247 scratch.run(ker, non_linear, ia, ib)?;
248 }
249 }
250 Ok(())
251 }
252 #[cfg(feature = "multithread-mm")]
253 Executor::MultiThread(pool) => pool.install(|| {
254 (0..n.div_ceil(ker.nr())).into_par_iter().try_for_each(|ib| {
255 for ia in 0..m.divceil(ker.mr()) {
256 scratch.run(ker, non_linear, ia, ib)?;
257 }
258 Ok(())
259 })
260 }),
261 }
262}
263
264unsafe fn run_with_scratch_space_row_outer<K: MatMatMulKer>(
265 ker: &K,
266 m: usize,
267 n: usize,
268 scratch: &mut ScratchSpaceImpl<K::Acc>,
269 non_linear: &[FusedSpec],
270) -> TractResult<()> {
271 match crate::multithread::current_tract_executor() {
272 Executor::SingleThread => {
273 for ia in 0..m.divceil(ker.mr()) {
274 for ib in 0..n.divceil(ker.nr()) {
275 scratch.run(ker, non_linear, ia, ib)?;
276 }
277 }
278 Ok(())
279 }
280 #[cfg(feature = "multithread-mm")]
281 Executor::MultiThread(pool) => pool.install(|| {
282 pool.install(|| {
283 (0..m.div_ceil(ker.mr())).into_par_iter().try_for_each(|ia| {
284 for ib in 0..n.divceil(ker.nr()) {
285 scratch.run(ker, non_linear, ia, ib)?;
286 }
287 Ok(())
288 })
289 })
290 }),
291 }
292}