1use std::fmt::{Debug, Display};
2use tract_data::internal::*;
3
4use super::{EagerPackedInput, MMMInputFormat, MMMInputValue};
5use crate::pack::PackedFormat;
6
/// Signature of a panel extraction kernel.
///
/// Reads one packed panel starting at `src` and writes the converted panel to `dst`. `k` is
/// the value of the input's `k()` (see `PanelExtractInput::panel_bytes`). The caller must
/// pass pointers valid for a whole source panel and a whole destination panel respectively.
type Kernel = unsafe fn(src: *const u8, dst: *mut u8, k: usize);
8
9#[allow(clippy::derived_hash_with_manual_eq)]
10#[derive(Hash, Clone)]
11pub struct PanelExtractor {
12 pub name: String,
13 pub from: Box<dyn MMMInputFormat>,
14 pub to: PackedFormat,
15 pub kernel: Kernel,
16 pub supported_predicate: fn() -> bool,
17}
18
19impl Debug for PanelExtractor {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 write!(f, "{} ({:?} -> {:?})", self.name, self.from, self.to)
22 }
23}
24
25impl Display for PanelExtractor {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 write!(f, "{}", self.name)
28 }
29}
30
31impl PartialEq for PanelExtractor {
32 fn eq(&self, other: &Self) -> bool {
33 self.name == other.name && self.from.same_as(&*other.from) && self.to == other.to
34 }
35}
36
37impl PanelExtractor {
38 #[allow(unused_variables)]
39 pub fn is_supported_here(&self) -> bool {
40 (self.supported_predicate)()
41 }
42}
43
44#[derive(Clone, Hash)]
45pub struct PanelExtractInput {
46 pub format: PanelExtractor,
47 pub data: EagerPackedInput,
48}
49
50impl MMMInputValue for PanelExtractInput {
51 fn scratch_panel_buffer_layout(&self) -> Option<std::alloc::Layout> {
52 Some(self.format.to.single_panel_layout(self.data.k(), self.format.to.dt.size_of()))
53 }
54 fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> {
55 let scratch = buffer.unwrap();
56 unsafe {
57 let source = self.data.packed.as_ptr().add(self.data.panel_bytes * i);
58 (self.format.kernel)(source, scratch, self.data.k());
59 }
60 Ok(scratch)
61 }
62 fn mn(&self) -> usize {
63 self.data.mn()
64 }
65 fn k(&self) -> usize {
66 self.data.k()
67 }
68 fn format(&self) -> &dyn MMMInputFormat {
69 &self.format.to
70 }
71 fn opaque_fact(&self) -> &dyn OpaqueFact {
72 self.data.opaque_fact()
73 }
74 fn same_as(&self, other: &dyn MMMInputValue) -> bool {
75 other
76 .downcast_ref::<Self>()
77 .is_some_and(|o| o.format == self.format && o.data.same_as(&self.data))
78 }
79 fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
80 self.data.extract_at_mn_f16(mn, slice)
81 }
82 fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
83 self.data.extract_at_mn_f32(mn, slice)
84 }
85}
86
87impl Display for PanelExtractInput {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 write!(f, "PanelExtract({})", self.data)
90 }
91}
92
93impl Debug for PanelExtractInput {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 write!(f, "PanelExtract({})", self.data)
96 }
97}
98
/// Declares a lazily initialized `pub static` `PanelExtractor` named `$id`, plus a
/// `#[cfg(test)]` module `test_$id` that runs `test::test_packing` on it for several
/// block/panel counts.
///
/// Syntax: `panel_extractor!(kernel_fn as NAME(from_format, to_format) where(predicate))`.
/// The `where(...)` clause is optional; without it the support predicate always returns
/// `true`.
///
/// `paste!` and `lazy_static::lazy_static!` are used unqualified, so they resolve at the
/// invocation site: the invoking module needs `paste!` in scope and the crate must depend on
/// `lazy_static`.
///
/// # Panics
/// The static's initializer panics if the `from` and `to` formats disagree on `r()`.
#[macro_export]
macro_rules! panel_extractor {
    ( $func:path as $id:ident($from:expr, $to: expr)
      $(where($where:expr))?
    ) => {
        // `paste!` is needed for the `[<test_$id>]` module name below.
        paste! {
            lazy_static::lazy_static! {
                pub static ref $id: $crate::mmm::PanelExtractor = {
                    // Brings `r()` into scope for the consistency check.
                    use $crate::mmm::MMMInputFormat;
                    let (from, to) = ($from, $to);
                    assert!(from.r() == to.r());
                    // `mut` is only needed when a `where(...)` clause is given.
                    #[allow(unused_mut)]
                    let mut it = $crate::mmm::PanelExtractor {
                        name: stringify!($id).to_string(),
                        from,
                        to,
                        kernel: $func,
                        supported_predicate: || true
                    };
                    // Optional override of the default always-true predicate.
                    $(
                        it.supported_predicate = $where;
                    )?
                    it
                };
            }

            // One test per (blocks, panels) combination, including the empty cases.
            #[cfg(test)]
            mod [<test_$id>] {
                use super::$id;
                #[test]
                fn repack_0block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 0, 1).unwrap();
                }

                #[test]
                fn repack_1block_0panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 0).unwrap();
                }

                #[test]
                fn repack_1block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 1).unwrap();
                }

                #[test]
                fn repack_2block_1panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 2, 1).unwrap();
                }

                #[test]
                fn repack_1block_2panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 1, 2).unwrap();
                }

                #[test]
                fn repack_2block_2panel() {
                    $crate::frame::mmm::panel_extract::test::test_packing(&$id, 2, 2).unwrap();
                }
            }
        }
    };
}
161
#[cfg(test)]
pub mod test {
    use crate::frame::block_quant::PackedBlockQuantFormat;
    use tract_data::internal::*;
    use tract_ndarray::Array2;

    use super::*;

    /// Checks `extractor.kernel` against a reference conversion on `panels` panels, with a
    /// depth `k` proportional to `blocks`.
    ///
    /// Succeeds immediately when the extractor is not supported here. Otherwise it dispatches
    /// on the concrete type of `extractor.from`.
    ///
    /// # Errors
    /// Fails when `extractor.from` is neither a `PackedBlockQuantFormat` nor a `PackedFormat`,
    /// or when building the test data fails.
    ///
    /// # Panics
    /// Panics when the two formats disagree on `r`, when the target type is neither `f32` nor
    /// `f16`, or when an extracted panel differs from the reference.
    pub fn test_packing(
        extractor: &PanelExtractor,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        if !extractor.is_supported_here() {
            return Ok(());
        }
        assert_eq!(extractor.from.r(), extractor.to.r());
        assert!(extractor.to.dt == f32::datum_type() || extractor.to.dt == f16::datum_type());
        if let Some(from) = extractor.from.downcast_ref::<PackedBlockQuantFormat>() {
            test_packing_bq(extractor, from, blocks, panels)
        } else if let Some(from) = extractor.from.downcast_ref::<PackedFormat>() {
            test_packing_plain(extractor, from, blocks, panels)
        } else {
            // Name the extractor that cannot be tested instead of a bare "not yet implemented".
            bail!("{:?}: no panel extraction test for this source format", extractor)
        }
    }

    /// Tests an extractor whose source is a plain `PackedFormat`.
    ///
    /// Each packed panel is reinterpreted as a `[k, r]` tensor of `from.dt` and cast to
    /// `to.dt`. That reference is compared with the kernel's output on the same bytes.
    pub fn test_packing_plain(
        extractor: &PanelExtractor,
        from: &PackedFormat,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        let m = from.r * panels;
        // For plain formats `blocks` only scales `k` (8 rows per block).
        let k = 8 * blocks;
        let to = &extractor.to;
        let weights_orig =
            Array2::from_shape_fn((m, k), |(m, k)| ((m * 31 + k * 17) % 20) as f32 - 10.)
                .into_tensor()
                .cast_to_dt(from.dt)?
                .into_owned();
        let packed_orig = from.prepare_tensor(&weights_orig, 1, 0)?;
        let packed_orig =
            packed_orig.to_scalar::<Opaque>()?.downcast_ref::<Box<dyn MMMInputValue>>().unwrap();
        let packed_orig = packed_orig.downcast_ref::<EagerPackedInput>().unwrap();

        for panel in 0..panels {
            let orig_panel = &packed_orig.packed[packed_orig.panel_bytes * panel..]
                [..k * from.r * from.dt.size_of()];
            let mut reference_panel = Tensor::zero_dt(from.dt, &[k, from.r])?;
            reference_panel.as_bytes_mut().copy_from_slice(orig_panel);
            reference_panel = reference_panel.cast_to_dt(to.dt)?.into_owned();

            let mut tested_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;
            // SAFETY: `orig_panel` holds `k * r` source elements and `tested_panel` has room
            // for `k * r` target elements, which is what the kernel is assumed to read/write.
            unsafe {
                (extractor.kernel)(
                    orig_panel.as_ptr(),
                    tested_panel.as_bytes_mut().as_mut_ptr(),
                    k,
                );
            }
            compare_panels(&tested_panel, &reference_panel, from.r, k);
        }
        Ok(())
    }

    /// Tests an extractor whose source is a `PackedBlockQuantFormat`.
    ///
    /// Weights are first round-tripped through quantization so they are exactly representable,
    /// then quantized and packed. The kernel output is compared with the block-quant
    /// `extract_packed_panel` reference for every packed panel.
    pub fn test_packing_bq(
        extractor: &PanelExtractor,
        from: &PackedBlockQuantFormat,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        let m = from.r * panels;
        let k = from.bq.block_len() * blocks;
        let to = &extractor.to;
        let weights_orig =
            Array2::from_shape_fn((m, k), |(m, k)| ((m * 31 + k * 17) % 20) as f32 - 10.)
                .into_tensor()
                .cast_to_dt(to.dt)?
                .into_owned();
        let weights = if to.dt == f32::datum_type() {
            from.bq
                .dequant_f32(&from.bq.quant_f32(weights_orig.as_slice::<f32>()?)?)?
                .into_shape(&[m, k])?
        } else {
            from.bq
                .dequant_f16(&from.bq.quant_f16(weights_orig.as_slice::<f16>()?)?)?
                .into_shape(&[m, k])?
        };
        let block_quant = if to.dt == f32::datum_type() {
            from.bq.quant_f32(weights.as_slice::<f32>()?)?
        } else {
            from.bq.quant_f16(weights.as_slice::<f16>()?)?
        };
        let packed_block_quant =
            from.bq.pack(&block_quant, k, from.r, from.zip, from.scales_at_end)?;

        let mut reference_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;
        let mut tested_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;

        for panel in 0..packed_block_quant.panels_count() {
            // SAFETY: `panel` is below `panels_count()`, and both destination tensors have
            // room for one `[k, r]` panel of `to.dt`.
            unsafe {
                from.bq.extract_packed_panel(
                    &packed_block_quant,
                    to,
                    panel,
                    reference_panel.as_bytes_mut().as_mut_ptr(),
                )?;

                let source =
                    packed_block_quant.packed.as_ptr().add(packed_block_quant.panel_bytes * panel);
                (extractor.kernel)(source, tested_panel.as_bytes_mut().as_mut_ptr(), k);
            }
            compare_panels(&tested_panel, &reference_panel, from.r, k);
        }
        Ok(())
    }

    /// Asserts both panels are equal. On mismatch it first prints a detailed error report via
    /// `display_error`, using `f32` or `f16` depending on the reference's datum type.
    fn compare_panels(tested_panel: &Tensor, reference_panel: &Tensor, r: usize, k: usize) {
        if tested_panel != reference_panel {
            if reference_panel.datum_type() == f32::datum_type() {
                crate::frame::mmm::tests::display_error(
                    tested_panel.as_slice::<f32>().unwrap(),
                    reference_panel.as_slice::<f32>().unwrap(),
                    r,
                    k,
                );
            } else {
                crate::frame::mmm::tests::display_error(
                    tested_panel.as_slice::<f16>().unwrap(),
                    reference_panel.as_slice::<f16>().unwrap(),
                    r,
                    k,
                );
            }
        }
        assert_eq!(tested_panel, reference_panel);
    }
}