tract_linalg/frame/mmm/
storage.rs

1use std::fmt::Debug;
2use tract_data::internal::*;
3
/// Describes how an output tensor is to be addressed by a matrix-multiplication kernel.
///
/// Both variants carry the kernel tile dimensions `mr` (rows) and `nr` (columns).
/// A spec is turned into a concrete [`OutputStore`] by `OutputStoreSpec::wrap`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutputStoreSpec {
    /// Strides are looked up on the tensor itself, at axes `m_axis` (rows) and `n_axis`
    /// (columns), and scaled by the tensor's item size.
    View { m_axis: usize, n_axis: usize, mr: usize, nr: usize },
    /// Strides are given explicitly, already expressed in bytes.
    Strides { row_byte_stride: isize, col_byte_stride: isize, mr: usize, nr: usize },
}
9
/// Raw, pointer-based description of an output buffer.
///
/// All strides are expressed in bytes. The `panel_*` strides are the distance between two
/// consecutive panels, while `row_byte_stride` / `col_byte_stride` are the distance between
/// two consecutive items inside a panel.
#[derive(Debug, Clone, Copy)]
pub struct OutputStore {
    /// Base address of the output buffer.
    pub(crate) ptr: *mut u8,
    /// Byte distance between two consecutive rows.
    pub(crate) row_byte_stride: isize,
    /// Byte distance between two consecutive columns.
    pub(crate) col_byte_stride: isize,
    /// Byte distance between two vertically adjacent panels.
    pub(crate) panel_row_byte_stride: isize,
    /// Byte distance between two horizontally adjacent panels.
    pub(crate) panel_col_byte_stride: isize,
    /// Size in bytes of one stored item.
    pub(crate) item_size: usize,
    /// Number of items in the wrapped buffer.
    pub(crate) item_count: usize,
    /// Number of rows in one kernel tile.
    pub(crate) mr: usize,
}

// SAFETY: NOTE(review): the raw pointer prevents the automatic `Send`/`Sync` impls. These
// manual impls presumably rely on users of `OutputStore` never writing to overlapping regions
// from several threads at once -- confirm against the callers.
unsafe impl Sync for OutputStore {}
unsafe impl Send for OutputStore {}
24
25impl OutputStoreSpec {
26    #[inline]
27    pub unsafe fn wrap(&self, tensor: &TensorView) -> OutputStore {
28        let (mr, nr, row_byte_stride, col_byte_stride) = self.compute_strides(tensor);
29        OutputStore {
30            ptr: tensor.as_ptr_unchecked::<u8>() as _,
31            row_byte_stride,
32            col_byte_stride,
33            panel_row_byte_stride: row_byte_stride * mr as isize,
34            panel_col_byte_stride: col_byte_stride * nr as isize,
35            item_size: tensor.datum_type().size_of(),
36            mr,
37            item_count: tensor.len(),
38        }
39    }
40
41    #[inline]
42    unsafe fn compute_strides(&self, tensor: &TensorView) -> (usize, usize, isize, isize) {
43        let size_of = tensor.datum_type().size_of() as isize;
44        match self {
45            OutputStoreSpec::View { m_axis, n_axis, mr, nr, .. } => {
46                let tensor_strides = tensor.strides();
47                let row_item_stride = *tensor_strides.get_unchecked(*m_axis);
48                let col_item_stride = *tensor_strides.get_unchecked(*n_axis);
49                let row_byte_stride = row_item_stride * size_of;
50                let col_byte_stride = col_item_stride * size_of;
51                (*mr, *nr, row_byte_stride, col_byte_stride)
52            }
53            OutputStoreSpec::Strides { row_byte_stride, col_byte_stride, mr, nr, .. } => {
54                (*mr, *nr, *row_byte_stride, *col_byte_stride)
55            }
56        }
57    }
58}
59
60impl OutputStore {
61    #[inline]
62    pub(super) unsafe fn tile_c(&self, down: usize, right: usize) -> OutputStoreKer {
63        let (down, right) = (down as isize, right as isize);
64        OutputStoreKer {
65            ptr: self
66                .ptr
67                .offset(self.panel_row_byte_stride * down + self.panel_col_byte_stride * right)
68                as *mut _,
69            row_byte_stride: self.row_byte_stride,
70            col_byte_stride: self.col_byte_stride,
71            item_size: self.item_size,
72        }
73    }
74
75    #[inline]
76    pub fn item_size(&self) -> usize {
77        self.item_size
78    }
79
80    #[inline]
81    pub(super) unsafe fn set_from_tile(
82        &self,
83        down: usize,
84        right: usize,
85        height: usize,
86        width: usize,
87        tile: &OutputStoreKer,
88    ) {
89        if self.item_size() == 1 {
90            self.set_from_tile_t::<i8>(down, right, height, width, tile)
91        } else if self.item_size() == 2 {
92            self.set_from_tile_t::<i16>(down, right, height, width, tile)
93        } else if self.item_size() == 4 {
94            self.set_from_tile_t::<i32>(down, right, height, width, tile)
95        } else {
96            self.set_from_tile_t::<i64>(down, right, height, width, tile)
97        }
98    }
99
100    #[inline]
101    unsafe fn set_from_tile_t<T: Datum + Copy>(
102        &self,
103        down: usize,
104        right: usize,
105        height: usize,
106        width: usize,
107        tile: &OutputStoreKer,
108    ) {
109        let tile = tile.ptr as *mut T;
110        let dst = self.ptr.add(
111            self.panel_row_byte_stride as usize * down
112                + self.panel_col_byte_stride as usize * right,
113        );
114        for y in 0..height as isize {
115            for x in 0..width as isize {
116                let value = tile.offset(y + x * self.mr as isize);
117                let dst = dst.offset(y * self.row_byte_stride + x * self.col_byte_stride);
118                *(dst as *mut T) = *value;
119            }
120        }
121    }
122}
123
/// Descriptor of a single output panel.
///
/// The layout is `#[repr(C)]`, so the field order below is part of the type's contract and
/// must not be changed.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OutputStoreKer {
    /// Address of the first item of the panel.
    pub ptr: *mut u8,
    /// Byte distance between two consecutive rows.
    pub row_byte_stride: isize,
    /// Byte distance between two consecutive columns.
    pub col_byte_stride: isize,
    /// Size in bytes of one item.
    pub item_size: usize,
}