Skip to main content

tract_linalg/frame/mmm/
mod.rs

1#[macro_use]
2mod macros;
3
4pub mod cost_model;
5#[macro_use]
6pub(crate) mod fuse;
7pub(crate) mod input_store;
8pub(crate) mod kernel;
9#[macro_use]
10pub(crate) mod panel_extract;
11mod scratch;
12mod storage;
13
14#[cfg(test)]
15#[macro_use]
16pub mod tests;
17
18use crate::multithread::Executor;
19#[cfg(feature = "multithread-mm")]
20use rayon::prelude::*;
21use std::borrow::Cow;
22use std::cmp::Ordering;
23use std::fmt::Debug;
24use tract_data::internal::*;
25
26pub use cost_model::*;
27pub use fuse::*;
28pub use input_store::*;
29pub use kernel::*;
30pub use panel_extract::*;
31pub use scratch::*;
32pub use storage::*;
33
/// No-op taking a `(pointer, length)` pair and ignoring both.
///
/// NOTE(review): the name suggests this stands in wherever a prefetch callback is
/// expected but no prefetching should happen — confirm against callers.
pub fn no_prefetch(_ptr: *const u8, _len: usize) {
    // Intentionally empty: nothing is read, written or prefetched.
}
35
/// How much implementation effort went into a kernel, from emulated up to hand-written assembly.
///
/// Ordering follows [`ImplementationQuality::cost`]: better implementations have a lower cost
/// and therefore compare *less* than worse ones (e.g. `ManuallyOptimized < Dreadful`).
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ImplementationQuality {
    /// Individual operations are emulated by individual conversion (f16->f32->f16)
    Dreadful,
    /// Rust scalar operation (with whatever optimisation the compiler manages)
    Generic,
    /// Implicit vectorization (e.g. Rust code, some unrolled loops, explicit template instantiations for small constant)
    RustOptimized,
    /// Explicit vectorization (e.g. intrinsics vector code)
    TargetOptimized,
    /// Hand optimized (assembly)
    ManuallyOptimized,
}

impl ImplementationQuality {
    /// Every variant, from best (lowest cost) to worst (highest cost).
    pub fn best_to_worst() -> &'static [ImplementationQuality] {
        use ImplementationQuality::*;
        &[ManuallyOptimized, TargetOptimized, RustOptimized, Generic, Dreadful]
    }

    /// Rank of this quality in [`Self::best_to_worst`]: `0` for the best, `4` for the worst.
    ///
    /// Written as an exhaustive `match` rather than a `position(..).unwrap()` lookup, so a
    /// newly added variant is a compile error here instead of a runtime panic.
    pub fn cost(&self) -> usize {
        use ImplementationQuality::*;
        match self {
            ManuallyOptimized => 0,
            TargetOptimized => 1,
            RustOptimized => 2,
            Generic => 3,
            Dreadful => 4,
        }
    }
}

impl PartialOrd for ImplementationQuality {
    /// Compares by [`ImplementationQuality::cost`]; the order is total, so this is always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cost().cmp(&other.cost()))
    }
}

impl From<ImplementationQuality> for usize {
    /// Converts a quality into its [`ImplementationQuality::cost`].
    fn from(value: ImplementationQuality) -> Self {
        value.cost()
    }
}
72
73pub trait MatMatMul: Debug + dyn_clone::DynClone + Send + Sync + std::any::Any {
74    fn name(&self) -> &str;
75    fn mr(&self) -> usize;
76    fn nr(&self) -> usize;
77
78    fn quality(&self) -> ImplementationQuality;
79    fn dynamic_boost(&self) -> isize;
80
81    #[allow(clippy::type_complexity)]
82    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
83
84    fn internal_type(&self) -> DatumType;
85
86    unsafe fn c_view(&self, m_axis: Option<usize>, n_axis: Option<usize>) -> OutputStoreSpec;
87    unsafe fn c_from_data_and_strides(
88        &self,
89        item_size: usize,
90        row_stride: isize,
91        col_stride: isize,
92    ) -> OutputStoreSpec;
93
94    fn can_fuse(&self, spec: &FusedSpec) -> bool;
95
96    fn stores(&self) -> Cow<'_, [DatumType]>;
97
98    unsafe fn run(&self, m: usize, n: usize, non_linear: &[FusedSpec]) -> TractResult<()> {
99        unsafe {
100            let mut scratch = self.allocate_scratch_space();
101            self.run_with_scratch_space(m, n, &mut *scratch, non_linear)
102        }
103    }
104
105    unsafe fn allocate_scratch_space(&self) -> Box<dyn ScratchSpace>;
106    unsafe fn can_use_scratch_space(&self, scratch: &dyn ScratchSpace) -> bool;
107    unsafe fn run_with_scratch_space(
108        &self,
109        m: usize,
110        n: usize,
111        scratch: &mut dyn ScratchSpace,
112        non_linear: &[FusedSpec],
113    ) -> TractResult<()>;
114}
115
116dyn_clone::clone_trait_object!(MatMatMul);
117
118impl PartialEq for Box<dyn MatMatMul> {
119    fn eq(&self, other: &Box<dyn MatMatMul>) -> bool {
120        self.name() == other.name()
121    }
122}
123impl Eq for Box<dyn MatMatMul> {}
124
125impl std::hash::Hash for Box<dyn MatMatMul> {
126    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
127        self.name().hash(state)
128    }
129}
130
131impl<K: MatMatMulKer> MatMatMul for K {
132    fn name(&self) -> &str {
133        self.name()
134    }
135    fn mr(&self) -> usize {
136        self.mr()
137    }
138    fn nr(&self) -> usize {
139        self.nr()
140    }
141
142    fn quality(&self) -> ImplementationQuality {
143        MatMatMulKer::quality(self)
144    }
145
146    fn dynamic_boost(&self) -> isize {
147        MatMatMulKer::dynamic_boost(self)
148    }
149
150    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
151        self.packings()
152    }
153
154    fn internal_type(&self) -> DatumType {
155        K::Acc::datum_type()
156    }
157
158    fn can_fuse(&self, spec: &FusedSpec) -> bool {
159        self.can_fuse(spec)
160    }
161
162    unsafe fn c_view(&self, m_axis: Option<usize>, n_axis: Option<usize>) -> OutputStoreSpec {
163        OutputStoreSpec::View { m_axis, n_axis, mr: self.mr(), nr: self.nr() }
164    }
165
166    unsafe fn c_from_data_and_strides(
167        &self,
168        item_size: usize,
169        row_stride: isize,
170        col_stride: isize,
171    ) -> OutputStoreSpec {
172        OutputStoreSpec::Strides {
173            row_byte_stride: row_stride * item_size as isize,
174            col_byte_stride: col_stride * item_size as isize,
175            mr: self.mr(),
176            nr: self.nr(),
177        }
178    }
179
180    fn stores(&self) -> Cow<'_, [DatumType]> {
181        self.stores()
182    }
183
184    unsafe fn allocate_scratch_space(&self) -> Box<dyn ScratchSpace> {
185        Box::<ScratchSpaceImpl<K::Acc>>::default()
186    }
187
188    unsafe fn can_use_scratch_space(&self, scratch: &dyn ScratchSpace) -> bool {
189        scratch.downcast_ref::<ScratchSpaceImpl<K::Acc>>().is_some()
190    }
191
192    unsafe fn run_with_scratch_space(
193        &self,
194        m: usize,
195        n: usize,
196        scratch: &mut dyn ScratchSpace,
197        non_linear: &[FusedSpec],
198    ) -> TractResult<()> {
199        unsafe {
200            let scratch = scratch
201                .downcast_mut::<ScratchSpaceImpl<K::Acc>>()
202                .context("Wrong scratch space type")?;
203            scratch.prepare(self, m, n, non_linear)?;
204            if n == 1 && self.nr() == 1 {
205                run_with_scratch_space_vec(self, m, scratch, non_linear)
206            } else {
207                let (mut prefer_col, mut prefer_row) = (0, 0);
208                for uop in non_linear.iter() {
209                    if let Some(col) = uop.prefer_col_outer() {
210                        prefer_col = col as usize;
211                        prefer_row = (!col) as usize;
212                    }
213                }
214                if prefer_col > prefer_row {
215                    run_with_scratch_space_col_outer(self, m, n, scratch, non_linear)
216                } else {
217                    run_with_scratch_space_row_outer(self, m, n, scratch, non_linear)
218                }
219            }
220        }
221    }
222}
223
224unsafe fn run_with_scratch_space_vec<K: MatMatMulKer>(
225    ker: &K,
226    m: usize,
227    scratch: &mut ScratchSpaceImpl<K::Acc>,
228    non_linear: &[FusedSpec],
229) -> TractResult<()> {
230    unsafe {
231        match crate::multithread::current_tract_executor() {
232            Executor::SingleThread => {
233                for ia in 0..m.divceil(ker.mr()) {
234                    scratch.run(ker, non_linear, ia, 0)?;
235                }
236                Ok(())
237            }
238            #[cfg(feature = "multithread-mm")]
239            Executor::MultiThread(pool) => pool.install(|| {
240                (0..m.div_ceil(ker.mr()))
241                    .into_par_iter()
242                    .try_for_each(|ia| scratch.run(ker, non_linear, ia, 0))
243            }),
244        }
245    }
246}
247
248unsafe fn run_with_scratch_space_col_outer<K: MatMatMulKer>(
249    ker: &K,
250    m: usize,
251    n: usize,
252    scratch: &mut ScratchSpaceImpl<K::Acc>,
253    non_linear: &[FusedSpec],
254) -> TractResult<()> {
255    unsafe {
256        match crate::multithread::current_tract_executor() {
257            Executor::SingleThread => {
258                for ib in 0..n.divceil(ker.nr()) {
259                    for ia in 0..m.divceil(ker.mr()) {
260                        scratch.run(ker, non_linear, ia, ib)?;
261                    }
262                }
263                Ok(())
264            }
265            #[cfg(feature = "multithread-mm")]
266            Executor::MultiThread(pool) => pool.install(|| {
267                (0..n.div_ceil(ker.nr())).into_par_iter().try_for_each(|ib| {
268                    for ia in 0..m.divceil(ker.mr()) {
269                        scratch.run(ker, non_linear, ia, ib)?;
270                    }
271                    Ok(())
272                })
273            }),
274        }
275    }
276}
277
278unsafe fn run_with_scratch_space_row_outer<K: MatMatMulKer>(
279    ker: &K,
280    m: usize,
281    n: usize,
282    scratch: &mut ScratchSpaceImpl<K::Acc>,
283    non_linear: &[FusedSpec],
284) -> TractResult<()> {
285    unsafe {
286        match crate::multithread::current_tract_executor() {
287            Executor::SingleThread => {
288                for ia in 0..m.divceil(ker.mr()) {
289                    for ib in 0..n.divceil(ker.nr()) {
290                        scratch.run(ker, non_linear, ia, ib)?;
291                    }
292                }
293                Ok(())
294            }
295            #[cfg(feature = "multithread-mm")]
296            Executor::MultiThread(pool) => pool.install(|| {
297                pool.install(|| {
298                    (0..m.div_ceil(ker.mr())).into_par_iter().try_for_each(|ia| {
299                        for ib in 0..n.divceil(ker.nr()) {
300                            scratch.run(ker, non_linear, ia, ib)?;
301                        }
302                        Ok(())
303                    })
304                })
305            }),
306        }
307    }
308}