1use std::fmt::Debug;
2use tract_data::internal::*;
3
/// Describes how a matrix-multiplication output maps onto a tensor.
///
/// `mr` / `nr` are the panel (tile) height and width; the resulting
/// [`OutputStore`] scales the per-element byte strides by them to step
/// from one panel to the next.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum OutputStoreSpec {
    /// Strides are read from the wrapped tensor: element strides of axis
    /// `m_axis` (rows) and `n_axis` (columns), scaled by the datum size.
    View {
        m_axis: usize,
        n_axis: usize,
        mr: usize,
        nr: usize,
    },
    /// Explicit byte strides, used verbatim regardless of the tensor's own strides.
    Strides {
        row_byte_stride: isize,
        col_byte_stride: isize,
        mr: usize,
        nr: usize,
    },
}
9
/// Raw, stride-resolved handle on an output tensor, produced by
/// `OutputStoreSpec::wrap`.
///
/// It holds a raw pointer with no lifetime: whoever builds it must keep the
/// underlying tensor data alive while it is in use.
#[derive(Debug, Clone, Copy)]
pub struct OutputStore {
    /// Address of the first byte of the tensor data.
    pub(crate) ptr: *mut u8,
    /// Byte distance between two consecutive rows.
    pub(crate) row_byte_stride: isize,
    /// Byte distance between two consecutive columns.
    pub(crate) col_byte_stride: isize,
    /// Byte distance between two vertically adjacent panels (`row_byte_stride * mr`).
    pub(crate) panel_row_byte_stride: isize,
    /// Byte distance between two horizontally adjacent panels (`col_byte_stride * nr`).
    pub(crate) panel_col_byte_stride: isize,
    /// Size in bytes of one element.
    pub(crate) item_size: usize,
    /// Number of elements in the wrapped tensor.
    pub(crate) item_count: usize,
    /// Panel height in rows.
    pub(crate) mr: usize,
}

// SAFETY / NOTE(review): the raw pointer makes this type !Send/!Sync by
// default. Sharing it across threads is only sound if callers ensure that
// concurrent writers touch disjoint tiles and that the tensor outlives every
// use. Nothing here enforces that; confirm at call sites.
unsafe impl Sync for OutputStore {}
unsafe impl Send for OutputStore {}
24
25impl OutputStoreSpec {
26 #[inline]
27 pub unsafe fn wrap(&self, tensor: &TensorView) -> OutputStore {
28 let (mr, nr, row_byte_stride, col_byte_stride) = self.compute_strides(tensor);
29 OutputStore {
30 ptr: tensor.as_ptr_unchecked::<u8>() as _,
31 row_byte_stride,
32 col_byte_stride,
33 panel_row_byte_stride: row_byte_stride * mr as isize,
34 panel_col_byte_stride: col_byte_stride * nr as isize,
35 item_size: tensor.datum_type().size_of(),
36 mr,
37 item_count: tensor.len(),
38 }
39 }
40
41 #[inline]
42 unsafe fn compute_strides(&self, tensor: &TensorView) -> (usize, usize, isize, isize) {
43 let size_of = tensor.datum_type().size_of() as isize;
44 match self {
45 OutputStoreSpec::View { m_axis, n_axis, mr, nr, .. } => {
46 let tensor_strides = tensor.strides();
47 let row_item_stride = *tensor_strides.get_unchecked(*m_axis);
48 let col_item_stride = *tensor_strides.get_unchecked(*n_axis);
49 let row_byte_stride = row_item_stride * size_of;
50 let col_byte_stride = col_item_stride * size_of;
51 (*mr, *nr, row_byte_stride, col_byte_stride)
52 }
53 OutputStoreSpec::Strides { row_byte_stride, col_byte_stride, mr, nr, .. } => {
54 (*mr, *nr, *row_byte_stride, *col_byte_stride)
55 }
56 }
57 }
58}
59
60impl OutputStore {
61 #[inline]
62 pub(super) unsafe fn tile_c(&self, down: usize, right: usize) -> OutputStoreKer {
63 let (down, right) = (down as isize, right as isize);
64 OutputStoreKer {
65 ptr: self
66 .ptr
67 .offset(self.panel_row_byte_stride * down + self.panel_col_byte_stride * right)
68 as *mut _,
69 row_byte_stride: self.row_byte_stride,
70 col_byte_stride: self.col_byte_stride,
71 item_size: self.item_size,
72 }
73 }
74
75 #[inline]
76 pub fn item_size(&self) -> usize {
77 self.item_size
78 }
79
80 #[inline]
81 pub(super) unsafe fn set_from_tile(
82 &self,
83 down: usize,
84 right: usize,
85 height: usize,
86 width: usize,
87 tile: &OutputStoreKer,
88 ) {
89 if self.item_size() == 1 {
90 self.set_from_tile_t::<i8>(down, right, height, width, tile)
91 } else if self.item_size() == 2 {
92 self.set_from_tile_t::<i16>(down, right, height, width, tile)
93 } else if self.item_size() == 4 {
94 self.set_from_tile_t::<i32>(down, right, height, width, tile)
95 } else {
96 self.set_from_tile_t::<i64>(down, right, height, width, tile)
97 }
98 }
99
100 #[inline]
101 unsafe fn set_from_tile_t<T: Datum + Copy>(
102 &self,
103 down: usize,
104 right: usize,
105 height: usize,
106 width: usize,
107 tile: &OutputStoreKer,
108 ) {
109 let tile = tile.ptr as *mut T;
110 let dst = self.ptr.add(
111 self.panel_row_byte_stride as usize * down
112 + self.panel_col_byte_stride as usize * right,
113 );
114 for y in 0..height as isize {
115 for x in 0..width as isize {
116 let value = tile.offset(y + x * self.mr as isize);
117 let dst = dst.offset(y * self.row_byte_stride + x * self.col_byte_stride);
118 *(dst as *mut T) = *value;
119 }
120 }
121 }
122}
123
/// Per-tile output descriptor handed to matrix-multiplication kernels.
///
/// `#[repr(C)]` fixes the field order and layout. NOTE(review): presumably
/// this is read by non-Rust (assembly/FFI) kernels, so do not reorder fields
/// without checking those consumers.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OutputStoreKer {
    /// Address of the tile's first element.
    pub ptr: *mut u8,
    /// Byte distance between consecutive rows of the tile.
    pub row_byte_stride: isize,
    /// Byte distance between consecutive columns of the tile.
    pub col_byte_stride: isize,
    /// Size in bytes of one element.
    pub item_size: usize,
}