// tract_linalg/frame/pack.rs

1use std::alloc::Layout;
2use std::fmt::{Debug, Display};
3use std::marker::PhantomData;
4use std::ops::Range;
5use std::sync::Arc;
6use tract_data::internal::*;
7
8use crate::mmm::{EagerPackedInput, MMMInputFormat, MMMInputValue, PackedOpaqueFact};
9
10use crate::WeightType;
11
12#[derive(Clone, Eq, PartialEq, Hash)]
13pub struct PackedFormat {
14    pub dt: DatumType,
15    pub r: usize,
16    pub alignment_bytes: usize,
17    pub end_padding_record: usize,
18}
19
20impl MMMInputFormat for PackedFormat {
21    fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor> {
22        let packed = PackedFormat::pack_tensor(self, t, k_axis, mn_axis)?;
23        Ok(tensor0(Opaque(Arc::new(packed))))
24    }
25
26    fn prepare_one(
27        &self,
28        t: &Tensor,
29        k_axis: usize,
30        mn_axis: usize,
31    ) -> TractResult<Box<dyn MMMInputValue>> {
32        PackedFormat::pack_tensor(self, t, k_axis, mn_axis)
33    }
34
35    fn precursor(&self) -> WeightType {
36        WeightType::Plain(self.dt)
37    }
38
39    fn r(&self) -> usize {
40        self.r
41    }
42
43    fn k_alignment(&self) -> usize {
44        1
45    }
46
47    fn same_as(&self, other: &dyn MMMInputFormat) -> bool {
48        other.downcast_ref::<Self>().is_some_and(|other| self == other)
49    }
50
51    #[allow(clippy::collapsible_if)]
52    fn merge_with<'o, 'a: 'o, 'b: 'o>(
53        &'a self,
54        other: &'b dyn MMMInputFormat,
55    ) -> Option<&'o dyn MMMInputFormat> {
56        if let Some(other) = other.downcast_ref::<PackedFormat>() {
57            if self.r == other.r && self.dt == other.dt {
58                if self.alignment_bytes % other.alignment_bytes == 0
59                    && self.end_padding_record >= other.end_padding_record
60                {
61                    return Some(self);
62                }
63                if other.alignment_bytes % self.alignment_bytes == 0
64                    && other.end_padding_record >= self.end_padding_record
65                {
66                    return Some(other);
67                }
68            }
69        }
70        None
71    }
72
73    fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
74        self.len(k, mn) * self.dt.size_of()
75    }
76
77    fn extract_at_mn_f16(
78        &self,
79        data: &EagerPackedInput,
80        mn: usize,
81        slice: &mut [f16],
82    ) -> TractResult<()> {
83        ensure!(data.format().same_as(self));
84        ensure!(self.len(data.k(), data.mn()) * self.dt.size_of() == data.packed.len());
85        unsafe {
86            let ptr = data.packed.as_ptr().add(
87                (self.single_panel_len(data.k()) * (mn / self.r) + mn % self.r) * self.dt.size_of(),
88            );
89            for (i, slot) in slice.iter_mut().enumerate() {
90                let ptr = ptr.add(i * self.dt.size_of() * self.r);
91                *slot = if self.dt == f16::datum_type() {
92                    *(ptr as *const f16)
93                } else if self.dt == f32::datum_type() {
94                    f16::from_f32(*(ptr as *const f32))
95                } else {
96                    bail!("Unexpected DT {:?}", self.dt)
97                }
98            }
99        }
100        Ok(())
101    }
102
103    fn extract_at_mn_f32(
104        &self,
105        data: &EagerPackedInput,
106        mn: usize,
107        slice: &mut [f32],
108    ) -> TractResult<()> {
109        ensure!(data.format().same_as(self));
110        ensure!(self.len(data.k(), data.mn()) * self.dt.size_of() == data.packed.len());
111        unsafe {
112            let ptr = data.packed.as_ptr().add(
113                (self.single_panel_len(data.k()) * (mn / self.r) + mn % self.r) * self.dt.size_of(),
114            );
115            for (i, slot) in slice.iter_mut().enumerate() {
116                let ptr = ptr.add(i * self.dt.size_of() * self.r);
117                *slot = if self.dt == f16::datum_type() {
118                    (*(ptr as *const f16)).to_f32()
119                } else if self.dt == f32::datum_type() {
120                    *(ptr as *const f32)
121                } else {
122                    bail!("Unexpected DT {:?}", self.dt)
123                }
124            }
125        }
126        Ok(())
127    }
128}
129
130impl Display for PackedFormat {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        write!(f, "Packed{:?}[{}]", self.dt, self.r)
133    }
134}
135
136impl Debug for PackedFormat {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        write!(
139            f,
140            "Packed{:?}[{}]@{}+{}",
141            self.dt, self.r, self.alignment_bytes, self.end_padding_record
142        )
143    }
144}
145
146impl PackedFormat {
147    pub const fn new(dt: DatumType, nr: usize, alignment_bytes: usize) -> PackedFormat {
148        PackedFormat { dt, r: nr, alignment_bytes, end_padding_record: 1 }
149    }
150
151    pub const fn with_end_padding_record(self, end_padding_record: usize) -> Self {
152        PackedFormat { end_padding_record, ..self }
153    }
154
155    #[inline]
156    pub fn align(self, alignment: usize) -> Self {
157        Self { alignment_bytes: alignment, ..self }
158    }
159
160    #[inline]
161    pub fn alignment(&self) -> usize {
162        self.alignment_bytes
163    }
164
165    #[inline]
166    pub fn panel_width(&self) -> usize {
167        self.r
168    }
169
170    #[inline]
171    pub fn len<D: DimLike>(&self, k: D, n: D) -> D {
172        n.divceil(self.r) * self.single_panel_len(k)
173    }
174
175    #[inline]
176    pub fn single_panel_len<D: DimLike>(&self, k: D) -> D {
177        ((k + self.end_padding_record) * self.r).divceil(self.alignment()) * self.alignment()
178    }
179
180    #[inline]
181    pub fn single_panel_layout(&self, k: usize, item_size: usize) -> Layout {
182        Layout::from_size_align(self.single_panel_len(k) * item_size, self.alignment()).unwrap()
183    }
184
185    pub fn pack_tensor(
186        &self,
187        t: &Tensor,
188        k_axis: usize,
189        mn_axis: usize,
190    ) -> TractResult<Box<dyn MMMInputValue>> {
191        ensure!(t.datum_type().is_copy());
192        ensure!(
193            t.datum_type().unquantized() == self.dt.unquantized(),
194            "Attempting to pack for {self} tensor {t:?}"
195        );
196        let k = t.shape()[k_axis];
197        let mn = t.shape()[mn_axis];
198        let packed_len = self.len(k, mn);
199        let panel_len = self.single_panel_len(k);
200        let panel_bytes = panel_len * t.datum_type().size_of();
201        let strides = t.strides();
202        unsafe {
203            let mut packed = Blob::new_for_size_and_align(
204                t.datum_type().size_of() * packed_len,
205                self.alignment_bytes,
206            );
207            if cfg!(debug_assertions) {
208                packed.as_bytes_mut().fill(0u8);
209            }
210            dispatch_copy!(Self::pack_t(t.datum_type())(
211                self,
212                packed.as_mut_ptr() as _,
213                t.as_ptr_unchecked(),
214                mn,
215                strides[k_axis],
216                strides[mn_axis],
217                0..k,
218                0..mn
219            ));
220            Ok(Box::new(EagerPackedInput {
221                fact: PackedOpaqueFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
222                packed: packed.into(),
223                panel_bytes,
224                mn,
225            }))
226        }
227    }
228
229    pub fn pack_tensor_view(
230        &self,
231        t: &TensorView,
232        k_axis: usize,
233        mn_axis: usize,
234    ) -> TractResult<Box<dyn MMMInputValue>> {
235        ensure!(
236            t.datum_type().unquantized() == self.dt.unquantized(),
237            "Attempting to pack for {self} tensor view {t:?}"
238        );
239        let k = t.shape()[k_axis];
240        let mn = t.shape()[mn_axis];
241        let packed_len = self.len(k, mn);
242        let panel_len = self.single_panel_len(k);
243        let panel_bytes = panel_len * t.datum_type().size_of();
244        let strides = t.strides();
245        unsafe {
246            let mut packed = Blob::new_for_size_and_align(
247                t.datum_type().size_of() * packed_len,
248                self.alignment_bytes,
249            );
250            if cfg!(debug_assertions) {
251                packed.as_bytes_mut().fill(0u8);
252            }
253            dispatch_copy!(Self::pack_t(t.datum_type())(
254                self,
255                packed.as_mut_ptr() as _,
256                t.as_ptr_unchecked(),
257                mn,
258                strides[k_axis],
259                strides[mn_axis],
260                0..k,
261                0..mn
262            ));
263            Ok(Box::new(EagerPackedInput {
264                fact: PackedOpaqueFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
265                packed: packed.into(),
266                panel_bytes,
267                mn,
268            }))
269        }
270    }
271
272    pub unsafe fn pack<'a, 'b>(
273        &self,
274        pb: impl std::borrow::BorrowMut<TensorView<'a>>,
275        b: impl std::borrow::Borrow<TensorView<'b>>,
276        k_axis: usize,
277        mn_axis: usize,
278    ) {
279        let k = b.borrow().shape()[k_axis];
280        let mn = b.borrow().shape()[mn_axis];
281        unsafe { self.pack_segment(pb, b, k_axis, mn_axis, 0..k, 0..mn) };
282    }
283
284
285    #[allow(clippy::too_many_arguments)]
286    #[rustfmt::skip]
287    pub unsafe fn pack_t<T: Datum + Copy>(
288        &self,
289        pb: *mut T,
290        b: *const T,
291        mn: usize,
292        k_stride: isize,
293        mn_stride: isize,
294        k_range: Range<usize>,
295        mn_range: Range<usize>,
296        ) { unsafe {
297        if k_range.len() == 0 || mn_range.len() == 0 {
298            return
299        }
300        if self.r == 1 && k_stride == 1 && mn == 1 {
301            pb.copy_from_nonoverlapping(b.add(k_range.start), k_range.len())
302        } else if mn_stride == 1 {
303            let size_of = T::datum_type().size_of();
304            let rbytes = self.r * size_of;
305            let mn_valid_end = mn_range.end.min(mn);
306            let mn_range_bytes = mn_range.start * size_of..mn_valid_end * size_of;
307            let k_stride_bytes = k_stride * size_of as isize;
308            let bb = b as *const u8;
309            let pbb = pb as *mut u8;
310            let panel_len = self.single_panel_len(k_range.len()) * size_of;
311            match rbytes {
312                16 => pack_mn_major::<[u8; 16]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
313                24 => pack_mn_major::<[u8; 24]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
314                32 => pack_mn_major::<[u8; 32]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
315                48 => pack_mn_major::<[u8; 48]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
316                64 => pack_mn_major::<[u8; 64]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
317                _ => {
318                    let mut packer = self.write_with_k_outer(pb, k_range.len(), mn_range.len());
319                    for k in k_range {
320                        for x in mn_range.start..mn_valid_end {
321                            packer.write(*b.offset(x as isize + k_stride * k as isize))
322                        }
323                        for _x in mn_valid_end..mn_range.end {
324                            packer.write(T::default())
325                        }
326                    }
327                }
328            }
329        } else if k_stride == 1 {
330            let mut packer = self.write_with_k_inner(pb, k_range.len(), mn);
331            let mn_valid_end = mn_range.end.min(mn);
332            for x in mn_range.start..mn_valid_end {
333                for k in k_range.clone() {
334                    packer.write(*b.offset(x as isize * mn_stride + k as isize))
335                }
336            }
337            // just ignore invalid mn_range
338        } else {
339            let mut packer = self.write_with_k_outer(pb, k_range.len(), mn);
340            let mn_valid_end = mn_range.end.min(mn);
341            for k in k_range {
342                for x in mn_range.start..mn_valid_end {
343                    packer.write(*b.offset(x as isize * mn_stride + k_stride * k as isize))
344                }
345                for _x in mn_valid_end..mn_range.end {
346                    packer.write(T::default())
347                }
348            }
349        }
350    }}
351
352    #[inline]
353    pub unsafe fn pack_segment<'a, 'b>(
354        &self,
355        mut pb: impl std::borrow::BorrowMut<TensorView<'a>>,
356        b: impl std::borrow::Borrow<TensorView<'b>>,
357        k_axis: usize,
358        mn_axis: usize,
359        k_range: Range<usize>,
360        mn_range: Range<usize>,
361    ) {
362        debug_assert!(pb.borrow().len() >= self.len(k_range.len(), mn_range.len()));
363        let pb = pb.borrow_mut();
364        let b = b.borrow();
365        let dt = pb.datum_type();
366        unsafe {
367            dispatch_copy!(Self::pack_t(dt)(
368                self,
369                pb.as_ptr_mut_unchecked(),
370                b.as_ptr_unchecked(),
371                b.shape()[mn_axis],
372                b.strides()[k_axis],
373                b.strides()[mn_axis],
374                k_range,
375                mn_range
376            ));
377        }
378    }
379
380    pub fn write_with_k_outer<'p, T: Copy + Debug>(
381        &self,
382        pb: *mut T,
383        k: usize,
384        mn: usize,
385    ) -> KOutWriter<'p, T> {
386        KOutWriter::new(pb, self.r, self.single_panel_len(k), mn, k)
387    }
388
389    pub fn write_single_panel_with_k_outer<'p, T: Copy + Debug>(
390        &self,
391        pb: *mut T,
392    ) -> KOutSinglePanelWriter<'p, T> {
393        KOutSinglePanelWriter::new(pb)
394    }
395
396    pub fn write_with_k_inner<'p, T: Copy + Debug>(
397        &self,
398        pb: *mut T,
399        k: usize,
400        mn: usize,
401    ) -> KInWriter<'p, T> {
402        let panel_len = self.single_panel_len(k);
403        KInWriter::new(pb, panel_len, self.r, mn, k)
404    }
405}
406
/// Destination receiving packed values one at a time, in the order fixed by the implementor.
pub trait PackingWriter<T>
where
    T: Copy,
{
    /// Stores `t` at the current position and advances to the next position.
    fn write(&mut self, t: T);
}
410
411#[derive(Debug)]
412pub struct KOutSinglePanelWriter<'p, T>
413where
414    T: Copy + std::fmt::Debug,
415{
416    ptr: *mut T,
417    _phantom: PhantomData<&'p T>,
418}
419
420impl<'p, T> KOutSinglePanelWriter<'p, T>
421where
422    T: Copy + std::fmt::Debug,
423{
424    pub fn new(ptr: *mut T) -> KOutSinglePanelWriter<'p, T> {
425        KOutSinglePanelWriter { ptr, _phantom: PhantomData }
426    }
427}
428
429impl<T> PackingWriter<T> for KOutSinglePanelWriter<'_, T>
430where
431    T: Copy + std::fmt::Debug,
432{
433    #[inline(always)]
434    fn write(&mut self, t: T) {
435        unsafe {
436            *self.ptr = t;
437            self.ptr = self.ptr.offset(1);
438        }
439    }
440}
441
442#[derive(Debug)]
443pub struct KOutWriter<'p, T>
444where
445    T: Copy + std::fmt::Debug,
446{
447    ptr: *mut T,
448    panels: usize,
449    panel_width: usize,
450    last_panel_width: usize,
451    remain: usize,
452    current_panel: usize,
453    next_panel: isize,
454    next_lane: isize,
455    _phantom: PhantomData<&'p T>,
456}
457
458impl<'p, T> KOutWriter<'p, T>
459where
460    T: Copy + std::fmt::Debug,
461{
462    pub fn new(
463        ptr: *mut T,
464        panel_width: usize,
465        panel_len: usize,
466        mn: usize,
467        _k: usize,
468    ) -> KOutWriter<'p, T> {
469        let panels = mn.divceil(panel_width);
470        let last_panel_width = mn - (panels - 1) * panel_width;
471        KOutWriter {
472            ptr,
473            panels,
474            panel_width,
475            last_panel_width,
476            remain: if panels > 1 { panel_width } else { last_panel_width },
477            current_panel: 0,
478            next_panel: (panel_len - panel_width) as isize,
479            next_lane: (panel_width - last_panel_width) as isize
480                - (panel_len * (panels - 1)) as isize,
481            _phantom: PhantomData,
482        }
483    }
484}
485
486impl<T> PackingWriter<T> for KOutWriter<'_, T>
487where
488    T: Copy + std::fmt::Debug,
489{
490    #[inline(always)]
491    fn write(&mut self, t: T) {
492        unsafe {
493            *self.ptr = t;
494            self.remain -= 1;
495            self.ptr = self.ptr.offset(1);
496            if self.remain == 0 {
497                self.current_panel += 1;
498                if self.current_panel == self.panels {
499                    self.ptr = self.ptr.offset(self.next_lane);
500                    self.current_panel = 0;
501                } else {
502                    self.ptr = self.ptr.offset(self.next_panel);
503                }
504                if self.current_panel == self.panels - 1 {
505                    self.remain = self.last_panel_width;
506                } else {
507                    self.remain = self.panel_width;
508                }
509            }
510        }
511    }
512}
513
514#[derive(Debug)]
515pub struct KInWriter<'p, T>
516where
517    T: Copy + Debug,
518{
519    ptr: *mut T,
520    k: usize,
521    panels: usize,
522    panel_width: usize,
523    last_panel_width: usize,
524    remain_on_k: usize,
525    remain_on_mn: usize,
526    current_panel: usize,
527    next_mn_offset: isize,
528    next_panel_offset: isize,
529    _phantom: PhantomData<&'p T>,
530}
531
532impl<'p, T> KInWriter<'p, T>
533where
534    T: Copy + Debug,
535{
536    pub fn new(
537        ptr: *mut T,
538        panel_len: usize,
539        panel_width: usize,
540        mn: usize,
541        k: usize,
542    ) -> KInWriter<'p, T> {
543        let panels = mn.divceil(panel_width);
544        let last_panel_width = mn - (panels - 1) * panel_width;
545        KInWriter {
546            ptr,
547            k,
548            panels,
549            panel_width,
550            last_panel_width,
551            remain_on_k: k,
552            remain_on_mn: if panels == 1 { last_panel_width } else { panel_width },
553            current_panel: 0,
554            next_mn_offset: 1 - (k * panel_width) as isize,
555            next_panel_offset: panel_len as isize - (k * panel_width + panel_width - 1) as isize,
556            //                 ^ next panel     ^    ^ rewind left ^   ^ rewind up   ^
557            _phantom: PhantomData,
558        }
559    }
560}
561
562impl<T> PackingWriter<T> for KInWriter<'_, T>
563where
564    T: Copy + std::fmt::Debug,
565{
566    #[inline(always)]
567    fn write(&mut self, t: T) {
568        unsafe {
569            *self.ptr = t;
570            self.remain_on_k -= 1;
571            self.ptr = self.ptr.add(self.panel_width);
572            if self.remain_on_k == 0 {
573                self.remain_on_k = self.k;
574                self.remain_on_mn -= 1;
575                if self.remain_on_mn > 0 {
576                    self.ptr = self.ptr.offset(self.next_mn_offset);
577                } else {
578                    self.ptr = self.ptr.offset(self.next_panel_offset);
579                    self.current_panel += 1;
580                    if self.current_panel == self.panels - 1 {
581                        self.remain_on_mn = self.last_panel_width;
582                    } else {
583                        self.remain_on_mn = self.panel_width;
584                    }
585                }
586            }
587        }
588    }
589}
590
/// Byte-level packer for operands contiguous along mn, with panels exactly
/// `size_of::<Chunk>()` bytes wide.
///
/// For each k in `k_range`, the bytes `mn_range_bytes` of row k (rows are `k_stride_bytes`
/// apart in `b`) are split into chunk-sized panes. Pane `p` goes to panel `p`, which starts
/// `p * panel_len` bytes into `packed`, at the offset of that record inside the panel. A
/// trailing partial pane is copied as is; the rest of its panel row is left untouched.
///
/// # Safety
/// `b` must be readable over every addressed row segment, and `packed` writable for
/// `ceil(mn_range_bytes.len() / chunk) * panel_len` bytes.
#[inline(never)]
unsafe fn pack_mn_major<Chunk: Copy>(
    b: *const u8,
    packed: *mut u8,
    panel_len: usize,
    k_stride_bytes: isize,
    mn_range_bytes: Range<usize>,
    k_range: Range<usize>,
) {
    let chunk = std::mem::size_of::<Chunk>();
    let row_bytes = mn_range_bytes.len();
    let (whole, tail) = (row_bytes / chunk, row_bytes % chunk);
    for (record, k) in k_range.enumerate() {
        unsafe {
            let src = b.offset(k as isize * k_stride_bytes + mn_range_bytes.start as isize);
            let dst = packed.add(record * chunk);
            for pane in 0..whole {
                dst.add(pane * panel_len).copy_from_nonoverlapping(src.add(pane * chunk), chunk);
            }
            if tail > 0 {
                dst.add(whole * panel_len).copy_from_nonoverlapping(src.add(whole * chunk), tail);
            }
        }
    }
}
620
621pub trait Packing {
622    fn packing(r: usize) -> PackedFormat;
623}
624
625impl<D: Datum> Packing for D {
626    fn packing(r: usize) -> PackedFormat {
627        PackedFormat::new(Self::datum_type(), r, vector_size())
628    }
629}
630
// Tests for `PackedFormat::pack_segment`: each `PackProblem` packs a sub-range of a small
// `u32` matrix and compares the packed panels with a direct computation of the expected
// layout, ignoring positions that correspond to out-of-range (k, mn) coordinates.
#[cfg(test)]
mod test {
    use std::ops::Range;

    use proptest::prelude::*;
    use tract_data::internal::num_integer::Integer;
    use tract_data::internal::tract_ndarray::Zip;
    use tract_data::internal::*;
    use tract_ndarray::prelude::*;

    /// One packing scenario.
    #[derive(Debug)]
    struct PackProblem {
        // Full extent along k.
        k: usize,
        // Full extent along mn.
        mn: usize,
        // When true the input is laid out (mn, k), otherwise (k, mn).
        is_a: bool,
        // Panel width.
        r: usize,
        // Sub-range of k to pack.
        k_range: Range<usize>,
        // Sub-range of mn to pack.
        mn_range: Range<usize>,
        // Alignment passed to `PackedFormat::new`.
        align_panel: usize,
    }

    impl PackProblem {
        /// Input matrix filled with 0, 1, 2, ... in row-major order.
        fn input(&self) -> Array2<u32> {
            let shape = if self.is_a { (self.mn, self.k) } else { (self.k, self.mn) };
            let data = (0..(self.k * self.mn) as u32).collect();
            Array2::from_shape_vec(shape, data).unwrap()
        }

        /// Result of `pack_segment` (with no end padding), reshaped to (panels, panel_len).
        fn packer(&self) -> Array2<u32> {
            let panels = self.mn_range.len().divceil(self.r);
            let packer = super::PackedFormat::new(u32::datum_type(), self.r, self.align_panel)
                .with_end_padding_record(0);
            let input = self.input().into_tensor();
            let panel_len = packer.single_panel_len(self.k_range.len());
            let mut output =
                Tensor::zero::<u32>(&[packer.len(self.k_range.len(), self.mn_range.len())])
                    .unwrap();
            unsafe {
                packer.pack_segment(
                    output.view_mut(),
                    input.view(),
                    self.is_a as usize,
                    !self.is_a as usize,
                    self.k_range.clone(),
                    self.mn_range.clone(),
                )
            };
            output
                .into_dense_array::<u32>()
                .unwrap()
                .into_shape_with_order((panels, panel_len))
                .unwrap()
        }

        /// Expected packed layout: position z of a panel holds record z / r, lane z % r.
        /// Coordinates outside the input read as 0.
        fn reference(&self) -> Array2<u32> {
            let input = self.input();
            let panels = self.mn_range.len().divceil(self.r);
            let len = Integer::next_multiple_of(&(self.k_range.len() * self.r), &self.align_panel);
            Array2::from_shape_fn([panels, len], |(panel, z)| {
                let k = z / self.r;
                let x = z % self.r;
                let mn = panel * self.r + x + self.mn_range.start;
                let k = k + self.k_range.start;
                let coords = if self.is_a { (mn, k) } else { (k, mn) };
                *input.get(coords).unwrap_or(&0)
            })
        }

        /// Mask of packed positions that map to an in-range (k, mn) coordinate.
        fn valid(&self) -> Array2<bool> {
            let panels = self.mn_range.len().divceil(self.r);
            let len = Integer::next_multiple_of(&(self.k_range.len() * self.r), &self.align_panel);
            Array2::from_shape_fn([panels, len], |(panel, z)| {
                let k = z / self.r;
                let x = z % self.r;
                let k = k + self.k_range.start;
                let mn = panel * self.r + x + self.mn_range.start;
                k < self.k_range.end.min(self.k) && mn < self.mn_range.end.min(self.mn)
            })
        }

        /// Compares packed and reference layouts after overwriting invalid positions in both
        /// with the same sentinel.
        fn check(&self) {
            let mut packer = self.packer();
            let mut reference = self.reference();
            let valid = self.valid();
            Zip::from(&mut packer).and(&valid).for_each(|p, v| *p = if *v { *p } else { -1 as _ });
            Zip::from(&mut reference)
                .and(&valid)
                .for_each(|p, v| *p = if *v { *p } else { -1 as _ });
            assert_eq!(packer, reference);
        }
    }

    // Random problems: r in 1..9, k and mn in 1..20, non-empty sub-ranges, alignment in 1..5.
    impl Arbitrary for PackProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<PackProblem>;
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            (any::<bool>(), 1usize..9, 1usize..20, 1usize..20)
                .prop_flat_map(|(is_a, r, k, mn)| {
                    (
                        Just((is_a, r, k, mn)),
                        sub_range_strat(0..k),
                        sub_range_strat(0..mn),
                        1usize..5,
                    )
                })
                .prop_map(|((is_a, r, k, mn), k_range, mn_range, align_panel)| PackProblem {
                    k,
                    mn,
                    is_a,
                    r,
                    k_range,
                    mn_range,
                    align_panel,
                })
                .boxed()
        }
    }

    /// Non-empty sub-ranges of `range`: `cropped` elements (fewer than `range.len()`) are
    /// removed, `left` of them from the start and the rest from the end.
    fn sub_range_strat(range: Range<usize>) -> BoxedStrategy<Range<usize>> {
        (0..range.len())
            .prop_flat_map(|cropped| (Just(cropped), 0..=cropped))
            .prop_map(move |(cropped, left)| range.start + left..range.end - (cropped - left))
            .boxed()
    }

    proptest::proptest! {
        #[test]
        fn prop(pb in any::<PackProblem>()) {
            pb.check();
        }

        // Only checks that the sub-range strategy generates values without panicking.
        #[test]
        fn subrange_prop(_range in sub_range_strat(0..20)) {
        }

    }

    #[test]
    fn simple_b_1() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 0..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn simple_b_2() {
        PackProblem {
            k: 2,
            mn: 2,
            is_a: false,
            r: 1,
            k_range: 0..2,
            mn_range: 0..2,
            align_panel: 1,
        }
        .check()
    }

    #[test]
    fn simple_b_3() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 4,
            k_range: 0..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn simple_b_4() {
        PackProblem {
            k: 1,
            mn: 3,
            is_a: false,
            r: 2,
            k_range: 0..1,
            mn_range: 0..3,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn simple_a_1() {
        PackProblem {
            k: 2,
            mn: 2,
            is_a: true,
            r: 1,
            k_range: 0..2,
            mn_range: 0..2,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn simple_a_2() {
        PackProblem {
            k: 2,
            mn: 3,
            is_a: true,
            r: 2,
            k_range: 0..2,
            mn_range: 0..3,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_k_0() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 1..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_k_1() {
        PackProblem {
            k: 2,
            mn: 2,
            is_a: false,
            r: 1,
            k_range: 0..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_k_2() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 6,
            k_range: 1..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_mn_0() {
        PackProblem {
            k: 1,
            mn: 2,
            is_a: false,
            r: 2,
            k_range: 0..1,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_b_4() {
        PackProblem {
            k: 1,
            mn: 2,
            is_a: false,
            r: 6,
            k_range: 0..1,
            mn_range: 1..2,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_b_5() {
        PackProblem {
            k: 1,
            mn: 7,
            is_a: false,
            r: 6,
            k_range: 0..1,
            mn_range: 1..7,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn align_a_1() {
        PackProblem {
            k: 2,
            mn: 2,
            is_a: true,
            r: 1,
            k_range: 0..1,
            mn_range: 0..2,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_1() {
        PackProblem {
            k: 1,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 0..1,
            mn_range: 0..1,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_2() {
        PackProblem {
            k: 3,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 0..3,
            mn_range: 0..1,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_3() {
        PackProblem {
            k: 1,
            mn: 1,
            is_a: false,
            r: 3,
            k_range: 0..1,
            mn_range: 0..1,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_4() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 0..1,
            mn_range: 0..1,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_5() {
        PackProblem {
            k: 1,
            mn: 5,
            is_a: false,
            r: 4,
            k_range: 0..1,
            mn_range: 0..5,
            align_panel: 3,
        }
        .check();
    }
}