1#[macro_use]
2mod macros;
3
4pub mod cost_model;
5#[macro_use]
6pub(crate) mod fuse;
7pub(crate) mod input_store;
8pub(crate) mod kernel;
9#[macro_use]
10pub(crate) mod panel_extract;
11mod scratch;
12mod storage;
13
14#[cfg(test)]
15#[macro_use]
16pub mod tests;
17
18use crate::multithread::Executor;
19#[cfg(feature = "multithread-mm")]
20use rayon::prelude::*;
21use std::borrow::Cow;
22use std::cmp::Ordering;
23use std::fmt::Debug;
24use tract_data::internal::*;
25
26pub use cost_model::*;
27pub use fuse::*;
28pub use input_store::*;
29pub use kernel::*;
30pub use panel_extract::*;
31pub use scratch::*;
32pub use storage::*;
33
/// Prefetch hook that does nothing.
///
/// Takes a pointer and a length, ignores both, and returns immediately. The
/// pointer is never dereferenced.
pub fn no_prefetch(_ptr: *const u8, _len: usize) {
    // Intentionally empty.
}
35
/// Quality tier reported by a matrix-multiplication kernel
/// (see `MatMatMul::quality`).
///
/// Variants are declared in the reverse of the order returned by
/// [`ImplementationQuality::best_to_worst`]: the first variant is the one
/// ranked worst, the last variant is the one ranked best.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ImplementationQuality {
    /// Ranked last (worst) by `best_to_worst`.
    Dreadful,
    /// Ranked fourth by `best_to_worst`.
    Generic,
    /// Ranked third by `best_to_worst`.
    RustOptimized,
    /// Ranked second by `best_to_worst`.
    TargetOptimized,
    /// Ranked first (best) by `best_to_worst`.
    ManuallyOptimized,
}
49
50impl ImplementationQuality {
51 pub fn best_to_worst() -> &'static [ImplementationQuality] {
52 use ImplementationQuality::*;
53 &[ManuallyOptimized, TargetOptimized, RustOptimized, Generic, Dreadful]
54 }
55
56 pub fn cost(&self) -> usize {
57 ImplementationQuality::best_to_worst().iter().position(|x| x == self).unwrap()
58 }
59}
60
61impl PartialOrd for ImplementationQuality {
62 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
63 Some(usize::from(*self).cmp(&usize::from(*other)))
64 }
65}
66
67impl From<ImplementationQuality> for usize {
68 fn from(value: ImplementationQuality) -> Self {
69 value.cost()
70 }
71}
72
73pub trait MatMatMul: Debug + dyn_clone::DynClone + Send + Sync + std::any::Any {
74 fn name(&self) -> &str;
75 fn mr(&self) -> usize;
76 fn nr(&self) -> usize;
77
78 fn quality(&self) -> ImplementationQuality;
79 fn dynamic_boost(&self) -> isize;
80
81 #[allow(clippy::type_complexity)]
82 fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
83
84 fn internal_type(&self) -> DatumType;
85
86 unsafe fn c_view(&self, m_axis: Option<usize>, n_axis: Option<usize>) -> OutputStoreSpec;
87 unsafe fn c_from_data_and_strides(
88 &self,
89 item_size: usize,
90 row_stride: isize,
91 col_stride: isize,
92 ) -> OutputStoreSpec;
93
94 fn can_fuse(&self, spec: &FusedSpec) -> bool;
95
96 fn stores(&self) -> Cow<'_, [DatumType]>;
97
98 unsafe fn run(&self, m: usize, n: usize, non_linear: &[FusedSpec]) -> TractResult<()> {
99 unsafe {
100 let mut scratch = self.allocate_scratch_space();
101 self.run_with_scratch_space(m, n, &mut *scratch, non_linear)
102 }
103 }
104
105 unsafe fn allocate_scratch_space(&self) -> Box<dyn ScratchSpace>;
106 unsafe fn can_use_scratch_space(&self, scratch: &dyn ScratchSpace) -> bool;
107 unsafe fn run_with_scratch_space(
108 &self,
109 m: usize,
110 n: usize,
111 scratch: &mut dyn ScratchSpace,
112 non_linear: &[FusedSpec],
113 ) -> TractResult<()>;
114}
115
116dyn_clone::clone_trait_object!(MatMatMul);
117
118impl PartialEq for Box<dyn MatMatMul> {
119 fn eq(&self, other: &Box<dyn MatMatMul>) -> bool {
120 self.name() == other.name()
121 }
122}
123impl Eq for Box<dyn MatMatMul> {}
124
125impl std::hash::Hash for Box<dyn MatMatMul> {
126 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
127 self.name().hash(state)
128 }
129}
130
131impl<K: MatMatMulKer> MatMatMul for K {
132 fn name(&self) -> &str {
133 self.name()
134 }
135 fn mr(&self) -> usize {
136 self.mr()
137 }
138 fn nr(&self) -> usize {
139 self.nr()
140 }
141
142 fn quality(&self) -> ImplementationQuality {
143 MatMatMulKer::quality(self)
144 }
145
146 fn dynamic_boost(&self) -> isize {
147 MatMatMulKer::dynamic_boost(self)
148 }
149
150 fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
151 self.packings()
152 }
153
154 fn internal_type(&self) -> DatumType {
155 K::Acc::datum_type()
156 }
157
158 fn can_fuse(&self, spec: &FusedSpec) -> bool {
159 self.can_fuse(spec)
160 }
161
162 unsafe fn c_view(&self, m_axis: Option<usize>, n_axis: Option<usize>) -> OutputStoreSpec {
163 OutputStoreSpec::View { m_axis, n_axis, mr: self.mr(), nr: self.nr() }
164 }
165
166 unsafe fn c_from_data_and_strides(
167 &self,
168 item_size: usize,
169 row_stride: isize,
170 col_stride: isize,
171 ) -> OutputStoreSpec {
172 OutputStoreSpec::Strides {
173 row_byte_stride: row_stride * item_size as isize,
174 col_byte_stride: col_stride * item_size as isize,
175 mr: self.mr(),
176 nr: self.nr(),
177 }
178 }
179
180 fn stores(&self) -> Cow<'_, [DatumType]> {
181 self.stores()
182 }
183
184 unsafe fn allocate_scratch_space(&self) -> Box<dyn ScratchSpace> {
185 Box::<ScratchSpaceImpl<K::Acc>>::default()
186 }
187
188 unsafe fn can_use_scratch_space(&self, scratch: &dyn ScratchSpace) -> bool {
189 scratch.downcast_ref::<ScratchSpaceImpl<K::Acc>>().is_some()
190 }
191
192 unsafe fn run_with_scratch_space(
193 &self,
194 m: usize,
195 n: usize,
196 scratch: &mut dyn ScratchSpace,
197 non_linear: &[FusedSpec],
198 ) -> TractResult<()> {
199 unsafe {
200 let scratch = scratch
201 .downcast_mut::<ScratchSpaceImpl<K::Acc>>()
202 .context("Wrong scratch space type")?;
203 scratch.prepare(self, m, n, non_linear)?;
204 if n == 1 && self.nr() == 1 {
205 run_with_scratch_space_vec(self, m, scratch, non_linear)
206 } else {
207 let (mut prefer_col, mut prefer_row) = (0, 0);
208 for uop in non_linear.iter() {
209 if let Some(col) = uop.prefer_col_outer() {
210 prefer_col = col as usize;
211 prefer_row = (!col) as usize;
212 }
213 }
214 if prefer_col > prefer_row {
215 run_with_scratch_space_col_outer(self, m, n, scratch, non_linear)
216 } else {
217 run_with_scratch_space_row_outer(self, m, n, scratch, non_linear)
218 }
219 }
220 }
221 }
222}
223
224unsafe fn run_with_scratch_space_vec<K: MatMatMulKer>(
225 ker: &K,
226 m: usize,
227 scratch: &mut ScratchSpaceImpl<K::Acc>,
228 non_linear: &[FusedSpec],
229) -> TractResult<()> {
230 unsafe {
231 match crate::multithread::current_tract_executor() {
232 Executor::SingleThread => {
233 for ia in 0..m.divceil(ker.mr()) {
234 scratch.run(ker, non_linear, ia, 0)?;
235 }
236 Ok(())
237 }
238 #[cfg(feature = "multithread-mm")]
239 Executor::MultiThread(pool) => pool.install(|| {
240 (0..m.div_ceil(ker.mr()))
241 .into_par_iter()
242 .try_for_each(|ia| scratch.run(ker, non_linear, ia, 0))
243 }),
244 }
245 }
246}
247
248unsafe fn run_with_scratch_space_col_outer<K: MatMatMulKer>(
249 ker: &K,
250 m: usize,
251 n: usize,
252 scratch: &mut ScratchSpaceImpl<K::Acc>,
253 non_linear: &[FusedSpec],
254) -> TractResult<()> {
255 unsafe {
256 match crate::multithread::current_tract_executor() {
257 Executor::SingleThread => {
258 for ib in 0..n.divceil(ker.nr()) {
259 for ia in 0..m.divceil(ker.mr()) {
260 scratch.run(ker, non_linear, ia, ib)?;
261 }
262 }
263 Ok(())
264 }
265 #[cfg(feature = "multithread-mm")]
266 Executor::MultiThread(pool) => pool.install(|| {
267 (0..n.div_ceil(ker.nr())).into_par_iter().try_for_each(|ib| {
268 for ia in 0..m.divceil(ker.mr()) {
269 scratch.run(ker, non_linear, ia, ib)?;
270 }
271 Ok(())
272 })
273 }),
274 }
275 }
276}
277
278unsafe fn run_with_scratch_space_row_outer<K: MatMatMulKer>(
279 ker: &K,
280 m: usize,
281 n: usize,
282 scratch: &mut ScratchSpaceImpl<K::Acc>,
283 non_linear: &[FusedSpec],
284) -> TractResult<()> {
285 unsafe {
286 match crate::multithread::current_tract_executor() {
287 Executor::SingleThread => {
288 for ia in 0..m.divceil(ker.mr()) {
289 for ib in 0..n.divceil(ker.nr()) {
290 scratch.run(ker, non_linear, ia, ib)?;
291 }
292 }
293 Ok(())
294 }
295 #[cfg(feature = "multithread-mm")]
296 Executor::MultiThread(pool) => pool.install(|| {
297 pool.install(|| {
298 (0..m.div_ceil(ker.mr())).into_par_iter().try_for_each(|ia| {
299 for ib in 0..n.divceil(ker.nr()) {
300 scratch.run(ker, non_linear, ia, ib)?;
301 }
302 Ok(())
303 })
304 })
305 }),
306 }
307 }
308}