Skip to main content

tract_linalg/frame/mmm/
panel_extract.rs

1use std::fmt::{Debug, Display};
2use tract_data::internal::*;
3
4use super::{EagerPackedInput, MMMInputFormat, MMMInputValue};
5use crate::pack::PackedFormat;
6
/// Function-pointer signature of a panel-extraction kernel.
///
/// It is handed a pointer to one source panel (`src`), a pointer to the destination scratch
/// panel (`dst`) and the panel depth `k`. NOTE(review): the exact number of bytes read from
/// `src` / written to `dst` is defined by each kernel and its formats; callers in this file
/// size `dst` with `PackedFormat::single_panel_layout`.
type Kernel = unsafe fn(src: *const u8, dst: *mut u8, k: usize);
8
9#[allow(clippy::derived_hash_with_manual_eq)]
10#[derive(Hash, Clone)]
11pub struct PanelExtractor {
12    pub name: String,
13    pub from: Box<dyn MMMInputFormat>,
14    pub to: PackedFormat,
15    pub kernel: Kernel,
16    pub supported_predicate: fn() -> bool,
17}
18
19impl Debug for PanelExtractor {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        write!(f, "{} ({:?} -> {:?})", self.name, self.from, self.to)
22    }
23}
24
25impl Display for PanelExtractor {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        write!(f, "{}", self.name)
28    }
29}
30
31impl PartialEq for PanelExtractor {
32    fn eq(&self, other: &Self) -> bool {
33        self.name == other.name && *self.from == *other.from && self.to == other.to
34    }
35}
36impl Eq for PanelExtractor {}
37
38impl PanelExtractor {
39    #[allow(unused_variables)]
40    pub fn is_supported_here(&self) -> bool {
41        (self.supported_predicate)()
42    }
43}
44
45#[derive(Clone, Hash, PartialEq, Eq)]
46pub struct PanelExtractInput {
47    pub format: PanelExtractor,
48    pub data: EagerPackedInput,
49}
50
51impl MMMInputValue for PanelExtractInput {
52    fn scratch_panel_buffer_layout(&self) -> Option<std::alloc::Layout> {
53        Some(self.format.to.single_panel_layout(self.data.k(), self.format.to.dt.size_of()))
54    }
55    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> {
56        let scratch = buffer.unwrap();
57        unsafe {
58            let source = self.data.packed.as_ptr().add(self.data.panel_bytes * i);
59            (self.format.kernel)(source, scratch, self.data.k());
60        }
61        Ok(scratch)
62    }
63    fn mn(&self) -> usize {
64        self.data.mn()
65    }
66    fn k(&self) -> usize {
67        self.data.k()
68    }
69    fn format(&self) -> &dyn MMMInputFormat {
70        &self.format.to
71    }
72    fn exotic_fact(&self) -> &dyn ExoticFact {
73        self.data.exotic_fact()
74    }
75    fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
76        self.data.extract_at_mn_f16(mn, slice)
77    }
78    fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
79        self.data.extract_at_mn_f32(mn, slice)
80    }
81}
82
83impl Display for PanelExtractInput {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        write!(f, "PanelExtract({})", self.data)
86    }
87}
88
89impl Debug for PanelExtractInput {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        write!(f, "PanelExtract({})", self.data)
92    }
93}
94
/// Declares a lazily-initialised `pub static` [`PanelExtractor`] named `$id` wrapping the kernel
/// `$func`. Under `cfg(test)` it also emits a `test_$id` module of repacking tests.
///
/// Syntax: `panel_extractor!(path::to::kernel as NAME(from_format, to_format) where(predicate));`.
/// The `where(...)` clause is optional; without it the predicate is `|| true`.
///
/// The invocation site must have `paste!` in scope and depend on the `lazy_static` crate.
///
/// # Panics
/// The static's initialiser panics if `from` and `to` report different `r()` values.
#[macro_export]
macro_rules! panel_extractor {
    ( $func:path as $id:ident($from:expr, $to: expr)
            $(where($where:expr))?
     ) => {
        paste! {
            lazy_static::lazy_static! {
                pub static ref $id: $crate::mmm::PanelExtractor = {
                    use $crate::mmm::MMMInputFormat;
                    let (from, to) = ($from, $to);
                    // Panel i of the source becomes panel i of the target, so both formats
                    // must agree on the panel width.
                    assert_eq!(
                        from.r(),
                        to.r(),
                        "panel extractor {}: `from` and `to` formats must have the same r",
                        stringify!($id)
                    );
                    // `mut` is only exercised when a `where(...)` clause overrides the predicate.
                    #[allow(unused_mut)]
                    let mut it = $crate::mmm::PanelExtractor {
                        name: stringify!($id).to_string(),
                        from,
                        to,
                        kernel: $func,
                        supported_predicate: || true
                    };
                    $(
                        it.supported_predicate = $where;
                    )?
                    it
                };
            }

            #[cfg(test)]
            mod [<test_$id>] {
                use super::$id;
                #[test]
                fn repack_0block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 0, 1).unwrap();
                }

                #[test]
                fn repack_1block_0panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 0).unwrap();
                }

                #[test]
                fn repack_1block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 1).unwrap();
                }

                #[test]
                fn repack_2block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 2, 1).unwrap();
                }

                #[test]
                fn repack_1block_2panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 2).unwrap();
                }

                #[test]
                fn repack_2block_2panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 2, 2).unwrap();
                }
            }
        }
    };
}
157
/// Test helpers driven by the `test_*` modules that `panel_extractor!` generates.
#[cfg(test)]
pub mod test {
    use crate::frame::block_quant::PackedBlockQuantFormat;
    use crate::mmm::PackedMatrixStorage;
    use tract_data::internal::*;
    use tract_ndarray::Array2;

    use super::*;

    /// Checks `extractor.kernel` against a reference conversion on `panels` panels.
    ///
    /// The depth `k` is `blocks` quantization blocks for block-quant sources, and `8 * blocks`
    /// rows for plain packed sources. Returns `Ok(())` without testing anything when the
    /// extractor is not supported here.
    ///
    /// # Panics
    /// Panics on any mismatch, on a non-f32/f16 target type, or when the source format is
    /// neither `PackedBlockQuantFormat` nor `PackedFormat`.
    pub fn test_packing(
        extractor: &PanelExtractor,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        if !extractor.is_supported_here() {
            return Ok(());
        }
        assert_eq!(extractor.from.r(), extractor.to.r());
        assert!(extractor.to.dt == f32::datum_type() || extractor.to.dt == f16::datum_type());
        if let Some(from) = extractor.from.downcast_ref::<PackedBlockQuantFormat>() {
            test_packing_bq(extractor, from, blocks, panels)
        } else if let Some(from) = extractor.from.downcast_ref() {
            test_packing_plain(extractor, from, blocks, panels)
        } else {
            unimplemented!(
                "panel extractor {}: no test harness for source format {:?}",
                extractor.name,
                extractor.from
            )
        }
    }

    /// Tests an extractor whose source is a plain `PackedFormat`.
    ///
    /// For each panel, the reference is the raw source panel reinterpreted as a `[k, r]` tensor
    /// of `from.dt`, then cast to `to.dt`.
    pub fn test_packing_plain(
        extractor: &PanelExtractor,
        from: &PackedFormat,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        let m = from.r * panels;
        let k = 8 * blocks; // 8 is arbitrary
        let to = &extractor.to;
        // Small integers in [-10, 9]: exactly representable in both f16 and f32.
        let weights_orig =
            Array2::from_shape_fn((m, k), |(m, k)| ((m * 31 + k * 17) % 20) as f32 - 10.)
                .into_tensor()
                .cast_to_dt(from.dt)?
                .into_owned();
        let packed_orig = from.prepare_tensor(&weights_orig, 1, 0)?;
        let packed_orig_storage = packed_orig.try_storage_as::<PackedMatrixStorage>()?;
        let packed_orig = packed_orig_storage.value().downcast_ref::<EagerPackedInput>().unwrap();

        for panel in 0..panels {
            let orig_panel = &packed_orig.packed[packed_orig.panel_bytes * panel..]
                [..k * from.r * from.dt.size_of()];
            let mut reference_panel = Tensor::zero_dt(from.dt, &[k, from.r])?;
            reference_panel.as_bytes_mut().copy_from_slice(orig_panel);
            reference_panel = reference_panel.cast_to_dt(to.dt)?.into_owned();

            let mut tested_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;
            // SAFETY: `orig_panel` holds one full source panel and `tested_panel` holds
            // `k * r` elements of `to.dt`, the extent this test expects the kernel to touch.
            unsafe {
                (extractor.kernel)(
                    orig_panel.as_ptr(),
                    tested_panel.as_bytes_mut().as_mut_ptr(),
                    k,
                );
            }
            compare_panels(&tested_panel, &reference_panel, from.r, k);
        }
        Ok(())
    }

    /// Tests an extractor whose source is a block-quantized packed format.
    ///
    /// For each panel, the reference comes from `from.bq.extract_packed_panel`.
    pub fn test_packing_bq(
        extractor: &PanelExtractor,
        from: &PackedBlockQuantFormat,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        let m = from.r * panels;
        let k = from.bq.block_len() * blocks;
        let to = &extractor.to;
        let weights_orig =
            Array2::from_shape_fn((m, k), |(m, k)| ((m * 31 + k * 17) % 20) as f32 - 10.)
                .into_tensor()
                .cast_to_dt(to.dt)?
                .into_owned();
        // Round-trip through the quantizer once. Presumably this makes the weights exactly
        // representable, so the quantization below is lossless.
        let weights = if to.dt == f32::datum_type() {
            from.bq
                .dequant_f32(&from.bq.quant_f32(weights_orig.try_as_plain()?.as_slice::<f32>()?)?)?
                .into_shape(&[m, k])?
        } else {
            from.bq
                .dequant_f16(&from.bq.quant_f16(weights_orig.try_as_plain()?.as_slice::<f16>()?)?)?
                .into_shape(&[m, k])?
        };
        let block_quant = if to.dt == f32::datum_type() {
            from.bq.quant_f32(weights.try_as_plain()?.as_slice::<f32>()?)?
        } else {
            from.bq.quant_f16(weights.try_as_plain()?.as_slice::<f16>()?)?
        };
        let packed_block_quant =
            from.bq.pack(&block_quant, k, from.r, from.zip, from.scales_at_end)?;

        // Both buffers are reused across panels and fully rewritten on every iteration.
        let mut reference_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;
        let mut tested_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;

        for panel in 0..packed_block_quant.panels_count() {
            // SAFETY: `panel < panels_count()` keeps the source offset inside `packed`, and
            // both destination tensors hold `k * r` elements of `to.dt`.
            unsafe {
                from.bq.extract_packed_panel(
                    &packed_block_quant,
                    to,
                    panel,
                    reference_panel.as_bytes_mut().as_mut_ptr(),
                )?;

                let source =
                    packed_block_quant.packed.as_ptr().add(packed_block_quant.panel_bytes * panel);
                (extractor.kernel)(source, tested_panel.as_bytes_mut().as_mut_ptr(), k);
            }
            compare_panels(&tested_panel, &reference_panel, from.r, k);
        }
        Ok(())
    }

    /// Asserts that the two panels are identical.
    ///
    /// On mismatch, an element-wise diff is printed through `display_error` (f32 or f16 view)
    /// before the final assertion fails.
    fn compare_panels(tested_panel: &Tensor, reference_panel: &Tensor, r: usize, k: usize) {
        if tested_panel != reference_panel {
            if reference_panel.datum_type() == f32::datum_type() {
                crate::frame::mmm::tests::display_error(
                    tested_panel.try_as_plain().unwrap().as_slice::<f32>().unwrap(),
                    reference_panel.try_as_plain().unwrap().as_slice::<f32>().unwrap(),
                    r,
                    k,
                );
            } else {
                crate::frame::mmm::tests::display_error(
                    tested_panel.try_as_plain().unwrap().as_slice::<f16>().unwrap(),
                    reference_panel.try_as_plain().unwrap().as_slice::<f16>().unwrap(),
                    r,
                    k,
                );
            }
        }
        assert_eq!(tested_panel, reference_panel);
    }
}