1use std::alloc::Layout;
2use std::fmt::{Debug, Display};
3use std::marker::PhantomData;
4use std::ops::Range;
5use tract_data::internal::*;
6
7use crate::mmm::{
8 EagerPackedInput, MMMInputFormat, MMMInputValue, PackedExoticFact, PackedMatrixStorage,
9};
10
11use crate::WeightType;
12
13#[derive(Clone, Eq, PartialEq, Hash)]
14pub struct PackedFormat {
15 pub dt: DatumType,
16 pub r: usize,
17 pub alignment_bytes: usize,
18 pub end_padding_record: usize,
19}
20
21impl MMMInputFormat for PackedFormat {
22 fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor> {
23 let packed = PackedFormat::pack_tensor(self, t, k_axis, mn_axis)?;
24 Ok(PackedMatrixStorage::new(packed).into_tensor(t.datum_type()))
25 }
26
27 fn prepare_one(
28 &self,
29 t: &Tensor,
30 k_axis: usize,
31 mn_axis: usize,
32 ) -> TractResult<Box<dyn MMMInputValue>> {
33 PackedFormat::pack_tensor(self, t, k_axis, mn_axis)
34 }
35
36 fn precursor(&self) -> WeightType {
37 WeightType::Plain(self.dt)
38 }
39
40 fn r(&self) -> usize {
41 self.r
42 }
43
44 fn k_alignment(&self) -> usize {
45 1
46 }
47
48 #[allow(clippy::collapsible_if)]
49 fn merge_with<'o, 'a: 'o, 'b: 'o>(
50 &'a self,
51 other: &'b dyn MMMInputFormat,
52 ) -> Option<&'o dyn MMMInputFormat> {
53 if let Some(other) = other.downcast_ref::<PackedFormat>() {
54 if self.r == other.r && self.dt == other.dt {
55 if self.alignment_bytes % other.alignment_bytes == 0
56 && self.end_padding_record >= other.end_padding_record
57 {
58 return Some(self);
59 }
60 if other.alignment_bytes % self.alignment_bytes == 0
61 && other.end_padding_record >= self.end_padding_record
62 {
63 return Some(other);
64 }
65 }
66 }
67 None
68 }
69
70 fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
71 self.len(k, mn) * self.dt.size_of()
72 }
73
74 fn extract_at_mn_f16(
75 &self,
76 data: &EagerPackedInput,
77 mn: usize,
78 slice: &mut [f16],
79 ) -> TractResult<()> {
80 ensure!(data.format().dyn_eq(self));
81 ensure!(self.len(data.k(), data.mn()) * self.dt.size_of() == data.packed.len());
82 unsafe {
83 let ptr = data.packed.as_ptr().add(
84 (self.single_panel_len(data.k()) * (mn / self.r) + mn % self.r) * self.dt.size_of(),
85 );
86 for (i, slot) in slice.iter_mut().enumerate() {
87 let ptr = ptr.add(i * self.dt.size_of() * self.r);
88 *slot = if self.dt == f16::datum_type() {
89 *(ptr as *const f16)
90 } else if self.dt == f32::datum_type() {
91 f16::from_f32(*(ptr as *const f32))
92 } else {
93 bail!("Unexpected DT {:?}", self.dt)
94 }
95 }
96 }
97 Ok(())
98 }
99
100 fn extract_at_mn_f32(
101 &self,
102 data: &EagerPackedInput,
103 mn: usize,
104 slice: &mut [f32],
105 ) -> TractResult<()> {
106 ensure!(data.format().dyn_eq(self));
107 ensure!(self.len(data.k(), data.mn()) * self.dt.size_of() == data.packed.len());
108 unsafe {
109 let ptr = data.packed.as_ptr().add(
110 (self.single_panel_len(data.k()) * (mn / self.r) + mn % self.r) * self.dt.size_of(),
111 );
112 for (i, slot) in slice.iter_mut().enumerate() {
113 let ptr = ptr.add(i * self.dt.size_of() * self.r);
114 *slot = if self.dt == f16::datum_type() {
115 (*(ptr as *const f16)).to_f32()
116 } else if self.dt == f32::datum_type() {
117 *(ptr as *const f32)
118 } else {
119 bail!("Unexpected DT {:?}", self.dt)
120 }
121 }
122 }
123 Ok(())
124 }
125}
126
127impl Display for PackedFormat {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 write!(f, "Packed{:?}[{}]", self.dt, self.r)
130 }
131}
132
133impl Debug for PackedFormat {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 write!(
136 f,
137 "Packed{:?}[{}]@{}+{}",
138 self.dt, self.r, self.alignment_bytes, self.end_padding_record
139 )
140 }
141}
142
143impl PackedFormat {
144 pub const fn new(dt: DatumType, nr: usize, alignment_bytes: usize) -> PackedFormat {
145 PackedFormat { dt, r: nr, alignment_bytes, end_padding_record: 1 }
146 }
147
148 pub const fn with_end_padding_record(self, end_padding_record: usize) -> Self {
149 PackedFormat { end_padding_record, ..self }
150 }
151
152 #[inline]
153 pub fn align(self, alignment: usize) -> Self {
154 Self { alignment_bytes: alignment, ..self }
155 }
156
157 #[inline]
158 pub fn alignment(&self) -> usize {
159 self.alignment_bytes
160 }
161
162 #[inline]
163 pub fn panel_width(&self) -> usize {
164 self.r
165 }
166
167 #[inline]
168 pub fn len<D: DimLike>(&self, k: D, n: D) -> D {
169 n.divceil(self.r) * self.single_panel_len(k)
170 }
171
172 #[inline]
173 pub fn single_panel_len<D: DimLike>(&self, k: D) -> D {
174 ((k + self.end_padding_record) * self.r).divceil(self.alignment()) * self.alignment()
175 }
176
177 #[inline]
178 pub fn single_panel_layout(&self, k: usize, item_size: usize) -> Layout {
179 Layout::from_size_align(self.single_panel_len(k) * item_size, self.alignment()).unwrap()
180 }
181
182 pub fn pack_tensor(
183 &self,
184 t: &Tensor,
185 k_axis: usize,
186 mn_axis: usize,
187 ) -> TractResult<Box<dyn MMMInputValue>> {
188 ensure!(t.datum_type().is_copy());
189 ensure!(
190 t.datum_type().unquantized() == self.dt.unquantized(),
191 "Attempting to pack for {self} tensor {t:?}"
192 );
193 let k = t.shape()[k_axis];
194 let mn = t.shape()[mn_axis];
195 let packed_len = self.len(k, mn);
196 let panel_len = self.single_panel_len(k);
197 let panel_bytes = panel_len * t.datum_type().size_of();
198 let strides = t.strides();
199 unsafe {
200 let mut packed = Blob::new_for_size_and_align(
201 t.datum_type().size_of() * packed_len,
202 self.alignment_bytes,
203 );
204 if cfg!(debug_assertions) {
205 packed.as_bytes_mut().fill(0u8);
206 }
207 dispatch_copy!(Self::pack_t(t.datum_type())(
208 self,
209 packed.as_mut_ptr() as _,
210 t.as_ptr_unchecked(),
211 mn,
212 strides[k_axis],
213 strides[mn_axis],
214 0..k,
215 0..mn
216 ));
217 Ok(Box::new(EagerPackedInput {
218 fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
219 packed: packed.into(),
220 panel_bytes,
221 mn,
222 }))
223 }
224 }
225
226 pub fn pack_tensor_view(
227 &self,
228 t: &TensorView,
229 k_axis: usize,
230 mn_axis: usize,
231 ) -> TractResult<Box<dyn MMMInputValue>> {
232 ensure!(
233 t.datum_type().unquantized() == self.dt.unquantized(),
234 "Attempting to pack for {self} tensor view {t:?}"
235 );
236 let k = t.shape()[k_axis];
237 let mn = t.shape()[mn_axis];
238 let packed_len = self.len(k, mn);
239 let panel_len = self.single_panel_len(k);
240 let panel_bytes = panel_len * t.datum_type().size_of();
241 let strides = t.strides();
242 unsafe {
243 let mut packed = Blob::new_for_size_and_align(
244 t.datum_type().size_of() * packed_len,
245 self.alignment_bytes,
246 );
247 if cfg!(debug_assertions) {
248 packed.as_bytes_mut().fill(0u8);
249 }
250 dispatch_copy!(Self::pack_t(t.datum_type())(
251 self,
252 packed.as_mut_ptr() as _,
253 t.as_ptr_unchecked(),
254 mn,
255 strides[k_axis],
256 strides[mn_axis],
257 0..k,
258 0..mn
259 ));
260 Ok(Box::new(EagerPackedInput {
261 fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
262 packed: packed.into(),
263 panel_bytes,
264 mn,
265 }))
266 }
267 }
268
269 pub unsafe fn pack<'a, 'b>(
270 &self,
271 pb: impl std::borrow::BorrowMut<TensorView<'a>>,
272 b: impl std::borrow::Borrow<TensorView<'b>>,
273 k_axis: usize,
274 mn_axis: usize,
275 ) {
276 let k = b.borrow().shape()[k_axis];
277 let mn = b.borrow().shape()[mn_axis];
278 unsafe { self.pack_segment(pb, b, k_axis, mn_axis, 0..k, 0..mn) };
279 }
280
281
282 #[allow(clippy::too_many_arguments)]
283 #[rustfmt::skip]
284 pub unsafe fn pack_t<T: Datum + Copy>(
285 &self,
286 pb: *mut T,
287 b: *const T,
288 mn: usize,
289 k_stride: isize,
290 mn_stride: isize,
291 k_range: Range<usize>,
292 mn_range: Range<usize>,
293 ) { unsafe {
294 if k_range.len() == 0 || mn_range.len() == 0 {
295 return
296 }
297 if self.r == 1 && k_stride == 1 && mn == 1 {
298 pb.copy_from_nonoverlapping(b.add(k_range.start), k_range.len())
299 } else if mn_stride == 1 {
300 let size_of = T::datum_type().size_of();
301 let rbytes = self.r * size_of;
302 let mn_valid_end = mn_range.end.min(mn);
303 let mn_range_bytes = mn_range.start * size_of..mn_valid_end * size_of;
304 let k_stride_bytes = k_stride * size_of as isize;
305 let bb = b as *const u8;
306 let pbb = pb as *mut u8;
307 let panel_len = self.single_panel_len(k_range.len()) * size_of;
308 match rbytes {
309 16 => pack_mn_major::<[u8; 16]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
310 24 => pack_mn_major::<[u8; 24]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
311 32 => pack_mn_major::<[u8; 32]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
312 48 => pack_mn_major::<[u8; 48]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
313 64 => pack_mn_major::<[u8; 64]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
314 96 => pack_mn_major::<[u8; 96]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
315 128 => pack_mn_major::<[u8; 128]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
316 _ => {
317 let mut packer = self.write_with_k_outer(pb, k_range.len(), mn_range.len());
318 for k in k_range {
319 for x in mn_range.start..mn_valid_end {
320 packer.write(*b.offset(x as isize + k_stride * k as isize))
321 }
322 for _x in mn_valid_end..mn_range.end {
323 packer.write(T::default())
324 }
325 }
326 }
327 }
328 } else if k_stride == 1 {
329 let mut packer = self.write_with_k_inner(pb, k_range.len(), mn);
330 let mn_valid_end = mn_range.end.min(mn);
331 for x in mn_range.start..mn_valid_end {
332 for k in k_range.clone() {
333 packer.write(*b.offset(x as isize * mn_stride + k as isize))
334 }
335 }
336 } else {
338 let mut packer = self.write_with_k_outer(pb, k_range.len(), mn);
339 let mn_valid_end = mn_range.end.min(mn);
340 for k in k_range {
341 for x in mn_range.start..mn_valid_end {
342 packer.write(*b.offset(x as isize * mn_stride + k_stride * k as isize))
343 }
344 for _x in mn_valid_end..mn_range.end {
345 packer.write(T::default())
346 }
347 }
348 }
349 }}
350
351 #[inline]
352 pub unsafe fn pack_segment<'a, 'b>(
353 &self,
354 mut pb: impl std::borrow::BorrowMut<TensorView<'a>>,
355 b: impl std::borrow::Borrow<TensorView<'b>>,
356 k_axis: usize,
357 mn_axis: usize,
358 k_range: Range<usize>,
359 mn_range: Range<usize>,
360 ) {
361 debug_assert!(pb.borrow().len() >= self.len(k_range.len(), mn_range.len()));
362 let pb = pb.borrow_mut();
363 let b = b.borrow();
364 let dt = pb.datum_type();
365 unsafe {
366 dispatch_copy!(Self::pack_t(dt)(
367 self,
368 pb.as_ptr_mut_unchecked(),
369 b.as_ptr_unchecked(),
370 b.shape()[mn_axis],
371 b.strides()[k_axis],
372 b.strides()[mn_axis],
373 k_range,
374 mn_range
375 ));
376 }
377 }
378
379 pub fn write_with_k_outer<'p, T: Copy + Debug>(
380 &self,
381 pb: *mut T,
382 k: usize,
383 mn: usize,
384 ) -> KOutWriter<'p, T> {
385 KOutWriter::new(pb, self.r, self.single_panel_len(k), mn, k)
386 }
387
388 pub fn write_single_panel_with_k_outer<'p, T: Copy + Debug>(
389 &self,
390 pb: *mut T,
391 ) -> KOutSinglePanelWriter<'p, T> {
392 KOutSinglePanelWriter::new(pb)
393 }
394
395 pub fn write_with_k_inner<'p, T: Copy + Debug>(
396 &self,
397 pb: *mut T,
398 k: usize,
399 mn: usize,
400 ) -> KInWriter<'p, T> {
401 let panel_len = self.single_panel_len(k);
402 KInWriter::new(pb, panel_len, self.r, mn, k)
403 }
404}
405
/// Sink that places values into a packed buffer, one at a time, in the order the
/// implementation expects.
pub trait PackingWriter<T: Copy> {
    /// Stores `t` at the current position and moves to the next one.
    fn write(&mut self, t: T);

    /// Stores every value of `ts`, in order. The default forwards each value to `write`;
    /// implementations may override it with a bulk copy.
    #[inline]
    fn write_slice(&mut self, ts: &[T]) {
        ts.iter().copied().for_each(|t| self.write(t));
    }
}
422
423#[derive(Debug)]
424pub struct KOutSinglePanelWriter<'p, T>
425where
426 T: Copy + std::fmt::Debug,
427{
428 ptr: *mut T,
429 _phantom: PhantomData<&'p T>,
430}
431
432impl<'p, T> KOutSinglePanelWriter<'p, T>
433where
434 T: Copy + std::fmt::Debug,
435{
436 pub fn new(ptr: *mut T) -> KOutSinglePanelWriter<'p, T> {
437 KOutSinglePanelWriter { ptr, _phantom: PhantomData }
438 }
439}
440
441impl<T> PackingWriter<T> for KOutSinglePanelWriter<'_, T>
442where
443 T: Copy + std::fmt::Debug,
444{
445 #[inline(always)]
446 fn write(&mut self, t: T) {
447 unsafe {
448 *self.ptr = t;
449 self.ptr = self.ptr.offset(1);
450 }
451 }
452
453 #[inline]
454 fn write_slice(&mut self, ts: &[T]) {
455 unsafe {
459 std::ptr::copy_nonoverlapping(ts.as_ptr(), self.ptr, ts.len());
460 self.ptr = self.ptr.add(ts.len());
461 }
462 }
463}
464
465#[derive(Debug)]
466pub struct KOutWriter<'p, T>
467where
468 T: Copy + std::fmt::Debug,
469{
470 ptr: *mut T,
471 panels: usize,
472 panel_width: usize,
473 last_panel_width: usize,
474 remain: usize,
475 current_panel: usize,
476 next_panel: isize,
477 next_lane: isize,
478 _phantom: PhantomData<&'p T>,
479}
480
481impl<'p, T> KOutWriter<'p, T>
482where
483 T: Copy + std::fmt::Debug,
484{
485 pub fn new(
486 ptr: *mut T,
487 panel_width: usize,
488 panel_len: usize,
489 mn: usize,
490 _k: usize,
491 ) -> KOutWriter<'p, T> {
492 let panels = mn.divceil(panel_width);
493 let last_panel_width = mn - (panels - 1) * panel_width;
494 KOutWriter {
495 ptr,
496 panels,
497 panel_width,
498 last_panel_width,
499 remain: if panels > 1 { panel_width } else { last_panel_width },
500 current_panel: 0,
501 next_panel: (panel_len - panel_width) as isize,
502 next_lane: (panel_width - last_panel_width) as isize
503 - (panel_len * (panels - 1)) as isize,
504 _phantom: PhantomData,
505 }
506 }
507}
508
509impl<T> PackingWriter<T> for KOutWriter<'_, T>
510where
511 T: Copy + std::fmt::Debug,
512{
513 #[inline(always)]
514 fn write(&mut self, t: T) {
515 unsafe {
516 *self.ptr = t;
517 self.remain -= 1;
518 self.ptr = self.ptr.offset(1);
519 if self.remain == 0 {
520 self.current_panel += 1;
521 if self.current_panel == self.panels {
522 self.ptr = self.ptr.offset(self.next_lane);
523 self.current_panel = 0;
524 } else {
525 self.ptr = self.ptr.offset(self.next_panel);
526 }
527 if self.current_panel == self.panels - 1 {
528 self.remain = self.last_panel_width;
529 } else {
530 self.remain = self.panel_width;
531 }
532 }
533 }
534 }
535
536 #[inline]
537 fn write_slice(&mut self, ts: &[T]) {
538 let n = ts.len();
546 if n == 0 {
547 return;
548 }
549 if n < self.remain {
550 unsafe {
552 std::ptr::copy_nonoverlapping(ts.as_ptr(), self.ptr, n);
553 self.ptr = self.ptr.add(n);
554 }
555 self.remain -= n;
556 } else if n == self.remain {
557 unsafe {
563 std::ptr::copy_nonoverlapping(ts.as_ptr(), self.ptr, n);
564 self.ptr = self.ptr.add(n);
565 self.current_panel += 1;
566 if self.current_panel == self.panels {
567 self.ptr = self.ptr.offset(self.next_lane);
568 self.current_panel = 0;
569 } else {
570 self.ptr = self.ptr.offset(self.next_panel);
571 }
572 if self.current_panel == self.panels - 1 {
573 self.remain = self.last_panel_width;
574 } else {
575 self.remain = self.panel_width;
576 }
577 }
578 } else {
579 for t in ts {
582 self.write(*t);
583 }
584 }
585 }
586}
587
588#[derive(Debug)]
589pub struct KInWriter<'p, T>
590where
591 T: Copy + Debug,
592{
593 ptr: *mut T,
594 k: usize,
595 panels: usize,
596 panel_width: usize,
597 last_panel_width: usize,
598 remain_on_k: usize,
599 remain_on_mn: usize,
600 current_panel: usize,
601 next_mn_offset: isize,
602 next_panel_offset: isize,
603 _phantom: PhantomData<&'p T>,
604}
605
606impl<'p, T> KInWriter<'p, T>
607where
608 T: Copy + Debug,
609{
610 pub fn new(
611 ptr: *mut T,
612 panel_len: usize,
613 panel_width: usize,
614 mn: usize,
615 k: usize,
616 ) -> KInWriter<'p, T> {
617 let panels = mn.divceil(panel_width);
618 let last_panel_width = mn - (panels - 1) * panel_width;
619 KInWriter {
620 ptr,
621 k,
622 panels,
623 panel_width,
624 last_panel_width,
625 remain_on_k: k,
626 remain_on_mn: if panels == 1 { last_panel_width } else { panel_width },
627 current_panel: 0,
628 next_mn_offset: 1 - (k * panel_width) as isize,
629 next_panel_offset: panel_len as isize - (k * panel_width + panel_width - 1) as isize,
630 _phantom: PhantomData,
632 }
633 }
634}
635
636impl<T> PackingWriter<T> for KInWriter<'_, T>
637where
638 T: Copy + std::fmt::Debug,
639{
640 #[inline(always)]
641 fn write(&mut self, t: T) {
642 unsafe {
643 *self.ptr = t;
644 self.remain_on_k -= 1;
645 self.ptr = self.ptr.add(self.panel_width);
646 if self.remain_on_k == 0 {
647 self.remain_on_k = self.k;
648 self.remain_on_mn -= 1;
649 if self.remain_on_mn > 0 {
650 self.ptr = self.ptr.offset(self.next_mn_offset);
651 } else {
652 self.ptr = self.ptr.offset(self.next_panel_offset);
653 self.current_panel += 1;
654 if self.current_panel == self.panels - 1 {
655 self.remain_on_mn = self.last_panel_width;
656 } else {
657 self.remain_on_mn = self.panel_width;
658 }
659 }
660 }
661 }
662 }
663}
664
/// Packs an mn-contiguous source into panels of `size_of::<Chunk>()` bytes per k row.
///
/// For each k in `k_range`, the bytes `mn_range_bytes` of source row `k` (rows are
/// `k_stride_bytes` apart) are split into chunk-sized pieces. Piece `j` is copied to panel
/// `j`, row `k - k_range.start`; panels are `panel_len` bytes apart. A trailing partial piece
/// is copied as-is, and the rest of that panel row is left untouched.
///
/// # Safety
/// Every source and destination byte addressed this way must be valid, and the two buffers
/// must not overlap.
#[inline(never)]
unsafe fn pack_mn_major<Chunk: Copy>(
    b: *const u8,
    packed: *mut u8,
    panel_len: usize,
    k_stride_bytes: isize,
    mn_range_bytes: Range<usize>,
    k_range: Range<usize>,
) {
    let chunk_bytes = std::mem::size_of::<Chunk>();
    let span = mn_range_bytes.len();
    let (whole_chunks, tail_bytes) = (span / chunk_bytes, span % chunk_bytes);
    for (row, k) in k_range.enumerate() {
        // SAFETY: guaranteed by the caller, see `# Safety`.
        unsafe {
            let mut src = b.offset(k as isize * k_stride_bytes + mn_range_bytes.start as isize);
            let mut dst = packed.add(row * chunk_bytes);
            for _ in 0..whole_chunks {
                dst.copy_from_nonoverlapping(src, chunk_bytes);
                dst = dst.add(panel_len);
                src = src.add(chunk_bytes);
            }
            if tail_bytes != 0 {
                dst.copy_from_nonoverlapping(src, tail_bytes);
            }
        }
    }
}
694
695#[derive(Debug)]
700pub struct KOut4Writer<'p, T>
701where
702 T: Copy + std::fmt::Debug,
703{
704 base: *mut T,
705 r4: usize, panel_len: usize, panels: usize,
708 panel_width: usize,
709 last_panel_width: usize,
710 kb: usize, kr: usize, panel: usize,
713 local_mn: usize,
714 _phantom: PhantomData<&'p T>,
715}
716
717impl<'p, T> KOut4Writer<'p, T>
718where
719 T: Copy + std::fmt::Debug,
720{
721 pub fn new(base: *mut T, r: usize, panel_len: usize, mn: usize) -> KOut4Writer<'p, T> {
722 let panels = mn.divceil(r).max(1);
723 let last_panel_width = mn - (panels - 1) * r;
724 KOut4Writer {
725 base,
726 r4: r * 4,
727 panel_len,
728 panels,
729 panel_width: r,
730 last_panel_width,
731 kb: 0,
732 kr: 0,
733 panel: 0,
734 local_mn: 0,
735 _phantom: PhantomData,
736 }
737 }
738 #[inline(always)]
739 fn panel_width(&self) -> usize {
740 if self.panel == self.panels - 1 { self.last_panel_width } else { self.panel_width }
741 }
742 #[inline(always)]
743 fn advance(&mut self, by: usize) {
744 self.local_mn += by;
745 if self.local_mn >= self.panel_width() {
746 self.local_mn = 0;
747 self.panel += 1;
748 if self.panel == self.panels {
749 self.panel = 0;
750 self.kr += 1;
751 if self.kr == 4 {
752 self.kr = 0;
753 self.kb += 1;
754 }
755 }
756 }
757 }
758}
759
760impl<T> PackingWriter<T> for KOut4Writer<'_, T>
761where
762 T: Copy + std::fmt::Debug,
763{
764 #[inline(always)]
765 fn write(&mut self, t: T) {
766 unsafe {
767 let off = self.panel * self.panel_len + self.kb * self.r4 + self.local_mn * 4 + self.kr;
768 *self.base.add(off) = t;
769 }
770 self.advance(1);
771 }
772
773 #[inline]
774 fn write_slice(&mut self, ts: &[T]) {
775 let n = ts.len();
776 if n == 0 {
777 return;
778 }
779 let pw = self.panel_width();
780 if self.local_mn + n <= pw {
781 unsafe {
783 let mut d = self.base.add(
784 self.panel * self.panel_len + self.kb * self.r4 + self.local_mn * 4 + self.kr,
785 );
786 for &t in ts {
787 *d = t;
788 d = d.add(4);
789 }
790 }
791 self.advance(n);
792 } else {
793 for &t in ts {
794 self.write(t);
795 }
796 }
797 }
798}
799
/// Packing format for i8 operands where each lane stores groups of 4 consecutive k values
/// contiguously (k is padded to a multiple of 4).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PackedI8K4 {
    /// Panel width (lanes per panel).
    pub r: usize,
    /// Alignment of the packed buffer, in bytes.
    pub align: usize,
}
808impl PackedI8K4 {
809 pub fn new(r: usize) -> Self {
810 PackedI8K4 { r, align: 16 }
811 }
812 fn panel(&self, k: usize) -> usize {
813 (k.div_ceil(4) * 4) * self.r
814 }
815 pub fn single_panel_len(&self, k: usize) -> usize {
816 self.panel(k)
817 }
818 pub fn len(&self, k: usize, mn: usize) -> usize {
819 mn.divceil(self.r) * self.panel(k)
820 }
821 pub fn alignment(&self) -> usize {
822 self.align
823 }
824 pub fn write_with_k_outer<'p, T: Copy + std::fmt::Debug>(
826 &self,
827 pb: *mut T,
828 k: usize,
829 mn: usize,
830 ) -> KOut4Writer<'p, T> {
831 KOut4Writer::new(pb, self.r, self.panel(k), mn)
832 }
833 pub fn pack_view(
835 &self,
836 t: &TensorView,
837 k_axis: usize,
838 mn_axis: usize,
839 ) -> TractResult<Box<dyn MMMInputValue>> {
840 let k = t.shape()[k_axis];
841 let mn = t.shape()[mn_axis];
842 let kp = k.div_ceil(4) * 4;
843 let pl = kp * self.r;
844 let panels = mn.div_ceil(self.r);
845 let st = t.strides();
846 let mut blob = unsafe { Blob::new_for_size_and_align(panels * pl, self.align) };
847 blob.as_bytes_mut().fill(0);
848 let (ks, ms) = (st[k_axis], st[mn_axis]);
849 let kblocks = kp / 4;
850 unsafe {
851 let src = t.as_ptr_unchecked::<i8>();
852 let dst = blob.as_mut_ptr() as *mut i8;
853 for p in 0..panels {
854 let pw = self.r.min(mn - p * self.r);
855 let panel = dst.add(p * pl);
856 let mn0 = (p * self.r) as isize;
857 for kb in 0..kblocks {
858 for kr in 0..4 {
859 let kk = kb * 4 + kr;
860 if kk >= k {
861 break;
862 }
863 let srow = src.offset(kk as isize * ks + mn0 * ms);
864 let dcol = panel.add(kb * self.r * 4 + kr);
865 for lm in 0..pw {
866 *dcol.add(lm * 4) = *srow.offset(lm as isize * ms);
867 }
868 }
869 }
870 }
871 }
872 Ok(Box::new(EagerPackedInput {
873 fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
874 packed: blob.into(),
875 panel_bytes: pl,
876 mn,
877 }))
878 }
879}
880impl std::fmt::Display for PackedI8K4 {
881 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
882 write!(f, "I8K4[{}]", self.r)
883 }
884}
885impl MMMInputFormat for PackedI8K4 {
886 fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor> {
887 Ok(PackedMatrixStorage::new(self.prepare_one(t, k_axis, mn_axis)?)
888 .into_tensor(t.datum_type()))
889 }
890 fn prepare_one(
891 &self,
892 t: &Tensor,
893 k_axis: usize,
894 mn_axis: usize,
895 ) -> TractResult<Box<dyn MMMInputValue>> {
896 self.pack_view(&t.view(), k_axis, mn_axis)
897 }
898 fn precursor(&self) -> WeightType {
899 WeightType::Plain(i8::datum_type())
900 }
901 fn r(&self) -> usize {
902 self.r
903 }
904 fn k_alignment(&self) -> usize {
905 4
906 }
907 fn merge_with<'o, 'a: 'o, 'b: 'o>(
908 &'a self,
909 o: &'b dyn MMMInputFormat,
910 ) -> Option<&'o dyn MMMInputFormat> {
911 o.downcast_ref::<PackedI8K4>().filter(|x| x.r == self.r).map(|_| self as _)
912 }
913 fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
914 mn.divceil(self.r) * self.panel(k.to_usize().unwrap_or(0))
915 }
916 fn extract_at_mn_f16(&self, _: &EagerPackedInput, _: usize, _: &mut [f16]) -> TractResult<()> {
917 bail!("no f16 extract")
918 }
919 fn extract_at_mn_f32(&self, _: &EagerPackedInput, _: usize, _: &mut [f32]) -> TractResult<()> {
920 bail!("no f32 extract")
921 }
922}
923
924pub trait Packing {
925 fn packing(r: usize) -> PackedFormat;
926}
927
928impl<D: Datum> Packing for D {
929 fn packing(r: usize) -> PackedFormat {
930 PackedFormat::new(Self::datum_type(), r, vector_size())
931 }
932}
933
#[cfg(test)]
mod test {
    use std::ops::Range;

    use proptest::prelude::*;
    use tract_data::internal::num_integer::Integer;
    use tract_data::internal::tract_ndarray::Zip;
    use tract_data::internal::*;
    use tract_ndarray::prelude::*;

    /// One `PackedFormat::pack_segment` scenario: a `k` x `mn` u32 input, stored as
    /// (mn, k) when `is_a` and as (k, mn) otherwise, packed over `k_range` x `mn_range`.
    #[derive(Debug)]
    struct PackProblem {
        k: usize,
        mn: usize,
        is_a: bool,
        r: usize,
        k_range: Range<usize>,
        mn_range: Range<usize>,
        align_panel: usize,
    }

    impl PackProblem {
        /// Input matrix filled with 0, 1, 2, ... in row-major order.
        fn input(&self) -> Array2<u32> {
            let shape = if self.is_a { (self.mn, self.k) } else { (self.k, self.mn) };
            let data = (0..(self.k * self.mn) as u32).collect();
            Array2::from_shape_vec(shape, data).unwrap()
        }

        /// Runs the real packer (end padding record disabled) and reshapes its output to
        /// (panels, panel_len).
        fn packer(&self) -> Array2<u32> {
            let panels = self.mn_range.len().divceil(self.r);
            let packer = super::PackedFormat::new(u32::datum_type(), self.r, self.align_panel)
                .with_end_padding_record(0);
            let input = self.input().into_tensor();
            let panel_len = packer.single_panel_len(self.k_range.len());
            let mut output =
                Tensor::zero::<u32>(&[packer.len(self.k_range.len(), self.mn_range.len())])
                    .unwrap();
            unsafe {
                packer.pack_segment(
                    output.view_mut(),
                    input.view(),
                    self.is_a as usize,
                    !self.is_a as usize,
                    self.k_range.clone(),
                    self.mn_range.clone(),
                )
            };
            output
                .into_plain_array::<u32>()
                .unwrap()
                .into_shape_with_order((panels, panel_len))
                .unwrap()
        }

        /// Expected layout: slot z of a panel holds k = z / r and lane z % r, and coordinates
        /// outside the input read as 0.
        fn reference(&self) -> Array2<u32> {
            let input = self.input();
            let panels = self.mn_range.len().divceil(self.r);
            let len = Integer::next_multiple_of(&(self.k_range.len() * self.r), &self.align_panel);
            Array2::from_shape_fn([panels, len], |(panel, z)| {
                let k = z / self.r;
                let x = z % self.r;
                let mn = panel * self.r + x + self.mn_range.start;
                let k = k + self.k_range.start;
                let coords = if self.is_a { (mn, k) } else { (k, mn) };
                *input.get(coords).unwrap_or(&0)
            })
        }

        /// Mask of the slots that carry real input data (padding slots are excluded).
        fn valid(&self) -> Array2<bool> {
            let panels = self.mn_range.len().divceil(self.r);
            let len = Integer::next_multiple_of(&(self.k_range.len() * self.r), &self.align_panel);
            Array2::from_shape_fn([panels, len], |(panel, z)| {
                let k = z / self.r;
                let x = z % self.r;
                let k = k + self.k_range.start;
                let mn = panel * self.r + x + self.mn_range.start;
                k < self.k_range.end.min(self.k) && mn < self.mn_range.end.min(self.mn)
            })
        }

        /// Compares packer and reference on valid slots only; invalid slots are overwritten
        /// with the same sentinel on both sides.
        fn check(&self) {
            let mut packer = self.packer();
            let mut reference = self.reference();
            let valid = self.valid();
            Zip::from(&mut packer).and(&valid).for_each(|p, v| *p = if *v { *p } else { -1 as _ });
            Zip::from(&mut reference)
                .and(&valid)
                .for_each(|p, v| *p = if *v { *p } else { -1 as _ });
            assert_eq!(packer, reference);
        }
    }

    impl Arbitrary for PackProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<PackProblem>;
        // r in 1..9 on u32 covers both pack_mn_major chunk widths (16/24/32 bytes) and the
        // generic writer fallback.
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            (any::<bool>(), 1usize..9, 1usize..20, 1usize..20)
                .prop_flat_map(|(is_a, r, k, mn)| {
                    (
                        Just((is_a, r, k, mn)),
                        sub_range_strat(0..k),
                        sub_range_strat(0..mn),
                        1usize..5,
                    )
                })
                .prop_map(|((is_a, r, k, mn), k_range, mn_range, align_panel)| PackProblem {
                    k,
                    mn,
                    is_a,
                    r,
                    k_range,
                    mn_range,
                    align_panel,
                })
                .boxed()
        }
    }

    /// Non-empty sub-ranges of `range`: crops `cropped` elements in total, `left` of them
    /// from the start.
    fn sub_range_strat(range: Range<usize>) -> BoxedStrategy<Range<usize>> {
        (0..range.len())
            .prop_flat_map(|cropped| (Just(cropped), 0..=cropped))
            .prop_map(move |(cropped, left)| range.start + left..range.end - (cropped - left))
            .boxed()
    }

    proptest::proptest! {
        #[test]
        fn prop(pb in any::<PackProblem>()) {
            pb.check();
        }

        // Only exercises the sub-range strategy itself.
        #[test]
        fn subrange_prop(_range in sub_range_strat(0..20)) {
        }

    }
    /// One `PackedI8K4` scenario: a `k` x `mn` i8 matrix packed with panel width `r`, stored
    /// as (mn, k) when `is_a` and as (k, mn) otherwise.
    #[derive(Debug, Clone)]
    struct PackI8K4Problem {
        k: usize,
        mn: usize,
        r: usize,
        is_a: bool,
    }

    impl PackI8K4Problem {
        /// Deterministic (k, mn) matrix with distinct-looking values.
        fn logical(&self) -> Array2<i8> {
            Array2::from_shape_fn((self.k, self.mn), |(kk, m)| {
                (kk.wrapping_mul(31).wrapping_add(m.wrapping_mul(17)).wrapping_add(1)) as i8
            })
        }

        /// k rounded up to a multiple of 4, times r.
        fn panel_len(&self) -> usize {
            (self.k.div_ceil(4) * 4) * self.r
        }

        /// Expected K4 layout, built directly from the offset formula; padding stays 0.
        fn reference(&self) -> Vec<i8> {
            let logical = self.logical();
            let r = self.r;
            let pl = self.panel_len();
            let panels = self.mn.div_ceil(r);
            let mut out = vec![0i8; panels * pl];
            for p in 0..panels {
                let pw = r.min(self.mn - p * r);
                for kk in 0..self.k {
                    for lm in 0..pw {
                        let m = p * r + lm;
                        let off = p * pl + (kk / 4) * r * 4 + lm * 4 + (kk % 4);
                        out[off] = logical[[kk, m]];
                    }
                }
            }
            out
        }

        /// Packs through `PackedI8K4::pack_view` and concatenates the panels read back.
        fn pack_view_bytes(&self) -> Vec<i8> {
            let logical = self.logical();
            let packer = super::PackedI8K4::new(self.r);
            let (tensor, k_axis, mn_axis) = if self.is_a {
                let a = Array2::from_shape_fn((self.mn, self.k), |(m, kk)| logical[[kk, m]]);
                (a.into_tensor(), 1usize, 0usize)
            } else {
                (logical.clone().into_tensor(), 0usize, 1usize)
            };
            let packed = packer.pack_view(&tensor.view(), k_axis, mn_axis).unwrap();
            let pl = self.panel_len();
            let panels = self.mn.div_ceil(self.r);
            assert_eq!(packed.panels_count(), panels);
            assert_eq!(packed.k(), self.k);
            assert_eq!(packed.mn(), self.mn);
            let mut out = vec![0i8; panels * pl];
            unsafe {
                for p in 0..panels {
                    let ptr = packed.panel_bytes(p, None).unwrap() as *const i8;
                    std::ptr::copy_nonoverlapping(ptr, out.as_mut_ptr().add(p * pl), pl);
                }
            }
            out
        }

        /// Packs through `KOut4Writer`, feeding values in k-outer / mn-inner order.
        fn writer_bytes(&self) -> Vec<i8> {
            let logical = self.logical();
            let packer = super::PackedI8K4::new(self.r);
            let total = packer.len(self.k, self.mn);
            assert_eq!(total, self.mn.div_ceil(self.r) * self.panel_len());
            let mut buf = vec![0i8; total];
            {
                let mut w = packer.write_with_k_outer(buf.as_mut_ptr(), self.k, self.mn);
                for kk in 0..self.k {
                    for m in 0..self.mn {
                        super::PackingWriter::write(&mut w, logical[[kk, m]]);
                    }
                }
            }
            buf
        }

        /// Both packing paths must agree with the reference layout.
        fn check(&self) {
            let reference = self.reference();
            assert_eq!(
                self.pack_view_bytes(),
                reference,
                "pack_view disagrees with reference for {self:?}"
            );
            assert_eq!(
                self.writer_bytes(),
                reference,
                "write_with_k_outer disagrees with reference for {self:?}"
            );
        }
    }

    impl Arbitrary for PackI8K4Problem {
        type Parameters = ();
        type Strategy = BoxedStrategy<PackI8K4Problem>;
        fn arbitrary_with(_: ()) -> Self::Strategy {
            (any::<bool>(), prop::sample::select(vec![4usize, 8, 16, 32]), 1usize..40, 1usize..40)
                .prop_map(|(is_a, r, k, mn)| PackI8K4Problem { k, mn, r, is_a })
                .boxed()
        }
    }

    proptest::proptest! {
        #[test]
        fn pack_i8k4_prop(pb in any::<PackI8K4Problem>()) {
            pb.check();
        }
    }

    /// Shorthand constructor for the hand-written I8K4 cases below.
    fn k4(k: usize, mn: usize, r: usize, is_a: bool) -> PackI8K4Problem {
        PackI8K4Problem { k, mn, r, is_a }
    }

    #[test]
    fn i8k4_smallest() {
        k4(1, 1, 4, false).check();
        k4(1, 1, 4, true).check();
    }

    #[test]
    fn i8k4_exact_tile() {
        k4(4, 4, 4, false).check();
        k4(8, 32, 32, false).check();
        k4(8, 32, 32, true).check();
    }

    #[test]
    fn i8k4_k_not_multiple_of_4() {
        for k in [1, 2, 3, 5, 6, 7, 9] {
            k4(k, 4, 4, false).check();
            k4(k, 7, 8, true).check();
        }
    }

    #[test]
    fn i8k4_partial_last_panel() {
        k4(5, 7, 4, false).check();
        k4(5, 7, 4, true).check();
        k4(4, 33, 32, false).check();
        k4(4, 33, 32, true).check();
        k4(3, 1, 32, false).check();
    }

    #[test]
    fn i8k4_single_wide_tile() {
        k4(7, 1, 32, false).check();
        k4(7, 5, 16, true).check();
    }

    #[test]
    fn i8k4_many_panels() {
        k4(13, 100, 8, false).check();
        k4(13, 100, 8, true).check();
        k4(17, 65, 16, false).check();
    }

    // Hand-picked PackProblem regressions follow (b = k-major input, a = mn-major input).

    #[test]
    fn simple_b_1() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 0..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn simple_b_2() {
        PackProblem {
            k: 2,
            mn: 2,
            is_a: false,
            r: 1,
            k_range: 0..2,
            mn_range: 0..2,
            align_panel: 1,
        }
        .check()
    }

    #[test]
    fn simple_b_3() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 4,
            k_range: 0..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn simple_b_4() {
        PackProblem {
            k: 1,
            mn: 3,
            is_a: false,
            r: 2,
            k_range: 0..1,
            mn_range: 0..3,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn simple_a_1() {
        PackProblem {
            k: 2,
            mn: 2,
            is_a: true,
            r: 1,
            k_range: 0..2,
            mn_range: 0..2,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn simple_a_2() {
        PackProblem {
            k: 2,
            mn: 3,
            is_a: true,
            r: 2,
            k_range: 0..2,
            mn_range: 0..3,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_k_0() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 1..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_k_1() {
        PackProblem {
            k: 2,
            mn: 2,
            is_a: false,
            r: 1,
            k_range: 0..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_k_2() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 6,
            k_range: 1..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_mn_0() {
        PackProblem {
            k: 1,
            mn: 2,
            is_a: false,
            r: 2,
            k_range: 0..1,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_b_4() {
        PackProblem {
            k: 1,
            mn: 2,
            is_a: false,
            r: 6,
            k_range: 0..1,
            mn_range: 1..2,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_b_5() {
        PackProblem {
            k: 1,
            mn: 7,
            is_a: false,
            r: 6,
            k_range: 0..1,
            mn_range: 1..7,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn align_a_1() {
        PackProblem {
            k: 2,
            mn: 2,
            is_a: true,
            r: 1,
            k_range: 0..1,
            mn_range: 0..2,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_1() {
        PackProblem {
            k: 1,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 0..1,
            mn_range: 0..1,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_2() {
        PackProblem {
            k: 3,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 0..3,
            mn_range: 0..1,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_3() {
        PackProblem {
            k: 1,
            mn: 1,
            is_a: false,
            r: 3,
            k_range: 0..1,
            mn_range: 0..1,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_4() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 0..1,
            mn_range: 0..1,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_5() {
        PackProblem {
            k: 1,
            mn: 5,
            is_a: false,
            r: 4,
            k_range: 0..1,
            mn_range: 0..5,
            align_panel: 3,
        }
        .check();
    }
}