1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
use crate::mmm::MMMInputValue;

use super::*;
use crate::frame::PackedFormat;

/// A block-quantized packed input whose panels are rewritten into another packing
/// format on demand, one panel per `panel_bytes` call.
#[derive(Hash, Clone)]
pub struct RepackingPackedBlockQuantValue {
    /// Source data. Its `format` is expected to be a `PackedBlockQuantFormat`.
    pub value: EagerPackedInput,
    /// Target packing that each requested panel is repacked into.
    pub pack: PackedFormat,
}

impl Display for RepackingPackedBlockQuantValue {
    /// Renders as `<value> repacked to <pack>`, using the `Debug` form of both parts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { value, pack } = self;
        write!(f, "{:?} repacked to {:?}", value, pack)
    }
}

impl Debug for RepackingPackedBlockQuantValue {
    /// `Debug` output is intentionally the same text as the `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(self, f)
    }
}

impl MMMInputValue for RepackingPackedBlockQuantValue {
    /// Layout of the scratch buffer `panel_bytes` needs: one panel of `self.pack`
    /// spanning `self.value.k`, computed with an item size of 4.
    // NOTE(review): the item size is hard-coded to 4 (f32-sized). Confirm this covers
    // every target type `self.pack` can describe for block-quant repacking.
    fn scratch_panel_buffer_layout(&self) -> Option<Layout> {
        Some(self.pack.single_panel_layout(self.value.k, 4))
    }

    /// Repacks panel `i` of the block-quantized `self.value` into `buffer` according to
    /// `self.pack`, and returns `buffer` as the pointer to the panel's bytes.
    ///
    /// # Errors
    /// Returns an error when no scratch `buffer` is supplied, when `self.value.format`
    /// is not a `PackedBlockQuantFormat`, or when `repack_panel` fails.
    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> {
        let buffer = buffer.context("Scratch panel expected")?;
        // Report a mismatched format as an error instead of panicking: the struct's
        // fields are public, so the type system does not guarantee this downcast.
        let pbqf = self
            .value
            .format
            .downcast_ref::<PackedBlockQuantFormat>()
            .context("RepackingPackedBlockQuantValue expects a PackedBlockQuantFormat input")?;
        // SAFETY: `repack_panel` is handed the raw `buffer` pointer and presumably writes
        // one panel through it. This relies on the caller having allocated `buffer` with
        // at least `scratch_panel_buffer_layout()` — TODO confirm against the
        // `MMMInputValue::panel_bytes` contract.
        unsafe {
            pbqf.bq.repack_panel(&self.value, &self.pack, i, buffer)?;
        }
        Ok(buffer)
    }

    /// The `mn` dimension, forwarded from the wrapped value.
    fn mn(&self) -> usize {
        self.value.mn
    }

    /// The `r` of the target packing.
    fn r(&self) -> usize {
        self.pack.r
    }

    /// The `k` dimension, forwarded from the wrapped value.
    fn k(&self) -> usize {
        self.value.k
    }
}