1#[macro_use]
2mod macros;
3
4pub mod cost_model;
5#[macro_use]
6pub(crate) mod fuse;
7pub(crate) mod input_store;
8pub(crate) mod kernel;
9#[macro_use]
10pub(crate) mod panel_extract;
11mod scratch;
12mod storage;
13
14#[cfg(test)]
15#[macro_use]
16pub mod tests;
17
18use crate::multithread::Executor;
19use std::borrow::Cow;
20use std::cmp::Ordering;
21use std::fmt::Debug;
22use tract_data::internal::*;
23
24pub use cost_model::*;
25pub use fuse::*;
26pub use input_store::*;
27pub use kernel::*;
28pub use panel_extract::*;
29pub use scratch::*;
30pub use storage::*;
31
/// Prefetch hook that does nothing.
///
/// Both arguments are ignored: the pointer is never dereferenced and the length is never read,
/// so any values (including a null pointer) are accepted.
pub fn no_prefetch(_ptr: *const u8, _len: usize) {}
33
/// Coarse quality tier of a matrix-multiplication kernel implementation, as reported by
/// `MatMatMul::quality`.
///
/// Variants are declared from the lowest tier to the highest. Comparisons between tiers do
/// not use the declaration order: they go through `cost()`, where the best tier has the
/// smallest value.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ImplementationQuality {
    Dreadful,
    Generic,
    RustOptimized,
    TargetOptimized,
    ManuallyOptimized,
}
47
48impl ImplementationQuality {
49 pub fn best_to_worst() -> &'static [ImplementationQuality] {
50 use ImplementationQuality::*;
51 &[ManuallyOptimized, TargetOptimized, RustOptimized, Generic, Dreadful]
52 }
53
54 pub fn cost(&self) -> usize {
55 ImplementationQuality::best_to_worst().iter().position(|x| x == self).unwrap()
56 }
57}
58
59impl PartialOrd for ImplementationQuality {
60 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
61 Some(usize::from(*self).cmp(&usize::from(*other)))
62 }
63}
64
65impl From<ImplementationQuality> for usize {
66 fn from(value: ImplementationQuality) -> Self {
67 value.cost()
68 }
69}
70
71pub trait MatMatMul: Debug + dyn_clone::DynClone + Send + Sync + std::any::Any {
72 fn name(&self) -> &str;
73 fn mr(&self) -> usize;
74 fn nr(&self) -> usize;
75
76 fn quality(&self) -> ImplementationQuality;
77 fn dynamic_boost(&self) -> isize;
78
79 fn is_supported_here(&self) -> bool;
82
83 #[allow(clippy::type_complexity)]
84 fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
85
86 fn internal_type(&self) -> DatumType;
87
88 unsafe fn c_view(&self, m_axis: Option<usize>, n_axis: Option<usize>) -> OutputStoreSpec;
89 unsafe fn c_from_data_and_strides(
90 &self,
91 item_size: usize,
92 row_stride: isize,
93 col_stride: isize,
94 ) -> OutputStoreSpec;
95
96 fn can_fuse(&self, spec: &FusedSpec) -> bool;
97
98 fn stores(&self) -> Cow<'_, [DatumType]>;
99
100 unsafe fn run(&self, m: usize, n: usize, non_linear: &[FusedSpec]) -> TractResult<()> {
101 unsafe {
102 let mut scratch = self.allocate_scratch_space();
103 self.run_with_scratch_space(m, n, &mut *scratch, non_linear)
104 }
105 }
106
107 unsafe fn allocate_scratch_space(&self) -> Box<dyn ScratchSpace>;
108 unsafe fn can_use_scratch_space(&self, scratch: &dyn ScratchSpace) -> bool;
109 unsafe fn run_with_scratch_space(
110 &self,
111 m: usize,
112 n: usize,
113 scratch: &mut dyn ScratchSpace,
114 non_linear: &[FusedSpec],
115 ) -> TractResult<()>;
116}
117
118dyn_clone::clone_trait_object!(MatMatMul);
119
120impl PartialEq for Box<dyn MatMatMul> {
121 fn eq(&self, other: &Box<dyn MatMatMul>) -> bool {
122 self.name() == other.name()
123 }
124}
125impl Eq for Box<dyn MatMatMul> {}
126
127impl std::hash::Hash for Box<dyn MatMatMul> {
128 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
129 self.name().hash(state)
130 }
131}
132
133impl<K: MatMatMulKer> MatMatMul for K {
134 fn name(&self) -> &str {
135 self.name()
136 }
137 fn mr(&self) -> usize {
138 self.mr()
139 }
140 fn nr(&self) -> usize {
141 self.nr()
142 }
143
144 fn quality(&self) -> ImplementationQuality {
145 MatMatMulKer::quality(self)
146 }
147
148 fn dynamic_boost(&self) -> isize {
149 MatMatMulKer::dynamic_boost(self)
150 }
151
152 fn is_supported_here(&self) -> bool {
153 MatMatMulKer::is_supported_here(self)
154 }
155
156 fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
157 self.packings()
158 }
159
160 fn internal_type(&self) -> DatumType {
161 K::Acc::datum_type()
162 }
163
164 fn can_fuse(&self, spec: &FusedSpec) -> bool {
165 self.can_fuse(spec)
166 }
167
168 unsafe fn c_view(&self, m_axis: Option<usize>, n_axis: Option<usize>) -> OutputStoreSpec {
169 OutputStoreSpec::View { m_axis, n_axis, mr: self.mr(), nr: self.nr() }
170 }
171
172 unsafe fn c_from_data_and_strides(
173 &self,
174 item_size: usize,
175 row_stride: isize,
176 col_stride: isize,
177 ) -> OutputStoreSpec {
178 OutputStoreSpec::Strides {
179 row_byte_stride: row_stride * item_size as isize,
180 col_byte_stride: col_stride * item_size as isize,
181 mr: self.mr(),
182 nr: self.nr(),
183 }
184 }
185
186 fn stores(&self) -> Cow<'_, [DatumType]> {
187 self.stores()
188 }
189
190 unsafe fn allocate_scratch_space(&self) -> Box<dyn ScratchSpace> {
191 Box::<ScratchSpaceImpl<K::Acc>>::default()
192 }
193
194 unsafe fn can_use_scratch_space(&self, scratch: &dyn ScratchSpace) -> bool {
195 scratch.downcast_ref::<ScratchSpaceImpl<K::Acc>>().is_some()
196 }
197
198 unsafe fn run_with_scratch_space(
199 &self,
200 m: usize,
201 n: usize,
202 scratch: &mut dyn ScratchSpace,
203 non_linear: &[FusedSpec],
204 ) -> TractResult<()> {
205 unsafe {
206 let scratch = scratch
207 .downcast_mut::<ScratchSpaceImpl<K::Acc>>()
208 .context("Wrong scratch space type")?;
209 scratch.prepare(self, m, n, non_linear)?;
210 if n == 1 && self.nr() == 1 {
211 run_with_scratch_space_vec(self, m, scratch, non_linear)
212 } else {
213 let (mut prefer_col, mut prefer_row) = (0, 0);
214 for uop in non_linear.iter() {
215 if let Some(col) = uop.prefer_col_outer() {
216 prefer_col = col as usize;
217 prefer_row = (!col) as usize;
218 }
219 }
220 let k = non_linear
223 .iter()
224 .find_map(|f| match f {
225 FusedSpec::AddMatMul { a, .. } => Some(a.k()),
226 _ => None,
227 })
228 .unwrap_or(0);
229 if prefer_col > prefer_row {
230 run_with_scratch_space_col_outer(self, m, n, k, scratch, non_linear)
231 } else {
232 run_with_scratch_space_row_outer(self, m, n, k, scratch, non_linear)
233 }
234 }
235 }
236 }
237}
238
239unsafe fn run_with_scratch_space_vec<K: MatMatMulKer>(
240 ker: &K,
241 m: usize,
242 scratch: &mut ScratchSpaceImpl<K::Acc>,
243 non_linear: &[FusedSpec],
244) -> TractResult<()> {
245 unsafe {
246 match crate::multithread::current_tract_executor() {
247 Executor::SingleThread => scratch.run_in_tls_scope(|scratch, tls| {
248 for ia in 0..m.divceil(ker.mr()) {
249 scratch.run_one_tile(ker, non_linear, tls, ia, 0)?;
250 }
251 TractResult::Ok(())
252 }),
253 #[cfg(feature = "multithread-mm")]
254 Executor::MultiThread(pool) => chunked_dispatch_rayon(
255 Some(&pool),
256 m.divceil(ker.mr()),
257 1,
258 |ia_start, ia_end, _, _| {
259 scratch.run_in_tls_scope(|scratch, tls| {
260 for ia in ia_start..ia_end {
261 scratch.run_one_tile(ker, non_linear, tls, ia, 0)?;
262 }
263 TractResult::Ok(())
264 })
265 },
266 ),
267 #[cfg(feature = "multithread-mm")]
268 Executor::RayonGlobal => {
269 chunked_dispatch_rayon(None, m.divceil(ker.mr()), 1, |ia_start, ia_end, _, _| {
270 scratch.run_in_tls_scope(|scratch, tls| {
271 for ia in ia_start..ia_end {
272 scratch.run_one_tile(ker, non_linear, tls, ia, 0)?;
273 }
274 TractResult::Ok(())
275 })
276 })
277 }
278 }
279 }
280}
281
/// Upper bound, in panels, on the block edge returned by `st_block_edge`; it is also the
/// value returned when `k == 0`.
const ST_BLK_MAX: usize = 16;
285
286#[cfg(target_os = "linux")]
/// Parses a cache size written as a decimal count with an optional unit suffix.
///
/// Accepted forms (surrounding whitespace is ignored):
/// * `"<n>K"` / `"<n>k"` — `n` * 1024 bytes,
/// * `"<n>M"` / `"<n>m"` — `n` * 1024 * 1024 bytes,
/// * `"<n>"` — `n` bytes.
///
/// Returns 0 when the text is not a valid count, and also when the resulting byte count
/// does not fit in a `usize` (instead of overflowing).
fn parse_cache_size(s: &str) -> usize {
    let s = s.trim();
    let (digits, multiplier) = if let Some(rest) = s.strip_suffix(['K', 'k']) {
        (rest, 1024usize)
    } else if let Some(rest) = s.strip_suffix(['M', 'm']) {
        (rest, 1024 * 1024)
    } else {
        (s, 1)
    };
    digits
        .trim()
        .parse::<usize>()
        .ok()
        // An absurdly large count is treated like unparseable input rather than wrapping.
        .and_then(|count| count.checked_mul(multiplier))
        .unwrap_or(0)
}
298
299fn detect_l2_bytes() -> usize {
303 static L2: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
304 *L2.get_or_init(|| {
305 #[cfg(target_os = "macos")]
306 {
307 let sysctl = |k: &str| -> Option<usize> {
308 let o = std::process::Command::new("sysctl").arg("-n").arg(k).output().ok()?;
309 if !o.status.success() {
310 return None;
311 }
312 String::from_utf8_lossy(&o.stdout).trim().parse().ok()
313 };
314 sysctl("hw.perflevel0.l2cachesize").or_else(|| sysctl("hw.l2cachesize")).unwrap_or(0)
316 }
317 #[cfg(target_os = "linux")]
318 {
319 for idx in [2usize, 3] {
321 if let Ok(s) = std::fs::read_to_string(format!(
322 "/sys/devices/system/cpu/cpu0/cache/index{idx}/size"
323 )) {
324 let b = parse_cache_size(s.trim());
325 if b > 0 {
326 return b;
327 }
328 }
329 }
330 0
331 }
332 #[cfg(not(any(target_os = "macos", target_os = "linux")))]
333 {
334 0
335 }
336 })
337}
338
339fn block_budget_bytes() -> usize {
344 let l2 = detect_l2_bytes();
345 if l2 == 0 { 256 * 1024 } else { (l2 / 3).clamp(64 * 1024, 8 * 1024 * 1024) }
346}
347
348#[inline]
355fn st_block_edge(mr: usize, nr: usize, k: usize, elem_bytes: usize) -> usize {
356 if k == 0 {
357 return ST_BLK_MAX;
358 }
359 let per_blk = ((mr + nr) * k * elem_bytes.max(1)).max(1);
360 (block_budget_bytes() / per_blk).clamp(1, ST_BLK_MAX)
361}
362
363#[inline]
370unsafe fn run_single_thread_blocked<K: MatMatMulKer>(
371 ker: &K,
372 m_panels: usize,
373 n_panels: usize,
374 k: usize,
375 col_outer: bool,
376 scratch: &mut ScratchSpaceImpl<K::Acc>,
377 non_linear: &[FusedSpec],
378) -> TractResult<()> {
379 unsafe {
380 let blk = st_block_edge(ker.mr(), ker.nr(), k, K::Acc::datum_type().size_of());
381 scratch.run_in_tls_scope(|scratch, tls| {
382 let mut jb = 0;
383 while jb < n_panels {
384 let jb_end = (jb + blk).min(n_panels);
385 let mut ja = 0;
386 while ja < m_panels {
387 let ja_end = (ja + blk).min(m_panels);
388 if col_outer {
389 for ib in jb..jb_end {
390 for ia in ja..ja_end {
391 scratch.run_one_tile(ker, non_linear, tls, ia, ib)?;
392 }
393 }
394 } else {
395 for ia in ja..ja_end {
396 for ib in jb..jb_end {
397 scratch.run_one_tile(ker, non_linear, tls, ia, ib)?;
398 }
399 }
400 }
401 ja = ja_end;
402 }
403 jb = jb_end;
404 }
405 TractResult::Ok(())
406 })
407 }
408}
409
410unsafe fn run_with_scratch_space_col_outer<K: MatMatMulKer>(
411 ker: &K,
412 m: usize,
413 n: usize,
414 k: usize,
415 scratch: &mut ScratchSpaceImpl<K::Acc>,
416 non_linear: &[FusedSpec],
417) -> TractResult<()> {
418 unsafe {
419 match crate::multithread::current_tract_executor() {
420 Executor::SingleThread => run_single_thread_blocked(
421 ker,
422 m.divceil(ker.mr()),
423 n.divceil(ker.nr()),
424 k,
425 true,
426 scratch,
427 non_linear,
428 ),
429 #[cfg(feature = "multithread-mm")]
430 Executor::MultiThread(pool) => chunked_dispatch_rayon(
431 Some(&pool),
432 m.divceil(ker.mr()),
433 n.divceil(ker.nr()),
434 |ia_start, ia_end, ib_start, ib_end| {
435 scratch.run_in_tls_scope(|scratch, tls| {
436 for ib in ib_start..ib_end {
437 for ia in ia_start..ia_end {
438 scratch.run_one_tile(ker, non_linear, tls, ia, ib)?;
439 }
440 }
441 TractResult::Ok(())
442 })
443 },
444 ),
445 #[cfg(feature = "multithread-mm")]
446 Executor::RayonGlobal => chunked_dispatch_rayon(
447 None,
448 m.divceil(ker.mr()),
449 n.divceil(ker.nr()),
450 |ia_start, ia_end, ib_start, ib_end| {
451 scratch.run_in_tls_scope(|scratch, tls| {
452 for ib in ib_start..ib_end {
453 for ia in ia_start..ia_end {
454 scratch.run_one_tile(ker, non_linear, tls, ia, ib)?;
455 }
456 }
457 TractResult::Ok(())
458 })
459 },
460 ),
461 }
462 }
463}
464
465unsafe fn run_with_scratch_space_row_outer<K: MatMatMulKer>(
466 ker: &K,
467 m: usize,
468 n: usize,
469 k: usize,
470 scratch: &mut ScratchSpaceImpl<K::Acc>,
471 non_linear: &[FusedSpec],
472) -> TractResult<()> {
473 unsafe {
474 match crate::multithread::current_tract_executor() {
475 Executor::SingleThread => run_single_thread_blocked(
476 ker,
477 m.divceil(ker.mr()),
478 n.divceil(ker.nr()),
479 k,
480 false,
481 scratch,
482 non_linear,
483 ),
484 #[cfg(feature = "multithread-mm")]
485 Executor::MultiThread(pool) => chunked_dispatch_rayon(
486 Some(&pool),
487 m.divceil(ker.mr()),
488 n.divceil(ker.nr()),
489 |ia_start, ia_end, ib_start, ib_end| {
490 scratch.run_in_tls_scope(|scratch, tls| {
491 for ia in ia_start..ia_end {
492 for ib in ib_start..ib_end {
493 scratch.run_one_tile(ker, non_linear, tls, ia, ib)?;
494 }
495 }
496 TractResult::Ok(())
497 })
498 },
499 ),
500 #[cfg(feature = "multithread-mm")]
501 Executor::RayonGlobal => chunked_dispatch_rayon(
502 None,
503 m.divceil(ker.mr()),
504 n.divceil(ker.nr()),
505 |ia_start, ia_end, ib_start, ib_end| {
506 scratch.run_in_tls_scope(|scratch, tls| {
507 for ia in ia_start..ia_end {
508 for ib in ib_start..ib_end {
509 scratch.run_one_tile(ker, non_linear, tls, ia, ib)?;
510 }
511 }
512 TractResult::Ok(())
513 })
514 },
515 ),
516 }
517 }
518}
519
520#[cfg(feature = "multithread-mm")]
/// Splits an `n_panels_m` x `n_panels_n` panel grid into rectangular chunks for `nth` workers.
///
/// Returns `(nchunks_m, nchunks_n, dr_m, dr_n)`: the number of chunks along each dimension
/// and the number of panels per chunk along each dimension (the last chunk of a dimension
/// may be shorter).
///
/// Policy:
/// * the target chunk edge is 64 panels when either dimension has a single panel, and 16
///   panels otherwise;
/// * if that yields fewer than `4 * nth` chunks, only the longer dimension is split, into
///   at most `nth` chunks (the `n` dimension on ties).
///
/// The chunk counts are re-derived from the final chunk sizes, so for a non-empty grid
/// every reported chunk contains at least one panel: no chunk starts at or past the end of
/// its dimension. Degenerate input (a dimension with 0 panels, or `nth == 0`) yields at
/// least one chunk of size 1 per dimension instead of dividing by zero.
fn chunk_grid(n_panels_m: usize, n_panels_n: usize, nth: usize) -> (usize, usize, usize, usize) {
    let chunk_size = if n_panels_m == 1 || n_panels_n == 1 { 64 } else { 16 };
    let mut target_m = n_panels_m.div_ceil(chunk_size);
    let mut target_n = n_panels_n.div_ceil(chunk_size);
    if target_m * target_n < 4 * nth {
        // Too few chunks to keep the workers busy: split the longer dimension only.
        if n_panels_m > n_panels_n {
            target_m = nth;
            target_n = 1;
        } else {
            target_m = 1;
            target_n = nth;
        }
    }

    // `.max(1)` on the divisor guards the empty-grid / zero-worker corner case.
    let dr_m = n_panels_m.div_ceil(target_m.max(1)).max(1);
    let dr_n = n_panels_n.div_ceil(target_n.max(1)).max(1);

    // Rounding the chunk size up can cover the dimension with fewer chunks than targeted
    // (e.g. 10 panels over 8 workers gives chunks of 2, hence 5 chunks). Report the number
    // of chunks actually needed so callers never receive an empty or inverted range.
    let nchunks_m = n_panels_m.div_ceil(dr_m).max(1);
    let nchunks_n = n_panels_n.div_ceil(dr_n).max(1);
    (nchunks_m, nchunks_n, dr_m, dr_n)
}
547
/// Runs `run_chunk` over every chunk of an `n_panels_m` x `n_panels_n` panel grid using rayon.
///
/// `run_chunk(ia_start, ia_end, ib_start, ib_end)` receives half-open panel ranges.
///
/// * An empty grid returns `Ok(())` without calling `run_chunk`.
/// * A grid with fewer panels than `current_threading_panel_threshold()` is run as a single
///   chunk on the calling thread.
/// * Otherwise the grid is cut by `chunk_grid` according to `rayon::current_num_threads()`
///   and the chunks run in parallel, stopping at the first error. The parallel section runs
///   inside `pool` when one is given and has more than one thread; otherwise it runs in the
///   caller's current rayon context.
#[cfg(feature = "multithread-mm")]
unsafe fn chunked_dispatch_rayon<F>(
    pool: Option<&rayon::ThreadPool>,
    n_panels_m: usize,
    n_panels_n: usize,
    run_chunk: F,
) -> TractResult<()>
where
    F: Fn(usize, usize, usize, usize) -> TractResult<()> + Sync,
{
    use rayon::prelude::*;

    if n_panels_m == 0 || n_panels_n == 0 {
        return Ok(());
    }
    if n_panels_m * n_panels_n < crate::multithread::current_threading_panel_threshold() {
        return run_chunk(0, n_panels_m, 0, n_panels_n);
    }

    let dispatch = || {
        let threads = rayon::current_num_threads();
        let (nchunks_m, nchunks_n, dr_m, dr_n) = chunk_grid(n_panels_m, n_panels_n, threads);
        (0..nchunks_m * nchunks_n).into_par_iter().try_for_each(|idx| {
            // Chunk indices are laid out with the `m` chunk index varying fastest.
            let (chunk_m, chunk_n) = (idx % nchunks_m, idx / nchunks_m);
            let ia_start = chunk_m * dr_m;
            let ib_start = chunk_n * dr_n;
            run_chunk(
                ia_start,
                (ia_start + dr_m).min(n_panels_m),
                ib_start,
                (ib_start + dr_n).min(n_panels_n),
            )
        })
    };

    match pool {
        Some(dedicated) if dedicated.current_num_threads() > 1 => dedicated.install(dispatch),
        _ => dispatch(),
    }
}