tract_linalg/frame/mmm/
storage.rs

1use std::fmt::Debug;
2use tract_data::internal::*;
3
/// Describes how an `OutputStore` should locate elements of the output tensor.
///
/// Both variants carry `mr` / `nr`, the number of rows / columns in one panel; the
/// store multiplies the element strides by these to obtain panel strides.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutputStoreSpec {
    /// Take the byte strides from the wrapped tensor's own strides along `m_axis`
    /// (rows) and `n_axis` (columns). An axis given as `None` yields a stride of 0.
    View {
        m_axis: Option<usize>,
        n_axis: Option<usize>,
        /// Rows per panel.
        mr: usize,
        /// Columns per panel.
        nr: usize,
    },
    /// Use the given byte strides verbatim, regardless of the tensor's layout.
    Strides {
        row_byte_stride: isize,
        col_byte_stride: isize,
        /// Rows per panel.
        mr: usize,
        /// Columns per panel.
        nr: usize,
    },
}
9
/// Raw description of where, and with which strides, a matrix-multiply output is written.
///
/// Built by `OutputStoreSpec::wrap`. It keeps only a bare pointer into the output
/// tensor's data and carries no lifetime, so that tensor must outlive every use of the store.
#[derive(Debug, Clone, Copy)]
pub struct OutputStore {
    /// First byte of the output tensor's data.
    pub(crate) ptr: *mut u8,
    /// Byte distance between two consecutive rows.
    pub(crate) row_byte_stride: isize,
    /// Byte distance between two consecutive columns.
    pub(crate) col_byte_stride: isize,
    /// Byte distance between two consecutive row panels (`row_byte_stride * mr`).
    pub(crate) panel_row_byte_stride: isize,
    /// Byte distance between two consecutive column panels (`col_byte_stride * nr`).
    pub(crate) panel_col_byte_stride: isize,
    /// Size in bytes of one element.
    pub(crate) item_size: usize,
    /// Total number of elements in the wrapped tensor.
    pub(crate) item_count: usize,
    /// Rows per panel.
    pub(crate) mr: usize,
}

// SAFETY: the struct is a raw pointer plus plain integers; the raw pointer alone is what
// opts it out of `Send`. NOTE(review): soundness relies on callers keeping the tensor
// alive and on concurrent users writing disjoint tiles — not enforced here, confirm
// against the MMM executor.
unsafe impl Send for OutputStore {}
// SAFETY: `&OutputStore` exposes no interior mutability; the same caller-side
// assumptions as for `Send` apply to writes performed through `ptr`.
unsafe impl Sync for OutputStore {}
24
25impl OutputStoreSpec {
26    #[inline]
27    pub unsafe fn wrap(&self, tensor: &TensorView) -> OutputStore {
28        let (mr, nr, row_byte_stride, col_byte_stride) = unsafe { self.compute_strides(tensor) };
29        OutputStore {
30            ptr: unsafe { tensor.as_ptr_unchecked::<u8>() } as _,
31            row_byte_stride,
32            col_byte_stride,
33            panel_row_byte_stride: row_byte_stride * mr as isize,
34            panel_col_byte_stride: col_byte_stride * nr as isize,
35            item_size: tensor.datum_type().size_of(),
36            mr,
37            item_count: tensor.len(),
38        }
39    }
40
41    #[inline]
42    unsafe fn compute_strides(&self, tensor: &TensorView) -> (usize, usize, isize, isize) {
43        let size_of = tensor.datum_type().size_of() as isize;
44        match self {
45            OutputStoreSpec::View { m_axis, n_axis, mr, nr, .. } => {
46                let tensor_strides = tensor.strides();
47                let row_item_stride =
48                    m_axis.map(|ax| *unsafe { tensor_strides.get_unchecked(ax) }).unwrap_or(0);
49                let col_item_stride =
50                    n_axis.map(|ax| *unsafe { tensor_strides.get_unchecked(ax) }).unwrap_or(0);
51                let row_byte_stride = row_item_stride * size_of;
52                let col_byte_stride = col_item_stride * size_of;
53                (*mr, *nr, row_byte_stride, col_byte_stride)
54            }
55            OutputStoreSpec::Strides { row_byte_stride, col_byte_stride, mr, nr, .. } => {
56                (*mr, *nr, *row_byte_stride, *col_byte_stride)
57            }
58        }
59    }
60}
61
62impl OutputStore {
63    #[inline]
64    pub(super) unsafe fn tile_c(&self, down: usize, right: usize) -> OutputStoreKer {
65        unsafe {
66            let (down, right) = (down as isize, right as isize);
67            OutputStoreKer {
68                ptr: self
69                    .ptr
70                    .offset(self.panel_row_byte_stride * down + self.panel_col_byte_stride * right)
71                    as *mut _,
72                row_byte_stride: self.row_byte_stride,
73                col_byte_stride: self.col_byte_stride,
74                item_size: self.item_size,
75            }
76        }
77    }
78
79    #[inline]
80    pub fn item_size(&self) -> usize {
81        self.item_size
82    }
83
84    #[inline]
85    pub(super) unsafe fn set_from_tile(
86        &self,
87        down: usize,
88        right: usize,
89        height: usize,
90        width: usize,
91        tile: &OutputStoreKer,
92    ) {
93        unsafe {
94            if self.item_size() == 1 {
95                self.set_from_tile_t::<i8>(down, right, height, width, tile)
96            } else if self.item_size() == 2 {
97                self.set_from_tile_t::<i16>(down, right, height, width, tile)
98            } else if self.item_size() == 4 {
99                self.set_from_tile_t::<i32>(down, right, height, width, tile)
100            } else {
101                self.set_from_tile_t::<i64>(down, right, height, width, tile)
102            }
103        }
104    }
105
106    #[inline]
107    unsafe fn set_from_tile_t<T: Datum + Copy>(
108        &self,
109        down: usize,
110        right: usize,
111        height: usize,
112        width: usize,
113        tile: &OutputStoreKer,
114    ) {
115        unsafe {
116            let tile = tile.ptr as *mut T;
117            let dst = self.ptr.add(
118                self.panel_row_byte_stride as usize * down
119                    + self.panel_col_byte_stride as usize * right,
120            );
121            for y in 0..height as isize {
122                for x in 0..width as isize {
123                    let value = tile.offset(y + x * self.mr as isize);
124                    let dst = dst.offset(y * self.row_byte_stride + x * self.col_byte_stride);
125                    *(dst as *mut T) = *value;
126                }
127            }
128        }
129    }
130}
131
/// Per-tile view of the output handed to kernels: a pointer to the tile's first element
/// plus the byte strides needed to walk it.
///
/// `#[repr(C)]` pins the field order and layout. NOTE(review): this presumably matches a
/// layout expected by low-level (possibly assembly) kernels — do not reorder fields
/// without checking them.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OutputStoreKer {
    /// Address of the tile's top-left element.
    pub ptr: *mut u8,
    /// Byte distance between two consecutive rows.
    pub row_byte_stride: isize,
    /// Byte distance between two consecutive columns.
    pub col_byte_stride: isize,
    /// Size in bytes of one element.
    pub item_size: usize,
}