tract_linalg/frame/
pack.rs

1use std::alloc::Layout;
2use std::fmt::{Debug, Display};
3use std::marker::PhantomData;
4use std::ops::Range;
5use std::sync::Arc;
6use tract_data::internal::*;
7
8use crate::mmm::{EagerPackedInput, MMMInputFormat, MMMInputValue, PackedOpaqueFact};
9
10use crate::WeightType;
11
12#[derive(Clone, Eq, PartialEq, Hash)]
13pub struct PackedFormat {
14    pub dt: DatumType,
15    pub r: usize,
16    pub alignment_bytes: usize,
17    pub end_padding_record: usize,
18}
19
20impl MMMInputFormat for PackedFormat {
21    fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor> {
22        let packed = PackedFormat::pack_tensor(self, t, k_axis, mn_axis)?;
23        Ok(tensor0(Opaque(Arc::new(packed))))
24    }
25
26    fn prepare_one(
27        &self,
28        t: &Tensor,
29        k_axis: usize,
30        mn_axis: usize,
31    ) -> TractResult<Box<dyn MMMInputValue>> {
32        PackedFormat::pack_tensor(self, t, k_axis, mn_axis)
33    }
34
35    fn precursor(&self) -> WeightType {
36        WeightType::Plain(self.dt)
37    }
38
39    fn r(&self) -> usize {
40        self.r
41    }
42
43    fn k_alignment(&self) -> usize {
44        1
45    }
46
47    fn same_as(&self, other: &dyn MMMInputFormat) -> bool {
48        other.downcast_ref::<Self>().is_some_and(|other| self == other)
49    }
50
51    fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
52        self.len(k, mn) * self.dt.size_of()
53    }
54
55    fn extract_at_mn_f16(
56        &self,
57        data: &EagerPackedInput,
58        mn: usize,
59        slice: &mut [f16],
60    ) -> TractResult<()> {
61        ensure!(data.format().same_as(self));
62        ensure!(self.len(data.k(), data.mn()) * self.dt.size_of() == data.packed.len());
63        unsafe {
64            let ptr = data.packed.as_ptr().add(
65                (self.single_panel_len(data.k()) * (mn / self.r) + mn % self.r) * self.dt.size_of(),
66            );
67            for (i, slot) in slice.iter_mut().enumerate() {
68                let ptr = ptr.add(i * self.dt.size_of() * self.r);
69                *slot = if self.dt == f16::datum_type() {
70                    *(ptr as *const f16)
71                } else if self.dt == f32::datum_type() {
72                    f16::from_f32(*(ptr as *const f32))
73                } else {
74                    bail!("Unexpected DT {:?}", self.dt)
75                }
76            }
77        }
78        Ok(())
79    }
80
81    fn extract_at_mn_f32(
82        &self,
83        data: &EagerPackedInput,
84        mn: usize,
85        slice: &mut [f32],
86    ) -> TractResult<()> {
87        ensure!(data.format().same_as(self));
88        ensure!(self.len(data.k(), data.mn()) * self.dt.size_of() == data.packed.len());
89        unsafe {
90            let ptr = data.packed.as_ptr().add(
91                (self.single_panel_len(data.k()) * (mn / self.r) + mn % self.r) * self.dt.size_of(),
92            );
93            for (i, slot) in slice.iter_mut().enumerate() {
94                let ptr = ptr.add(i * self.dt.size_of() * self.r);
95                *slot = if self.dt == f16::datum_type() {
96                    (*(ptr as *const f16)).to_f32()
97                } else if self.dt == f32::datum_type() {
98                    *(ptr as *const f32)
99                } else {
100                    bail!("Unexpected DT {:?}", self.dt)
101                }
102            }
103        }
104        Ok(())
105    }
106}
107
108impl Display for PackedFormat {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        write!(f, "Packed{:?}[{}]", self.dt, self.r)
111    }
112}
113
114impl Debug for PackedFormat {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        write!(
117            f,
118            "Packed{:?}[{}]@{}+{}",
119            self.dt, self.r, self.alignment_bytes, self.end_padding_record
120        )
121    }
122}
123
124impl PackedFormat {
125    pub const fn new(dt: DatumType, nr: usize, alignment_bytes: usize) -> PackedFormat {
126        PackedFormat { dt, r: nr, alignment_bytes, end_padding_record: 1 }
127    }
128
129    pub const fn with_end_padding_record(self, end_padding_record: usize) -> Self {
130        PackedFormat { end_padding_record, ..self }
131    }
132
133    #[inline]
134    pub fn align(self, alignment: usize) -> Self {
135        Self { alignment_bytes: alignment, ..self }
136    }
137
138    #[inline]
139    pub fn alignment(&self) -> usize {
140        self.alignment_bytes
141    }
142
143    #[inline]
144    pub fn panel_width(&self) -> usize {
145        self.r
146    }
147
148    #[inline]
149    pub fn len<D: DimLike>(&self, k: D, n: D) -> D {
150        n.divceil(self.r) * self.single_panel_len(k)
151    }
152
153    #[inline]
154    pub fn single_panel_len<D: DimLike>(&self, k: D) -> D {
155        ((k + self.end_padding_record) * self.r).divceil(self.alignment()) * self.alignment()
156    }
157
158    #[inline]
159    pub fn single_panel_layout(&self, k: usize, item_size: usize) -> Layout {
160        assert!(k > 0);
161        Layout::from_size_align(self.single_panel_len(k) * item_size, self.alignment()).unwrap()
162    }
163
164    pub fn pack_tensor(
165        &self,
166        t: &Tensor,
167        k_axis: usize,
168        mn_axis: usize,
169    ) -> TractResult<Box<dyn MMMInputValue>> {
170        ensure!(t.datum_type().is_copy());
171        ensure!(
172            t.datum_type().unquantized() == self.dt.unquantized(),
173            "Attempting to pack for {self} tensor {t:?}"
174        );
175        let k = t.shape()[k_axis];
176        let mn = t.shape()[mn_axis];
177        let packed_len = self.len(k, mn);
178        let panel_len = self.single_panel_len(k);
179        let panel_bytes = panel_len * t.datum_type().size_of();
180        let strides = t.strides();
181        unsafe {
182            let mut packed = Blob::new_for_size_and_align(
183                t.datum_type().size_of() * packed_len,
184                self.alignment_bytes,
185            );
186            if cfg!(debug_assertions) {
187                packed.as_bytes_mut().fill(0u8);
188            }
189            dispatch_copy!(Self::pack_t(t.datum_type())(
190                self,
191                packed.as_mut_ptr() as _,
192                t.as_ptr_unchecked(),
193                mn,
194                strides[k_axis],
195                strides[mn_axis],
196                0..k,
197                0..mn
198            ));
199            Ok(Box::new(EagerPackedInput {
200                fact: PackedOpaqueFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
201                packed: packed.into(),
202                panel_bytes,
203                mn,
204            }))
205        }
206    }
207
208    pub fn pack_tensor_view(
209        &self,
210        t: &TensorView,
211        k_axis: usize,
212        mn_axis: usize,
213    ) -> TractResult<Box<dyn MMMInputValue>> {
214        ensure!(
215            t.datum_type().unquantized() == self.dt.unquantized(),
216            "Attempting to pack for {self} tensor view {t:?}"
217        );
218        let k = t.shape()[k_axis];
219        let mn = t.shape()[mn_axis];
220        let packed_len = self.len(k, mn);
221        let panel_len = self.single_panel_len(k);
222        let panel_bytes = panel_len * t.datum_type().size_of();
223        let strides = t.strides();
224        unsafe {
225            let mut packed = Blob::new_for_size_and_align(
226                t.datum_type().size_of() * packed_len,
227                self.alignment_bytes,
228            );
229            if cfg!(debug_assertions) {
230                packed.as_bytes_mut().fill(0u8);
231            }
232            dispatch_copy!(Self::pack_t(t.datum_type())(
233                self,
234                packed.as_mut_ptr() as _,
235                t.as_ptr_unchecked(),
236                mn,
237                strides[k_axis],
238                strides[mn_axis],
239                0..k,
240                0..mn
241            ));
242            Ok(Box::new(EagerPackedInput {
243                fact: PackedOpaqueFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
244                packed: packed.into(),
245                panel_bytes,
246                mn,
247            }))
248        }
249    }
250
251    pub unsafe fn pack<'a, 'b>(
252        &self,
253        pb: impl std::borrow::BorrowMut<TensorView<'a>>,
254        b: impl std::borrow::Borrow<TensorView<'b>>,
255        k_axis: usize,
256        mn_axis: usize,
257    ) {
258        let k = b.borrow().shape()[k_axis];
259        let mn = b.borrow().shape()[mn_axis];
260        self.pack_segment(pb, b, k_axis, mn_axis, 0..k, 0..mn);
261    }
262
263
264    #[allow(clippy::too_many_arguments)]
265    #[rustfmt::skip]
266    pub unsafe fn pack_t<T: Datum + Copy>(
267        &self,
268        pb: *mut T,
269        b: *const T,
270        mn: usize,
271        k_stride: isize,
272        mn_stride: isize,
273        k_range: Range<usize>,
274        mn_range: Range<usize>,
275        ) {
276        if k_range.len() == 0 || mn_range.len() == 0 {
277            return
278        }
279        if self.r == 1 && k_stride == 1 && mn == 1 {
280            pb.copy_from_nonoverlapping(b.add(k_range.start), k_range.len())
281        } else if mn_stride == 1 {
282            let size_of = T::datum_type().size_of();
283            let rbytes = self.r * size_of;
284            let mn_valid_end = mn_range.end.min(mn);
285            let mn_range_bytes = mn_range.start * size_of..mn_valid_end * size_of;
286            let k_stride_bytes = k_stride * size_of as isize;
287            let bb = b as *const u8;
288            let pbb = pb as *mut u8;
289            let panel_len = self.single_panel_len(k_range.len()) * size_of;
290            match rbytes {
291                16 => pack_mn_major::<[u8; 16]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
292                24 => pack_mn_major::<[u8; 24]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
293                32 => pack_mn_major::<[u8; 32]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
294                48 => pack_mn_major::<[u8; 48]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
295                64 => pack_mn_major::<[u8; 64]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
296                _ => {
297                    let mut packer = self.write_with_k_outer(pb, k_range.len(), mn_range.len());
298                    for k in k_range {
299                        for x in mn_range.start..mn_valid_end {
300                            packer.write(*b.offset(x as isize + k_stride * k as isize))
301                        }
302                        for _x in mn_valid_end..mn_range.end {
303                            packer.write(T::default())
304                        }
305                    }
306                }
307            }
308        } else if k_stride == 1 {
309            let mut packer = self.write_with_k_inner(pb, k_range.len(), mn);
310            let mn_valid_end = mn_range.end.min(mn);
311            for x in mn_range.start..mn_valid_end {
312                for k in k_range.clone() {
313                    packer.write(*b.offset(x as isize * mn_stride + k as isize))
314                }
315            }
316            // just ignore invalid mn_range
317        } else {
318            let mut packer = self.write_with_k_outer(pb, k_range.len(), mn);
319            let mn_valid_end = mn_range.end.min(mn);
320            for k in k_range {
321                for x in mn_range.start..mn_valid_end {
322                    packer.write(*b.offset(x as isize * mn_stride + k_stride * k as isize))
323                }
324                for _x in mn_valid_end..mn_range.end {
325                    packer.write(T::default())
326                }
327            }
328        }
329    }
330
331    #[inline]
332    pub unsafe fn pack_segment<'a, 'b>(
333        &self,
334        mut pb: impl std::borrow::BorrowMut<TensorView<'a>>,
335        b: impl std::borrow::Borrow<TensorView<'b>>,
336        k_axis: usize,
337        mn_axis: usize,
338        k_range: Range<usize>,
339        mn_range: Range<usize>,
340    ) {
341        debug_assert!(pb.borrow().len() >= self.len(k_range.len(), mn_range.len()));
342        let pb = pb.borrow_mut();
343        let b = b.borrow();
344        let dt = pb.datum_type();
345        dispatch_copy!(Self::pack_t(dt)(
346            self,
347            pb.as_ptr_mut_unchecked(),
348            b.as_ptr_unchecked(),
349            b.shape()[mn_axis],
350            b.strides()[k_axis],
351            b.strides()[mn_axis],
352            k_range,
353            mn_range
354        ));
355    }
356
357    pub fn write_with_k_outer<'p, T: Copy + Debug>(
358        &self,
359        pb: *mut T,
360        k: usize,
361        mn: usize,
362    ) -> KOutWriter<'p, T> {
363        KOutWriter::new(pb, self.r, self.single_panel_len(k), mn, k)
364    }
365
366    pub fn write_single_panel_with_k_outer<'p, T: Copy + Debug>(
367        &self,
368        pb: *mut T,
369    ) -> KOutSinglePanelWriter<'p, T> {
370        KOutSinglePanelWriter::new(pb)
371    }
372
373    pub fn write_with_k_inner<'p, T: Copy + Debug>(
374        &self,
375        pb: *mut T,
376        k: usize,
377        mn: usize,
378    ) -> KInWriter<'p, T> {
379        let panel_len = self.single_panel_len(k);
380        KInWriter::new(pb, panel_len, self.r, mn, k)
381    }
382}
383
/// A cursor that stores values one after the other into a packed buffer, advancing through
/// the buffer in an implementation-defined order.
pub trait PackingWriter<T: Copy> {
    /// Stores `t` at the current position and moves to the next one.
    fn write(&mut self, t: T);
}
387
/// Sequential writer for a single panel: each `write` stores at `ptr` and moves one slot on.
#[derive(Debug)]
pub struct KOutSinglePanelWriter<'p, T>
where
    T: Copy + Debug,
{
    /// Next slot to be written.
    ptr: *mut T,
    /// Ties the writer to the lifetime of the destination buffer.
    _phantom: PhantomData<&'p T>,
}
396
397impl<'p, T> KOutSinglePanelWriter<'p, T>
398where
399    T: Copy + std::fmt::Debug,
400{
401    pub fn new(ptr: *mut T) -> KOutSinglePanelWriter<'p, T> {
402        KOutSinglePanelWriter { ptr, _phantom: PhantomData }
403    }
404}
405
406impl<T> PackingWriter<T> for KOutSinglePanelWriter<'_, T>
407where
408    T: Copy + std::fmt::Debug,
409{
410    #[inline(always)]
411    fn write(&mut self, t: T) {
412        unsafe {
413            *self.ptr = t;
414            self.ptr = self.ptr.offset(1);
415        }
416    }
417}
418
/// Writer filling a multi-panel buffer k record by k record: for each k, all mn lanes are
/// written in order, hopping from panel to panel.
#[derive(Debug)]
pub struct KOutWriter<'p, T>
where
    T: Copy + Debug,
{
    /// Next slot to be written.
    ptr: *mut T,
    /// Number of panels across mn.
    panels: usize,
    /// Lanes in every panel but the last.
    panel_width: usize,
    /// Lanes in the last panel.
    last_panel_width: usize,
    /// Values still to write in the current panel for the current k record.
    remain: usize,
    /// Index of the panel currently being written.
    current_panel: usize,
    /// Pointer jump from the end of a record segment to the same record in the next panel.
    next_panel: isize,
    /// Pointer jump from the end of the last panel back to the next record of panel 0.
    next_lane: isize,
    /// Ties the writer to the lifetime of the destination buffer.
    _phantom: PhantomData<&'p T>,
}
434
435impl<'p, T> KOutWriter<'p, T>
436where
437    T: Copy + std::fmt::Debug,
438{
439    pub fn new(
440        ptr: *mut T,
441        panel_width: usize,
442        panel_len: usize,
443        mn: usize,
444        _k: usize,
445    ) -> KOutWriter<'p, T> {
446        let panels = mn.divceil(panel_width);
447        let last_panel_width = mn - (panels - 1) * panel_width;
448        KOutWriter {
449            ptr,
450            panels,
451            panel_width,
452            last_panel_width,
453            remain: if panels > 1 { panel_width } else { last_panel_width },
454            current_panel: 0,
455            next_panel: (panel_len - panel_width) as isize,
456            next_lane: (panel_width - last_panel_width) as isize
457                - (panel_len * (panels - 1)) as isize,
458            _phantom: PhantomData,
459        }
460    }
461}
462
463impl<T> PackingWriter<T> for KOutWriter<'_, T>
464where
465    T: Copy + std::fmt::Debug,
466{
467    #[inline(always)]
468    fn write(&mut self, t: T) {
469        unsafe {
470            *self.ptr = t;
471            self.remain -= 1;
472            self.ptr = self.ptr.offset(1);
473            if self.remain == 0 {
474                self.current_panel += 1;
475                if self.current_panel == self.panels {
476                    self.ptr = self.ptr.offset(self.next_lane);
477                    self.current_panel = 0;
478                } else {
479                    self.ptr = self.ptr.offset(self.next_panel);
480                }
481                if self.current_panel == self.panels - 1 {
482                    self.remain = self.last_panel_width;
483                } else {
484                    self.remain = self.panel_width;
485                }
486            }
487        }
488    }
489}
490
/// Writer filling a multi-panel buffer lane by lane: all k values of mn lane 0, then of lane
/// 1, and so on, moving to the next panel once a panel's lanes are complete.
#[derive(Debug)]
pub struct KInWriter<'p, T>
where
    T: Copy + Debug,
{
    /// Next slot to be written.
    ptr: *mut T,
    /// Number of k values per lane.
    k: usize,
    /// Number of panels across mn.
    panels: usize,
    /// Lanes in every panel but the last.
    panel_width: usize,
    /// Lanes in the last panel.
    last_panel_width: usize,
    /// k values still to write for the current lane.
    remain_on_k: usize,
    /// Lanes still to write in the current panel (including the current one).
    remain_on_mn: usize,
    /// Index of the panel currently being written.
    current_panel: usize,
    /// Pointer jump from below a finished lane to the top of the next lane of the same panel.
    next_mn_offset: isize,
    /// Pointer jump from below a panel's last lane to the top of the next panel.
    next_panel_offset: isize,
    /// Ties the writer to the lifetime of the destination buffer.
    _phantom: PhantomData<&'p T>,
}
508
509impl<'p, T> KInWriter<'p, T>
510where
511    T: Copy + Debug,
512{
513    pub fn new(
514        ptr: *mut T,
515        panel_len: usize,
516        panel_width: usize,
517        mn: usize,
518        k: usize,
519    ) -> KInWriter<'p, T> {
520        let panels = mn.divceil(panel_width);
521        let last_panel_width = mn - (panels - 1) * panel_width;
522        KInWriter {
523            ptr,
524            k,
525            panels,
526            panel_width,
527            last_panel_width,
528            remain_on_k: k,
529            remain_on_mn: if panels == 1 { last_panel_width } else { panel_width },
530            current_panel: 0,
531            next_mn_offset: 1 - (k * panel_width) as isize,
532            next_panel_offset: panel_len as isize - (k * panel_width + panel_width - 1) as isize,
533            //                 ^ next panel     ^    ^ rewind left ^   ^ rewind up   ^
534            _phantom: PhantomData,
535        }
536    }
537}
538
539impl<T> PackingWriter<T> for KInWriter<'_, T>
540where
541    T: Copy + std::fmt::Debug,
542{
543    #[inline(always)]
544    fn write(&mut self, t: T) {
545        unsafe {
546            *self.ptr = t;
547            self.remain_on_k -= 1;
548            self.ptr = self.ptr.add(self.panel_width);
549            if self.remain_on_k == 0 {
550                self.remain_on_k = self.k;
551                self.remain_on_mn -= 1;
552                if self.remain_on_mn > 0 {
553                    self.ptr = self.ptr.offset(self.next_mn_offset);
554                } else {
555                    self.ptr = self.ptr.offset(self.next_panel_offset);
556                    self.current_panel += 1;
557                    if self.current_panel == self.panels - 1 {
558                        self.remain_on_mn = self.last_panel_width;
559                    } else {
560                        self.remain_on_mn = self.panel_width;
561                    }
562                }
563            }
564        }
565    }
566}
567
/// Packs an mn-contiguous operand whose panel rows are `size_of::<Chunk>()` bytes wide.
///
/// For each k in `k_range`, the bytes `mn_range_bytes` of source row k (rows are
/// `k_stride_bytes` apart from `b`) are cut into chunk-sized pieces. Piece `p` is copied to
/// panel `p` (panels are `panel_len` bytes apart from `packed`) at byte offset
/// `(k - k_range.start) * size_of::<Chunk>()`. A trailing piece shorter than a chunk is copied
/// as-is and the rest of its slot is left untouched. `Chunk` is only used for its size.
///
/// # Safety
/// The source bytes must be readable and the destination panels must have room for
/// `k_range.len()` rows of every (full or partial) pane.
#[inline(never)]
unsafe fn pack_mn_major<Chunk: Copy>(
    b: *const u8,
    packed: *mut u8,
    panel_len: usize,
    k_stride_bytes: isize,
    mn_range_bytes: Range<usize>,
    k_range: Range<usize>,
) {
    let mnr = std::mem::size_of::<Chunk>();
    let full_panes = mn_range_bytes.len() / mnr;
    let partial_pane = mn_range_bytes.len() % mnr;
    for (row, k) in k_range.enumerate() {
        let src_row = b.offset(k as isize * k_stride_bytes + mn_range_bytes.start as isize);
        let dst_row = row * mnr;
        // Each destination pointer is computed from the buffer start, so no pointer ever
        // steps past the last pane. The previous running `p_row += panel_len` went beyond
        // one-past-the-end of the buffer after the last pane for every row but the first,
        // which is undefined behaviour for `add`.
        for pane in 0..full_panes {
            packed
                .add(pane * panel_len + dst_row)
                .copy_from_nonoverlapping(src_row.add(pane * mnr), mnr);
        }
        if partial_pane > 0 {
            packed
                .add(full_panes * panel_len + dst_row)
                .copy_from_nonoverlapping(src_row.add(full_panes * mnr), partial_pane);
        }
    }
}
594
595pub trait Packing {
596    fn packing(r: usize) -> PackedFormat;
597}
598
599impl<D: Datum> Packing for D {
600    fn packing(r: usize) -> PackedFormat {
601        PackedFormat::new(Self::datum_type(), r, vector_size())
602    }
603}
604
// Tests for `PackedFormat::pack_segment`.
//
// Each `PackProblem` packs a u32 matrix whose entries are 0, 1, 2, ... in row-major order, so
// any misplaced value is easy to spot. The packer output is compared with a reference built
// element by element. Positions that do not map to a requested input element (padding) are
// masked out of the comparison.
#[cfg(test)]
mod test {
    use std::ops::Range;

    use proptest::prelude::*;
    use tract_data::internal::num_integer::Integer;
    use tract_data::internal::tract_ndarray::Zip;
    use tract_data::internal::*;
    use tract_ndarray::prelude::*;

    // One packing scenario: a k × mn operand stored with shape (mn, k) when `is_a` and (k, mn)
    // otherwise, packed with panel width `r` over `k_range` × `mn_range`, with panel lengths
    // rounded up to a multiple of `align_panel` elements.
    #[derive(Debug)]
    struct PackProblem {
        k: usize,
        mn: usize,
        is_a: bool,
        r: usize,
        k_range: Range<usize>,
        mn_range: Range<usize>,
        align_panel: usize,
    }

    impl PackProblem {
        // Input matrix: entries 0..k*mn in row-major order, shape depending on `is_a`.
        fn input(&self) -> Array2<u32> {
            let shape = if self.is_a { (self.mn, self.k) } else { (self.k, self.mn) };
            let data = (0..(self.k * self.mn) as u32).collect();
            Array2::from_shape_vec(shape, data).unwrap()
        }

        // Runs the real packer (without end padding records) and reshapes its output to
        // (panels, panel_len).
        fn packer(&self) -> Array2<u32> {
            let panels = self.mn_range.len().divceil(self.r);
            let packer = super::PackedFormat::new(u32::datum_type(), self.r, self.align_panel)
                .with_end_padding_record(0);
            let input = self.input().into_tensor();
            let panel_len = packer.single_panel_len(self.k_range.len());
            let mut output =
                Tensor::zero::<u32>(&[packer.len(self.k_range.len(), self.mn_range.len())])
                    .unwrap();
            unsafe {
                packer.pack_segment(
                    output.view_mut(),
                    input.view(),
                    // k axis is 1 for the (mn, k) "A" layout, 0 for the (k, mn) "B" layout.
                    self.is_a as usize,
                    !self.is_a as usize,
                    self.k_range.clone(),
                    self.mn_range.clone(),
                )
            };
            output.into_array::<u32>().unwrap().into_shape_with_order((panels, panel_len)).unwrap()
        }

        // Expected packed content: position z of panel p holds input (k, mn) with
        // k = z / r + k_range.start and mn = p * r + z % r + mn_range.start; positions outside
        // the input read as 0.
        fn reference(&self) -> Array2<u32> {
            let input = self.input();
            let panels = self.mn_range.len().divceil(self.r);
            let len = Integer::next_multiple_of(&(self.k_range.len() * self.r), &self.align_panel);
            Array2::from_shape_fn([panels, len], |(panel, z)| {
                let k = z / self.r;
                let x = z % self.r;
                let mn = panel * self.r + x + self.mn_range.start;
                let k = k + self.k_range.start;
                let coords = if self.is_a { (mn, k) } else { (k, mn) };
                *input.get(coords).unwrap_or(&0)
            })
        }

        // Mask of packed positions that map to a real input element inside the requested
        // ranges; everything else is padding whose content is not specified.
        fn valid(&self) -> Array2<bool> {
            let panels = self.mn_range.len().divceil(self.r);
            let len = Integer::next_multiple_of(&(self.k_range.len() * self.r), &self.align_panel);
            Array2::from_shape_fn([panels, len], |(panel, z)| {
                let k = z / self.r;
                let x = z % self.r;
                let k = k + self.k_range.start;
                let mn = panel * self.r + x + self.mn_range.start;
                k < self.k_range.end.min(self.k) && mn < self.mn_range.end.min(self.mn)
            })
        }

        // Compares packer output and reference after overwriting padding positions with the
        // same sentinel (u32::MAX, via `-1 as _`) in both.
        fn check(&self) {
            let mut packer = self.packer();
            let mut reference = self.reference();
            let valid = self.valid();
            Zip::from(&mut packer).and(&valid).for_each(|p, v| *p = if *v { *p } else { -1 as _ });
            Zip::from(&mut reference)
                .and(&valid)
                .for_each(|p, v| *p = if *v { *p } else { -1 as _ });
            assert_eq!(packer, reference);
        }
    }

    // Random problems: either layout, r in 1..9, k and mn in 1..20, non-empty sub-ranges of
    // both axes, and panel alignment in 1..5 elements.
    impl Arbitrary for PackProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<PackProblem>;
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            (any::<bool>(), 1usize..9, 1usize..20, 1usize..20)
                .prop_flat_map(|(is_a, r, k, mn)| {
                    (
                        Just((is_a, r, k, mn)),
                        sub_range_strat(0..k),
                        sub_range_strat(0..mn),
                        1usize..5,
                    )
                })
                .prop_map(|((is_a, r, k, mn), k_range, mn_range, align_panel)| PackProblem {
                    k,
                    mn,
                    is_a,
                    r,
                    k_range,
                    mn_range,
                    align_panel,
                })
                .boxed()
        }
    }

    // Non-empty sub-range of `range`: crops `cropped` (< range.len()) elements in total,
    // `left` of them from the start and the rest from the end.
    fn sub_range_strat(range: Range<usize>) -> BoxedStrategy<Range<usize>> {
        (0..range.len())
            .prop_flat_map(|cropped| (Just(cropped), 0..=cropped))
            .prop_map(move |(cropped, left)| range.start + left..range.end - (cropped - left))
            .boxed()
    }

    proptest::proptest! {
        #[test]
        fn prop(pb in any::<PackProblem>()) {
            pb.check();
        }

        // Only exercises the sub-range strategy itself (generation must not panic).
        #[test]
        fn subrange_prop(_range in sub_range_strat(0..20)) {
        }

    }

    // Fixed small cases. NOTE(review): these look like minimized proptest findings or edge
    // cases (r == 1, partial panels, sub-ranges, alignment padding) — keep them as regressions.
    #[test]
    fn simple_b_1() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 0..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn simple_b_2() {
        PackProblem {
            k: 2,
            mn: 2,
            is_a: false,
            r: 1,
            k_range: 0..2,
            mn_range: 0..2,
            align_panel: 1,
        }
        .check()
    }

    #[test]
    fn simple_b_3() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 4,
            k_range: 0..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn simple_b_4() {
        PackProblem {
            k: 1,
            mn: 3,
            is_a: false,
            r: 2,
            k_range: 0..1,
            mn_range: 0..3,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn simple_a_1() {
        PackProblem {
            k: 2,
            mn: 2,
            is_a: true,
            r: 1,
            k_range: 0..2,
            mn_range: 0..2,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn simple_a_2() {
        PackProblem {
            k: 2,
            mn: 3,
            is_a: true,
            r: 2,
            k_range: 0..2,
            mn_range: 0..3,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_k_0() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 1..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_k_1() {
        PackProblem {
            k: 2,
            mn: 2,
            is_a: false,
            r: 1,
            k_range: 0..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_k_2() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 6,
            k_range: 1..2,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_mn_0() {
        PackProblem {
            k: 1,
            mn: 2,
            is_a: false,
            r: 2,
            k_range: 0..1,
            mn_range: 0..1,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_b_4() {
        PackProblem {
            k: 1,
            mn: 2,
            is_a: false,
            r: 6,
            k_range: 0..1,
            mn_range: 1..2,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn range_b_5() {
        PackProblem {
            k: 1,
            mn: 7,
            is_a: false,
            r: 6,
            k_range: 0..1,
            mn_range: 1..7,
            align_panel: 1,
        }
        .check();
    }

    #[test]
    fn align_a_1() {
        PackProblem {
            k: 2,
            mn: 2,
            is_a: true,
            r: 1,
            k_range: 0..1,
            mn_range: 0..2,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_1() {
        PackProblem {
            k: 1,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 0..1,
            mn_range: 0..1,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_2() {
        PackProblem {
            k: 3,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 0..3,
            mn_range: 0..1,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_3() {
        PackProblem {
            k: 1,
            mn: 1,
            is_a: false,
            r: 3,
            k_range: 0..1,
            mn_range: 0..1,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_4() {
        PackProblem {
            k: 2,
            mn: 1,
            is_a: false,
            r: 1,
            k_range: 0..1,
            mn_range: 0..1,
            align_panel: 2,
        }
        .check();
    }

    #[test]
    fn align_b_5() {
        PackProblem {
            k: 1,
            mn: 5,
            is_a: false,
            r: 4,
            k_range: 0..1,
            mn_range: 0..5,
            align_panel: 3,
        }
        .check();
    }
}