Skip to main content

tract_linalg/frame/mmm/
mod.rs

1#[macro_use]
2mod macros;
3
4pub mod cost_model;
5#[macro_use]
6pub(crate) mod fuse;
7pub(crate) mod input_store;
8pub(crate) mod kernel;
9#[macro_use]
10pub(crate) mod panel_extract;
11mod scratch;
12mod storage;
13
14#[cfg(test)]
15#[macro_use]
16pub mod tests;
17
18use crate::multithread::Executor;
19use std::borrow::Cow;
20use std::cmp::Ordering;
21use std::fmt::Debug;
22use tract_data::internal::*;
23
24pub use cost_model::*;
25pub use fuse::*;
26pub use input_store::*;
27pub use kernel::*;
28pub use panel_extract::*;
29pub use scratch::*;
30pub use storage::*;
31
/// No-op prefetch hint.
///
/// Both arguments are ignored and nothing is read or written.
/// NOTE(review): presumably passed where a prefetch callback is expected but
/// the target has no useful hint to issue — confirm against callers.
pub fn no_prefetch(_ptr: *const u8, _len: usize) {
    // Intentionally empty.
}
33
/// How aggressively a kernel implementation has been optimised.
///
/// Ordering (`PartialOrd` / `Ord`) follows [`ImplementationQuality::cost`]:
/// a *better* implementation has a *lower* cost and therefore compares as
/// *smaller*, so `ManuallyOptimized < Dreadful`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ImplementationQuality {
    /// Individual operations are emulated by individual conversion (f16->f32->f16)
    Dreadful,
    /// Rust scalar operation (with whatever optimisation the compiler manages)
    Generic,
    /// Implicit vectorization (e.g. Rust code, some unrolled loops, explicit template instantiations for small constant)
    RustOptimized,
    /// Explicit vectorization (e.g. intrinsics vector code)
    TargetOptimized,
    /// Hand optimized (assembly)
    ManuallyOptimized,
}

impl ImplementationQuality {
    /// Every variant, ordered from the most to the least optimised.
    pub fn best_to_worst() -> &'static [ImplementationQuality] {
        use ImplementationQuality::*;
        &[ManuallyOptimized, TargetOptimized, RustOptimized, Generic, Dreadful]
    }

    /// Rank of this quality level: `0` for the best implementation, growing
    /// as quality degrades. Always equal to the variant's index in
    /// [`ImplementationQuality::best_to_worst`].
    pub fn cost(&self) -> usize {
        // Exhaustive match: no runtime search, no panic path, and adding a
        // variant forces this table to be updated at compile time.
        match self {
            Self::ManuallyOptimized => 0,
            Self::TargetOptimized => 1,
            Self::RustOptimized => 2,
            Self::Generic => 3,
            Self::Dreadful => 4,
        }
    }
}

/// Total order by [`ImplementationQuality::cost`] (lower cost sorts first).
impl Ord for ImplementationQuality {
    fn cmp(&self, other: &Self) -> Ordering {
        self.cost().cmp(&other.cost())
    }
}

impl PartialOrd for ImplementationQuality {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: the order is total, so defer to `Ord`.
        Some(self.cmp(other))
    }
}

/// Converts a quality level into its [`ImplementationQuality::cost`].
impl From<ImplementationQuality> for usize {
    fn from(value: ImplementationQuality) -> Self {
        value.cost()
    }
}
70
71pub trait MatMatMul: Debug + dyn_clone::DynClone + Send + Sync + std::any::Any {
72    fn name(&self) -> &str;
73    fn mr(&self) -> usize;
74    fn nr(&self) -> usize;
75
76    fn quality(&self) -> ImplementationQuality;
77    fn dynamic_boost(&self) -> isize;
78
79    /// Whether this kernel is runnable on the current CPU (platform feature
80    /// gate, e.g. FEAT_DotProd for the SDOT i8 kernel).
81    fn is_supported_here(&self) -> bool;
82
83    #[allow(clippy::type_complexity)]
84    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
85
86    fn internal_type(&self) -> DatumType;
87
88    unsafe fn c_view(&self, m_axis: Option<usize>, n_axis: Option<usize>) -> OutputStoreSpec;
89    unsafe fn c_from_data_and_strides(
90        &self,
91        item_size: usize,
92        row_stride: isize,
93        col_stride: isize,
94    ) -> OutputStoreSpec;
95
96    fn can_fuse(&self, spec: &FusedSpec) -> bool;
97
98    fn stores(&self) -> Cow<'_, [DatumType]>;
99
100    unsafe fn run(&self, m: usize, n: usize, non_linear: &[FusedSpec]) -> TractResult<()> {
101        unsafe {
102            let mut scratch = self.allocate_scratch_space();
103            self.run_with_scratch_space(m, n, &mut *scratch, non_linear)
104        }
105    }
106
107    unsafe fn allocate_scratch_space(&self) -> Box<dyn ScratchSpace>;
108    unsafe fn can_use_scratch_space(&self, scratch: &dyn ScratchSpace) -> bool;
109    unsafe fn run_with_scratch_space(
110        &self,
111        m: usize,
112        n: usize,
113        scratch: &mut dyn ScratchSpace,
114        non_linear: &[FusedSpec],
115    ) -> TractResult<()>;
116}
117
118dyn_clone::clone_trait_object!(MatMatMul);
119
120impl PartialEq for Box<dyn MatMatMul> {
121    fn eq(&self, other: &Box<dyn MatMatMul>) -> bool {
122        self.name() == other.name()
123    }
124}
125impl Eq for Box<dyn MatMatMul> {}
126
127impl std::hash::Hash for Box<dyn MatMatMul> {
128    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
129        self.name().hash(state)
130    }
131}
132
133impl<K: MatMatMulKer> MatMatMul for K {
134    fn name(&self) -> &str {
135        self.name()
136    }
137    fn mr(&self) -> usize {
138        self.mr()
139    }
140    fn nr(&self) -> usize {
141        self.nr()
142    }
143
144    fn quality(&self) -> ImplementationQuality {
145        MatMatMulKer::quality(self)
146    }
147
148    fn dynamic_boost(&self) -> isize {
149        MatMatMulKer::dynamic_boost(self)
150    }
151
152    fn is_supported_here(&self) -> bool {
153        MatMatMulKer::is_supported_here(self)
154    }
155
156    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
157        self.packings()
158    }
159
160    fn internal_type(&self) -> DatumType {
161        K::Acc::datum_type()
162    }
163
164    fn can_fuse(&self, spec: &FusedSpec) -> bool {
165        self.can_fuse(spec)
166    }
167
168    unsafe fn c_view(&self, m_axis: Option<usize>, n_axis: Option<usize>) -> OutputStoreSpec {
169        OutputStoreSpec::View { m_axis, n_axis, mr: self.mr(), nr: self.nr() }
170    }
171
172    unsafe fn c_from_data_and_strides(
173        &self,
174        item_size: usize,
175        row_stride: isize,
176        col_stride: isize,
177    ) -> OutputStoreSpec {
178        OutputStoreSpec::Strides {
179            row_byte_stride: row_stride * item_size as isize,
180            col_byte_stride: col_stride * item_size as isize,
181            mr: self.mr(),
182            nr: self.nr(),
183        }
184    }
185
186    fn stores(&self) -> Cow<'_, [DatumType]> {
187        self.stores()
188    }
189
190    unsafe fn allocate_scratch_space(&self) -> Box<dyn ScratchSpace> {
191        Box::<ScratchSpaceImpl<K::Acc>>::default()
192    }
193
194    unsafe fn can_use_scratch_space(&self, scratch: &dyn ScratchSpace) -> bool {
195        scratch.downcast_ref::<ScratchSpaceImpl<K::Acc>>().is_some()
196    }
197
198    unsafe fn run_with_scratch_space(
199        &self,
200        m: usize,
201        n: usize,
202        scratch: &mut dyn ScratchSpace,
203        non_linear: &[FusedSpec],
204    ) -> TractResult<()> {
205        unsafe {
206            let scratch = scratch
207                .downcast_mut::<ScratchSpaceImpl<K::Acc>>()
208                .context("Wrong scratch space type")?;
209            scratch.prepare(self, m, n, non_linear)?;
210            if n == 1 && self.nr() == 1 {
211                run_with_scratch_space_vec(self, m, scratch, non_linear)
212            } else {
213                let (mut prefer_col, mut prefer_row) = (0, 0);
214                for uop in non_linear.iter() {
215                    if let Some(col) = uop.prefer_col_outer() {
216                        prefer_col = col as usize;
217                        prefer_row = (!col) as usize;
218                    }
219                }
220                // k drives the single-thread cache-block size; read it from the
221                // first AddMatMul's packed input (0 if none → max block).
222                let k = non_linear
223                    .iter()
224                    .find_map(|f| match f {
225                        FusedSpec::AddMatMul { a, .. } => Some(a.k()),
226                        _ => None,
227                    })
228                    .unwrap_or(0);
229                if prefer_col > prefer_row {
230                    run_with_scratch_space_col_outer(self, m, n, k, scratch, non_linear)
231                } else {
232                    run_with_scratch_space_row_outer(self, m, n, k, scratch, non_linear)
233                }
234            }
235        }
236    }
237}
238
239unsafe fn run_with_scratch_space_vec<K: MatMatMulKer>(
240    ker: &K,
241    m: usize,
242    scratch: &mut ScratchSpaceImpl<K::Acc>,
243    non_linear: &[FusedSpec],
244) -> TractResult<()> {
245    unsafe {
246        match crate::multithread::current_tract_executor() {
247            Executor::SingleThread => scratch.run_in_tls_scope(|scratch, tls| {
248                for ia in 0..m.divceil(ker.mr()) {
249                    scratch.run_one_tile(ker, non_linear, tls, ia, 0)?;
250                }
251                TractResult::Ok(())
252            }),
253            #[cfg(feature = "multithread-mm")]
254            Executor::MultiThread(pool) => chunked_dispatch_rayon(
255                Some(&pool),
256                m.divceil(ker.mr()),
257                1,
258                |ia_start, ia_end, _, _| {
259                    scratch.run_in_tls_scope(|scratch, tls| {
260                        for ia in ia_start..ia_end {
261                            scratch.run_one_tile(ker, non_linear, tls, ia, 0)?;
262                        }
263                        TractResult::Ok(())
264                    })
265                },
266            ),
267            #[cfg(feature = "multithread-mm")]
268            Executor::RayonGlobal => {
269                chunked_dispatch_rayon(None, m.divceil(ker.mr()), 1, |ia_start, ia_end, _, _| {
270                    scratch.run_in_tls_scope(|scratch, tls| {
271                        for ia in ia_start..ia_end {
272                            scratch.run_one_tile(ker, non_linear, tls, ia, 0)?;
273                        }
274                        TractResult::Ok(())
275                    })
276                })
277            }
278        }
279    }
280}
281
/// Upper bound on the single-thread panel-block edge, counted in panels per
/// side. Same value as the default chunk size of the multithread `chunk_grid`
/// (16). `st_block_edge` never returns more than this, and returns exactly
/// this when `k == 0`.
const ST_BLK_MAX: usize = 16;
285
/// Parses a cache-size string such as `"512K"`, `"2M"` or `"4096"` into a
/// byte count.
///
/// Surrounding whitespace is ignored. A trailing `K`/`k` multiplies by 1024,
/// a trailing `M`/`m` by 1024², and no suffix means plain bytes.
///
/// Returns `0` (meaning "unknown") when the numeric part does not parse or
/// when the scaled value does not fit in a `usize`, instead of overflowing.
#[cfg(target_os = "linux")]
fn parse_cache_size(s: &str) -> usize {
    let s = s.trim();
    let (digits, multiplier) = if let Some(rest) = s.strip_suffix(['K', 'k']) {
        (rest, 1024)
    } else if let Some(rest) = s.strip_suffix(['M', 'm']) {
        (rest, 1024 * 1024)
    } else {
        (s, 1)
    };
    digits
        .trim()
        .parse::<usize>()
        .ok()
        // Checked: a bogus huge value must read as "unknown", not panic
        // (debug) or wrap to a small number (release).
        .and_then(|count| count.checked_mul(multiplier))
        .unwrap_or(0)
}
298
299/// Best-effort L2 data-cache size in bytes (per perf-core / cluster); 0 if
300/// unknown. Cached. Used to size the single-thread cache-block budget so it is
301/// correct across hardware instead of a hard-coded constant.
302fn detect_l2_bytes() -> usize {
303    static L2: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
304    *L2.get_or_init(|| {
305        #[cfg(target_os = "macos")]
306        {
307            let sysctl = |k: &str| -> Option<usize> {
308                let o = std::process::Command::new("sysctl").arg("-n").arg(k).output().ok()?;
309                if !o.status.success() {
310                    return None;
311                }
312                String::from_utf8_lossy(&o.stdout).trim().parse().ok()
313            };
314            // Prefer the performance-core L2 on hybrid Apple Silicon.
315            sysctl("hw.perflevel0.l2cachesize").or_else(|| sysctl("hw.l2cachesize")).unwrap_or(0)
316        }
317        #[cfg(target_os = "linux")]
318        {
319            // index2/index3 is typically the unified L2 (index0/1 are L1 d/i).
320            for idx in [2usize, 3] {
321                if let Ok(s) = std::fs::read_to_string(format!(
322                    "/sys/devices/system/cpu/cpu0/cache/index{idx}/size"
323                )) {
324                    let b = parse_cache_size(s.trim());
325                    if b > 0 {
326                        return b;
327                    }
328                }
329            }
330            0
331        }
332        #[cfg(not(any(target_os = "macos", target_os = "linux")))]
333        {
334            0
335        }
336    })
337}
338
339/// Working-set budget (bytes) for the single-thread cache-block: ~a third of L2
340/// (leaving room for the C accumulator tile + packing metadata). Conservative
341/// 256 KiB fallback when L2 is unknown (WASM/Windows/BSD) ⇒ small blocks ≈ the
342/// naive loop, so it can never over-block a cache it can't see.
343fn block_budget_bytes() -> usize {
344    let l2 = detect_l2_bytes();
345    if l2 == 0 { 256 * 1024 } else { (l2 / 3).clamp(64 * 1024, 8 * 1024 * 1024) }
346}
347
348/// Cache-adaptive panel-block edge: large enough to amortise streaming, small
349/// enough that the block's A+B sub-panels (`~blk·(mr+nr)·k·elem_bytes`) stay
350/// L2-resident at the given `k`. Capped at [`ST_BLK_MAX`]; the floor of 1
351/// degrades exactly to the naive loop, so an unknown/small cache can never
352/// over-block (regression-safe). The budget is **cache-size derived** (not a
353/// hard-coded constant), so it is correct across hardware.
354#[inline]
355fn st_block_edge(mr: usize, nr: usize, k: usize, elem_bytes: usize) -> usize {
356    if k == 0 {
357        return ST_BLK_MAX;
358    }
359    let per_blk = ((mr + nr) * k * elem_bytes.max(1)).max(1);
360    (block_budget_bytes() / per_blk).clamp(1, ST_BLK_MAX)
361}
362
363/// Single-thread tile walk over the `m_panels × n_panels` grid, blocked into
364/// cache-sized panel blocks for locality (the naive nested loop re-streams the
365/// whole inner operand per outer panel at large k; the multithread path already
366/// blocks this way via `chunk_grid`). `col_outer` selects the within-block inner
367/// order (B-reuse vs A-reuse). Reordering independent tiles changes no result —
368/// bit-exact with the naive loop.
369#[inline]
370unsafe fn run_single_thread_blocked<K: MatMatMulKer>(
371    ker: &K,
372    m_panels: usize,
373    n_panels: usize,
374    k: usize,
375    col_outer: bool,
376    scratch: &mut ScratchSpaceImpl<K::Acc>,
377    non_linear: &[FusedSpec],
378) -> TractResult<()> {
379    unsafe {
380        let blk = st_block_edge(ker.mr(), ker.nr(), k, K::Acc::datum_type().size_of());
381        scratch.run_in_tls_scope(|scratch, tls| {
382            let mut jb = 0;
383            while jb < n_panels {
384                let jb_end = (jb + blk).min(n_panels);
385                let mut ja = 0;
386                while ja < m_panels {
387                    let ja_end = (ja + blk).min(m_panels);
388                    if col_outer {
389                        for ib in jb..jb_end {
390                            for ia in ja..ja_end {
391                                scratch.run_one_tile(ker, non_linear, tls, ia, ib)?;
392                            }
393                        }
394                    } else {
395                        for ia in ja..ja_end {
396                            for ib in jb..jb_end {
397                                scratch.run_one_tile(ker, non_linear, tls, ia, ib)?;
398                            }
399                        }
400                    }
401                    ja = ja_end;
402                }
403                jb = jb_end;
404            }
405            TractResult::Ok(())
406        })
407    }
408}
409
410unsafe fn run_with_scratch_space_col_outer<K: MatMatMulKer>(
411    ker: &K,
412    m: usize,
413    n: usize,
414    k: usize,
415    scratch: &mut ScratchSpaceImpl<K::Acc>,
416    non_linear: &[FusedSpec],
417) -> TractResult<()> {
418    unsafe {
419        match crate::multithread::current_tract_executor() {
420            Executor::SingleThread => run_single_thread_blocked(
421                ker,
422                m.divceil(ker.mr()),
423                n.divceil(ker.nr()),
424                k,
425                true,
426                scratch,
427                non_linear,
428            ),
429            #[cfg(feature = "multithread-mm")]
430            Executor::MultiThread(pool) => chunked_dispatch_rayon(
431                Some(&pool),
432                m.divceil(ker.mr()),
433                n.divceil(ker.nr()),
434                |ia_start, ia_end, ib_start, ib_end| {
435                    scratch.run_in_tls_scope(|scratch, tls| {
436                        for ib in ib_start..ib_end {
437                            for ia in ia_start..ia_end {
438                                scratch.run_one_tile(ker, non_linear, tls, ia, ib)?;
439                            }
440                        }
441                        TractResult::Ok(())
442                    })
443                },
444            ),
445            #[cfg(feature = "multithread-mm")]
446            Executor::RayonGlobal => chunked_dispatch_rayon(
447                None,
448                m.divceil(ker.mr()),
449                n.divceil(ker.nr()),
450                |ia_start, ia_end, ib_start, ib_end| {
451                    scratch.run_in_tls_scope(|scratch, tls| {
452                        for ib in ib_start..ib_end {
453                            for ia in ia_start..ia_end {
454                                scratch.run_one_tile(ker, non_linear, tls, ia, ib)?;
455                            }
456                        }
457                        TractResult::Ok(())
458                    })
459                },
460            ),
461        }
462    }
463}
464
465unsafe fn run_with_scratch_space_row_outer<K: MatMatMulKer>(
466    ker: &K,
467    m: usize,
468    n: usize,
469    k: usize,
470    scratch: &mut ScratchSpaceImpl<K::Acc>,
471    non_linear: &[FusedSpec],
472) -> TractResult<()> {
473    unsafe {
474        match crate::multithread::current_tract_executor() {
475            Executor::SingleThread => run_single_thread_blocked(
476                ker,
477                m.divceil(ker.mr()),
478                n.divceil(ker.nr()),
479                k,
480                false,
481                scratch,
482                non_linear,
483            ),
484            #[cfg(feature = "multithread-mm")]
485            Executor::MultiThread(pool) => chunked_dispatch_rayon(
486                Some(&pool),
487                m.divceil(ker.mr()),
488                n.divceil(ker.nr()),
489                |ia_start, ia_end, ib_start, ib_end| {
490                    scratch.run_in_tls_scope(|scratch, tls| {
491                        for ia in ia_start..ia_end {
492                            for ib in ib_start..ib_end {
493                                scratch.run_one_tile(ker, non_linear, tls, ia, ib)?;
494                            }
495                        }
496                        TractResult::Ok(())
497                    })
498                },
499            ),
500            #[cfg(feature = "multithread-mm")]
501            Executor::RayonGlobal => chunked_dispatch_rayon(
502                None,
503                m.divceil(ker.mr()),
504                n.divceil(ker.nr()),
505                |ia_start, ia_end, ib_start, ib_end| {
506                    scratch.run_in_tls_scope(|scratch, tls| {
507                        for ia in ia_start..ia_end {
508                            for ib in ib_start..ib_end {
509                                scratch.run_one_tile(ker, non_linear, tls, ia, ib)?;
510                            }
511                        }
512                        TractResult::Ok(())
513                    })
514                },
515            ),
516        }
517    }
518}
519
/// Chunk grid for the 2D dispatch.
///
/// Follows ggml's `mul_mat` heuristic (`ggml/src/ggml-cpu/ggml-cpu.c:1378-1398`):
///  * 16-tile panel chunks by default;
///  * 64-tile chunks when one dimension is 1 (vec / vec-mat);
///  * fallback to "block-per-thread along the longer axis" when the natural
///    grid would have fewer than `4·nth` chunks.
///
/// Unlike the plain heuristic, the returned chunk counts are tightened to
/// `ceil(n_panels / dr)` once the per-chunk extents are known. The fallback
/// can otherwise ask for more chunks than there are panel groups (e.g. 9
/// panels on 8 threads gives `dr = 2`, so only 5 chunks hold work), which
/// would make the dispatcher invoke its closure on empty or inverted ranges.
/// With the tightened counts every `(im, in)` cell of the grid is non-empty
/// and the grid still covers all panels.
///
/// `nth` is the worker count; `0` is treated as `1`.
///
/// Returns `(nchunks_m, nchunks_n, dr_m, dr_n)`, where `dr_*` is the number of
/// panels per chunk along each axis (at least 1).
#[cfg(feature = "multithread-mm")]
fn chunk_grid(n_panels_m: usize, n_panels_n: usize, nth: usize) -> (usize, usize, usize, usize) {
    // A zero worker count would otherwise zero a chunk count and divide by it.
    let nth = nth.max(1);
    let chunk_size = if n_panels_m == 1 || n_panels_n == 1 { 64 } else { 16 };
    let mut nchunks_m = n_panels_m.div_ceil(chunk_size);
    let mut nchunks_n = n_panels_n.div_ceil(chunk_size);
    if nchunks_m * nchunks_n < 4 * nth {
        // Too few natural chunks to keep the workers busy: one block per
        // thread along the longer axis. This branch is always taken when a
        // natural count is 0, so the divisors below are never 0.
        (nchunks_m, nchunks_n) = if n_panels_m > n_panels_n { (nth, 1) } else { (1, nth) };
    }
    let dr_m = n_panels_m.div_ceil(nchunks_m).max(1);
    let dr_n = n_panels_n.div_ceil(nchunks_n).max(1);
    // Drop chunks that would start past the end of their axis.
    (n_panels_m.div_ceil(dr_m), n_panels_n.div_ceil(dr_n), dr_m, dr_n)
}
547
/// 2D chunked dispatcher over the `n_panels_m × n_panels_n` panel grid for the
/// rayon path. Splitting along both axes uses the workers better on small or
/// skewed shapes, where one dimension has fewer panels than there are workers,
/// than a 1D parallel iteration over a single panel axis.
///
/// `run_chunk` receives **chunk bounds** (`ia_start, ia_end, ib_start,
/// ib_end`), not per-tile indices, so the caller can amortise per-worker setup
/// (e.g. `ScratchSpaceImpl::run_in_tls_scope`) over every tile of the chunk
/// (see #2206). It is called once per rayon work item, or once in total when
/// the grid is below the threading threshold.
///
/// Behaviour, in order:
///   * an empty grid returns `Ok(())` without calling `run_chunk`;
///   * fewer tiles than `current_threading_panel_threshold()` runs the whole
///     grid as one chunk on the calling thread;
///   * otherwise the grid from `chunk_grid` is iterated in parallel, stopping
///     at the first error.
///
/// `pool`:
///   * `Some(p)` with `p.current_num_threads() > 1` → scoped via `p.install`
///     (native, custom pool path).
///   * `Some(p)` with a single-thread pool, or `None` → dispatched via
///     `into_par_iter` directly, which uses rayon's GLOBAL pool. This is the
///     only working path on `wasm32-unknown-unknown` via
///     `wasm_bindgen_rayon::init_thread_pool`.
///
/// # Safety
/// NOTE(review): nothing in this body is itself unsafe; the marker presumably
/// mirrors the unsafe tile-running closures callers pass in — confirm.
#[cfg(feature = "multithread-mm")]
unsafe fn chunked_dispatch_rayon<F>(
    pool: Option<&rayon::ThreadPool>,
    n_panels_m: usize,
    n_panels_n: usize,
    run_chunk: F,
) -> TractResult<()>
where
    F: Fn(usize, usize, usize, usize) -> TractResult<()> + Sync,
{
    use rayon::prelude::*;

    if n_panels_m == 0 || n_panels_n == 0 {
        return Ok(());
    }

    let tile_count = n_panels_m * n_panels_n;
    if tile_count < crate::multithread::current_threading_panel_threshold() {
        // Below the threading threshold: run the whole grid as a single chunk
        // on the calling thread. Closure handles its own TLS scope.
        return run_chunk(0, n_panels_m, 0, n_panels_n);
    }

    let dispatch = || {
        // Evaluated inside whichever pool ends up running this closure.
        let workers = rayon::current_num_threads();
        let (nchunks_m, nchunks_n, dr_m, dr_n) = chunk_grid(n_panels_m, n_panels_n, workers);
        (0..nchunks_m * nchunks_n).into_par_iter().try_for_each(|chunk| {
            // Linear chunk id → (m, n) chunk coordinates, m varying fastest.
            let (chunk_m, chunk_n) = (chunk % nchunks_m, chunk / nchunks_m);
            let ia_start = chunk_m * dr_m;
            let ib_start = chunk_n * dr_n;
            run_chunk(
                ia_start,
                (ia_start + dr_m).min(n_panels_m),
                ib_start,
                (ib_start + dr_n).min(n_panels_n),
            )
        })
    };

    match pool {
        Some(custom) if custom.current_num_threads() > 1 => custom.install(dispatch),
        _ => dispatch(),
    }
}