1use std::fmt::{Debug, Display};
2use tract_data::internal::*;
3
4use super::{EagerPackedInput, MMMInputFormat, MMMInputValue};
5use crate::pack::PackedFormat;
6
/// Raw function pointer type for a panel extraction kernel.
///
/// In this file kernels are invoked as `kernel(source_panel_ptr, destination_ptr, k)`.
/// NOTE(review): presumably `input` must point to one complete panel in the
/// extractor's `from` format and `output` to a buffer big enough for one panel in
/// the `to` format; this is not checked anywhere here, confirm against the kernels.
type Kernel = unsafe fn(input: *const u8, output: *mut u8, k: usize);
8
9#[allow(clippy::derived_hash_with_manual_eq)]
10#[derive(Hash, Clone)]
11pub struct PanelExtractor {
12 pub name: String,
13 pub from: Box<dyn MMMInputFormat>,
14 pub to: PackedFormat,
15 pub kernel: Kernel,
16 pub supported_predicate: fn() -> bool,
17}
18
19impl Debug for PanelExtractor {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 write!(f, "{} ({:?} -> {:?})", self.name, self.from, self.to)
22 }
23}
24
25impl Display for PanelExtractor {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 write!(f, "{}", self.name)
28 }
29}
30
31impl PartialEq for PanelExtractor {
32 fn eq(&self, other: &Self) -> bool {
33 self.name == other.name && *self.from == *other.from && self.to == other.to
34 }
35}
36impl Eq for PanelExtractor {}
37
38impl PanelExtractor {
39 #[allow(unused_variables)]
40 pub fn is_supported_here(&self) -> bool {
41 (self.supported_predicate)()
42 }
43}
44
/// Matrix multiplication input pairing eagerly packed data with the
/// `PanelExtractor` used to convert its panels when they are requested.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct PanelExtractInput {
    /// Extractor whose kernel is run on each requested panel; its `to` format is the
    /// one reported by the `MMMInputValue::format` implementation.
    pub format: PanelExtractor,
    /// Packed source data; panels are read from its `packed` buffer.
    pub data: EagerPackedInput,
}
50
51impl MMMInputValue for PanelExtractInput {
52 fn scratch_panel_buffer_layout(&self) -> Option<std::alloc::Layout> {
53 Some(self.format.to.single_panel_layout(self.data.k(), self.format.to.dt.size_of()))
54 }
55 fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> {
56 let scratch = buffer.unwrap();
57 unsafe {
58 let source = self.data.packed.as_ptr().add(self.data.panel_bytes * i);
59 (self.format.kernel)(source, scratch, self.data.k());
60 }
61 Ok(scratch)
62 }
63 fn mn(&self) -> usize {
64 self.data.mn()
65 }
66 fn k(&self) -> usize {
67 self.data.k()
68 }
69 fn format(&self) -> &dyn MMMInputFormat {
70 &self.format.to
71 }
72 fn exotic_fact(&self) -> &dyn ExoticFact {
73 self.data.exotic_fact()
74 }
75 fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
76 self.data.extract_at_mn_f16(mn, slice)
77 }
78 fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
79 self.data.extract_at_mn_f32(mn, slice)
80 }
81}
82
83impl Display for PanelExtractInput {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 write!(f, "PanelExtract({})", self.data)
86 }
87}
88
89impl Debug for PanelExtractInput {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 write!(f, "PanelExtract({})", self.data)
92 }
93}
94
/// Declares a `lazy_static` `PanelExtractor` named `$id`, plus a test module for it.
///
/// Syntax: `panel_extractor!(kernel_path as NAME(from_format, to_format) where(predicate))`,
/// the `where(...)` clause being optional.
///
/// * `$func` is stored as the extractor's `kernel`.
/// * `$from` / `$to` are evaluated once when the static is initialised; the
///   initialiser asserts that both report the same `r()`.
/// * `$where`, when given, replaces the default `|| true` support predicate.
///
/// Under `cfg(test)` a module `test_$id` is generated, running
/// `panel_extract::test::test_packing` for several `(blocks, panels)` combinations.
///
/// NOTE(review): `paste!` is invoked unqualified, so it presumably has to be in
/// scope at the call site — confirm against the users of this macro.
#[macro_export]
macro_rules! panel_extractor {
    ( $func:path as $id:ident($from:expr, $to: expr)
            $(where($where:expr))?
     ) => {
        paste! {
            lazy_static::lazy_static! {
                pub static ref $id: $crate::mmm::PanelExtractor = {
                    use $crate::mmm::MMMInputFormat;
                    let (from, to) = ($from, $to);
                    // Source and target formats must agree on `r()`.
                    assert!(from.r() == to.r());
                    // `mut` is only needed when a `where(...)` clause is present.
                    #[allow(unused_mut)]
                    let mut it = $crate::mmm::PanelExtractor {
                        name: stringify!($id).to_string(),
                        from,
                        to,
                        kernel: $func,
                        supported_predicate: || true
                    };
                    // Optional override of the support predicate.
                    $(
                        it.supported_predicate = $where;
                    )?
                    it
                };
            }

            // One test per (blocks, panels) combination, named `repack_<blocks>block_<panels>panel`.
            #[cfg(test)]
            mod [<test_$id>] {
                use super::$id;
                #[test]
                fn repack_0block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 0, 1).unwrap();
                }

                #[test]
                fn repack_1block_0panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 0).unwrap();
                }

                #[test]
                fn repack_1block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 1).unwrap();
                }

                #[test]
                fn repack_2block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 2, 1).unwrap();
                }

                #[test]
                fn repack_1block_2panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 2).unwrap();
                }

                #[test]
                fn repack_2block_2panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 2, 2).unwrap();
                }
            }
        }
    };
}
157
#[cfg(test)]
pub mod test {
    use crate::frame::block_quant::PackedBlockQuantFormat;
    use crate::mmm::PackedMatrixStorage;
    use tract_data::internal::*;
    use tract_ndarray::Array2;

    use super::*;

    /// Entry point used by the tests generated by `panel_extractor!`.
    ///
    /// Does nothing when the extractor is not supported here. Otherwise checks the
    /// extractor's invariants and dispatches on the concrete type of its source
    /// format: block-quantized or plain packed.
    pub fn test_packing(
        extractor: &PanelExtractor,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        if !extractor.is_supported_here() {
            return Ok(());
        }
        assert!(extractor.from.r() == extractor.to.r());
        assert!(extractor.to.dt == f32::datum_type() || extractor.to.dt == f16::datum_type());

        if let Some(bq_format) = extractor.from.downcast_ref::<PackedBlockQuantFormat>() {
            return test_packing_bq(extractor, bq_format, blocks, panels);
        }
        if let Some(plain_format) = extractor.from.downcast_ref::<PackedFormat>() {
            return test_packing_plain(extractor, plain_format, blocks, panels);
        }
        todo!()
    }

    /// Checks an extractor reading a plain `PackedFormat`.
    ///
    /// Packs a deterministic `(from.r * panels) x (8 * blocks)` matrix with `from`,
    /// then for each panel compares the kernel output against the packed bytes cast
    /// to the target datum type.
    pub fn test_packing_plain(
        extractor: &PanelExtractor,
        from: &PackedFormat,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        let rows = from.r * panels;
        let depth = 8 * blocks;
        let target = &extractor.to;

        let generated =
            Array2::from_shape_fn((rows, depth), |(row, col)| ((row * 31 + col * 17) % 20) as f32 - 10.)
                .into_tensor();
        let weights = generated.cast_to_dt(from.dt)?.into_owned();

        let prepared = from.prepare_tensor(&weights, 1, 0)?;
        let storage = prepared.try_storage_as::<PackedMatrixStorage>()?;
        let eager = storage.value().downcast_ref::<EagerPackedInput>().unwrap();

        // Bytes actually used by one panel of the source format.
        let panel_len = depth * from.r * from.dt.size_of();

        for panel in 0..panels {
            let source_bytes = &eager.packed[eager.panel_bytes * panel..][..panel_len];

            // Expected result: the packed panel itself, converted to the target type.
            let mut raw_reference = Tensor::zero_dt(from.dt, &[depth, from.r])?;
            raw_reference.as_bytes_mut().copy_from_slice(source_bytes);
            let expected = raw_reference.cast_to_dt(target.dt)?.into_owned();

            let mut actual = Tensor::zero_dt(target.dt, &[depth, from.r])?;
            unsafe {
                (extractor.kernel)(
                    source_bytes.as_ptr(),
                    actual.as_bytes_mut().as_mut_ptr(),
                    depth,
                );
            }
            compare_panels(&actual, &expected, from.r, depth);
        }
        Ok(())
    }

    /// Checks an extractor reading a `PackedBlockQuantFormat`.
    ///
    /// Builds a deterministic matrix, round-trips it through quantization, packs the
    /// quantized data, then for each panel compares the kernel output against
    /// `extract_packed_panel`.
    pub fn test_packing_bq(
        extractor: &PanelExtractor,
        from: &PackedBlockQuantFormat,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        let rows = from.r * panels;
        let depth = from.bq.block_len() * blocks;
        let target = &extractor.to;
        let wants_f32 = target.dt == f32::datum_type();

        let generated =
            Array2::from_shape_fn((rows, depth), |(row, col)| ((row * 31 + col * 17) % 20) as f32 - 10.)
                .into_tensor();
        let original = generated.cast_to_dt(target.dt)?.into_owned();

        // Quantize then dequantize once before the final quantization.
        let weights = if wants_f32 {
            let quantized = from.bq.quant_f32(original.try_as_plain()?.as_slice::<f32>()?)?;
            from.bq.dequant_f32(&quantized)?.into_shape(&[rows, depth])?
        } else {
            let quantized = from.bq.quant_f16(original.try_as_plain()?.as_slice::<f16>()?)?;
            from.bq.dequant_f16(&quantized)?.into_shape(&[rows, depth])?
        };
        let block_quant = if wants_f32 {
            from.bq.quant_f32(weights.try_as_plain()?.as_slice::<f32>()?)?
        } else {
            from.bq.quant_f16(weights.try_as_plain()?.as_slice::<f16>()?)?
        };
        let packed = from.bq.pack(&block_quant, depth, from.r, from.zip, from.scales_at_end)?;

        // Both buffers are reused across panels.
        let mut expected = Tensor::zero_dt(target.dt, &[depth, from.r])?;
        let mut actual = Tensor::zero_dt(target.dt, &[depth, from.r])?;

        for panel in 0..packed.panels_count() {
            unsafe {
                from.bq.extract_packed_panel(
                    &packed,
                    target,
                    panel,
                    expected.as_bytes_mut().as_mut_ptr(),
                )?;

                let source = packed.packed.as_ptr().add(packed.panel_bytes * panel);
                (extractor.kernel)(source, actual.as_bytes_mut().as_mut_ptr(), depth);
            }
            compare_panels(&actual, &expected, from.r, depth);
        }
        Ok(())
    }

    /// Asserts that both panels are equal, calling `display_error` on their contents
    /// first when they differ.
    fn compare_panels(tested_panel: &Tensor, reference_panel: &Tensor, r: usize, k: usize) {
        if tested_panel != reference_panel {
            let is_f32 = reference_panel.datum_type() == f32::datum_type();
            if is_f32 {
                crate::frame::mmm::tests::display_error(
                    tested_panel.try_as_plain().unwrap().as_slice::<f32>().unwrap(),
                    reference_panel.try_as_plain().unwrap().as_slice::<f32>().unwrap(),
                    r,
                    k,
                );
            } else {
                crate::frame::mmm::tests::display_error(
                    tested_panel.try_as_plain().unwrap().as_slice::<f16>().unwrap(),
                    reference_panel.try_as_plain().unwrap().as_slice::<f16>().unwrap(),
                    r,
                    k,
                );
            }
        }
        assert_eq!(tested_panel, reference_panel);
    }
}