// tract_linalg/frame/mmm/panel_extract.rs

1use std::fmt::{Debug, Display};
2use tract_data::internal::*;
3
4use super::{EagerPackedInput, MMMInputFormat, MMMInputValue};
5use crate::pack::PackedFormat;
6
7type Kernel = unsafe fn(input: *const u8, output: *mut u8, k: usize);
8
9#[allow(clippy::derived_hash_with_manual_eq)]
10#[derive(Hash, Clone)]
11pub struct PanelExtractor {
12    pub name: String,
13    pub from: Box<dyn MMMInputFormat>,
14    pub to: PackedFormat,
15    pub kernel: Kernel,
16    pub supported_predicate: fn() -> bool,
17}
18
19impl Debug for PanelExtractor {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        write!(f, "{} ({:?} -> {:?})", self.name, self.from, self.to)
22    }
23}
24
25impl Display for PanelExtractor {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        write!(f, "{}", self.name)
28    }
29}
30
31impl PartialEq for PanelExtractor {
32    fn eq(&self, other: &Self) -> bool {
33        self.name == other.name && self.from.same_as(&*other.from) && self.to == other.to
34    }
35}
36
37impl PanelExtractor {
38    #[allow(unused_variables)]
39    pub fn is_supported_here(&self) -> bool {
40        (self.supported_predicate)()
41    }
42}
43
/// Matrix-multiplication input made of eagerly packed data plus the extractor used
/// to convert its panels on demand.
#[derive(Clone, Hash)]
pub struct PanelExtractInput {
    /// Extractor whose kernel converts each stored panel; its `to` format is the
    /// format this input reports.
    pub format: PanelExtractor,
    /// Packed source data the panels are read from.
    pub data: EagerPackedInput,
}
49
50impl MMMInputValue for PanelExtractInput {
51    fn scratch_panel_buffer_layout(&self) -> Option<std::alloc::Layout> {
52        Some(self.format.to.single_panel_layout(self.data.k(), self.format.to.dt.size_of()))
53    }
54    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> {
55        let scratch = buffer.unwrap();
56        unsafe {
57            let source = self.data.packed.as_ptr().add(self.data.panel_bytes * i);
58            (self.format.kernel)(source, scratch, self.data.k());
59        }
60        Ok(scratch)
61    }
62    fn mn(&self) -> usize {
63        self.data.mn()
64    }
65    fn k(&self) -> usize {
66        self.data.k()
67    }
68    fn format(&self) -> &dyn MMMInputFormat {
69        &self.format.to
70    }
71    fn opaque_fact(&self) -> &dyn OpaqueFact {
72        self.data.opaque_fact()
73    }
74    fn same_as(&self, other: &dyn MMMInputValue) -> bool {
75        other
76            .downcast_ref::<Self>()
77            .is_some_and(|o| o.format == self.format && o.data.same_as(&self.data))
78    }
79    fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
80        self.data.extract_at_mn_f16(mn, slice)
81    }
82    fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
83        self.data.extract_at_mn_f32(mn, slice)
84    }
85}
86
87impl Display for PanelExtractInput {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        write!(f, "PanelExtract({})", self.data)
90    }
91}
92
93impl Debug for PanelExtractInput {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        write!(f, "PanelExtract({})", self.data)
96    }
97}
98
/// Declares a lazily initialised static `PanelExtractor` named `$id`, wrapping the
/// kernel `$func` and converting from `$from` to `$to`.
///
/// The optional `where(...)` clause overrides the support predicate, which is
/// `|| true` otherwise. Initialisation asserts that both formats have the same `r()`.
/// Under `cfg(test)` a `test_$id` module is generated that runs `test_packing`
/// for several block/panel counts.
#[macro_export]
macro_rules! panel_extractor {
    ( $func:path as $id:ident($from:expr, $to: expr)
            $(where($where:expr))?
     ) => {
        paste! {
            lazy_static::lazy_static! {
                pub static ref $id: $crate::mmm::PanelExtractor = {
                    use $crate::mmm::MMMInputFormat;
                    let from = $from;
                    let to = $to;
                    assert!(from.r() == to.r());
                    // Only mutated when a `where(...)` clause is given.
                    #[allow(unused_mut)]
                    let mut extractor = $crate::mmm::PanelExtractor {
                        name: String::from(stringify!($id)),
                        from,
                        to,
                        kernel: $func,
                        supported_predicate: || true
                    };
                    $(
                        extractor.supported_predicate = $where;
                    )?
                    extractor
                };
            }

            #[cfg(test)]
            mod [<test_$id>] {
                use super::$id;
                use $crate::frame::mmm::panel_extract::test::test_packing;

                #[test]
                fn repack_0block_1panel() {
                    test_packing(&$id, 0, 1).unwrap();
                }

                #[test]
                fn repack_1block_0panel() {
                    test_packing(&$id, 1, 0).unwrap();
                }

                #[test]
                fn repack_1block_1panel() {
                    test_packing(&$id, 1, 1).unwrap();
                }

                #[test]
                fn repack_2block_1panel() {
                    test_packing(&$id, 2, 1).unwrap();
                }

                #[test]
                fn repack_1block_2panel() {
                    test_packing(&$id, 1, 2).unwrap();
                }

                #[test]
                fn repack_2block_2panel() {
                    test_packing(&$id, 2, 2).unwrap();
                }
            }
        }
    };
}
161
162#[cfg(test)]
163pub mod test {
164    use crate::frame::block_quant::PackedBlockQuantFormat;
165    use tract_data::internal::*;
166    use tract_ndarray::Array2;
167
168    use super::*;
169
170    pub fn test_packing(
171        extractor: &PanelExtractor,
172        blocks: usize,
173        panels: usize,
174    ) -> TractResult<()> {
175        if !extractor.is_supported_here() {
176            return Ok(());
177        }
178        assert!(extractor.from.r() == extractor.to.r());
179        assert!(extractor.to.dt == f32::datum_type() || extractor.to.dt == f16::datum_type());
180        if let Some(from) = extractor.from.downcast_ref::<PackedBlockQuantFormat>() {
181            test_packing_bq(extractor, from, blocks, panels)
182        } else if let Some(from) = extractor.from.downcast_ref() {
183            test_packing_plain(extractor, from, blocks, panels)
184        } else {
185            todo!()
186        }
187    }
188
189    pub fn test_packing_plain(
190        extractor: &PanelExtractor,
191        from: &PackedFormat,
192        blocks: usize,
193        panels: usize,
194    ) -> TractResult<()> {
195        let m = from.r * panels;
196        let k = 8 * blocks; // 8 is arbitrary
197        let to = &extractor.to;
198        let weights_orig =
199            Array2::from_shape_fn((m, k), |(m, k)| ((m * 31 + k * 17) % 20) as f32 - 10.)
200                .into_tensor()
201                .cast_to_dt(from.dt)?
202                .into_owned();
203        let packed_orig = from.prepare_tensor(&weights_orig, 1, 0)?;
204        let packed_orig =
205            packed_orig.to_scalar::<Opaque>()?.downcast_ref::<Box<dyn MMMInputValue>>().unwrap();
206        let packed_orig = packed_orig.downcast_ref::<EagerPackedInput>().unwrap();
207
208        for panel in 0..panels {
209            let orig_panel = &packed_orig.packed[packed_orig.panel_bytes * panel..]
210                [..k * from.r * from.dt.size_of()];
211            let mut reference_panel = Tensor::zero_dt(from.dt, &[k, from.r])?;
212            reference_panel.as_bytes_mut().copy_from_slice(orig_panel);
213            reference_panel = reference_panel.cast_to_dt(to.dt)?.into_owned();
214
215            let mut tested_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;
216            unsafe {
217                (extractor.kernel)(
218                    orig_panel.as_ptr(),
219                    tested_panel.as_bytes_mut().as_mut_ptr(),
220                    k,
221                );
222            }
223            compare_panels(&tested_panel, &reference_panel, from.r, k);
224        }
225        Ok(())
226    }
227
228    pub fn test_packing_bq(
229        extractor: &PanelExtractor,
230        from: &PackedBlockQuantFormat,
231        blocks: usize,
232        panels: usize,
233    ) -> TractResult<()> {
234        let m = from.r * panels;
235        let k = from.bq.block_len() * blocks;
236        let to = &extractor.to;
237        let weights_orig =
238            Array2::from_shape_fn((m, k), |(m, k)| ((m * 31 + k * 17) % 20) as f32 - 10.)
239                .into_tensor()
240                .cast_to_dt(to.dt)?
241                .into_owned();
242        let weights = if to.dt == f32::datum_type() {
243            from.bq
244                .dequant_f32(&from.bq.quant_f32(weights_orig.as_slice::<f32>()?)?)?
245                .into_shape(&[m, k])?
246        } else {
247            from.bq
248                .dequant_f16(&from.bq.quant_f16(weights_orig.as_slice::<f16>()?)?)?
249                .into_shape(&[m, k])?
250        };
251        let block_quant = if to.dt == f32::datum_type() {
252            from.bq.quant_f32(weights.as_slice::<f32>()?)?
253        } else {
254            from.bq.quant_f16(weights.as_slice::<f16>()?)?
255        };
256        let packed_block_quant =
257            from.bq.pack(&block_quant, k, from.r, from.zip, from.scales_at_end)?;
258
259        let mut reference_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;
260        let mut tested_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;
261
262        for panel in 0..packed_block_quant.panels_count() {
263            unsafe {
264                from.bq.extract_packed_panel(
265                    &packed_block_quant,
266                    to,
267                    panel,
268                    reference_panel.as_bytes_mut().as_mut_ptr(),
269                )?;
270
271                let source =
272                    packed_block_quant.packed.as_ptr().add(packed_block_quant.panel_bytes * panel);
273                (extractor.kernel)(source, tested_panel.as_bytes_mut().as_mut_ptr(), k);
274            }
275            compare_panels(&tested_panel, &reference_panel, from.r, k);
276        }
277        Ok(())
278    }
279
280    fn compare_panels(tested_panel: &Tensor, reference_panel: &Tensor, r: usize, k: usize) {
281        if tested_panel != reference_panel {
282            if reference_panel.datum_type() == f32::datum_type() {
283                crate::frame::mmm::tests::display_error(
284                    tested_panel.as_slice::<f32>().unwrap(),
285                    reference_panel.as_slice::<f32>().unwrap(),
286                    r,
287                    k,
288                );
289            } else {
290                crate::frame::mmm::tests::display_error(
291                    tested_panel.as_slice::<f16>().unwrap(),
292                    reference_panel.as_slice::<f16>().unwrap(),
293                    r,
294                    k,
295                );
296            }
297        }
298        assert_eq!(tested_panel, reference_panel);
299    }
300}