1use std::fmt;
2use std::fmt::Debug;
3use tract_data::internal::*;
4
5use super::MMMInputValue;
6
/// Batched collection of pre-packed matrix-multiplication operands.
///
/// Holds one packed value per batch element in flat row-major order over
/// `batch_shape`, with `batch_strides` caching the row-major strides used to
/// translate batch coordinates into a flat index.
#[derive(Clone, PartialEq, Eq)]
pub struct PackedMatrixStorage {
    // One packed operand per batch element, flat row-major order.
    values: Vec<Box<dyn MMMInputValue>>,
    // Batch dimensions; empty for a single un-batched value.
    batch_shape: TVec<usize>,
    // Row-major strides matching `batch_shape` (see `compute_strides`).
    batch_strides: TVec<isize>,
}
17
18impl PackedMatrixStorage {
19 pub fn new(value: Box<dyn MMMInputValue>) -> Self {
21 PackedMatrixStorage { values: vec![value], batch_shape: tvec![], batch_strides: tvec![] }
22 }
23
24 pub fn new_batched(shape: &[usize], values: Vec<Box<dyn MMMInputValue>>) -> Self {
26 let expected: usize = shape.iter().product();
27 assert_eq!(values.len(), expected, "values length must match shape product");
28 let strides = Self::compute_strides(shape);
29 PackedMatrixStorage { values, batch_shape: shape.into(), batch_strides: strides }
30 }
31
32 fn compute_strides(shape: &[usize]) -> TVec<isize> {
33 let mut strides: TVec<isize> = tvec![0; shape.len()];
34 if !shape.is_empty() {
35 strides[shape.len() - 1] = 1;
36 for i in (0..shape.len() - 1).rev() {
37 strides[i] = strides[i + 1] * shape[i + 1] as isize;
38 }
39 }
40 strides
41 }
42
43 #[inline]
45 pub fn value(&self) -> &dyn MMMInputValue {
46 debug_assert_eq!(self.values.len(), 1);
47 &*self.values[0]
48 }
49
50 pub fn value_at(&self, coords: &[usize]) -> &dyn MMMInputValue {
52 let idx = self.flat_index(coords);
53 &*self.values[idx]
54 }
55
56 #[inline]
58 pub fn value_at_flat(&self, idx: usize) -> &dyn MMMInputValue {
59 &*self.values[idx]
60 }
61
62 pub fn values(&self) -> &[Box<dyn MMMInputValue>] {
63 &self.values
64 }
65
66 pub fn batch_shape(&self) -> &[usize] {
67 &self.batch_shape
68 }
69
70 pub fn batch_strides(&self) -> &[isize] {
71 &self.batch_strides
72 }
73
74 pub fn into_tensor(self, dt: DatumType) -> Tensor {
76 let shape: TVec<usize> = self.batch_shape.clone();
77 Tensor::from_storage(dt, &shape, self)
78 }
79
80 fn flat_index(&self, coords: &[usize]) -> usize {
81 coords.iter().zip(self.batch_strides.iter()).map(|(c, s)| *c as isize * s).sum::<isize>()
82 as usize
83 }
84}
85
86impl fmt::Debug for PackedMatrixStorage {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 write!(f, "PackedMatrixStorage({} values, shape={:?})", self.values.len(), self.batch_shape)
89 }
90}
91
92impl fmt::Display for PackedMatrixStorage {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 write!(f, "PackedMatrixStorage({} values, shape={:?})", self.values.len(), self.batch_shape)
95 }
96}
97
impl TensorStorage for PackedMatrixStorage {
    // Size accounted to the tensor: one `Box` fat pointer per value. The
    // packed payloads behind the boxes are not counted — presumably
    // intentional since they are opaque; TODO confirm against the
    // `TensorStorage` contract.
    fn byte_len(&self) -> usize {
        self.values.len() * std::mem::size_of::<Box<dyn MMMInputValue>>()
    }

    fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    fn deep_clone(&self) -> Box<dyn TensorStorage> {
        Box::new(self.clone())
    }

    // This storage is never a plain byte buffer, so all plain-storage views
    // are unavailable.
    fn as_plain(&self) -> Option<&PlainStorage> {
        None
    }

    fn as_plain_mut(&mut self) -> Option<&mut PlainStorage> {
        None
    }

    fn into_plain(self: Box<Self>) -> Option<PlainStorage> {
        None
    }

    // Hash every packed value, in flat batch order.
    fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
        for v in &self.values {
            v.dyn_hash(state);
        }
    }

    // Exposes the packing layout: the single value's fact directly when
    // un-batched, otherwise the collection of per-value facts.
    fn exotic_fact(&self, _shape: &[usize]) -> TractResult<Option<Box<dyn ExoticFact>>> {
        if self.values.len() == 1 {
            Ok(Some(dyn_clone::clone_box(self.values[0].exotic_fact())))
        } else {
            let facts: TVec<Box<dyn ExoticFact>> =
                self.values.iter().map(|v| dyn_clone::clone_box(v.exotic_fact())).collect();
            Ok(Some(Box::new(facts)))
        }
    }
}
140
/// How to derive the row/column byte strides of a matmul output buffer.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum OutputStoreSpec {
    /// Strides are taken from the wrapped tensor view's own strides along the
    /// optional `m`/`n` axes (a missing axis contributes stride 0).
    View { m_axis: Option<usize>, n_axis: Option<usize>, mr: usize, nr: usize },
    /// Strides are given explicitly, in bytes.
    Strides { row_byte_stride: isize, col_byte_stride: isize, mr: usize, nr: usize },
}
146
/// Raw-pointer view over a matmul output buffer, with byte strides
/// precomputed both per element and per (mr x nr) kernel panel.
#[derive(Clone, Copy, Debug)]
pub struct OutputStore {
    pub(crate) ptr: *mut u8,
    pub(crate) row_byte_stride: isize,
    pub(crate) col_byte_stride: isize,
    // Byte stride of one panel of `mr` rows.
    pub(crate) panel_row_byte_stride: isize,
    // Byte stride of one panel of `nr` columns.
    pub(crate) panel_col_byte_stride: isize,
    pub(crate) item_size: usize,
    pub(crate) item_count: usize,
    pub(crate) mr: usize,
}
158
// SAFETY: OutputStore itself is just a pointer-plus-strides view and performs
// no interior mutation; soundness relies on callers not concurrently writing
// overlapping regions of the target buffer — NOTE(review): confirm this
// invariant at the call sites that share stores across threads.
unsafe impl Send for OutputStore {}
unsafe impl Sync for OutputStore {}
161
impl OutputStoreSpec {
    /// Builds an `OutputStore` over `tensor` according to this spec.
    ///
    /// # Safety
    /// `tensor` must be a valid, initialized view; for the `View` variant,
    /// `m_axis`/`n_axis` must be in-bounds axes of `tensor` (their strides
    /// are read with `get_unchecked`).
    #[inline]
    pub unsafe fn wrap(&self, tensor: &TensorView) -> OutputStore {
        let (mr, nr, row_byte_stride, col_byte_stride) = unsafe { self.compute_strides(tensor) };
        OutputStore {
            ptr: unsafe { tensor.as_ptr_unchecked::<u8>() } as _,
            row_byte_stride,
            col_byte_stride,
            // One panel spans `mr` rows by `nr` columns.
            panel_row_byte_stride: row_byte_stride * mr as isize,
            panel_col_byte_stride: col_byte_stride * nr as isize,
            item_size: tensor.datum_type().size_of(),
            mr,
            item_count: tensor.len(),
        }
    }

    /// Resolves `(mr, nr, row_byte_stride, col_byte_stride)` for `tensor`.
    ///
    /// # Safety
    /// For the `View` variant, `m_axis`/`n_axis` must be valid axis indices
    /// of `tensor` — they index the stride array unchecked.
    #[inline]
    unsafe fn compute_strides(&self, tensor: &TensorView) -> (usize, usize, isize, isize) {
        let size_of = tensor.datum_type().size_of() as isize;
        match self {
            OutputStoreSpec::View { m_axis, n_axis, mr, nr, .. } => {
                let tensor_strides = tensor.strides();
                // A missing axis means the output is degenerate along that
                // dimension: stride 0.
                let row_item_stride =
                    m_axis.map(|ax| *unsafe { tensor_strides.get_unchecked(ax) }).unwrap_or(0);
                let col_item_stride =
                    n_axis.map(|ax| *unsafe { tensor_strides.get_unchecked(ax) }).unwrap_or(0);
                // Tensor strides are in items; convert to bytes.
                let row_byte_stride = row_item_stride * size_of;
                let col_byte_stride = col_item_stride * size_of;
                (*mr, *nr, row_byte_stride, col_byte_stride)
            }
            OutputStoreSpec::Strides { row_byte_stride, col_byte_stride, mr, nr, .. } => {
                (*mr, *nr, *row_byte_stride, *col_byte_stride)
            }
        }
    }
}
198
199impl OutputStore {
200 #[inline]
201 pub(super) unsafe fn tile_c(&self, down: usize, right: usize) -> OutputStoreKer {
202 unsafe {
203 let (down, right) = (down as isize, right as isize);
204 OutputStoreKer {
205 ptr: self
206 .ptr
207 .offset(self.panel_row_byte_stride * down + self.panel_col_byte_stride * right)
208 as *mut _,
209 row_byte_stride: self.row_byte_stride,
210 col_byte_stride: self.col_byte_stride,
211 item_size: self.item_size,
212 }
213 }
214 }
215
216 #[inline]
217 pub fn item_size(&self) -> usize {
218 self.item_size
219 }
220
221 #[inline]
222 pub(super) unsafe fn set_from_tile(
223 &self,
224 down: usize,
225 right: usize,
226 height: usize,
227 width: usize,
228 tile: &OutputStoreKer,
229 ) {
230 unsafe {
231 if self.item_size() == 1 {
232 self.set_from_tile_t::<i8>(down, right, height, width, tile)
233 } else if self.item_size() == 2 {
234 self.set_from_tile_t::<i16>(down, right, height, width, tile)
235 } else if self.item_size() == 4 {
236 self.set_from_tile_t::<i32>(down, right, height, width, tile)
237 } else {
238 self.set_from_tile_t::<i64>(down, right, height, width, tile)
239 }
240 }
241 }
242
243 #[inline]
244 unsafe fn set_from_tile_t<T: Datum + Copy>(
245 &self,
246 down: usize,
247 right: usize,
248 height: usize,
249 width: usize,
250 tile: &OutputStoreKer,
251 ) {
252 unsafe {
253 let tile = tile.ptr as *mut T;
254 let dst = self.ptr.add(
255 self.panel_row_byte_stride as usize * down
256 + self.panel_col_byte_stride as usize * right,
257 );
258 for y in 0..height as isize {
259 for x in 0..width as isize {
260 let value = tile.offset(y + x * self.mr as isize);
261 let dst = dst.offset(y * self.row_byte_stride + x * self.col_byte_stride);
262 *(dst as *mut T) = *value;
263 }
264 }
265 }
266 }
267}
268
/// Per-tile output descriptor handed to the matmul micro-kernel.
///
/// `#[repr(C)]` because this crosses into kernel code (presumably assembly);
/// field order and layout must remain stable — confirm before reordering.
#[repr(C)]
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub struct OutputStoreKer {
    /// Pointer to the top-left element of the tile.
    pub ptr: *mut u8,
    pub row_byte_stride: isize,
    pub col_byte_stride: isize,
    pub item_size: usize,
}