1use std::fmt::{Debug, Display};
2use tract_data::internal::*;
3
4use super::{EagerPackedInput, MMMInputFormat, MMMInputValue};
5use crate::pack::PackedFormat;
6
/// Signature shared by all panel extraction kernels.
///
/// A kernel is invoked as `kernel(input, output, k)`:
/// * `input` points at the first byte of one packed source panel,
/// * `output` points at the buffer receiving the extracted panel,
/// * `k` is the `k` dimension of the packed data.
///
/// The function is `unsafe`: it works on raw pointers and nothing in the
/// signature bounds how many bytes are read or written.
// NOTE(review): presumably `input` must hold a full panel in the `from` format and
// `output` a full panel in the `to` format for the given `k` — confirm with kernel authors.
type Kernel = unsafe fn(input: *const u8, output: *mut u8, k: usize);
8
/// Describes a kernel able to turn one panel stored in the `from` input format
/// into a panel laid out in the `to` packed format.
// `Hash` is derived over every field while `PartialEq` is written by hand and only
// compares `name`, `from` and `to`; the lint exemption below acknowledges that mismatch.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Hash, Clone)]
pub struct PanelExtractor {
    /// Identifier of the extractor; this is what `Display` prints.
    pub name: String,
    /// Format of the data the kernel reads from.
    pub from: Box<dyn MMMInputFormat>,
    /// Packed format of the panel the kernel produces.
    pub to: PackedFormat,
    /// Function performing the extraction of a single panel.
    pub kernel: Kernel,
    /// Predicate evaluated by `is_supported_here` to tell whether the kernel may be used.
    pub supported_predicate: fn() -> bool,
}
18
19impl Debug for PanelExtractor {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 write!(f, "{} ({:?} -> {:?})", self.name, self.from, self.to)
22 }
23}
24
25impl Display for PanelExtractor {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 write!(f, "{}", self.name)
28 }
29}
30
31impl PartialEq for PanelExtractor {
32 fn eq(&self, other: &Self) -> bool {
33 self.name == other.name && self.from.same_as(&*other.from) && self.to == other.to
34 }
35}
36
37impl PanelExtractor {
38 #[allow(unused_variables)]
39 pub fn is_supported_here(&self) -> bool {
40 (self.supported_predicate)()
41 }
42}
43
/// Matrix multiplication input pairing eagerly packed data with a [`PanelExtractor`].
///
/// Its `MMMInputValue` implementation advertises `format.to` as its format and
/// produces each panel on demand by running `format.kernel` over the matching
/// panel of `data`, writing into a caller-provided scratch buffer.
#[derive(Clone, Hash)]
pub struct PanelExtractInput {
    /// Extractor converting panels of `data` to the target packed format.
    pub format: PanelExtractor,
    /// Packed data the panels are extracted from.
    pub data: EagerPackedInput,
}
49
50impl MMMInputValue for PanelExtractInput {
51 fn scratch_panel_buffer_layout(&self) -> Option<std::alloc::Layout> {
52 Some(self.format.to.single_panel_layout(self.data.k(), self.format.to.dt.size_of()))
53 }
54 fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> {
55 let scratch = buffer.unwrap();
56 unsafe {
57 let source = self.data.packed.as_ptr().add(self.data.panel_bytes * i);
58 (self.format.kernel)(source, scratch, self.data.k());
59 }
60 Ok(scratch)
61 }
62 fn mn(&self) -> usize {
63 self.data.mn()
64 }
65 fn k(&self) -> usize {
66 self.data.k()
67 }
68 fn format(&self) -> &dyn MMMInputFormat {
69 &self.format.to
70 }
71 fn opaque_fact(&self) -> &dyn OpaqueFact {
72 self.data.opaque_fact()
73 }
74 fn same_as(&self, other: &dyn MMMInputValue) -> bool {
75 other
76 .downcast_ref::<Self>()
77 .is_some_and(|o| o.format == self.format && o.data.same_as(&self.data))
78 }
79 fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
80 self.data.extract_at_mn_f16(mn, slice)
81 }
82 fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
83 self.data.extract_at_mn_f32(mn, slice)
84 }
85}
86
87impl Display for PanelExtractInput {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 write!(f, "PanelExtract({})", self.data)
90 }
91}
92
93impl Debug for PanelExtractInput {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 write!(f, "PanelExtract({})", self.data)
96 }
97}
98
/// Declares a lazily-initialized `pub static` [`PanelExtractor`] together with its tests.
///
/// Invocation shape: `panel_extractor!(kernel_path as ID(from_expr, to_expr) where(predicate))`,
/// the `where(...)` part being optional.
///
/// * `kernel_path` becomes the `kernel` field,
/// * `ID` names the static and, stringified, becomes the `name` field,
/// * `from_expr` / `to_expr` become the `from` / `to` fields,
/// * `predicate`, when given, replaces the default always-true `supported_predicate`.
///
/// Initializing the static panics if `from.r()` and `to.r()` differ.
///
/// Under `cfg(test)` a `test_<ID>` module is generated, running `test_packing`
/// for several (blocks, panels) combinations.
// NOTE(review): `paste!` is referenced without a path, so it presumably has to be in
// scope wherever this macro is invoked — confirm.
#[macro_export]
macro_rules! panel_extractor {
    ( $func:path as $id:ident($from:expr, $to: expr)
             $(where($where:expr))?
     ) => {
        paste! {
            lazy_static::lazy_static! {
                pub static ref $id: $crate::mmm::PanelExtractor = {
                    // Trait needed to call `r()` on both formats.
                    use $crate::mmm::MMMInputFormat;
                    let (from, to) = ($from, $to);
                    // Source and target formats must agree on `r`.
                    assert!(from.r() == to.r());
                    // `it` is only mutated when a `where(...)` clause is supplied.
                    #[allow(unused_mut)]
                    let mut it = $crate::mmm::PanelExtractor {
                        name: stringify!($id).to_string(),
                        from,
                        to,
                        kernel: $func,
                        supported_predicate: || true
                    };
                    $(
                        it.supported_predicate = $where;
                    )?
                    it
                };
            }

            // One test per (blocks, panels) combination; the test names encode the pair.
            #[cfg(test)]
            mod [<test_$id>] {
                use super::$id;
                #[test]
                fn repack_0block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 0, 1).unwrap();
                }

                #[test]
                fn repack_1block_0panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 0).unwrap();
                }

                #[test]
                fn repack_1block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 1).unwrap();
                }

                #[test]
                fn repack_2block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 2, 1).unwrap();
                }

                #[test]
                fn repack_1block_2panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 2).unwrap();
                }

                #[test]
                fn repack_2block_2panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 2, 2).unwrap();
                }
            }
        }
    };
}
161
/// Shared test drivers for panel extractors; the `panel_extractor!` macro generates
/// tests that call into `test_packing`.
#[cfg(test)]
pub mod test {
    use crate::frame::block_quant::PackedBlockQuantFormat;
    use tract_data::internal::*;
    use tract_ndarray::Array2;

    use super::*;

    /// Entry point: checks `extractor` against a reference extraction for an input
    /// made of `blocks` blocks along `k` and `panels` panels.
    ///
    /// Does nothing when the extractor is not supported here. Dispatches on the
    /// concrete type of the source format; any other source format hits `todo!()`.
    pub fn test_packing(
        extractor: &PanelExtractor,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        if !extractor.is_supported_here() {
            return Ok(());
        }
        assert!(extractor.from.r() == extractor.to.r());
        assert!(extractor.to.dt == f32::datum_type() || extractor.to.dt == f16::datum_type());

        if let Some(block_quant) = extractor.from.downcast_ref::<PackedBlockQuantFormat>() {
            return test_packing_bq(extractor, block_quant, blocks, panels);
        }
        if let Some(plain) = extractor.from.downcast_ref::<PackedFormat>() {
            return test_packing_plain(extractor, plain, blocks, panels);
        }
        todo!()
    }

    /// Checks an extractor whose source is a plain `PackedFormat`.
    ///
    /// The reference for each panel is the packed source panel cast to the target
    /// datum type; the kernel output must match it exactly.
    pub fn test_packing_plain(
        extractor: &PanelExtractor,
        from: &PackedFormat,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        let to = &extractor.to;
        let m = from.r * panels;
        // A "block" counts for 8 values along k in the plain case.
        let k = 8 * blocks;

        let weights_orig =
            Array2::from_shape_fn((m, k), |(row, col)| ((row * 31 + col * 17) % 20) as f32 - 10.)
                .into_tensor()
                .cast_to_dt(from.dt)?
                .into_owned();

        let prepared = from.prepare_tensor(&weights_orig, 1, 0)?;
        let input_value = prepared
            .try_as_dense()?
            .to_scalar::<Opaque>()?
            .downcast_ref::<Box<dyn MMMInputValue>>()
            .unwrap();
        let eager = input_value.downcast_ref::<EagerPackedInput>().unwrap();

        // Number of meaningful bytes at the start of each packed source panel.
        let panel_len = k * from.r * from.dt.size_of();

        for panel in 0..panels {
            let panel_start = eager.panel_bytes * panel;
            let source_bytes = &eager.packed[panel_start..][..panel_len];

            // Reference: reinterpret the packed bytes, then cast to the target type.
            let mut raw_panel = Tensor::zero_dt(from.dt, &[k, from.r])?;
            raw_panel.as_bytes_mut().copy_from_slice(source_bytes);
            let reference_panel = raw_panel.cast_to_dt(to.dt)?.into_owned();

            let mut tested_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;
            let destination = tested_panel.as_bytes_mut().as_mut_ptr();
            // SAFETY: `source_bytes` spans `k * r` source elements and `tested_panel`
            // holds `k * r` target elements, both alive for the duration of the call.
            unsafe {
                (extractor.kernel)(source_bytes.as_ptr(), destination, k);
            }

            compare_panels(&tested_panel, &reference_panel, from.r, k);
        }
        Ok(())
    }

    /// Checks an extractor whose source is a `PackedBlockQuantFormat`.
    ///
    /// Weights go through one quantize/dequantize round trip, are quantized again
    /// and packed. For every packed panel the kernel output must match what
    /// `extract_packed_panel` produces.
    pub fn test_packing_bq(
        extractor: &PanelExtractor,
        from: &PackedBlockQuantFormat,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        let to = &extractor.to;
        let m = from.r * panels;
        let k = from.bq.block_len() * blocks;
        let target_is_f32 = to.dt == f32::datum_type();

        let weights_orig =
            Array2::from_shape_fn((m, k), |(row, col)| ((row * 31 + col * 17) % 20) as f32 - 10.)
                .into_tensor()
                .cast_to_dt(to.dt)?
                .into_owned();

        // Round trip through the quantizer before the quantization that gets packed.
        let weights = if target_is_f32 {
            let quantized =
                from.bq.quant_f32(weights_orig.try_as_dense()?.as_slice::<f32>()?)?;
            from.bq.dequant_f32(&quantized)?.into_shape(&[m, k])?
        } else {
            let quantized =
                from.bq.quant_f16(weights_orig.try_as_dense()?.as_slice::<f16>()?)?;
            from.bq.dequant_f16(&quantized)?.into_shape(&[m, k])?
        };
        let block_quant = if target_is_f32 {
            from.bq.quant_f32(weights.try_as_dense()?.as_slice::<f32>()?)?
        } else {
            from.bq.quant_f16(weights.try_as_dense()?.as_slice::<f16>()?)?
        };
        let packed = from.bq.pack(&block_quant, k, from.r, from.zip, from.scales_at_end)?;

        // Both output buffers are reused across panels.
        let mut reference_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;
        let mut tested_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;

        let panel_count = packed.panels_count();
        for panel in 0..panel_count {
            let reference_out = reference_panel.as_bytes_mut().as_mut_ptr();
            let tested_out = tested_panel.as_bytes_mut().as_mut_ptr();
            // SAFETY: `panel` is below `panels_count()` and both outputs hold `k * r`
            // target elements. NOTE(review): assumes each packed panel spans
            // `panel_bytes` bytes of `packed.packed` — confirm.
            unsafe {
                from.bq.extract_packed_panel(&packed, to, panel, reference_out)?;

                let source = packed.packed.as_ptr().add(packed.panel_bytes * panel);
                (extractor.kernel)(source, tested_out, k);
            }
            compare_panels(&tested_panel, &reference_panel, from.r, k);
        }
        Ok(())
    }

    /// Asserts both panels are equal, displaying the differences first when they are not.
    fn compare_panels(tested_panel: &Tensor, reference_panel: &Tensor, r: usize, k: usize) {
        if tested_panel == reference_panel {
            return;
        }
        let panels_are_f32 = reference_panel.datum_type() == f32::datum_type();
        if panels_are_f32 {
            crate::frame::mmm::tests::display_error(
                tested_panel.try_as_dense().unwrap().as_slice::<f32>().unwrap(),
                reference_panel.try_as_dense().unwrap().as_slice::<f32>().unwrap(),
                r,
                k,
            );
        } else {
            crate::frame::mmm::tests::display_error(
                tested_panel.try_as_dense().unwrap().as_slice::<f16>().unwrap(),
                reference_panel.try_as_dense().unwrap().as_slice::<f16>().unwrap(),
                r,
                k,
            );
        }
        // Panels differ at this point: fail with the standard assertion message.
        assert_eq!(tested_panel, reference_panel);
    }
}