1use std::fmt::{Debug, Display};
2use tract_data::internal::*;
3
4use super::{EagerPackedInput, MMMInputFormat, MMMInputValue};
5use crate::pack::PackedFormat;
6
/// Signature shared by all panel extraction kernels.
///
/// A kernel is called with a pointer to one packed source panel (`input`), a
/// pointer to the destination buffer (`output`) and the `k` dimension of the
/// panel.
///
/// # Safety
///
/// NOTE(review): the validity and size requirements on both pointers are set
/// by each kernel implementation and are not visible from this file; callers
/// here size the destination for a single panel of the target format.
type Kernel = unsafe fn(input: *const u8, output: *mut u8, k: usize);
8
9#[allow(clippy::derived_hash_with_manual_eq)]
10#[derive(Hash, Clone)]
11pub struct PanelExtractor {
12 pub name: String,
13 pub from: Box<dyn MMMInputFormat>,
14 pub to: PackedFormat,
15 pub kernel: Kernel,
16 pub supported_predicate: fn() -> bool,
17}
18
19impl Debug for PanelExtractor {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 write!(f, "{} ({:?} -> {:?})", self.name, self.from, self.to)
22 }
23}
24
25impl Display for PanelExtractor {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 write!(f, "{}", self.name)
28 }
29}
30
31impl PartialEq for PanelExtractor {
32 fn eq(&self, other: &Self) -> bool {
33 self.name == other.name && self.from.same_as(&*other.from) && self.to == other.to
34 }
35}
36
37impl PanelExtractor {
38 #[allow(unused_variables)]
39 pub fn is_supported_here(&self) -> bool {
40 (self.supported_predicate)()
41 }
42}
43
44#[derive(Clone, Hash)]
45pub struct PanelExtractInput {
46 pub format: PanelExtractor,
47 pub data: EagerPackedInput,
48}
49
50impl MMMInputValue for PanelExtractInput {
51 fn scratch_panel_buffer_layout(&self) -> Option<std::alloc::Layout> {
52 Some(self.format.to.single_panel_layout(self.data.k(), self.format.to.dt.size_of()))
53 }
54 fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> {
55 let scratch = buffer.unwrap();
56 unsafe {
57 let source = self.data.packed.as_ptr().add(self.data.panel_bytes * i);
58 (self.format.kernel)(source, scratch, self.data.k());
59 }
60 Ok(scratch)
61 }
62 fn mn(&self) -> usize {
63 self.data.mn()
64 }
65 fn k(&self) -> usize {
66 self.data.k()
67 }
68 fn format(&self) -> &dyn MMMInputFormat {
69 &self.format.to
70 }
71 fn opaque_fact(&self) -> &dyn OpaqueFact {
72 self.data.opaque_fact()
73 }
74 fn same_as(&self, other: &dyn MMMInputValue) -> bool {
75 other
76 .downcast_ref::<Self>()
77 .is_some_and(|o| o.format == self.format && o.data.same_as(&self.data))
78 }
79 fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
80 self.data.extract_at_mn_f16(mn, slice)
81 }
82 fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
83 self.data.extract_at_mn_f32(mn, slice)
84 }
85}
86
87impl Display for PanelExtractInput {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 write!(f, "PanelExtract({})", self.data)
90 }
91}
92
93impl Debug for PanelExtractInput {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 write!(f, "PanelExtract({})", self.data)
96 }
97}
98
/// Declares a lazily initialised `pub static ref $id: PanelExtractor` wrapping
/// the kernel `$func`, with `$from` as source format and `$to` as target
/// format.
///
/// * The initialiser asserts that both formats report the same `r()`.
/// * The extractor name is the stringified `$id`.
/// * An optional `where(expr)` clause replaces the default, always-true
///   `supported_predicate`.
/// * A `#[cfg(test)]` module named `test_<id>` is generated next to the
///   static; it runs the `test_packing` driver for 1 or 2 blocks and 1 or 2
///   panels.
///
/// `paste!` is invoked unqualified, so it must be in scope at the call site.
#[macro_export]
macro_rules! panel_extractor {
    ( $func:path as $id:ident($from:expr, $to: expr)
        $(where($where:expr))?
    ) => {
        paste! {
            lazy_static::lazy_static! {
                pub static ref $id: $crate::mmm::PanelExtractor = {
                    use $crate::mmm::MMMInputFormat;
                    let from = $from;
                    let to = $to;
                    assert!(from.r() == to.r());
                    // Only mutated when a `where(...)` clause is supplied.
                    #[allow(unused_mut)]
                    let mut extractor = $crate::mmm::PanelExtractor {
                        name: String::from(stringify!($id)),
                        from,
                        to,
                        kernel: $func,
                        supported_predicate: || true,
                    };
                    $(
                        extractor.supported_predicate = $where;
                    )?
                    extractor
                };
            }

            #[cfg(test)]
            mod [<test_$id>] {
                use super::$id;
                use $crate::frame::mmm::panel_extract::test::test_packing;

                #[test]
                fn repack_1block_1panel() {
                    test_packing(&$id, 1, 1).unwrap();
                }

                #[test]
                fn repack_2block_1panel() {
                    test_packing(&$id, 2, 1).unwrap();
                }

                #[test]
                fn repack_1block_2panel() {
                    test_packing(&$id, 1, 2).unwrap();
                }

                #[test]
                fn repack_2block_2panel() {
                    test_packing(&$id, 2, 2).unwrap();
                }
            }
        }
    };
}
151
#[cfg(test)]
pub mod test {
    use crate::frame::block_quant::PackedBlockQuantFormat;
    use tract_data::internal::*;
    use tract_ndarray::Array2;

    use super::*;

    /// Builds the deterministic `rows x cols` weight matrix used by the
    /// packing tests (values in `-10..10`), cast to `dt`.
    fn synthetic_weights(rows: usize, cols: usize, dt: DatumType) -> TractResult<Tensor> {
        let values = Array2::from_shape_fn((rows, cols), |(row, col)| {
            ((row * 31 + col * 17) % 20) as f32 - 10.
        });
        Ok(values.into_tensor().cast_to_dt(dt)?.into_owned())
    }

    /// Checks `extractor` against a reference for `blocks` blocks of `k` and
    /// `panels` panels.
    ///
    /// Returns `Ok(())` without testing when the extractor is not supported
    /// here. Dispatches on the concrete source format: block-quantized or
    /// plain packed; any other format is not implemented.
    pub fn test_packing(
        extractor: &PanelExtractor,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        if !extractor.is_supported_here() {
            return Ok(());
        }
        assert!(extractor.from.r() == extractor.to.r());
        assert!(extractor.to.dt == f32::datum_type() || extractor.to.dt == f16::datum_type());
        if let Some(bq_format) = extractor.from.downcast_ref::<PackedBlockQuantFormat>() {
            return test_packing_bq(extractor, bq_format, blocks, panels);
        }
        match extractor.from.downcast_ref::<PackedFormat>() {
            Some(plain_format) => test_packing_plain(extractor, plain_format, blocks, panels),
            None => todo!(),
        }
    }

    /// Tests an extractor whose source is a plain `PackedFormat`: each packed
    /// panel is cast to the target datum type to form the reference, then
    /// compared with the kernel output.
    pub fn test_packing_plain(
        extractor: &PanelExtractor,
        from: &PackedFormat,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        let m = from.r * panels;
        let k = 8 * blocks;
        let to = &extractor.to;

        let weights = synthetic_weights(m, k, from.dt)?;
        let packed_tensor = from.prepare_tensor(&weights, 1, 0)?;
        let packed_value =
            packed_tensor.to_scalar::<Opaque>()?.downcast_ref::<Box<dyn MMMInputValue>>().unwrap();
        let packed_input = packed_value.downcast_ref::<EagerPackedInput>().unwrap();

        // Number of bytes holding the `k * r` values of one source panel.
        let panel_len = k * from.r * from.dt.size_of();
        for panel in 0..panels {
            let start = packed_input.panel_bytes * panel;
            let source_panel = &packed_input.packed[start..][..panel_len];

            let mut raw_panel = Tensor::zero_dt(from.dt, &[k, from.r])?;
            raw_panel.as_bytes_mut().copy_from_slice(source_panel);
            let expected = raw_panel.cast_to_dt(to.dt)?.into_owned();

            let mut actual = Tensor::zero_dt(to.dt, &[k, from.r])?;
            unsafe {
                (extractor.kernel)(source_panel.as_ptr(), actual.as_bytes_mut().as_mut_ptr(), k);
            }
            compare_panels(&actual, &expected, from.r, k);
        }
        Ok(())
    }

    /// Tests an extractor whose source is a block-quantized packed format:
    /// the reference panel comes from `extract_packed_panel` on the block
    /// quantizer and is compared with the kernel output.
    pub fn test_packing_bq(
        extractor: &PanelExtractor,
        from: &PackedBlockQuantFormat,
        blocks: usize,
        panels: usize,
    ) -> TractResult<()> {
        let m = from.r * panels;
        let k = from.bq.block_len() * blocks;
        let to = &extractor.to;

        let weights_orig = synthetic_weights(m, k, to.dt)?;
        // Quantize, dequantize, then quantize again in the target float type.
        let block_quant = if to.dt == f32::datum_type() {
            let first_pass = from.bq.quant_f32(weights_orig.as_slice::<f32>()?)?;
            let weights = from.bq.dequant_f32(&first_pass)?.into_shape(&[m, k])?;
            from.bq.quant_f32(weights.as_slice::<f32>()?)?
        } else {
            let first_pass = from.bq.quant_f16(weights_orig.as_slice::<f16>()?)?;
            let weights = from.bq.dequant_f16(&first_pass)?.into_shape(&[m, k])?;
            from.bq.quant_f16(weights.as_slice::<f16>()?)?
        };
        let packed_bq = from.bq.pack(&block_quant, k, from.r, from.zip, from.scales_at_end)?;

        // Both panels are allocated once and overwritten on every iteration.
        let mut expected = Tensor::zero_dt(to.dt, &[k, from.r])?;
        let mut actual = Tensor::zero_dt(to.dt, &[k, from.r])?;

        for panel in 0..packed_bq.panels_count() {
            let expected_ptr = expected.as_bytes_mut().as_mut_ptr();
            let actual_ptr = actual.as_bytes_mut().as_mut_ptr();
            unsafe {
                from.bq.extract_packed_panel(&packed_bq, to, panel, expected_ptr)?;
                let source = packed_bq.packed.as_ptr().add(packed_bq.panel_bytes * panel);
                (extractor.kernel)(source, actual_ptr, k);
            }
            compare_panels(&actual, &expected, from.r, k);
        }
        Ok(())
    }

    /// Asserts that both panels are equal; on mismatch, `display_error` is
    /// called with both panels as `f32` or `f16` slices before the assertion
    /// fails.
    fn compare_panels(tested_panel: &Tensor, reference_panel: &Tensor, r: usize, k: usize) {
        if tested_panel == reference_panel {
            return;
        }
        if reference_panel.datum_type() == f32::datum_type() {
            let tested = tested_panel.as_slice::<f32>().unwrap();
            let reference = reference_panel.as_slice::<f32>().unwrap();
            crate::frame::mmm::tests::display_error(tested, reference, r, k);
        } else {
            let tested = tested_panel.as_slice::<f16>().unwrap();
            let reference = reference_panel.as_slice::<f16>().unwrap();
            crate::frame::mmm::tests::display_error(tested, reference, r, k);
        }
        assert_eq!(tested_panel, reference_panel);
    }
}