tract_linalg/frame/mmm/panel_extract.rs

1use std::fmt::{Debug, Display};
2use tract_data::internal::*;
3
4use super::{EagerPackedInput, MMMInputFormat, MMMInputValue};
5use crate::pack::PackedFormat;
6
/// Raw panel extraction routine.
///
/// Reads one packed panel of the source format starting at `src` and writes the
/// converted panel to `dst`. `k` is forwarded by callers as the `k` dimension of the
/// panel (the test panels are shaped `[k, r]`).
///
/// Calling it is `unsafe`: the caller must guarantee that `src` points to a complete
/// source panel and that `dst` points to writable memory large enough for the output
/// panel (NOTE(review): exact sizes depend on the formats involved).
type Kernel = unsafe fn(src: *const u8, dst: *mut u8, k: usize);
8
9#[allow(clippy::derived_hash_with_manual_eq)]
10#[derive(Hash, Clone)]
11pub struct PanelExtractor {
12    pub name: String,
13    pub from: Box<dyn MMMInputFormat>,
14    pub to: PackedFormat,
15    pub kernel: Kernel,
16    pub supported_predicate: fn() -> bool,
17}
18
19impl Debug for PanelExtractor {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        write!(f, "{} ({:?} -> {:?})", self.name, self.from, self.to)
22    }
23}
24
25impl Display for PanelExtractor {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        write!(f, "{}", self.name)
28    }
29}
30
31impl PartialEq for PanelExtractor {
32    fn eq(&self, other: &Self) -> bool {
33        self.name == other.name && self.from.same_as(&*other.from) && self.to == other.to
34    }
35}
36
37impl PanelExtractor {
38    #[allow(unused_variables)]
39    pub fn is_supported_here(&self) -> bool {
40        (self.supported_predicate)()
41    }
42}
43
44#[derive(Clone, Hash)]
45pub struct PanelExtractInput {
46    pub format: PanelExtractor,
47    pub data: EagerPackedInput,
48}
49
50impl MMMInputValue for PanelExtractInput {
51    fn scratch_panel_buffer_layout(&self) -> Option<std::alloc::Layout> {
52        Some(self.format.to.single_panel_layout(self.data.k(), self.format.to.dt.size_of()))
53    }
54    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> {
55        let scratch = buffer.unwrap();
56        unsafe {
57            let source = self.data.packed.as_ptr().add(self.data.panel_bytes * i);
58            (self.format.kernel)(source, scratch, self.data.k());
59        }
60        Ok(scratch)
61    }
62    fn mn(&self) -> usize {
63        self.data.mn()
64    }
65    fn k(&self) -> usize {
66        self.data.k()
67    }
68    fn format(&self) -> &dyn MMMInputFormat {
69        &self.format.to
70    }
71    fn opaque_fact(&self) -> &dyn OpaqueFact {
72        self.data.opaque_fact()
73    }
74    fn same_as(&self, other: &dyn MMMInputValue) -> bool {
75        other
76            .downcast_ref::<Self>()
77            .is_some_and(|o| o.format == self.format && o.data.same_as(&self.data))
78    }
79    fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
80        self.data.extract_at_mn_f16(mn, slice)
81    }
82    fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
83        self.data.extract_at_mn_f32(mn, slice)
84    }
85}
86
87impl Display for PanelExtractInput {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        write!(f, "PanelExtract({})", self.data)
90    }
91}
92
93impl Debug for PanelExtractInput {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        write!(f, "PanelExtract({})", self.data)
96    }
97}
98
/// Declares a public, lazily initialized static `$crate::mmm::PanelExtractor` named `$id`,
/// plus a `#[cfg(test)]` module `test_$id` running `test::test_packing` on it.
///
/// Syntax: `panel_extractor!(kernel_fn as NAME(from_format, to_format) where(predicate))`.
/// The `where(...)` clause is optional; when given, it replaces the default always-true
/// `supported_predicate`.
///
/// # Panics
/// Initializing the static panics if `from.r()` and `to.r()` differ.
///
/// The expansion invokes `paste!` unqualified (it must be in scope at the call site) and
/// `lazy_static::lazy_static!`.
#[macro_export]
macro_rules! panel_extractor {
    ( $func:path as $id:ident($from:expr, $to: expr)
            $(where($where:expr))?
     ) => {
        paste! {
            lazy_static::lazy_static! {
                pub static ref $id: $crate::mmm::PanelExtractor = {
                    use $crate::mmm::MMMInputFormat;
                    let (from, to) = ($from, $to);
                    // Conversion is done panel by panel, so both formats must use the same `r`.
                    assert!(from.r() == to.r());
                    // `mut` is only needed when a `where(...)` clause is supplied.
                    #[allow(unused_mut)]
                    let mut it = $crate::mmm::PanelExtractor {
                        name: stringify!($id).to_string(),
                        from,
                        to,
                        kernel: $func,
                        supported_predicate: || true
                    };
                    $(
                        it.supported_predicate = $where;
                    )?
                    it
                };
            }

            // Exercises the extractor with 1 or 2 k-blocks and 1 or 2 panels.
            #[cfg(test)]
            mod [<test_$id>] {
                use super::$id;
                #[test]
                fn repack_1block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 1).unwrap();
                }

                #[test]
                fn repack_2block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 2, 1).unwrap();
                }

                #[test]
                fn repack_1block_2panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 2).unwrap();
                }

                #[test]
                fn repack_2block_2panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 2, 2).unwrap();
                }
            }
        }
    };
}
151
/// Helpers used by the tests generated by `panel_extractor!`: they run an extractor kernel
/// on known weights and compare its output with a reference panel.
#[cfg(test)]
pub mod test {
    use crate::frame::block_quant::PackedBlockQuantFormat;
    use tract_data::internal::*;
    use tract_ndarray::Array2;

    use super::*;

    /// Checks `extractor` on a matrix of `panels` panels whose `k` is scaled by `blocks`,
    /// dispatching on the concrete source format.
    ///
    /// Returns `Ok(())` immediately when the extractor is not supported on this machine.
    ///
    /// # Panics
    /// Panics if source and target `r` differ, if the target type is neither `f32` nor
    /// `f16`, if the source format is neither a `PackedBlockQuantFormat` nor a plain
    /// `PackedFormat`, or if an extracted panel differs from its reference.
    pub fn test_packing(
        extractor: &PanelExtractor,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        if !extractor.is_supported_here() {
            return Ok(());
        }
        assert_eq!(
            extractor.from.r(),
            extractor.to.r(),
            "source and target formats of {} must share the same r",
            extractor.name
        );
        assert!(extractor.to.dt == f32::datum_type() || extractor.to.dt == f16::datum_type());
        if let Some(from) = extractor.from.downcast_ref::<PackedBlockQuantFormat>() {
            test_packing_bq(extractor, from, blocks, panels)
        } else if let Some(from) = extractor.from.downcast_ref() {
            test_packing_plain(extractor, from, blocks, panels)
        } else {
            unimplemented!(
                "no panel extraction test for source format {:?} (extractor {})",
                extractor.from,
                extractor.name
            )
        }
    }

    /// Checks a kernel whose source is a plain `PackedFormat`.
    ///
    /// Each packed panel is reinterpreted as a `[k, r]` tensor of `from.dt`, cast to the
    /// target type as the reference, and compared with the kernel's output.
    pub fn test_packing_plain(
        extractor: &PanelExtractor,
        from: &PackedFormat,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        let m = from.r * panels;
        let k = 8 * blocks; // 8 is arbitrary
        let to = &extractor.to;
        let weights_orig =
            Array2::from_shape_fn((m, k), |(m, k)| ((m * 31 + k * 17) % 20) as f32 - 10.)
                .into_tensor()
                .cast_to_dt(from.dt)?
                .into_owned();
        let packed_orig = from.prepare_tensor(&weights_orig, 1, 0)?;
        let packed_orig =
            packed_orig.to_scalar::<Opaque>()?.downcast_ref::<Box<dyn MMMInputValue>>().unwrap();
        let packed_orig = packed_orig.downcast_ref::<EagerPackedInput>().unwrap();

        for panel in 0..panels {
            // Raw bytes of this panel in the source packing.
            let orig_panel = &packed_orig.packed[packed_orig.panel_bytes * panel..]
                [..k * from.r * from.dt.size_of()];
            let mut reference_panel = Tensor::zero_dt(from.dt, &[k, from.r])?;
            reference_panel.as_bytes_mut().copy_from_slice(orig_panel);
            reference_panel = reference_panel.cast_to_dt(to.dt)?.into_owned();

            let mut tested_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;
            unsafe {
                (extractor.kernel)(
                    orig_panel.as_ptr(),
                    tested_panel.as_bytes_mut().as_mut_ptr(),
                    k,
                );
            }
            compare_panels(&tested_panel, &reference_panel, from.r, k);
        }
        Ok(())
    }

    /// Checks a kernel whose source is a block-quantized packed format.
    ///
    /// Weights are first round-tripped through quantization (presumably so they are exactly
    /// representable), then quantized and packed. Each panel produced by the kernel is
    /// compared with the one produced by `bq.extract_packed_panel`.
    pub fn test_packing_bq(
        extractor: &PanelExtractor,
        from: &PackedBlockQuantFormat,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        let m = from.r * panels;
        let k = from.bq.block_len() * blocks;
        let to = &extractor.to;
        let weights_orig =
            Array2::from_shape_fn((m, k), |(m, k)| ((m * 31 + k * 17) % 20) as f32 - 10.)
                .into_tensor()
                .cast_to_dt(to.dt)?
                .into_owned();
        let weights = if to.dt == f32::datum_type() {
            from.bq
                .dequant_f32(&from.bq.quant_f32(weights_orig.as_slice::<f32>()?)?)?
                .into_shape(&[m, k])?
        } else {
            from.bq
                .dequant_f16(&from.bq.quant_f16(weights_orig.as_slice::<f16>()?)?)?
                .into_shape(&[m, k])?
        };
        let block_quant = if to.dt == f32::datum_type() {
            from.bq.quant_f32(weights.as_slice::<f32>()?)?
        } else {
            from.bq.quant_f16(weights.as_slice::<f16>()?)?
        };
        let packed_block_quant =
            from.bq.pack(&block_quant, k, from.r, from.zip, from.scales_at_end)?;

        // Both buffers are fully overwritten for every panel, so they are reused.
        let mut reference_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;
        let mut tested_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;

        for panel in 0..packed_block_quant.panels_count() {
            unsafe {
                from.bq.extract_packed_panel(
                    &packed_block_quant,
                    to,
                    panel,
                    reference_panel.as_bytes_mut().as_mut_ptr(),
                )?;

                let source =
                    packed_block_quant.packed.as_ptr().add(packed_block_quant.panel_bytes * panel);
                (extractor.kernel)(source, tested_panel.as_bytes_mut().as_mut_ptr(), k);
            }
            compare_panels(&tested_panel, &reference_panel, from.r, k);
        }
        Ok(())
    }

    /// Asserts that two `[k, r]` panels are equal, first printing a detailed comparison via
    /// `display_error` when they differ.
    fn compare_panels(tested_panel: &Tensor, reference_panel: &Tensor, r: usize, k: usize) {
        if tested_panel != reference_panel {
            if reference_panel.datum_type() == f32::datum_type() {
                crate::frame::mmm::tests::display_error(
                    tested_panel.as_slice::<f32>().unwrap(),
                    reference_panel.as_slice::<f32>().unwrap(),
                    r,
                    k,
                );
            } else {
                crate::frame::mmm::tests::display_error(
                    tested_panel.as_slice::<f16>().unwrap(),
                    reference_panel.as_slice::<f16>().unwrap(),
                    r,
                    k,
                );
            }
        }
        assert_eq!(tested_panel, reference_panel);
    }
}