#![allow(clippy::needless_range_loop)]
use num_traits::AsPrimitive;

use tract_data::prelude::f16;
use tract_data::prelude::*;

use super::*;
use crate::frame::block_quant::{BlockQuant, NibbleReader, PackedBlockQuantFormat, Q4_0};
use crate::frame::mmm::*;
use crate::{has_fp16, LADatum, Ops};
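// Applies the binary op `$f` between the scalar `$m` and every cell of the
// accumulator tile `$ab`.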
macro_rules! scalar {
    ($ab: expr, $m: expr, $f: expr) => {
        for i in 0..$ab.len() {
            for j in 0..$ab[0].len() {
                $ab[i][j] = $f($m, $ab[i][j])
            }
        }
    };
}
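// Applies `$f` between the i-th value behind the raw pointer `$m` and every
// cell of row i of the accumulator tile.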
macro_rules! per_row {
    ($ab: expr, $m: expr, $f: expr) => {
        for i in 0..$ab.len() {
            for j in 0..$ab[0].len() {
                $ab[i][j] = $f(*$m.add(i), $ab[i][j])
            }
        }
    };
}
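// Applies `$f` between the j-th value behind the raw pointer `$m` and every
// cell of column j of the accumulator tile.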
macro_rules! per_col {
    ($ab: expr, $m: expr, $f: expr) => {
        for i in 0..$ab.len() {
            for j in 0..$ab[0].len() {
                $ab[i][j] = $f(*$m.add(j), $ab[i][j])
            }
        }
    };
}
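/// Accumulates `k` rank-one updates into the MRxNR tile: the packed A panel
/// holds MR values of TA per step, the packed B panel NR values of TB.
///
/// Safety: `pa` must point to `k * MR` valid TA values and `pb` to `k * NR`
/// valid TB values.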
unsafe fn add_mat_mul<const MR: usize, const NR: usize, TI, TA, TB>(
    pa: *const u8,
    pb: *const u8,
    k: usize,
    ab: &mut [[TI; NR]; MR],
) where
    TA: LADatum + AsPrimitive<TI>,
    TB: LADatum + AsPrimitive<TI>,
    TI: LADatum,
{
    unsafe {
        let a = pa as *const TA;
        let b = pb as *const TB;
        for ik in 0..k {
            let a = std::slice::from_raw_parts(a.add(MR * ik), MR);
            let b = std::slice::from_raw_parts(b.add(NR * ik), NR);
            for i in 0..MR {
                for j in 0..NR {
                    ab[i][j] += a[i].as_() * b[j].as_();
                }
            }
        }
    }
}
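/// Like `add_mat_mul`, but the A panel is Q4_0 block-quantized: each 32-deep
/// block stores one f16 scale per row, then the 4-bit weights, decoded as
/// `scale * (nibble - 8)`.
///
/// Safety: `pa` must point to a valid Q4_0 panel for `k` (a multiple of the
/// block length) and `pb` to `k * NR` valid TB values.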
unsafe fn add_mat_mul_pq40<const MR: usize, const NR: usize, TB, TI>(
    pa: *const u8,
    pb: *const u8,
    k: usize,
    ab: &mut [[TI; NR]; MR],
) where
    TI: LADatum,
    f16: AsPrimitive<TI>,
    TB: AsPrimitive<TI>,
    i8: AsPrimitive<TI>,
{
    unsafe {
        assert!(k % Q4_0.block_len() == 0);
        let len = (k * MR) / Q4_0.block_len() * Q4_0.block_bytes();
        let mut pa = NibbleReader::for_slice(std::slice::from_raw_parts(pa, len));
        let b = pb as *const TB;
        for bk in 0..k / 32 {
            let mut scales: [TI; MR] = [TI::zero(); MR];
            scales.iter_mut().for_each(|x| *x = pa.read_f16().as_());
            for ik in 0..32 {
                let mut a: [TI; MR] = [TI::zero(); MR];
                a.iter_mut().zip(&scales).for_each(|(x, s)| *x = *s * (pa.read_i4() - 8).as_());
                let b = std::slice::from_raw_parts(b.add(NR * (ik + 32 * bk)), NR);
                for i in 0..MR {
                    for j in 0..NR {
                        ab[i][j] += a[i] * b[j].as_();
                    }
                }
            }
        }
    }
}
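/// Variant of `add_mat_mul_pq40` for panels that store the f16 scales after
/// the 4-bit weights of each block: the block product is accumulated unscaled
/// in a temporary tile, then scaled once per row.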
unsafe fn add_mat_mul_pq40_scales_at_end<const MR: usize, const NR: usize, TB, TI>(
    pa: *const u8,
    pb: *const u8,
    k: usize,
    ab: &mut [[TI; NR]; MR],
) where
    TI: LADatum,
    f16: AsPrimitive<TI>,
    TB: AsPrimitive<TI>,
    i8: AsPrimitive<TI>,
{
    unsafe {
        assert!(k % Q4_0.block_len() == 0);
        let len = (k * MR) / Q4_0.block_len() * Q4_0.block_bytes();
        let mut pa = NibbleReader::for_slice(std::slice::from_raw_parts(pa, len));
        let b = pb as *const TB;
        for bk in 0..k / 32 {
            let mut temp = [[TI::zero(); NR]; MR];
            for ik in 0..32 {
                let mut a: [TI; MR] = [TI::zero(); MR];
                a.iter_mut().for_each(|x| *x = (pa.read_i4() - 8).as_());
                let b = std::slice::from_raw_parts(b.add(NR * (ik + 32 * bk)), NR);
                for i in 0..MR {
                    for j in 0..NR {
                        temp[i][j] += a[i] * b[j].as_();
                    }
                }
            }
            for i in 0..MR {
                let scale = pa.read_f16().as_();
                for j in 0..NR {
                    ab[i][j] += temp[i][j] * scale;
                }
            }
        }
    }
}
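/// Accumulates an arbitrarily strided tile of TO values (described by
/// `other`) into the accumulator, converting each value to TI.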
unsafe fn add_unicast<const MR: usize, const NR: usize, TI, TO>(
    ab: &mut [[TI; NR]; MR],
    other: &OutputStoreKer,
) where
    TI: LADatum,
    TO: LADatum + AsPrimitive<TI>,
{
    unsafe {
        for i in 0usize..MR {
            for j in 0usize..NR {
                let value: *const TO = other
                    .ptr
                    .offset(other.row_byte_stride * i as isize + other.col_byte_stride * j as isize)
                    as _;
                ab[i].as_mut()[j] += (*value).as_();
            }
        }
    }
}
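/// Stores the accumulator tile through `tile` as a bit-for-bit copy; TC must
/// have the same size as TI.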
unsafe fn store_t<const MR: usize, const NR: usize, TC, TI>(
    tile: &OutputStoreKer,
    ab: &[[TI; NR]; MR],
) where
    TC: Copy,
{
    unsafe {
        for i in 0usize..MR {
            for j in 0usize..NR {
                let loc: *mut TC = tile
                    .ptr
                    .offset(tile.row_byte_stride * i as isize + tile.col_byte_stride * j as isize)
                    as _;
                let val: *const TC = (&ab[i].as_ref()[j]) as *const TI as _;
                *loc = *val
            }
        }
    }
}
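/// Stores the accumulator tile through `tile`, converting each TI value to
/// the float type TC.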
unsafe fn store_float_t<const MR: usize, const NR: usize, TC, TI>(
    tile: &OutputStoreKer,
    ab: &[[TI; NR]; MR],
) where
    TC: Copy + 'static,
    TI: Copy + 'static + AsPrimitive<TC>,
{
    unsafe {
        for i in 0usize..MR {
            for j in 0usize..NR {
                let loc: *mut TC = tile
                    .ptr
                    .offset(tile.row_byte_stride * i as isize + tile.col_byte_stride * j as isize)
                    as _;
                let val = ab[i].as_ref()[j].as_();
                *loc = val
            }
        }
    }
}
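/// Generic matrix-multiply micro-kernel over an MRxNR accumulator tile of TI.
/// Walks the `FusedKerSpec` micro-op list until `Done` (or a null pointer),
/// applying each fused operation in turn. Returns 0 on success, 1 when asked
/// for an unsupported packing.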
#[inline(never)]
unsafe fn kernel<TI, const MR: usize, const NR: usize>(mut pnl: *const FusedKerSpec<TI>) -> isize
where
    TI: LADatum + ScaleShiftAndRound + AsPrimitive<TI>,
    TI: AsPrimitive<f16> + AsPrimitive<f32> + AsPrimitive<f64>,
    usize: AsPrimitive<TI>,
    f16: AsPrimitive<TI>,
    f32: AsPrimitive<TI>,
    f64: AsPrimitive<TI>,
    i8: AsPrimitive<TI>,
    i32: AsPrimitive<TI>,
{
    unsafe {
        let mut ab = [[TI::zero(); NR]; MR];
        loop {
            if pnl.is_null() {
                break;
            }
            match *pnl {
                FusedKerSpec::Done => break,
                FusedKerSpec::Clear => ab = std::mem::zeroed(),
                FusedKerSpec::LoadTile(col_major, _row_major) => {
                    for row in 0..MR {
                        for col in 0..NR {
                            ab[row][col] = *col_major.add(col * MR + row);
                        }
                    }
                }
                FusedKerSpec::ScalarAdd(a) => scalar!(ab, a, |a, b| a + b),
                FusedKerSpec::ScalarMul(a) => scalar!(ab, a, |a, b| a * b),
                FusedKerSpec::ScalarMin(m) => scalar!(ab, m, |a, b| if a < b { a } else { b }),
                FusedKerSpec::ScalarMax(m) => scalar!(ab, m, |a, b| if a > b { a } else { b }),
                FusedKerSpec::ScalarSub(m) => scalar!(ab, m, |a, b| a - b),
                FusedKerSpec::ScalarSubF(m) => scalar!(ab, m, |a, b| b - a),
                FusedKerSpec::LeakyRelu(m) => {
                    scalar!(ab, m, |a, b| if b > TI::zero() { b } else { a * b })
                }
                FusedKerSpec::PerRowMin(m) => per_row!(ab, m, |a, b| if a < b { a } else { b }),
                FusedKerSpec::PerRowMax(m) => per_row!(ab, m, |a, b| if a > b { a } else { b }),
                FusedKerSpec::PerRowAdd(m) => per_row!(ab, m, |a, b| a + b),
                FusedKerSpec::PerRowMul(m) => per_row!(ab, m, |a, b| a * b),
                FusedKerSpec::PerRowSub(m) => per_row!(ab, m, |a, b| a - b),
                FusedKerSpec::PerRowSubF(m) => per_row!(ab, m, |a, b| b - a),
                FusedKerSpec::PerColMin(m) => per_col!(ab, m, |a, b| if a < b { a } else { b }),
                FusedKerSpec::PerColMax(m) => per_col!(ab, m, |a, b| if a > b { a } else { b }),
                FusedKerSpec::PerColAdd(m) => per_col!(ab, m, |a, b| a + b),
                FusedKerSpec::PerColMul(m) => per_col!(ab, m, |a, b| a * b),
                FusedKerSpec::PerColSub(m) => per_col!(ab, m, |a, b| a - b),
                FusedKerSpec::PerColSubF(m) => per_col!(ab, m, |a, b| b - a),
                FusedKerSpec::AddRowColProducts(rows, cols) => {
                    for i in 0..MR {
                        for j in 0..NR {
                            ab[i][j] += *rows.add(i) * *cols.add(j);
                        }
                    }
                }
                FusedKerSpec::AddUnicast(other) => {
                    if TI::datum_type().is_float() && other.item_size == 2 {
                        add_unicast::<MR, NR, TI, f16>(&mut ab, &other)
                    } else if TI::datum_type().is_float() && other.item_size == 4 {
                        add_unicast::<MR, NR, TI, f32>(&mut ab, &other)
                    } else if TI::datum_type().is_float() && other.item_size == 8 {
                        add_unicast::<MR, NR, TI, f64>(&mut ab, &other)
                    } else if TI::datum_type() == i32::datum_type() && other.item_size == 1 {
                        add_unicast::<MR, NR, TI, i8>(&mut ab, &other)
                    } else if TI::datum_type() == i32::datum_type() && other.item_size == 4 {
                        add_unicast::<MR, NR, TI, i32>(&mut ab, &other)
                    } else {
                        unimplemented!("Missing AddUnicast type");
                    }
                }
                FusedKerSpec::ShiftLeft(shift) => {
                    for i in 0..MR {
                        for j in 0..NR {
                            ab[i][j] = ab[i][j].q_shl(shift);
                        }
                    }
                }
                FusedKerSpec::RoundingShiftRight(shift, rp) => {
                    for i in 0..MR {
                        for j in 0..NR {
                            ab[i][j] = ab[i][j].q_shr(shift, rp);
                        }
                    }
                }
                FusedKerSpec::QScale(shift, rp, mult) => {
                    for i in 0..MR {
                        for j in 0..NR {
                            ab[i][j] = ab[i][j].q_scale(Scaler::from_fuse_params(shift, rp, mult));
                        }
                    }
                }
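                // The packing indices dispatched here must line up with the
                // packing[...] declarations of the MMMRustKernel! blocks below.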
                FusedKerSpec::AddMatMul { k, pa, pb, packing } => {
                    use std::mem::transmute;
                    if TI::datum_type().is_float() {
                        match packing {
                            0 => add_mat_mul::<MR, NR, TI, TI, TI>(pa, pb, k, &mut ab),
                            1 => add_mat_mul::<MR, NR, TI, f16, f16>(pa, pb, k, &mut ab),
                            2 => add_mat_mul::<MR, NR, TI, f32, f32>(pa, pb, k, &mut ab),
                            3 => add_mat_mul::<MR, NR, TI, f16, f32>(pa, pb, k, &mut ab),
                            4 => add_mat_mul::<MR, NR, TI, f32, f16>(pa, pb, k, &mut ab),
                            5 => add_mat_mul_pq40::<MR, NR, f16, TI>(pa, pb, k, &mut ab),
                            6 => add_mat_mul_pq40_scales_at_end::<MR, NR, f16, TI>(
                                pa, pb, k, &mut ab,
                            ),
                            7 => add_mat_mul_pq40::<MR, NR, f32, TI>(pa, pb, k, &mut ab),
                            _ => unreachable!(),
                        }
                    } else if TI::datum_type() == i32::datum_type() {
                        let ab = transmute::<&mut [[TI; NR]; MR], &mut [[i32; NR]; MR]>(&mut ab);
                        if packing == 0 {
                            add_mat_mul::<MR, NR, i32, i32, i32>(pa, pb, k, ab)
                        } else if packing == 1 {
                            add_mat_mul::<MR, NR, i32, i8, i8>(pa, pb, k, ab)
                        } else {
                            return 1;
                        }
                    } else {
                        return 1;
                    }
                }
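                // Integer tiles are stored bit-for-bit through the unsigned
                // integer of matching width; float tiles get a genuine float
                // conversion.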
                FusedKerSpec::Store(tile) => {
                    if TI::datum_type().is_float() {
                        match tile.item_size {
                            2 => store_float_t::<MR, NR, f16, _>(&tile, &ab),
                            4 => store_float_t::<MR, NR, f32, _>(&tile, &ab),
                            8 => store_float_t::<MR, NR, f64, _>(&tile, &ab),
                            _ => unimplemented!(),
                        }
                    } else {
                        match tile.item_size {
                            1 => store_t::<MR, NR, u8, _>(&tile, &ab),
                            2 => store_t::<MR, NR, u16, _>(&tile, &ab),
                            4 => store_t::<MR, NR, u32, _>(&tile, &ab),
                            8 => store_t::<MR, NR, u64, _>(&tile, &ab),
                            _ => unimplemented!(),
                        }
                    }
                }
            };
            pnl = pnl.add(1);
        }
    }
    0
}
345
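/// Q4_0 block-quantized packing of the A panel for 4 rows, scales interleaved
/// with the weights.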
fn pq40_r4() -> PackedBlockQuantFormat {
    PackedBlockQuantFormat::new(&Q4_0, 4, 0, false)
}
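/// Same packing, but with the scales stored at the end of each block.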
fn pq40_r4_se() -> PackedBlockQuantFormat {
    PackedBlockQuantFormat::new(&Q4_0, 4, 0, true)
}
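// Kernel instantiations. packing[0] is the natural TI/TI packing; the other
// indices must match the AddMatMul dispatch in `kernel` above.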
MMMRustKernel!(kernel::<f16, 4, 4> => generic_f16_4x4<f16>(4,4)
    packing[1] = f16f16bis => |k| k.with_packing(f16::packing(4), f16::packing(4));
    packing[2] = f32f32 => |k| k.with_packing(f32::packing(4), f32::packing(4));
    packing[3] = f16f32 => |k| k.with_packing(f16::packing(4), f32::packing(4));
    packing[4] = f32f16 => |k| k.with_packing(f32::packing(4), f16::packing(4));
    packing[5] = q40f16 => |k| k.with_packing(pq40_r4(), f16::packing(4));
    packing[6] = q40f16se => |k| k.with_packing(pq40_r4_se(), f16::packing(4));
    packing[7] = q40f32 => |k| k.with_packing(pq40_r4(), f32::packing(4));
    quality(if has_fp16() { ImplementationQuality::Generic } else { ImplementationQuality::Dreadful })
    store(f32, f64)
);

MMMRustKernel! {kernel::<f16, 4, 1> => generic_f16_4x1<f16>(4,1)
    packing[1] = f16f16bis => |k| k.with_packing(f16::packing(4), f16::packing(1));
    packing[2] = f32f32 => |k| k.with_packing(f32::packing(4), f32::packing(1));
    packing[3] = f16f32 => |k| k.with_packing(f16::packing(4), f32::packing(1));
    packing[4] = f32f16 => |k| k.with_packing(f32::packing(4), f16::packing(1));
    packing[5] = q40f16 => |k| k.with_packing(pq40_r4(), f16::packing(1));
    packing[6] = q40f16se => |k| k.with_packing(pq40_r4_se(), f16::packing(1));
    packing[7] = q40f32 => |k| k.with_packing(pq40_r4(), f32::packing(1));
    quality(if has_fp16() { ImplementationQuality::Generic } else { ImplementationQuality::Dreadful })
    store(f32, f64)
}

MMMRustKernel!(kernel::<f32, 4, 4> => generic_f32_4x4<f32>(4,4)
    packing[1] = f16f16 => |k| k.with_packing(f16::packing(4), f16::packing(4));
    packing[2] = f32f32bis => |k| k.with_packing(f32::packing(4), f32::packing(4));
    packing[3] = f16f32 => |k| k.with_packing(f16::packing(4), f32::packing(4));
    packing[4] = f32f16 => |k| k.with_packing(f32::packing(4), f16::packing(4));
    packing[5] = q40f16 => |k| k.with_packing(pq40_r4(), f16::packing(4));
    packing[6] = q40f16se => |k| k.with_packing(pq40_r4_se(), f16::packing(4));
    packing[7] = q40f32 => |k| k.with_packing(pq40_r4(), f32::packing(4));
    quality(ImplementationQuality::Generic)
    store(f16, f64)
);
MMMRustKernel! {kernel::<f32, 4, 1> => generic_f32_4x1<f32>(4,1)
    packing[1] = f16f16 => |k| k.with_packing(f16::packing(4), f16::packing(1));
    packing[2] = f32f32bis => |k| k.with_packing(f32::packing(4), f32::packing(1));
    packing[3] = f16f32 => |k| k.with_packing(f16::packing(4), f32::packing(1));
    packing[4] = f32f16 => |k| k.with_packing(f32::packing(4), f16::packing(1));
    packing[5] = q40f16 => |k| k.with_packing(pq40_r4(), f16::packing(1));
    packing[6] = q40f16se => |k| k.with_packing(pq40_r4_se(), f16::packing(1));
    packing[7] = q40f32 => |k| k.with_packing(pq40_r4(), f32::packing(1));
    quality(ImplementationQuality::Generic)
    store(f16, f64)
}

MMMRustKernel!(kernel::<f64, 4, 4> => generic_f64_4x4<f64>(4,4)
    quality(ImplementationQuality::Generic)
    store(f16, f32));
MMMRustKernel!(kernel::<f64, 4, 1> => generic_f64_4x1<f64>(4,1)
    quality(ImplementationQuality::Generic)
    store(f16, f32));

MMMRustKernel! {kernel::<i32, 4, 4> => generic_i32_4x4<i32>(4,4)
    packing[1] = i8i8 => |k| k.with_packing(i8::packing(4), i8::packing(4));
    quality(ImplementationQuality::Generic)
    store(i8)
}

MMMRustKernel! {kernel::<i32, 4, 1> => generic_i32_4x1<i32>(4,1)
    packing[1] = i8i8 => |k| k.with_packing(i8::packing(4), i8::packing(1));
    quality(ImplementationQuality::Generic)
    store(i8)
}
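// Odd-shaped 3x2 kernels, compiled for the test suite only.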
#[cfg(test)]
MMMRustKernel!(kernel::<f32, 3, 2> => generic_f32_3x2<f32>(3,2) store(f16, f64));

#[cfg(test)]
MMMRustKernel! {kernel::<i32, 3, 2> => generic_i32_3x2<i32>(3,2)
    packing[1] = i8i8 => |k| k.with_packing(i8::packing(3), i8::packing(2));
    store(i8)
}
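/// Registers the generic kernels declared in this module into `ops`.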
pub fn plug(ops: &mut Ops) {
    ops.mmm_impls.push(generic_f16_4x4.mmm());
    ops.mmm_impls.push(generic_f16_4x1.mmm());
    ops.mmm_impls.push(generic_f32_4x4.mmm());
    ops.mmm_impls.push(generic_f32_4x1.mmm());
    ops.mmm_impls.push(generic_f64_4x4.mmm());
    ops.mmm_impls.push(generic_f64_4x1.mmm());
    ops.mmm_impls.push(generic_i32_4x4.mmm());
    ops.mmm_impls.push(generic_i32_4x1.mmm());
}
#[cfg(test)]
mod test {
    #[test]
    fn kits() {
        let mut ops = crate::generic();
        super::plug(&mut ops);
    }
}