1use super::{FusedKerSpec, FusedSpec, MatMatMulKer, OutputStoreKer};
2use crate::{BinOp, LADatum};
3use downcast_rs::{Downcast, impl_downcast};
4use std::cell::RefCell;
5use std::fmt::Debug;
6use std::sync::atomic::AtomicUsize;
7use tract_data::internal::num_integer::Integer;
8use tract_data::internal::*;
9
// Process-wide counter from which `ScratchSpaceImpl::prepare` draws a fresh
// generation number. It starts at 1, whereas the derived `Default` for both
// `TLSScratch` and `ScratchSpaceImpl` leaves their `generation` field at 0.
static GENERATION: AtomicUsize = AtomicUsize::new(1);

thread_local! {
    // Per-thread scratch state; it is borrowed mutably for the duration of
    // `ScratchSpaceImpl::run_in_tls_scope`.
    static TLS: RefCell<TLSScratch> = Default::default();
}
15
/// Per-thread working state used while running a prepared scratch space.
///
/// `sync` rebuilds it whenever it is handed a `ScratchSpaceImpl` whose
/// generation differs from the one recorded here.
#[derive(Default, Debug)]
pub(crate) struct TLSScratch {
    /// Generation of the `ScratchSpaceImpl` this state was last synced with.
    generation: usize,
    /// Byte buffer sized and aligned by `sync`; the `loc` / `buffer_a` /
    /// `buffer_b` offsets of `LocDependant` index into it.
    blob: Blob,
    /// Kernel-spec list returned by `ker_specs` for `f16` accumulators.
    ker_specs_16: Vec<FusedKerSpec<f16>>,
    /// Kernel-spec list returned by `ker_specs` for `f32` and `i32` accumulators.
    ker_specs_32: Vec<FusedKerSpec<f32>>,
    /// Kernel-spec list returned by `ker_specs` for `f64` accumulators.
    ker_specs_64: Vec<FusedKerSpec<f64>>,
}
24
25impl TLSScratch {
26 #[allow(unknown_lints, clippy::missing_transmute_annotations)]
27 fn ker_specs<TI: LADatum>(&mut self) -> &mut Vec<FusedKerSpec<TI>> {
28 unsafe {
29 if TI::datum_type() == f32::datum_type() || TI::datum_type() == i32::datum_type() {
30 std::mem::transmute(&mut self.ker_specs_32)
31 } else if TI::datum_type() == f16::datum_type() {
32 std::mem::transmute(&mut self.ker_specs_16)
33 } else if TI::datum_type() == f64::datum_type() {
34 std::mem::transmute(&mut self.ker_specs_64)
35 } else {
36 todo!();
37 }
38 }
39 }
40
41 fn sync<TI: LADatum>(&mut self, scratch: &ScratchSpaceImpl<TI>) {
42 if self.generation == scratch.generation {
43 return;
44 }
45 let ker_specs = self.ker_specs::<TI>();
46 ker_specs.clear();
47 ker_specs.extend_from_slice(&scratch.ker_specs);
48
49 unsafe {
50 self.blob.ensure_size_and_align(scratch.blob_size, scratch.blob_align);
51
52 for LocDependant { loc, ker_spec, .. } in &scratch.loc_dependant {
53 #[allow(clippy::single_match)]
54 if matches!(scratch.ker_specs[*ker_spec], FusedKerSpec::AddMatMul { .. }) {
55 let scratch = &mut *(self.blob.as_ptr().add(*loc) as *mut AddMatMulTemp);
56 scratch.panel_a_id = usize::MAX;
57 scratch.panel_b_id = usize::MAX;
58 };
59 }
60 }
61 self.generation = scratch.generation;
62 }
63}
64
/// Type-erased handle on a scratch space.
///
/// The trait has no methods of its own; `impl_downcast!` adds the downcasting
/// helpers of `downcast_rs` so callers can recover the concrete
/// `ScratchSpaceImpl<TI>`.
pub trait ScratchSpace: Downcast + Send {}
impl_downcast!(ScratchSpace);
67
/// Result of compiling a list of `FusedSpec` for one kernel and one `m` x `n`
/// problem size (see `prepare`). The per-tile, mutable part of the state lives
/// in the thread-local `TLSScratch`.
#[derive(Debug, Default)]
pub struct ScratchSpaceImpl<TI: LADatum> {
    /// Value drawn from `GENERATION` by the last successful `prepare`;
    /// `TLSScratch::sync` compares against it to detect stale thread state.
    generation: usize,
    /// Number of bytes the thread-local blob must provide.
    blob_size: usize,
    /// Alignment the thread-local blob must provide.
    blob_align: usize,
    /// Kernel-spec template: `Clear`, one entry per fused spec, then `Done`.
    /// Entries that depend on the tile position hold a placeholder here.
    ker_specs: Vec<FusedKerSpec<TI>>,
    /// One record per fused spec whose kernel spec must be filled in per tile.
    loc_dependant: TVec<LocDependant>,
    /// `m / mr`: number of tile rows fully covered by the kernel.
    valid_down_tiles: usize,
    /// `m % mr`: number of rows in the partial bottom tile row.
    remnant_down: usize,
    /// `n / nr`: number of tile columns fully covered by the kernel.
    valid_right_tiles: usize,
    /// `n % nr`: number of columns in the partial right tile column.
    remnant_right: usize,
}
80
/// Book-keeping for one fused spec whose kernel spec depends on the tile
/// being computed.
#[derive(Debug, new)]
struct LocDependant {
    /// Index of the spec in the `specs` slice given to `prepare` / `run`.
    spec: usize,
    /// Index of the matching entry in the kernel-spec list.
    ker_spec: usize,
    /// Byte offset, inside the thread-local blob, of this spec's scratch area.
    loc: usize,
    /// Byte offset of the scratch panel buffer for operand `a`; only set for
    /// `AddMatMul` when `a.scratch_panel_buffer_layout()` returns a layout.
    buffer_a: Option<usize>,
    /// Same as `buffer_a`, for operand `b`.
    buffer_b: Option<usize>,
}
91
impl<TI: LADatum> ScratchSpace for ScratchSpaceImpl<TI> {}
// SAFETY(review): `Send` is asserted manually here rather than derived.
// Presumably `FusedKerSpec<TI>` holds raw pointers, which would block the
// automatic impl -- confirm that nothing stored in `ker_specs` is tied to the
// thread that called `prepare`.
unsafe impl<TI: LADatum> Send for ScratchSpaceImpl<TI> {}
94
/// Per-thread cache stored in the blob for each `AddMatMul` spec: the panel
/// pointers last returned by `panel_bytes`, with the tile index they were
/// fetched for. `TLSScratch::sync` resets both ids to `usize::MAX`.
#[derive(Debug)]
struct AddMatMulTemp {
    /// Pointer last returned by `a.panel_bytes(..)`.
    ptr_a: *const u8,
    /// `down` index `ptr_a` was fetched for.
    panel_a_id: usize,
    /// Pointer last returned by `b.panel_bytes(..)`.
    ptr_b: *const u8,
    /// `right` index `ptr_b` was fetched for.
    panel_b_id: usize,
}
102
103impl<TI: LADatum> ScratchSpaceImpl<TI> {
104 pub unsafe fn prepare(
105 &mut self,
106 ker: &impl MatMatMulKer<Acc = TI>,
107 m: usize,
108 n: usize,
109 specs: &[FusedSpec],
110 ) -> TractResult<()> {
111 use FusedKerSpec as FKS;
112 use FusedSpec as FS;
113 self.ker_specs.clear();
114 self.loc_dependant.clear();
115 self.ker_specs.reserve(specs.len() + 2);
116 self.ker_specs.push(FusedKerSpec::Clear);
117 self.valid_down_tiles = m / ker.mr();
118 self.remnant_down = m % ker.mr();
119 self.valid_right_tiles = n / ker.nr();
120 self.remnant_right = n % ker.nr();
121 let mut offset = 0;
122 let mut align = std::mem::size_of::<*const ()>();
123 fn ld(spec: usize, uspec: usize, loc: usize) -> LocDependant {
124 LocDependant { spec, ker_spec: uspec, loc, buffer_a: None, buffer_b: None }
125 }
126 for (ix, spec) in specs.iter().enumerate() {
127 offset = offset.next_multiple_of(&align);
128 let ker_spec = match spec {
129 FS::BinScalar(t, op) => match op {
130 BinOp::Min => FKS::ScalarMin(*t.try_as_plain()?.to_scalar()?),
131 BinOp::Max => FKS::ScalarMax(*t.try_as_plain()?.to_scalar()?),
132 BinOp::Mul => FKS::ScalarMul(*t.try_as_plain()?.to_scalar()?),
133 BinOp::Add => FKS::ScalarAdd(*t.try_as_plain()?.to_scalar()?),
134 BinOp::Sub => FKS::ScalarSub(*t.try_as_plain()?.to_scalar()?),
135 BinOp::SubF => FKS::ScalarSubF(*t.try_as_plain()?.to_scalar()?),
136 },
137 FS::ShiftLeft(s) => FKS::ShiftLeft(*s),
138 FS::RoundingShiftRight(s, rp) => FKS::RoundingShiftRight(*s, *rp),
139 FS::QScale(s, rp, m) => FKS::QScale(*s, *rp, *m),
140 FS::BinPerRow(_, _) => {
141 self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset));
142 offset += TI::datum_type().size_of() * ker.mr();
143 FusedKerSpec::Done
144 }
145 FS::BinPerCol(_, _) => {
146 self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset));
147 offset += TI::datum_type().size_of() * ker.nr();
148 FusedKerSpec::Done
149 }
150 FS::AddRowColProducts(_, _) => {
151 self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset));
152 offset += TI::datum_type().size_of() * (ker.mr() + ker.nr());
153 FusedKerSpec::Done
154 }
155 FS::AddUnicast(_) => {
156 self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset));
157 offset += TI::datum_type().size_of() * ker.mr() * ker.nr();
158 FusedKerSpec::Done
159 }
160 FS::Store(store) => {
161 self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset));
162 offset += store.item_size * ker.mr() * ker.nr();
163 FusedKerSpec::Done
164 }
165 FS::LeakyRelu(t) => FKS::LeakyRelu(*t.try_as_plain()?.to_scalar()?),
166 FS::AddMatMul { a, b, packing } => {
167 let mut ld = ld(ix, self.ker_specs.len(), offset);
168 offset += std::mem::size_of::<AddMatMulTemp>();
169 if let Some(tmp) = a.scratch_panel_buffer_layout() {
170 align = tmp.align().lcm(&align);
171 offset = Integer::next_multiple_of(&offset, &tmp.align());
172 ld.buffer_a = Some(offset);
173 offset += tmp.size();
174 }
175 if let Some(tmp) = b.scratch_panel_buffer_layout() {
176 align = tmp.align().lcm(&align);
177 offset = Integer::next_multiple_of(&offset, &tmp.align());
178 ld.buffer_b = Some(offset);
179 offset += tmp.size();
180 }
181 self.loc_dependant.push(ld);
182 FusedKerSpec::AddMatMul {
183 k: 0,
184 pa: std::ptr::null(),
185 pb: std::ptr::null(),
186 packing: *packing,
187 }
188 }
189 };
190 self.ker_specs.push(ker_spec);
191 }
192 self.ker_specs.push(FKS::Done);
193 self.blob_size = offset;
194 self.blob_align = align;
195
196 self.generation = GENERATION.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
197 Ok(())
198 }
199
200 pub unsafe fn run(
201 &self,
202 ker: &impl MatMatMulKer<Acc = TI>,
203 specs: &[FusedSpec],
204 down: usize,
205 right: usize,
206 ) -> TractResult<()> {
207 unsafe {
211 self.run_in_tls_scope(|this, tls| this.run_one_tile(ker, specs, tls, down, right))
212 }
213 }
214
215 pub(crate) unsafe fn run_in_tls_scope<F, R>(&self, f: F) -> R
220 where
221 F: FnOnce(&Self, &mut TLSScratch) -> R,
222 {
223 TLS.with_borrow_mut(|tls| {
224 tls.sync(self);
225 f(self, tls)
226 })
227 }
228
229 #[inline(always)]
232 pub(crate) unsafe fn run_one_tile(
233 &self,
234 ker: &impl MatMatMulKer<Acc = TI>,
235 specs: &[FusedSpec],
236 tls: &mut TLSScratch,
237 down: usize,
238 right: usize,
239 ) -> TractResult<()> {
240 unsafe {
241 if down < self.valid_down_tiles && right < self.valid_right_tiles {
242 self.for_valid_tile(ker, specs, tls, down, right)?;
243 let err = ker.kernel(tls.ker_specs());
244 debug_assert_eq!(err, 0, "Kernel return error {err}");
245 } else {
246 let remnant_down =
247 if down < self.valid_down_tiles { ker.mr() } else { self.remnant_down };
248 let remnant_right =
249 if right < self.valid_right_tiles { ker.nr() } else { self.remnant_right };
250 self.for_border_tile(ker, specs, tls, down, right, remnant_down, remnant_right)?;
251 let err = ker.kernel(tls.ker_specs());
252 debug_assert_eq!(err, 0, "Kernel return error {err}");
253 self.postprocess_tile(specs, tls, down, right, remnant_down, remnant_right)?;
254 }
255 Ok(())
256 }
257 }
258
259 #[inline(always)]
260 unsafe fn for_valid_tile(
261 &self,
262 ker: &impl MatMatMulKer<Acc = TI>,
263 specs: &[FusedSpec],
264 tls: &mut TLSScratch,
265 down: usize,
266 right: usize,
267 ) -> TractResult<()> {
268 unsafe {
269 use FusedKerSpec as FKS;
270 use FusedSpec as FS;
271 let ScratchSpaceImpl { ker_specs, loc_dependant, .. } = self;
272 debug_assert!(specs.len() + 2 == ker_specs.len());
273 for LocDependant { spec, ker_spec, loc, buffer_a, buffer_b } in loc_dependant {
274 let spec = specs.get_unchecked(*spec);
275 let it = match spec {
276 FS::BinPerRow(v, op) => {
277 let v = v.as_ptr_unchecked::<TI>().add(down * ker.mr());
278 match op {
279 BinOp::Min => FKS::PerRowMin(v),
280 BinOp::Max => FKS::PerRowMax(v),
281 BinOp::Add => FKS::PerRowAdd(v),
282 BinOp::Mul => FKS::PerRowMul(v),
283 BinOp::Sub => FKS::PerRowSub(v),
284 BinOp::SubF => FKS::PerRowSubF(v),
285 }
286 }
287 FS::BinPerCol(v, op) => {
288 let v = v.as_ptr_unchecked::<TI>().add(right * ker.nr());
289 match op {
290 BinOp::Min => FKS::PerColMin(v),
291 BinOp::Max => FKS::PerColMax(v),
292 BinOp::Add => FKS::PerColAdd(v),
293 BinOp::Mul => FKS::PerColMul(v),
294 BinOp::Sub => FKS::PerColSub(v),
295 BinOp::SubF => FKS::PerColSubF(v),
296 }
297 }
298 FS::AddRowColProducts(rows, cols) => {
299 let row_ptr = rows.as_ptr_unchecked::<TI>().add(down * ker.mr());
300 let col_ptr = cols.as_ptr_unchecked::<TI>().add(right * ker.nr());
301 FKS::AddRowColProducts(row_ptr, col_ptr)
302 }
303 FS::AddUnicast(store) => FKS::AddUnicast(store.tile_c(down, right)),
304 FS::Store(c_store) => FKS::Store(c_store.tile_c(down, right)),
305 FS::AddMatMul { a, b, packing } => {
306 let scratch = (tls.blob.as_mut_ptr().add(*loc) as *mut AddMatMulTemp)
307 .as_mut()
308 .unwrap();
309 if scratch.panel_a_id != down {
310 scratch.ptr_a = a.panel_bytes(
311 down,
312 buffer_a.map(|o| tls.blob.as_mut_ptr().add(o)),
313 )?;
314 scratch.panel_a_id = down;
315 }
316 if scratch.panel_b_id != right {
317 scratch.ptr_b = b.panel_bytes(
318 right,
319 buffer_b.map(|o| tls.blob.as_mut_ptr().add(o)),
320 )?;
321 scratch.panel_b_id = right;
322 }
323 FKS::AddMatMul {
324 k: b.k(),
325 pa: scratch.ptr_a,
326 pb: scratch.ptr_b,
327 packing: *packing,
328 }
329 }
330 _ => std::hint::unreachable_unchecked(),
331 };
332 *tls.ker_specs().get_unchecked_mut(*ker_spec) = it;
333 }
334 Ok(())
335 }
336 }
337
338 #[inline(never)]
339 #[allow(clippy::too_many_arguments)]
340 unsafe fn for_border_tile(
341 &self,
342 ker: &impl MatMatMulKer<Acc = TI>,
343 specs: &[FusedSpec],
344 tls: &mut TLSScratch,
345 down: usize,
346 right: usize,
347 m_remnant: usize,
348 n_remnant: usize,
349 ) -> TractResult<()> {
350 unsafe {
351 use FusedKerSpec as FKS;
352 use FusedSpec as FS;
353 for LocDependant { spec, ker_spec: uspec, loc, buffer_a, buffer_b } in
354 &self.loc_dependant
355 {
356 let loc = tls.blob.as_mut_ptr().add(*loc);
357 let spec = specs.get_unchecked(*spec);
358 let it = match spec {
359 FS::BinPerRow(v, op) => {
360 let buf = std::slice::from_raw_parts_mut(loc as *mut TI, ker.mr());
361 let ptr = if m_remnant < ker.mr() {
362 if m_remnant > 0 {
363 buf.get_unchecked_mut(..m_remnant).copy_from_slice(
364 v.as_slice_unchecked()
365 .get_unchecked(down * ker.mr()..)
366 .get_unchecked(..m_remnant),
367 );
368 }
369 if cfg!(debug_assertions) {
370 buf.get_unchecked_mut(m_remnant..)
371 .iter_mut()
372 .for_each(|x| *x = TI::zero());
373 }
374 buf.as_ptr()
375 } else {
376 v.as_ptr_unchecked::<TI>().add(down * ker.mr())
377 };
378 match op {
379 BinOp::Min => FKS::PerRowMin(ptr),
380 BinOp::Max => FKS::PerRowMax(ptr),
381 BinOp::Add => FKS::PerRowAdd(ptr),
382 BinOp::Mul => FKS::PerRowMul(ptr),
383 BinOp::Sub => FKS::PerRowSub(ptr),
384 BinOp::SubF => FKS::PerRowSubF(ptr),
385 }
386 }
387 FS::BinPerCol(v, op) => {
388 let buf = std::slice::from_raw_parts_mut(loc as *mut TI, ker.nr());
389 let ptr = if n_remnant < ker.nr() {
390 if n_remnant > 0 {
391 buf.get_unchecked_mut(..n_remnant).copy_from_slice(
392 v.as_slice_unchecked()
393 .get_unchecked(right * ker.nr()..)
394 .get_unchecked(..n_remnant),
395 );
396 }
397 if cfg!(debug_assertions) {
398 buf.get_unchecked_mut(n_remnant..)
399 .iter_mut()
400 .for_each(|x| *x = TI::zero());
401 }
402 buf.as_ptr()
403 } else {
404 v.as_ptr_unchecked::<TI>().add(right * ker.nr())
405 };
406 match op {
407 BinOp::Min => FKS::PerColMin(ptr),
408 BinOp::Max => FKS::PerColMax(ptr),
409 BinOp::Add => FKS::PerColAdd(ptr),
410 BinOp::Mul => FKS::PerColMul(ptr),
411 BinOp::Sub => FKS::PerColSub(ptr),
412 BinOp::SubF => FKS::PerColSubF(ptr),
413 }
414 }
415 FS::AddRowColProducts(rows, cols) => {
416 let r = std::slice::from_raw_parts_mut(loc as *mut TI, ker.mr());
417 let row_ptr = if m_remnant < ker.mr() {
418 r.get_unchecked_mut(..m_remnant).copy_from_slice(
419 rows.as_slice_unchecked()
420 .get_unchecked(down * ker.mr()..)
421 .get_unchecked(..m_remnant),
422 );
423 if cfg!(debug_assertions) {
424 r.get_unchecked_mut(m_remnant..)
425 .iter_mut()
426 .for_each(|x| *x = TI::zero());
427 }
428 r.as_ptr()
429 } else {
430 rows.as_ptr_unchecked::<TI>().add(down * ker.mr())
431 };
432 let c = std::slice::from_raw_parts_mut(
433 (loc as *mut TI).add(ker.mr()),
434 ker.nr(),
435 );
436 let col_ptr = if n_remnant < ker.nr() {
437 c.get_unchecked_mut(..n_remnant).copy_from_slice(
438 cols.as_slice_unchecked()
439 .get_unchecked(right * ker.nr()..)
440 .get_unchecked(..n_remnant),
441 );
442 if cfg!(debug_assertions) {
443 r.get_unchecked_mut(n_remnant..)
444 .iter_mut()
445 .for_each(|x| *x = TI::zero());
446 }
447 c.as_ptr()
448 } else {
449 cols.as_ptr_unchecked::<TI>().add(right * ker.nr())
450 };
451 FKS::AddRowColProducts(row_ptr, col_ptr)
452 }
453 FS::AddUnicast(store) => {
454 let row_byte_stride = store.row_byte_stride;
455 let col_byte_stride = store.col_byte_stride;
456 let tile_offset = row_byte_stride * down as isize * ker.mr() as isize
457 + col_byte_stride * right as isize * ker.nr() as isize;
458 let tile_ptr = store.ptr.offset(tile_offset);
459 let tmp_d_tile =
460 std::slice::from_raw_parts_mut(loc as *mut TI, ker.mr() * ker.nr());
461 if cfg!(debug_assertions) {
462 tmp_d_tile.iter_mut().for_each(|t| *t = TI::zero());
463 }
464 for r in 0..m_remnant as isize {
465 for c in 0..n_remnant as isize {
466 let inner_offset = c * col_byte_stride + r * row_byte_stride;
467 if inner_offset + tile_offset
468 < (store.item_size * store.item_count) as isize
469 {
470 *tmp_d_tile
471 .get_unchecked_mut(r as usize + c as usize * ker.mr()) =
472 *(tile_ptr.offset(inner_offset) as *const TI);
473 }
474 }
475 }
476 FKS::AddUnicast(OutputStoreKer {
477 ptr: tmp_d_tile.as_ptr() as _,
478 row_byte_stride: std::mem::size_of::<TI>() as isize,
479 col_byte_stride: (std::mem::size_of::<TI>() * ker.mr()) as isize,
480 item_size: std::mem::size_of::<TI>(),
481 })
482 }
483 FS::Store(c_store) => {
484 let tmpc = OutputStoreKer {
485 ptr: loc as _,
486 item_size: c_store.item_size,
487 row_byte_stride: c_store.item_size as isize,
488 col_byte_stride: (c_store.item_size * ker.mr()) as isize,
489 };
490 FKS::Store(tmpc)
491 }
492 FS::AddMatMul { a, b, packing } => {
493 let scratch = (loc as *mut AddMatMulTemp).as_mut().unwrap();
494 if scratch.panel_a_id != down {
495 scratch.ptr_a = a.panel_bytes(
496 down,
497 buffer_a.map(|o| tls.blob.as_mut_ptr().add(o)),
498 )?;
499 scratch.panel_a_id = down;
500 }
501 if scratch.panel_b_id != right {
502 scratch.ptr_b = b.panel_bytes(
503 right,
504 buffer_b.map(|o| tls.blob.as_mut_ptr().add(o)),
505 )?;
506 scratch.panel_b_id = right;
507 }
508 FKS::AddMatMul {
509 k: b.k(),
510 pa: scratch.ptr_a,
511 pb: scratch.ptr_b,
512 packing: *packing,
513 }
514 }
515 _ => std::hint::unreachable_unchecked(),
516 };
517 *tls.ker_specs().get_unchecked_mut(*uspec) = it;
518 }
519 Ok(())
520 }
521 }
522
523 #[inline]
524 pub fn uspecs(&self) -> &[FusedKerSpec<TI>] {
525 &self.ker_specs
526 }
527
528 unsafe fn postprocess_tile(
529 &self,
530 specs: &[FusedSpec],
531 tls: &mut TLSScratch,
532 down: usize,
533 right: usize,
534 m_remnant: usize,
535 n_remnant: usize,
536 ) -> TractResult<()>
537 where
538 TI: LADatum,
539 {
540 unsafe {
541 for LocDependant { spec, ker_spec: uspec, .. } in self.loc_dependant.iter() {
542 let spec = specs.get_unchecked(*spec);
543 let ker_spec = tls.ker_specs::<TI>().get_unchecked(*uspec);
544 if let (FusedSpec::Store(c_store), FusedKerSpec::Store(tmp)) = (spec, ker_spec) {
545 c_store.set_from_tile(down, right, m_remnant, n_remnant, tmp)
546 }
547 }
548 Ok(())
549 }
550 }
551}