Skip to main content

tract_linalg/frame/mmm/
scratch.rs

1use super::{FusedKerSpec, FusedSpec, MatMatMulKer, OutputStoreKer};
2use crate::{BinOp, LADatum};
3use downcast_rs::{Downcast, impl_downcast};
4use std::cell::RefCell;
5use std::fmt::Debug;
6use std::sync::atomic::AtomicUsize;
7use tract_data::internal::num_integer::Integer;
8use tract_data::internal::*;
9
/// Process-wide counter from which `ScratchSpaceImpl::prepare` draws a fresh
/// generation number on every call. It starts at 1, whereas default-constructed
/// scratch structures carry generation 0.
static GENERATION: AtomicUsize = AtomicUsize::new(1);

thread_local! {
    // Per-thread scratch state. `ScratchSpaceImpl::run_in_tls_scope` borrows it
    // mutably and re-synchronises it through `TLSScratch::sync`.
    static TLS: RefCell<TLSScratch> = Default::default();
}
15
/// Per-thread working state used while running fused matrix-multiplication kernels.
///
/// It holds a private copy of the kernel specs of the `ScratchSpaceImpl` it was last
/// synchronised with, plus the byte blob in which location-dependent temporaries live.
#[derive(Default, Debug)]
pub(crate) struct TLSScratch {
    // Generation of the `ScratchSpaceImpl` this state was last synced with.
    generation: usize,
    // Raw buffer, resized from the scratch space's `blob_size` / `blob_align` in `sync`.
    blob: Blob,
    // One kernel-spec list per accumulator width. `ker_specs` selects the matching
    // one; the 32-bit list serves both `f32` and `i32` accumulators.
    ker_specs_16: Vec<FusedKerSpec<f16>>,
    ker_specs_32: Vec<FusedKerSpec<f32>>,
    ker_specs_64: Vec<FusedKerSpec<f64>>,
}
24
25impl TLSScratch {
26    #[allow(unknown_lints, clippy::missing_transmute_annotations)]
27    fn ker_specs<TI: LADatum>(&mut self) -> &mut Vec<FusedKerSpec<TI>> {
28        unsafe {
29            if TI::datum_type() == f32::datum_type() || TI::datum_type() == i32::datum_type() {
30                std::mem::transmute(&mut self.ker_specs_32)
31            } else if TI::datum_type() == f16::datum_type() {
32                std::mem::transmute(&mut self.ker_specs_16)
33            } else if TI::datum_type() == f64::datum_type() {
34                std::mem::transmute(&mut self.ker_specs_64)
35            } else {
36                todo!();
37            }
38        }
39    }
40
41    fn sync<TI: LADatum>(&mut self, scratch: &ScratchSpaceImpl<TI>) {
42        if self.generation == scratch.generation {
43            return;
44        }
45        let ker_specs = self.ker_specs::<TI>();
46        ker_specs.clear();
47        ker_specs.extend_from_slice(&scratch.ker_specs);
48
49        unsafe {
50            self.blob.ensure_size_and_align(scratch.blob_size, scratch.blob_align);
51
52            for LocDependant { loc, ker_spec, .. } in &scratch.loc_dependant {
53                #[allow(clippy::single_match)]
54                if matches!(scratch.ker_specs[*ker_spec], FusedKerSpec::AddMatMul { .. }) {
55                    let scratch = &mut *(self.blob.as_ptr().add(*loc) as *mut AddMatMulTemp);
56                    scratch.panel_a_id = usize::MAX;
57                    scratch.panel_b_id = usize::MAX;
58                };
59            }
60        }
61        self.generation = scratch.generation;
62    }
63}
64
/// Marker trait for scratch spaces: implementors can be downcast back to their
/// concrete type (`Downcast`) and moved to another thread (`Send`).
pub trait ScratchSpace: Downcast + Send {}
impl_downcast!(ScratchSpace);
67
/// Result of preparing one fused matrix multiplication for accumulator type `TI`:
/// a kernel-spec template, the layout of the per-thread blob, and the tiling of the
/// output. All fields are filled in by `prepare`.
#[derive(Debug, Default)]
pub struct ScratchSpaceImpl<TI: LADatum> {
    // Number drawn from `GENERATION` by `prepare`; `TLSScratch::sync` compares it
    // with its own to decide whether it must re-synchronise.
    generation: usize,
    // Size in bytes and alignment required for the per-thread blob.
    blob_size: usize,
    blob_align: usize,
    // Kernel-spec template: `Clear`, then one entry per fused spec, then `Done`.
    // Entries for location-dependent specs are placeholders, rewritten per tile.
    ker_specs: Vec<FusedKerSpec<TI>>,
    // One record per fused spec whose kernel spec depends on the tile position.
    loc_dependant: TVec<LocDependant>,
    // Number of complete tiles along m (`m / mr`) and leftover rows (`m % mr`).
    valid_down_tiles: usize,
    remnant_down: usize,
    // Number of complete tiles along n (`n / nr`) and leftover columns (`n % nr`).
    valid_right_tiles: usize,
    remnant_right: usize,
}
80
/// Bookkeeping for one fused spec whose kernel spec has to be rebuilt for each tile.
#[derive(Debug, new)]
struct LocDependant {
    // index of the spec in the `specs` slice passed to `prepare`
    spec: usize,
    // index of the corresponding entry in the kernel-spec list
    ker_spec: usize,
    // byte offset, inside the scratch blob, of the location-dependent structure
    loc: usize,
    // byte offsets, inside the scratch blob, of its associated dynamic-size buffers
    // (only set by `prepare` for `AddMatMul` operands that report a panel buffer layout)
    buffer_a: Option<usize>,
    buffer_b: Option<usize>,
}
91
impl<TI: LADatum> ScratchSpace for ScratchSpaceImpl<TI> {}
// NOTE(review): `Send` is asserted by hand, presumably because the stored kernel
// specs can carry raw pointers (see the `AddMatMul` placeholder built in `prepare`).
// The invariant that makes this sound is not visible in this file -- TODO confirm.
unsafe impl<TI: LADatum> Send for ScratchSpaceImpl<TI> {}
94
/// Per-thread cache stored in the scratch blob for one `AddMatMul` spec: the panel
/// pointers obtained for the last tile, together with the tile indices they belong to.
#[derive(Debug)]
struct AddMatMulTemp {
    // Pointer last returned by `a.panel_bytes(..)`.
    ptr_a: *const u8,
    // `down` index for which `ptr_a` was obtained; `usize::MAX` right after `sync`.
    panel_a_id: usize,
    // Pointer last returned by `b.panel_bytes(..)`.
    ptr_b: *const u8,
    // `right` index for which `ptr_b` was obtained; `usize::MAX` right after `sync`.
    panel_b_id: usize,
}
102
103impl<TI: LADatum> ScratchSpaceImpl<TI> {
104    pub unsafe fn prepare(
105        &mut self,
106        ker: &impl MatMatMulKer<Acc = TI>,
107        m: usize,
108        n: usize,
109        specs: &[FusedSpec],
110    ) -> TractResult<()> {
111        use FusedKerSpec as FKS;
112        use FusedSpec as FS;
113        self.ker_specs.clear();
114        self.loc_dependant.clear();
115        self.ker_specs.reserve(specs.len() + 2);
116        self.ker_specs.push(FusedKerSpec::Clear);
117        self.valid_down_tiles = m / ker.mr();
118        self.remnant_down = m % ker.mr();
119        self.valid_right_tiles = n / ker.nr();
120        self.remnant_right = n % ker.nr();
121        let mut offset = 0;
122        let mut align = std::mem::size_of::<*const ()>();
123        fn ld(spec: usize, uspec: usize, loc: usize) -> LocDependant {
124            LocDependant { spec, ker_spec: uspec, loc, buffer_a: None, buffer_b: None }
125        }
126        for (ix, spec) in specs.iter().enumerate() {
127            offset = offset.next_multiple_of(&align);
128            let ker_spec = match spec {
129                FS::BinScalar(t, op) => match op {
130                    BinOp::Min => FKS::ScalarMin(*t.try_as_plain()?.to_scalar()?),
131                    BinOp::Max => FKS::ScalarMax(*t.try_as_plain()?.to_scalar()?),
132                    BinOp::Mul => FKS::ScalarMul(*t.try_as_plain()?.to_scalar()?),
133                    BinOp::Add => FKS::ScalarAdd(*t.try_as_plain()?.to_scalar()?),
134                    BinOp::Sub => FKS::ScalarSub(*t.try_as_plain()?.to_scalar()?),
135                    BinOp::SubF => FKS::ScalarSubF(*t.try_as_plain()?.to_scalar()?),
136                },
137                FS::ShiftLeft(s) => FKS::ShiftLeft(*s),
138                FS::RoundingShiftRight(s, rp) => FKS::RoundingShiftRight(*s, *rp),
139                FS::QScale(s, rp, m) => FKS::QScale(*s, *rp, *m),
140                FS::BinPerRow(_, _) => {
141                    self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset));
142                    offset += TI::datum_type().size_of() * ker.mr();
143                    FusedKerSpec::Done
144                }
145                FS::BinPerCol(_, _) => {
146                    self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset));
147                    offset += TI::datum_type().size_of() * ker.nr();
148                    FusedKerSpec::Done
149                }
150                FS::AddRowColProducts(_, _) => {
151                    self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset));
152                    offset += TI::datum_type().size_of() * (ker.mr() + ker.nr());
153                    FusedKerSpec::Done
154                }
155                FS::AddUnicast(_) => {
156                    self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset));
157                    offset += TI::datum_type().size_of() * ker.mr() * ker.nr();
158                    FusedKerSpec::Done
159                }
160                FS::Store(store) => {
161                    self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset));
162                    offset += store.item_size * ker.mr() * ker.nr();
163                    FusedKerSpec::Done
164                }
165                FS::LeakyRelu(t) => FKS::LeakyRelu(*t.try_as_plain()?.to_scalar()?),
166                FS::AddMatMul { a, b, packing } => {
167                    let mut ld = ld(ix, self.ker_specs.len(), offset);
168                    offset += std::mem::size_of::<AddMatMulTemp>();
169                    if let Some(tmp) = a.scratch_panel_buffer_layout() {
170                        align = tmp.align().lcm(&align);
171                        offset = Integer::next_multiple_of(&offset, &tmp.align());
172                        ld.buffer_a = Some(offset);
173                        offset += tmp.size();
174                    }
175                    if let Some(tmp) = b.scratch_panel_buffer_layout() {
176                        align = tmp.align().lcm(&align);
177                        offset = Integer::next_multiple_of(&offset, &tmp.align());
178                        ld.buffer_b = Some(offset);
179                        offset += tmp.size();
180                    }
181                    self.loc_dependant.push(ld);
182                    FusedKerSpec::AddMatMul {
183                        k: 0,
184                        pa: std::ptr::null(),
185                        pb: std::ptr::null(),
186                        packing: *packing,
187                    }
188                }
189            };
190            self.ker_specs.push(ker_spec);
191        }
192        self.ker_specs.push(FKS::Done);
193        self.blob_size = offset;
194        self.blob_align = align;
195
196        self.generation = GENERATION.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
197        Ok(())
198    }
199
200    pub unsafe fn run(
201        &self,
202        ker: &impl MatMatMulKer<Acc = TI>,
203        specs: &[FusedSpec],
204        down: usize,
205        right: usize,
206    ) -> TractResult<()> {
207        // Per-tile entry: enter the TLS scope (does sync once) then run a single
208        // tile. Single-threaded callers should prefer `run_in_tls_scope`+
209        // `run_one_tile` to amortise the TLS borrow + sync over many tiles.
210        unsafe {
211            self.run_in_tls_scope(|this, tls| this.run_one_tile(ker, specs, tls, down, right))
212        }
213    }
214
215    /// Borrow the per-thread scratch blob for a single MMM call and `sync` it
216    /// once. The closure is invoked once with a mutable reference to the TLS
217    /// scratch and to `self`. Used by single-threaded matmul drivers to avoid
218    /// re-entering TLS / re-running `sync` per tile.
219    pub(crate) unsafe fn run_in_tls_scope<F, R>(&self, f: F) -> R
220    where
221        F: FnOnce(&Self, &mut TLSScratch) -> R,
222    {
223        TLS.with_borrow_mut(|tls| {
224            tls.sync(self);
225            f(self, tls)
226        })
227    }
228
229    /// Run a single tile against an already-borrowed TLS scratch. Caller is
230    /// responsible for entering `run_in_tls_scope` first (so `sync` has run).
231    #[inline(always)]
232    pub(crate) unsafe fn run_one_tile(
233        &self,
234        ker: &impl MatMatMulKer<Acc = TI>,
235        specs: &[FusedSpec],
236        tls: &mut TLSScratch,
237        down: usize,
238        right: usize,
239    ) -> TractResult<()> {
240        unsafe {
241            if down < self.valid_down_tiles && right < self.valid_right_tiles {
242                self.for_valid_tile(ker, specs, tls, down, right)?;
243                let err = ker.kernel(tls.ker_specs());
244                debug_assert_eq!(err, 0, "Kernel return error {err}");
245            } else {
246                let remnant_down =
247                    if down < self.valid_down_tiles { ker.mr() } else { self.remnant_down };
248                let remnant_right =
249                    if right < self.valid_right_tiles { ker.nr() } else { self.remnant_right };
250                self.for_border_tile(ker, specs, tls, down, right, remnant_down, remnant_right)?;
251                let err = ker.kernel(tls.ker_specs());
252                debug_assert_eq!(err, 0, "Kernel return error {err}");
253                self.postprocess_tile(specs, tls, down, right, remnant_down, remnant_right)?;
254            }
255            Ok(())
256        }
257    }
258
259    #[inline(always)]
260    unsafe fn for_valid_tile(
261        &self,
262        ker: &impl MatMatMulKer<Acc = TI>,
263        specs: &[FusedSpec],
264        tls: &mut TLSScratch,
265        down: usize,
266        right: usize,
267    ) -> TractResult<()> {
268        unsafe {
269            use FusedKerSpec as FKS;
270            use FusedSpec as FS;
271            let ScratchSpaceImpl { ker_specs, loc_dependant, .. } = self;
272            debug_assert!(specs.len() + 2 == ker_specs.len());
273            for LocDependant { spec, ker_spec, loc, buffer_a, buffer_b } in loc_dependant {
274                let spec = specs.get_unchecked(*spec);
275                let it = match spec {
276                    FS::BinPerRow(v, op) => {
277                        let v = v.as_ptr_unchecked::<TI>().add(down * ker.mr());
278                        match op {
279                            BinOp::Min => FKS::PerRowMin(v),
280                            BinOp::Max => FKS::PerRowMax(v),
281                            BinOp::Add => FKS::PerRowAdd(v),
282                            BinOp::Mul => FKS::PerRowMul(v),
283                            BinOp::Sub => FKS::PerRowSub(v),
284                            BinOp::SubF => FKS::PerRowSubF(v),
285                        }
286                    }
287                    FS::BinPerCol(v, op) => {
288                        let v = v.as_ptr_unchecked::<TI>().add(right * ker.nr());
289                        match op {
290                            BinOp::Min => FKS::PerColMin(v),
291                            BinOp::Max => FKS::PerColMax(v),
292                            BinOp::Add => FKS::PerColAdd(v),
293                            BinOp::Mul => FKS::PerColMul(v),
294                            BinOp::Sub => FKS::PerColSub(v),
295                            BinOp::SubF => FKS::PerColSubF(v),
296                        }
297                    }
298                    FS::AddRowColProducts(rows, cols) => {
299                        let row_ptr = rows.as_ptr_unchecked::<TI>().add(down * ker.mr());
300                        let col_ptr = cols.as_ptr_unchecked::<TI>().add(right * ker.nr());
301                        FKS::AddRowColProducts(row_ptr, col_ptr)
302                    }
303                    FS::AddUnicast(store) => FKS::AddUnicast(store.tile_c(down, right)),
304                    FS::Store(c_store) => FKS::Store(c_store.tile_c(down, right)),
305                    FS::AddMatMul { a, b, packing } => {
306                        let scratch = (tls.blob.as_mut_ptr().add(*loc) as *mut AddMatMulTemp)
307                            .as_mut()
308                            .unwrap();
309                        if scratch.panel_a_id != down {
310                            scratch.ptr_a = a.panel_bytes(
311                                down,
312                                buffer_a.map(|o| tls.blob.as_mut_ptr().add(o)),
313                            )?;
314                            scratch.panel_a_id = down;
315                        }
316                        if scratch.panel_b_id != right {
317                            scratch.ptr_b = b.panel_bytes(
318                                right,
319                                buffer_b.map(|o| tls.blob.as_mut_ptr().add(o)),
320                            )?;
321                            scratch.panel_b_id = right;
322                        }
323                        FKS::AddMatMul {
324                            k: b.k(),
325                            pa: scratch.ptr_a,
326                            pb: scratch.ptr_b,
327                            packing: *packing,
328                        }
329                    }
330                    _ => std::hint::unreachable_unchecked(),
331                };
332                *tls.ker_specs().get_unchecked_mut(*ker_spec) = it;
333            }
334            Ok(())
335        }
336    }
337
338    #[inline(never)]
339    #[allow(clippy::too_many_arguments)]
340    unsafe fn for_border_tile(
341        &self,
342        ker: &impl MatMatMulKer<Acc = TI>,
343        specs: &[FusedSpec],
344        tls: &mut TLSScratch,
345        down: usize,
346        right: usize,
347        m_remnant: usize,
348        n_remnant: usize,
349    ) -> TractResult<()> {
350        unsafe {
351            use FusedKerSpec as FKS;
352            use FusedSpec as FS;
353            for LocDependant { spec, ker_spec: uspec, loc, buffer_a, buffer_b } in
354                &self.loc_dependant
355            {
356                let loc = tls.blob.as_mut_ptr().add(*loc);
357                let spec = specs.get_unchecked(*spec);
358                let it = match spec {
359                    FS::BinPerRow(v, op) => {
360                        let buf = std::slice::from_raw_parts_mut(loc as *mut TI, ker.mr());
361                        let ptr = if m_remnant < ker.mr() {
362                            if m_remnant > 0 {
363                                buf.get_unchecked_mut(..m_remnant).copy_from_slice(
364                                    v.as_slice_unchecked()
365                                        .get_unchecked(down * ker.mr()..)
366                                        .get_unchecked(..m_remnant),
367                                );
368                            }
369                            if cfg!(debug_assertions) {
370                                buf.get_unchecked_mut(m_remnant..)
371                                    .iter_mut()
372                                    .for_each(|x| *x = TI::zero());
373                            }
374                            buf.as_ptr()
375                        } else {
376                            v.as_ptr_unchecked::<TI>().add(down * ker.mr())
377                        };
378                        match op {
379                            BinOp::Min => FKS::PerRowMin(ptr),
380                            BinOp::Max => FKS::PerRowMax(ptr),
381                            BinOp::Add => FKS::PerRowAdd(ptr),
382                            BinOp::Mul => FKS::PerRowMul(ptr),
383                            BinOp::Sub => FKS::PerRowSub(ptr),
384                            BinOp::SubF => FKS::PerRowSubF(ptr),
385                        }
386                    }
387                    FS::BinPerCol(v, op) => {
388                        let buf = std::slice::from_raw_parts_mut(loc as *mut TI, ker.nr());
389                        let ptr = if n_remnant < ker.nr() {
390                            if n_remnant > 0 {
391                                buf.get_unchecked_mut(..n_remnant).copy_from_slice(
392                                    v.as_slice_unchecked()
393                                        .get_unchecked(right * ker.nr()..)
394                                        .get_unchecked(..n_remnant),
395                                );
396                            }
397                            if cfg!(debug_assertions) {
398                                buf.get_unchecked_mut(n_remnant..)
399                                    .iter_mut()
400                                    .for_each(|x| *x = TI::zero());
401                            }
402                            buf.as_ptr()
403                        } else {
404                            v.as_ptr_unchecked::<TI>().add(right * ker.nr())
405                        };
406                        match op {
407                            BinOp::Min => FKS::PerColMin(ptr),
408                            BinOp::Max => FKS::PerColMax(ptr),
409                            BinOp::Add => FKS::PerColAdd(ptr),
410                            BinOp::Mul => FKS::PerColMul(ptr),
411                            BinOp::Sub => FKS::PerColSub(ptr),
412                            BinOp::SubF => FKS::PerColSubF(ptr),
413                        }
414                    }
415                    FS::AddRowColProducts(rows, cols) => {
416                        let r = std::slice::from_raw_parts_mut(loc as *mut TI, ker.mr());
417                        let row_ptr = if m_remnant < ker.mr() {
418                            r.get_unchecked_mut(..m_remnant).copy_from_slice(
419                                rows.as_slice_unchecked()
420                                    .get_unchecked(down * ker.mr()..)
421                                    .get_unchecked(..m_remnant),
422                            );
423                            if cfg!(debug_assertions) {
424                                r.get_unchecked_mut(m_remnant..)
425                                    .iter_mut()
426                                    .for_each(|x| *x = TI::zero());
427                            }
428                            r.as_ptr()
429                        } else {
430                            rows.as_ptr_unchecked::<TI>().add(down * ker.mr())
431                        };
432                        let c = std::slice::from_raw_parts_mut(
433                            (loc as *mut TI).add(ker.mr()),
434                            ker.nr(),
435                        );
436                        let col_ptr = if n_remnant < ker.nr() {
437                            c.get_unchecked_mut(..n_remnant).copy_from_slice(
438                                cols.as_slice_unchecked()
439                                    .get_unchecked(right * ker.nr()..)
440                                    .get_unchecked(..n_remnant),
441                            );
442                            if cfg!(debug_assertions) {
443                                r.get_unchecked_mut(n_remnant..)
444                                    .iter_mut()
445                                    .for_each(|x| *x = TI::zero());
446                            }
447                            c.as_ptr()
448                        } else {
449                            cols.as_ptr_unchecked::<TI>().add(right * ker.nr())
450                        };
451                        FKS::AddRowColProducts(row_ptr, col_ptr)
452                    }
453                    FS::AddUnicast(store) => {
454                        let row_byte_stride = store.row_byte_stride;
455                        let col_byte_stride = store.col_byte_stride;
456                        let tile_offset = row_byte_stride * down as isize * ker.mr() as isize
457                            + col_byte_stride * right as isize * ker.nr() as isize;
458                        let tile_ptr = store.ptr.offset(tile_offset);
459                        let tmp_d_tile =
460                            std::slice::from_raw_parts_mut(loc as *mut TI, ker.mr() * ker.nr());
461                        if cfg!(debug_assertions) {
462                            tmp_d_tile.iter_mut().for_each(|t| *t = TI::zero());
463                        }
464                        for r in 0..m_remnant as isize {
465                            for c in 0..n_remnant as isize {
466                                let inner_offset = c * col_byte_stride + r * row_byte_stride;
467                                if inner_offset + tile_offset
468                                    < (store.item_size * store.item_count) as isize
469                                {
470                                    *tmp_d_tile
471                                        .get_unchecked_mut(r as usize + c as usize * ker.mr()) =
472                                        *(tile_ptr.offset(inner_offset) as *const TI);
473                                }
474                            }
475                        }
476                        FKS::AddUnicast(OutputStoreKer {
477                            ptr: tmp_d_tile.as_ptr() as _,
478                            row_byte_stride: std::mem::size_of::<TI>() as isize,
479                            col_byte_stride: (std::mem::size_of::<TI>() * ker.mr()) as isize,
480                            item_size: std::mem::size_of::<TI>(),
481                        })
482                    }
483                    FS::Store(c_store) => {
484                        let tmpc = OutputStoreKer {
485                            ptr: loc as _,
486                            item_size: c_store.item_size,
487                            row_byte_stride: c_store.item_size as isize,
488                            col_byte_stride: (c_store.item_size * ker.mr()) as isize,
489                        };
490                        FKS::Store(tmpc)
491                    }
492                    FS::AddMatMul { a, b, packing } => {
493                        let scratch = (loc as *mut AddMatMulTemp).as_mut().unwrap();
494                        if scratch.panel_a_id != down {
495                            scratch.ptr_a = a.panel_bytes(
496                                down,
497                                buffer_a.map(|o| tls.blob.as_mut_ptr().add(o)),
498                            )?;
499                            scratch.panel_a_id = down;
500                        }
501                        if scratch.panel_b_id != right {
502                            scratch.ptr_b = b.panel_bytes(
503                                right,
504                                buffer_b.map(|o| tls.blob.as_mut_ptr().add(o)),
505                            )?;
506                            scratch.panel_b_id = right;
507                        }
508                        FKS::AddMatMul {
509                            k: b.k(),
510                            pa: scratch.ptr_a,
511                            pb: scratch.ptr_b,
512                            packing: *packing,
513                        }
514                    }
515                    _ => std::hint::unreachable_unchecked(),
516                };
517                *tls.ker_specs().get_unchecked_mut(*uspec) = it;
518            }
519            Ok(())
520        }
521    }
522
523    #[inline]
524    pub fn uspecs(&self) -> &[FusedKerSpec<TI>] {
525        &self.ker_specs
526    }
527
528    unsafe fn postprocess_tile(
529        &self,
530        specs: &[FusedSpec],
531        tls: &mut TLSScratch,
532        down: usize,
533        right: usize,
534        m_remnant: usize,
535        n_remnant: usize,
536    ) -> TractResult<()>
537    where
538        TI: LADatum,
539    {
540        unsafe {
541            for LocDependant { spec, ker_spec: uspec, .. } in self.loc_dependant.iter() {
542                let spec = specs.get_unchecked(*spec);
543                let ker_spec = tls.ker_specs::<TI>().get_unchecked(*uspec);
544                if let (FusedSpec::Store(c_store), FusedKerSpec::Store(tmp)) = (spec, ker_spec) {
545                    c_store.set_from_tile(down, right, m_remnant, n_remnant, tmp)
546                }
547            }
548            Ok(())
549        }
550    }
551}