tract_linalg/frame/mmm/
mod.rs

1#[macro_use]
2mod macros;
3
4pub mod cost_model;
5#[macro_use]
6pub(crate) mod fuse;
7pub(crate) mod input_store;
8pub(crate) mod kernel;
9#[macro_use]
10pub(crate) mod panel_extract;
11mod scratch;
12mod storage;
13
14#[cfg(test)]
15#[macro_use]
16pub mod tests;
17
18use crate::multithread::Executor;
19#[cfg(feature = "multithread-mm")]
20use rayon::prelude::*;
21use std::borrow::Cow;
22use std::cmp::Ordering;
23use std::fmt::Debug;
24use tract_data::internal::*;
25
26pub use cost_model::*;
27pub use fuse::*;
28pub use input_store::*;
29pub use kernel::*;
30pub use panel_extract::*;
31pub use scratch::*;
32pub use storage::*;
33
/// Prefetch routine that does nothing.
///
/// Both arguments are ignored and nothing is read through `_ptr`.
pub fn no_prefetch(_ptr: *const u8, _len: usize) {
    // Intentionally empty.
}
35
/// How thoroughly optimised a kernel implementation is.
///
/// Variants are declared from the least to the most optimised one. Beware that
/// comparisons go through [`ImplementationQuality::cost`]: a *better* quality
/// compares as *smaller* (`ManuallyOptimized < Dreadful`).
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ImplementationQuality {
    /// Individual operations are emulated by individual conversion (f16->f32->f16)
    Dreadful,
    /// Rust scalar operation (with whatever optimisation the compiler manages)
    Generic,
    /// Implicit vectorization (e.g. Rust code, some unrolled loops, explicit template instantiations for small constant)
    RustOptimized,
    /// Explicit vectorization (e.g. intrinsics vector code)
    TargetOptimized,
    /// Hand optimized (assembly)
    ManuallyOptimized,
}

impl ImplementationQuality {
    /// Returns every variant, ordered from the most to the least optimised.
    pub fn best_to_worst() -> &'static [ImplementationQuality] {
        static RANKING: [ImplementationQuality; 5] = [
            ImplementationQuality::ManuallyOptimized,
            ImplementationQuality::TargetOptimized,
            ImplementationQuality::RustOptimized,
            ImplementationQuality::Generic,
            ImplementationQuality::Dreadful,
        ];
        &RANKING
    }

    /// Returns the rank of `self` in [`ImplementationQuality::best_to_worst`]:
    /// `0` for the best quality, growing as the quality degrades.
    pub fn cost(&self) -> usize {
        // Every variant is listed in `best_to_worst`, so the lookup always succeeds.
        Self::best_to_worst()
            .iter()
            .position(|quality| quality == self)
            .unwrap()
    }
}

/// Orders qualities by their cost, so the best quality is the smallest value.
impl PartialOrd for ImplementationQuality {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let lhs: usize = (*self).into();
        let rhs: usize = (*other).into();
        Some(lhs.cmp(&rhs))
    }
}

/// Converts a quality into its cost (see [`ImplementationQuality::cost`]).
impl From<ImplementationQuality> for usize {
    fn from(value: ImplementationQuality) -> Self {
        ImplementationQuality::cost(&value)
    }
}
72
73pub trait MatMatMul: Debug + dyn_clone::DynClone + Send + Sync + std::any::Any {
74    fn name(&self) -> &str;
75    fn mr(&self) -> usize;
76    fn nr(&self) -> usize;
77
78    fn quality(&self) -> ImplementationQuality;
79
80    #[allow(clippy::type_complexity)]
81    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
82
83    fn internal_type(&self) -> DatumType;
84
85    unsafe fn c_view(&self, m_axis: usize, n_axis: usize) -> OutputStoreSpec;
86    unsafe fn c_from_data_and_strides(
87        &self,
88        item_size: usize,
89        row_stride: isize,
90        col_stride: isize,
91    ) -> OutputStoreSpec;
92
93    fn can_fuse(&self, spec: &FusedSpec) -> bool;
94
95    fn stores(&self) -> Cow<[DatumType]>;
96
97    unsafe fn run(&self, m: usize, n: usize, non_linear: &[FusedSpec]) -> TractResult<()> {
98        let mut scratch = self.allocate_scratch_space();
99        self.run_with_scratch_space(m, n, &mut *scratch, non_linear)
100    }
101
102    unsafe fn allocate_scratch_space(&self) -> Box<dyn ScratchSpace>;
103    unsafe fn can_use_scratch_space(&self, scratch: &dyn ScratchSpace) -> bool;
104    unsafe fn run_with_scratch_space(
105        &self,
106        m: usize,
107        n: usize,
108        scratch: &mut dyn ScratchSpace,
109        non_linear: &[FusedSpec],
110    ) -> TractResult<()>;
111}
112
113dyn_clone::clone_trait_object!(MatMatMul);
114
115impl PartialEq for Box<dyn MatMatMul> {
116    fn eq(&self, other: &Box<dyn MatMatMul>) -> bool {
117        self.as_ref().type_id() == other.as_ref().type_id()
118    }
119}
120
121impl std::hash::Hash for Box<dyn MatMatMul> {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        self.as_ref().type_id().hash(state)
124    }
125}
126
127impl<K: MatMatMulKer> MatMatMul for K {
128    fn name(&self) -> &str {
129        self.name()
130    }
131    fn mr(&self) -> usize {
132        self.mr()
133    }
134    fn nr(&self) -> usize {
135        self.nr()
136    }
137
138    fn quality(&self) -> ImplementationQuality {
139        MatMatMulKer::quality(self)
140    }
141
142    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
143        self.packings()
144    }
145
146    fn internal_type(&self) -> DatumType {
147        K::Acc::datum_type()
148    }
149
150    fn can_fuse(&self, spec: &FusedSpec) -> bool {
151        self.can_fuse(spec)
152    }
153
154    unsafe fn c_view(&self, m_axis: usize, n_axis: usize) -> OutputStoreSpec {
155        OutputStoreSpec::View { m_axis, n_axis, mr: self.mr(), nr: self.nr() }
156    }
157
158    unsafe fn c_from_data_and_strides(
159        &self,
160        item_size: usize,
161        row_stride: isize,
162        col_stride: isize,
163    ) -> OutputStoreSpec {
164        OutputStoreSpec::Strides {
165            row_byte_stride: row_stride * item_size as isize,
166            col_byte_stride: col_stride * item_size as isize,
167            mr: self.mr(),
168            nr: self.nr(),
169        }
170    }
171
172    fn stores(&self) -> Cow<[DatumType]> {
173        self.stores()
174    }
175
176    unsafe fn allocate_scratch_space(&self) -> Box<dyn ScratchSpace> {
177        Box::<ScratchSpaceImpl<K::Acc>>::default()
178    }
179
180    unsafe fn can_use_scratch_space(&self, scratch: &dyn ScratchSpace) -> bool {
181        scratch.downcast_ref::<ScratchSpaceImpl<K::Acc>>().is_some()
182    }
183
184    unsafe fn run_with_scratch_space(
185        &self,
186        m: usize,
187        n: usize,
188        scratch: &mut dyn ScratchSpace,
189        non_linear: &[FusedSpec],
190    ) -> TractResult<()> {
191        let scratch = scratch
192            .downcast_mut::<ScratchSpaceImpl<K::Acc>>()
193            .context("Wrong scratch space type")?;
194        scratch.prepare(self, m, n, non_linear)?;
195        if n == 1 && self.nr() == 1 {
196            run_with_scratch_space_vec(self, m, scratch, non_linear)
197        } else {
198            let (mut prefer_col, mut prefer_row) = (0, 0);
199            for uop in non_linear.iter() {
200                if let Some(col) = uop.prefer_col_outer() {
201                    prefer_col = col as usize;
202                    prefer_row = (!col) as usize;
203                }
204            }
205            if prefer_col > prefer_row {
206                run_with_scratch_space_col_outer(self, m, n, scratch, non_linear)
207            } else {
208                run_with_scratch_space_row_outer(self, m, n, scratch, non_linear)
209            }
210        }
211    }
212}
213
214unsafe fn run_with_scratch_space_vec<K: MatMatMulKer>(
215    ker: &K,
216    m: usize,
217    scratch: &mut ScratchSpaceImpl<K::Acc>,
218    non_linear: &[FusedSpec],
219) -> TractResult<()> {
220    match crate::multithread::current_tract_executor() {
221        Executor::SingleThread => {
222            for ia in 0..m.divceil(ker.mr()) {
223                scratch.run(ker, non_linear, ia, 0)?;
224            }
225            Ok(())
226        }
227        #[cfg(feature = "multithread-mm")]
228        Executor::MultiThread(pool) => pool.install(|| {
229            (0..m.div_ceil(ker.mr()))
230                .into_par_iter()
231                .try_for_each(|ia| scratch.run(ker, non_linear, ia, 0))
232        }),
233    }
234}
235
236unsafe fn run_with_scratch_space_col_outer<K: MatMatMulKer>(
237    ker: &K,
238    m: usize,
239    n: usize,
240    scratch: &mut ScratchSpaceImpl<K::Acc>,
241    non_linear: &[FusedSpec],
242) -> TractResult<()> {
243    match crate::multithread::current_tract_executor() {
244        Executor::SingleThread => {
245            for ib in 0..n.divceil(ker.nr()) {
246                for ia in 0..m.divceil(ker.mr()) {
247                    scratch.run(ker, non_linear, ia, ib)?;
248                }
249            }
250            Ok(())
251        }
252        #[cfg(feature = "multithread-mm")]
253        Executor::MultiThread(pool) => pool.install(|| {
254            (0..n.div_ceil(ker.nr())).into_par_iter().try_for_each(|ib| {
255                for ia in 0..m.divceil(ker.mr()) {
256                    scratch.run(ker, non_linear, ia, ib)?;
257                }
258                Ok(())
259            })
260        }),
261    }
262}
263
264unsafe fn run_with_scratch_space_row_outer<K: MatMatMulKer>(
265    ker: &K,
266    m: usize,
267    n: usize,
268    scratch: &mut ScratchSpaceImpl<K::Acc>,
269    non_linear: &[FusedSpec],
270) -> TractResult<()> {
271    match crate::multithread::current_tract_executor() {
272        Executor::SingleThread => {
273            for ia in 0..m.divceil(ker.mr()) {
274                for ib in 0..n.divceil(ker.nr()) {
275                    scratch.run(ker, non_linear, ia, ib)?;
276                }
277            }
278            Ok(())
279        }
280        #[cfg(feature = "multithread-mm")]
281        Executor::MultiThread(pool) => pool.install(|| {
282            pool.install(|| {
283                (0..m.div_ceil(ker.mr())).into_par_iter().try_for_each(|ia| {
284                    for ib in 0..n.divceil(ker.nr()) {
285                        scratch.run(ker, non_linear, ia, ib)?;
286                    }
287                    Ok(())
288                })
289            })
290        }),
291    }
292}