1use std::fmt::Debug;
2use tract_data::internal::*;
3
/// Describes where the byte strides of an output matrix come from.
///
/// A spec is bound to a concrete tensor with `wrap`, which yields an
/// `OutputStore`. In both variants `mr` and `nr` are the factors by which the
/// row and column byte strides are multiplied to obtain the panel strides.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutputStoreSpec {
    /// Strides are read from the tensor being wrapped.
    View {
        /// Tensor axis whose stride is used as the row stride; `None` gives a
        /// row stride of zero.
        m_axis: Option<usize>,
        /// Tensor axis whose stride is used as the column stride; `None` gives
        /// a column stride of zero.
        n_axis: Option<usize>,
        /// Multiplier applied to the row byte stride to get the panel row stride.
        mr: usize,
        /// Multiplier applied to the column byte stride to get the panel column stride.
        nr: usize,
    },
    /// Strides are given explicitly, in bytes, and used as-is.
    Strides {
        /// Byte distance between two consecutive rows (signed).
        row_byte_stride: isize,
        /// Byte distance between two consecutive columns (signed).
        col_byte_stride: isize,
        /// Multiplier applied to the row byte stride to get the panel row stride.
        mr: usize,
        /// Multiplier applied to the column byte stride to get the panel column stride.
        nr: usize,
    },
}
9
/// An output buffer described by a raw base pointer plus byte strides.
///
/// Values of this type are produced by `OutputStoreSpec::wrap`. The struct
/// only stores the pointer; it carries no lifetime tying it to the buffer.
#[derive(Debug, Clone, Copy)]
pub struct OutputStore {
    /// Base address of the output data, as a byte pointer.
    pub(crate) ptr: *mut u8,
    /// Byte distance between two consecutive rows.
    pub(crate) row_byte_stride: isize,
    /// Byte distance between two consecutive columns.
    pub(crate) col_byte_stride: isize,
    /// Byte distance between two consecutive tiles along the rows
    /// (`row_byte_stride * mr`).
    pub(crate) panel_row_byte_stride: isize,
    /// Byte distance between two consecutive tiles along the columns
    /// (`col_byte_stride * nr`).
    pub(crate) panel_col_byte_stride: isize,
    /// Size in bytes of one stored item.
    pub(crate) item_size: usize,
    /// Number of items in the wrapped tensor at the time of wrapping.
    pub(crate) item_count: usize,
    /// Tile height in items; also the column stride (in items) of the
    /// temporary tiles copied in by `set_from_tile`.
    pub(crate) mr: usize,
}

// NOTE(review): the raw pointer makes `OutputStore` neither `Send` nor `Sync`
// by default. These impls opt in manually; presumably callers guarantee that
// concurrent users write to disjoint tiles — confirm against the call sites.
unsafe impl Send for OutputStore {}
unsafe impl Sync for OutputStore {}
24
25impl OutputStoreSpec {
26 #[inline]
27 pub unsafe fn wrap(&self, tensor: &TensorView) -> OutputStore {
28 let (mr, nr, row_byte_stride, col_byte_stride) = unsafe { self.compute_strides(tensor) };
29 OutputStore {
30 ptr: unsafe { tensor.as_ptr_unchecked::<u8>() } as _,
31 row_byte_stride,
32 col_byte_stride,
33 panel_row_byte_stride: row_byte_stride * mr as isize,
34 panel_col_byte_stride: col_byte_stride * nr as isize,
35 item_size: tensor.datum_type().size_of(),
36 mr,
37 item_count: tensor.len(),
38 }
39 }
40
41 #[inline]
42 unsafe fn compute_strides(&self, tensor: &TensorView) -> (usize, usize, isize, isize) {
43 let size_of = tensor.datum_type().size_of() as isize;
44 match self {
45 OutputStoreSpec::View { m_axis, n_axis, mr, nr, .. } => {
46 let tensor_strides = tensor.strides();
47 let row_item_stride =
48 m_axis.map(|ax| *unsafe { tensor_strides.get_unchecked(ax) }).unwrap_or(0);
49 let col_item_stride =
50 n_axis.map(|ax| *unsafe { tensor_strides.get_unchecked(ax) }).unwrap_or(0);
51 let row_byte_stride = row_item_stride * size_of;
52 let col_byte_stride = col_item_stride * size_of;
53 (*mr, *nr, row_byte_stride, col_byte_stride)
54 }
55 OutputStoreSpec::Strides { row_byte_stride, col_byte_stride, mr, nr, .. } => {
56 (*mr, *nr, *row_byte_stride, *col_byte_stride)
57 }
58 }
59 }
60}
61
62impl OutputStore {
63 #[inline]
64 pub(super) unsafe fn tile_c(&self, down: usize, right: usize) -> OutputStoreKer {
65 unsafe {
66 let (down, right) = (down as isize, right as isize);
67 OutputStoreKer {
68 ptr: self
69 .ptr
70 .offset(self.panel_row_byte_stride * down + self.panel_col_byte_stride * right)
71 as *mut _,
72 row_byte_stride: self.row_byte_stride,
73 col_byte_stride: self.col_byte_stride,
74 item_size: self.item_size,
75 }
76 }
77 }
78
79 #[inline]
80 pub fn item_size(&self) -> usize {
81 self.item_size
82 }
83
84 #[inline]
85 pub(super) unsafe fn set_from_tile(
86 &self,
87 down: usize,
88 right: usize,
89 height: usize,
90 width: usize,
91 tile: &OutputStoreKer,
92 ) {
93 unsafe {
94 if self.item_size() == 1 {
95 self.set_from_tile_t::<i8>(down, right, height, width, tile)
96 } else if self.item_size() == 2 {
97 self.set_from_tile_t::<i16>(down, right, height, width, tile)
98 } else if self.item_size() == 4 {
99 self.set_from_tile_t::<i32>(down, right, height, width, tile)
100 } else {
101 self.set_from_tile_t::<i64>(down, right, height, width, tile)
102 }
103 }
104 }
105
106 #[inline]
107 unsafe fn set_from_tile_t<T: Datum + Copy>(
108 &self,
109 down: usize,
110 right: usize,
111 height: usize,
112 width: usize,
113 tile: &OutputStoreKer,
114 ) {
115 unsafe {
116 let tile = tile.ptr as *mut T;
117 let dst = self.ptr.add(
118 self.panel_row_byte_stride as usize * down
119 + self.panel_col_byte_stride as usize * right,
120 );
121 for y in 0..height as isize {
122 for x in 0..width as isize {
123 let value = tile.offset(y + x * self.mr as isize);
124 let dst = dst.offset(y * self.row_byte_stride + x * self.col_byte_stride);
125 *(dst as *mut T) = *value;
126 }
127 }
128 }
129 }
130}
131
/// Flat description of one output tile: a base pointer, the byte strides
/// between rows and columns, and the size of one item.
///
/// The layout is fixed with `#[repr(C)]`, so the field order below is part of
/// the type's contract and must not be changed.
/// NOTE(review): presumably this is the structure handed to the low-level
/// kernels — confirm against the kernel call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct OutputStoreKer {
    /// Address of the tile's first item, as a byte pointer.
    pub ptr: *mut u8,
    /// Byte distance between two consecutive rows.
    pub row_byte_stride: isize,
    /// Byte distance between two consecutive columns.
    pub col_byte_stride: isize,
    /// Size in bytes of one item.
    pub item_size: usize,
}