tract_linalg/generic/mmm.rs

#![allow(clippy::needless_range_loop)]
use num_traits::AsPrimitive;

use tract_data::prelude::f16;
use tract_data::prelude::*;

use super::*;
use crate::frame::block_quant::{BlockQuant, NibbleReader, PackedBlockQuantFormat, Q4_0};
use crate::frame::mmm::*;
use crate::{has_fp16, LADatum, Ops};

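// The three macros below apply a binary operation `$f` to every cell of the
// accumulator tile `$ab`, combining it with a single scalar (`scalar!`), a
// per-row value (`per_row!`) or a per-column value (`per_col!`).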
macro_rules! scalar {
    ($ab: expr, $m: expr, $f: expr) => {
        for i in 0..$ab.len() {
            for j in 0..$ab[0].len() {
                $ab[i][j] = $f($m, $ab[i][j])
            }
        }
    };
}

macro_rules! per_row {
    ($ab: expr, $m: expr, $f: expr) => {
        for i in 0..$ab.len() {
            for j in 0..$ab[0].len() {
                $ab[i][j] = $f(*$m.add(i), $ab[i][j])
            }
        }
    };
}

macro_rules! per_col {
    ($ab: expr, $m: expr, $f: expr) => {
        for i in 0..$ab.len() {
            for j in 0..$ab[0].len() {
                $ab[i][j] = $f(*$m.add(j), $ab[i][j])
            }
        }
    };
}

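/// Accumulates a packed matrix product into the MRxNR tile `ab`: for each of
/// the `k` steps, reads MR values of type TA from the packed A panel and NR
/// values of type TB from the packed B panel, widening both to TI.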
unsafe fn add_mat_mul<const MR: usize, const NR: usize, TI, TA, TB>(
    pa: *const u8,
    pb: *const u8,
    k: usize,
    ab: &mut [[TI; NR]; MR],
) where
    TA: LADatum + AsPrimitive<TI>,
    TB: LADatum + AsPrimitive<TI>,
    TI: LADatum,
{
    unsafe {
        let a = pa as *const TA;
        let b = pb as *const TB;
        for ik in 0..k {
            let a = std::slice::from_raw_parts(a.add(MR * ik), MR);
            let b = std::slice::from_raw_parts(b.add(NR * ik), NR);
            for i in 0..MR {
                for j in 0..NR {
                    ab[i][j] += a[i].as_() * b[j].as_();
                }
            }
        }
    }
}

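/// Q4_0 variant of `add_mat_mul`: the A panel is block-quantized, each block
/// covering 32 k-steps and storing one f16 scale per row followed by 4-bit
/// values biased by 8. Values are dequantized on the fly before accumulation.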
unsafe fn add_mat_mul_pq40<const MR: usize, const NR: usize, TB, TI>(
    pa: *const u8,
    pb: *const u8,
    k: usize,
    ab: &mut [[TI; NR]; MR],
) where
    TI: LADatum,
    f16: AsPrimitive<TI>,
    TB: AsPrimitive<TI>,
    i8: AsPrimitive<TI>,
{
    unsafe {
        assert!(k % Q4_0.block_len() == 0);
        let len = (k * MR) / Q4_0.block_len() * Q4_0.block_bytes();
        let mut pa = NibbleReader::for_slice(std::slice::from_raw_parts(pa, len));
        let b = pb as *const TB;
        for bk in 0..k / 32 {
            let mut scales: [TI; MR] = [TI::zero(); MR];
            scales.iter_mut().for_each(|x| *x = pa.read_f16().as_());
            for ik in 0..32 {
                let mut a: [TI; MR] = [TI::zero(); MR];
                a.iter_mut().zip(&scales).for_each(|(x, s)| *x = *s * (pa.read_i4() - 8).as_());
                let b = std::slice::from_raw_parts(b.add(NR * (ik + 32 * bk)), NR);
                for i in 0..MR {
                    for j in 0..NR {
                        ab[i][j] += a[i] * b[j].as_();
                    }
                }
            }
        }
    }
}

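/// Same as `add_mat_mul_pq40`, but for the packing variant that stores each
/// block's f16 scales after its 32 nibbles: partial sums for the block are
/// accumulated in a temporary tile and multiplied by the scales once per block.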
unsafe fn add_mat_mul_pq40_scales_at_end<const MR: usize, const NR: usize, TB, TI>(
    pa: *const u8,
    pb: *const u8,
    k: usize,
    ab: &mut [[TI; NR]; MR],
) where
    TI: LADatum,
    f16: AsPrimitive<TI>,
    TB: AsPrimitive<TI>,
    i8: AsPrimitive<TI>,
{
    unsafe {
        assert!(k % Q4_0.block_len() == 0);
        let len = (k * MR) / Q4_0.block_len() * Q4_0.block_bytes();
        let mut pa = NibbleReader::for_slice(std::slice::from_raw_parts(pa, len));
        let b = pb as *const TB;
        for bk in 0..k / 32 {
            let mut temp = [[TI::zero(); NR]; MR];
            for ik in 0..32 {
                let mut a: [TI; MR] = [TI::zero(); MR];
                a.iter_mut().for_each(|x| *x = (pa.read_i4() - 8).as_());
                let b = std::slice::from_raw_parts(b.add(NR * (ik + 32 * bk)), NR);
                for i in 0..MR {
                    for j in 0..NR {
                        temp[i][j] += a[i] * b[j].as_();
                    }
                }
            }
            for i in 0..MR {
                let scale = pa.read_f16().as_();
                for j in 0..NR {
                    ab[i][j] += temp[i][j] * scale;
                }
            }
        }
    }
}

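/// Adds a tile read from memory into the accumulators, element by element,
/// following the row and column byte strides of `other` and converting each
/// value from TO to TI.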
unsafe fn add_unicast<const MR: usize, const NR: usize, TI, TO>(
    ab: &mut [[TI; NR]; MR],
    other: &OutputStoreKer,
) where
    TI: LADatum,
    TO: LADatum + AsPrimitive<TI>,
{
    unsafe {
        for i in 0usize..MR {
            for j in 0usize..NR {
                let value: *const TO = other
                    .ptr
                    .offset(other.row_byte_stride * i as isize + other.col_byte_stride * j as isize)
                    as _;
                ab[i].as_mut()[j] += (*value).as_();
            }
        }
    }
}

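/// Stores the accumulator tile through `tile`'s strides as a bitwise copy:
/// callers pick TC as an unsigned integer type of the same size as TI, so
/// the value bits are transferred without numeric conversion.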
unsafe fn store_t<const MR: usize, const NR: usize, TC, TI>(
    tile: &OutputStoreKer,
    ab: &[[TI; NR]; MR],
) where
    TC: Copy,
{
    unsafe {
        for i in 0usize..MR {
            for j in 0usize..NR {
                let loc: *mut TC = tile
                    .ptr
                    .offset(tile.row_byte_stride * i as isize + tile.col_byte_stride * j as isize)
                    as _;
                let val: *const TC = (&ab[i].as_ref()[j]) as *const TI as _;
                *loc = *val
            }
        }
    }
}

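/// Stores the accumulator tile through `tile`'s strides, converting each
/// value numerically from TI to the destination float type TC.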
unsafe fn store_float_t<const MR: usize, const NR: usize, TC, TI>(
    tile: &OutputStoreKer,
    ab: &[[TI; NR]; MR],
) where
    TC: Copy + 'static,
    TI: Copy + 'static + AsPrimitive<TC>,
{
    unsafe {
        for i in 0usize..MR {
            for j in 0usize..NR {
                let loc: *mut TC = tile
                    .ptr
                    .offset(tile.row_byte_stride * i as isize + tile.col_byte_stride * j as isize)
                    as _;
                let val = ab[i].as_ref()[j].as_();
                *loc = val
            }
        }
    }
}

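/// Generic matrix-multiply micro-kernel: walks the `FusedKerSpec` list until
/// a null pointer or `Done` is reached, applying each fused operation to an
/// MRxNR accumulator tile. Returns 0 on success, 1 on an unsupported packing.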
#[inline(never)]
unsafe fn kernel<TI, const MR: usize, const NR: usize>(mut pnl: *const FusedKerSpec<TI>) -> isize
where
    TI: LADatum + ScaleShiftAndRound + AsPrimitive<TI>,
    TI: AsPrimitive<f16> + AsPrimitive<f32> + AsPrimitive<f64>,
    usize: AsPrimitive<TI>,
    f16: AsPrimitive<TI>,
    f32: AsPrimitive<TI>,
    f64: AsPrimitive<TI>,
    i8: AsPrimitive<TI>,
    i32: AsPrimitive<TI>,
{
    unsafe {
        let mut ab = [[TI::zero(); NR]; MR];
        loop {
            if pnl.is_null() {
                break;
            }
            match *pnl {
                FusedKerSpec::Done => break,
                FusedKerSpec::Clear => ab = std::mem::zeroed(),
                FusedKerSpec::LoadTile(col_major, _row_major) => {
                    for row in 0..MR {
                        for col in 0..NR {
                            ab[row][col] = *col_major.add(col * MR + row);
                        }
                    }
                }
                FusedKerSpec::ScalarAdd(a) => scalar!(ab, a, |a, b| a + b),
                FusedKerSpec::ScalarMul(a) => scalar!(ab, a, |a, b| a * b),
                FusedKerSpec::ScalarMin(m) => scalar!(ab, m, |a, b| if a < b { a } else { b }),
                FusedKerSpec::ScalarMax(m) => scalar!(ab, m, |a, b| if a > b { a } else { b }),
                FusedKerSpec::ScalarSub(m) => scalar!(ab, m, |a, b| a - b),
                FusedKerSpec::ScalarSubF(m) => scalar!(ab, m, |a, b| b - a),
                FusedKerSpec::LeakyRelu(m) => {
                    scalar!(ab, m, |a, b| if b > TI::zero() { b } else { a * b })
                }
                FusedKerSpec::PerRowMin(m) => per_row!(ab, m, |a, b| if a < b { a } else { b }),
                FusedKerSpec::PerRowMax(m) => per_row!(ab, m, |a, b| if a > b { a } else { b }),
                FusedKerSpec::PerRowAdd(m) => per_row!(ab, m, |a, b| a + b),
                FusedKerSpec::PerRowMul(m) => per_row!(ab, m, |a, b| a * b),
                FusedKerSpec::PerRowSub(m) => per_row!(ab, m, |a, b| a - b),
                FusedKerSpec::PerRowSubF(m) => per_row!(ab, m, |a, b| b - a),
                FusedKerSpec::PerColMin(m) => per_col!(ab, m, |a, b| if a < b { a } else { b }),
                FusedKerSpec::PerColMax(m) => per_col!(ab, m, |a, b| if a > b { a } else { b }),
                FusedKerSpec::PerColAdd(m) => per_col!(ab, m, |a, b| a + b),
                FusedKerSpec::PerColMul(m) => per_col!(ab, m, |a, b| a * b),
                FusedKerSpec::PerColSub(m) => per_col!(ab, m, |a, b| a - b),
                FusedKerSpec::PerColSubF(m) => per_col!(ab, m, |a, b| b - a),
                FusedKerSpec::AddRowColProducts(rows, cols) => {
                    for i in 0..MR {
                        for j in 0..NR {
                            ab[i][j] += *rows.add(i) * *cols.add(j);
                        }
                    }
                }
                FusedKerSpec::AddUnicast(other) => {
                    if TI::datum_type().is_float() && other.item_size == 2 {
                        add_unicast::<MR, NR, TI, f16>(&mut ab, &other)
                    } else if TI::datum_type().is_float() && other.item_size == 4 {
                        add_unicast::<MR, NR, TI, f32>(&mut ab, &other)
                    } else if TI::datum_type().is_float() && other.item_size == 8 {
                        add_unicast::<MR, NR, TI, f64>(&mut ab, &other)
                    } else if TI::datum_type() == i32::datum_type() && other.item_size == 1 {
                        add_unicast::<MR, NR, TI, i8>(&mut ab, &other)
                    } else if TI::datum_type() == i32::datum_type() && other.item_size == 4 {
                        add_unicast::<MR, NR, TI, i32>(&mut ab, &other)
                    } else {
                        unimplemented!("Missing AddUnicast type");
                    }
                }
                FusedKerSpec::ShiftLeft(shift) => {
                    for i in 0..MR {
                        for j in 0..NR {
                            ab[i][j] = ab[i][j].q_shl(shift);
                        }
                    }
                }
                FusedKerSpec::RoundingShiftRight(shift, rp) => {
                    for i in 0..MR {
                        for j in 0..NR {
                            ab[i][j] = ab[i][j].q_shr(shift, rp);
                        }
                    }
                }
                FusedKerSpec::QScale(shift, rp, mult) => {
                    for i in 0..MR {
                        for j in 0..NR {
                            ab[i][j] = ab[i][j].q_scale(Scaler::from_fuse_params(shift, rp, mult));
                        }
                    }
                }
                FusedKerSpec::AddMatMul { k, pa, pb, packing } => {
                    use std::mem::transmute;
                    if TI::datum_type().is_float() {
                        match packing {
                            0 => add_mat_mul::<MR, NR, TI, TI, TI>(pa, pb, k, &mut ab),
                            1 => add_mat_mul::<MR, NR, TI, f16, f16>(pa, pb, k, &mut ab),
                            2 => add_mat_mul::<MR, NR, TI, f32, f32>(pa, pb, k, &mut ab),
                            3 => add_mat_mul::<MR, NR, TI, f16, f32>(pa, pb, k, &mut ab),
                            4 => add_mat_mul::<MR, NR, TI, f32, f16>(pa, pb, k, &mut ab),
                            5 => add_mat_mul_pq40::<MR, NR, f16, TI>(pa, pb, k, &mut ab),
                            6 => add_mat_mul_pq40_scales_at_end::<MR, NR, f16, TI>(
                                pa, pb, k, &mut ab,
                            ),
                            7 => add_mat_mul_pq40::<MR, NR, f32, TI>(pa, pb, k, &mut ab),
                            _ => unreachable!(),
                        }
                    } else if TI::datum_type() == i32::datum_type() {
                        // transmute to allow using i32 explicitly in add_mat_mul generic params
                        let ab = transmute::<&mut [[TI; NR]; MR], &mut [[i32; NR]; MR]>(&mut ab);
                        if packing == 0 {
                            add_mat_mul::<MR, NR, i32, i32, i32>(pa, pb, k, ab)
                        } else if packing == 1 {
                            add_mat_mul::<MR, NR, i32, i8, i8>(pa, pb, k, ab)
                        } else {
                            return 1;
                        }
                    } else {
                        return 1;
                    }
                }
                FusedKerSpec::Store(tile) => {
                    if TI::datum_type().is_float() {
                        match tile.item_size {
                            2 => store_float_t::<MR, NR, f16, _>(&tile, &ab),
                            4 => store_float_t::<MR, NR, f32, _>(&tile, &ab),
                            8 => store_float_t::<MR, NR, f64, _>(&tile, &ab),
                            _ => unimplemented!(),
                        }
                    } else {
                        match tile.item_size {
                            1 => store_t::<MR, NR, u8, _>(&tile, &ab),
                            2 => store_t::<MR, NR, u16, _>(&tile, &ab),
                            4 => store_t::<MR, NR, u32, _>(&tile, &ab),
                            8 => store_t::<MR, NR, u64, _>(&tile, &ab),
                            _ => unimplemented!(),
                        }
                    }
                }
            };
            pnl = pnl.add(1);
        }
    }
    0
}

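// Q4_0 packed formats for the A operand, four rows per panel. The `_se`
// variant stores each block's scales at the end rather than at the start.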
fn pq40_r4() -> PackedBlockQuantFormat {
    PackedBlockQuantFormat::new(&Q4_0, 4, 0, false)
}

fn pq40_r4_se() -> PackedBlockQuantFormat {
    PackedBlockQuantFormat::new(&Q4_0, 4, 0, true)
}

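// The packing indices declared in the kernel definitions below must stay in
// sync with the `packing` match arms of `kernel`'s AddMatMul case: 0 is the
// kernel's native packing, 1-4 are f16/f32 combinations, 5-7 are Q4_0 variants.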
// f16 kernels
MMMRustKernel!(kernel::<f16, 4, 4> => generic_f16_4x4<f16>(4,4)
    packing[1] = f16f16bis => |k| k.with_packing(f16::packing(4), f16::packing(4));
    packing[2] = f32f32 => |k| k.with_packing(f32::packing(4), f32::packing(4));
    packing[3] = f16f32 => |k| k.with_packing(f16::packing(4), f32::packing(4));
    packing[4] = f32f16 => |k| k.with_packing(f32::packing(4), f16::packing(4));
    packing[5] = q40f16 => |k| k.with_packing(pq40_r4(), f16::packing(4));
    packing[6] = q40f16se => |k| k.with_packing(pq40_r4_se(), f16::packing(4));
    packing[7] = q40f32 => |k| k.with_packing(pq40_r4(), f32::packing(4));
    quality(if has_fp16() { ImplementationQuality::Generic } else { ImplementationQuality::Dreadful })
    store(f32, f64)
);

MMMRustKernel! {kernel::<f16, 4, 1> => generic_f16_4x1<f16>(4,1)
    packing[1] = f16f16bis => |k| k.with_packing(f16::packing(4), f16::packing(1));
    packing[2] = f32f32 => |k| k.with_packing(f32::packing(4), f32::packing(1));
    packing[3] = f16f32 => |k| k.with_packing(f16::packing(4), f32::packing(1));
    packing[4] = f32f16 => |k| k.with_packing(f32::packing(4), f16::packing(1));
    packing[5] = q40f16 => |k| k.with_packing(pq40_r4(), f16::packing(1));
    packing[6] = q40f16se => |k| k.with_packing(pq40_r4_se(), f16::packing(1));
    packing[7] = q40f32 => |k| k.with_packing(pq40_r4(), f32::packing(1));
    quality(if has_fp16() { ImplementationQuality::Generic } else { ImplementationQuality::Dreadful })
    store(f32, f64)
}

// f32 kernels
MMMRustKernel!(kernel::<f32, 4, 4> => generic_f32_4x4<f32>(4,4)
    packing[1] = f16f16 => |k| k.with_packing(f16::packing(4), f16::packing(4));
    packing[2] = f32f32bis => |k| k.with_packing(f32::packing(4), f32::packing(4));
    packing[3] = f16f32 => |k| k.with_packing(f16::packing(4), f32::packing(4));
    packing[4] = f32f16 => |k| k.with_packing(f32::packing(4), f16::packing(4));
    packing[5] = q40f16 => |k| k.with_packing(pq40_r4(), f16::packing(4));
    packing[6] = q40f16se => |k| k.with_packing(pq40_r4_se(), f16::packing(4));
    packing[7] = q40f32 => |k| k.with_packing(pq40_r4(), f32::packing(4));
    quality(ImplementationQuality::Generic)
    store(f16, f64)
);
MMMRustKernel! {kernel::<f32, 4, 1> => generic_f32_4x1<f32>(4,1)
    packing[1] = f16f16 => |k| k.with_packing(f16::packing(4), f16::packing(1));
    packing[2] = f32f32bis => |k| k.with_packing(f32::packing(4), f32::packing(1));
    packing[3] = f16f32 => |k| k.with_packing(f16::packing(4), f32::packing(1));
    packing[4] = f32f16 => |k| k.with_packing(f32::packing(4), f16::packing(1));
    packing[5] = q40f16 => |k| k.with_packing(pq40_r4(), f16::packing(1));
    packing[6] = q40f16se => |k| k.with_packing(pq40_r4_se(), f16::packing(1));
    packing[7] = q40f32 => |k| k.with_packing(pq40_r4(), f32::packing(1));
    quality(ImplementationQuality::Generic)
    store(f16, f64)
}

// f64 kernels
MMMRustKernel!(kernel::<f64, 4, 4> => generic_f64_4x4<f64>(4,4)
    quality(ImplementationQuality::Generic)
    store(f16, f32));
MMMRustKernel!(kernel::<f64, 4, 1> => generic_f64_4x1<f64>(4,1)
    quality(ImplementationQuality::Generic)
    store(f16, f32));

// i32 kernels
MMMRustKernel! {kernel::<i32, 4, 4> => generic_i32_4x4<i32>(4,4)
    packing[1] = i8i8 => |k| k.with_packing(i8::packing(4), i8::packing(4));
    quality(ImplementationQuality::Generic)
    store(i8)
}

MMMRustKernel! {kernel::<i32, 4, 1> => generic_i32_4x1<i32>(4,1)
    packing[1] = i8i8 => |k| k.with_packing(i8::packing(4), i8::packing(1));
    quality(ImplementationQuality::Generic)
    store(i8)
}

// extra test kernels
#[cfg(test)]
MMMRustKernel!(kernel::<f32, 3, 2> => generic_f32_3x2<f32>(3,2) store(f16, f64));

#[cfg(test)]
MMMRustKernel! {kernel::<i32, 3, 2> => generic_i32_3x2<i32>(3,2)
    packing[1] = i8i8 => |k| k.with_packing(i8::packing(3), i8::packing(2));
    store(i8)
}

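/// Registers the generic matmul kernels with the `Ops` table. The 3x2
/// kernels defined above for tests are not registered here.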
pub fn plug(ops: &mut Ops) {
    ops.mmm_impls.push(generic_f16_4x4.mmm());
    ops.mmm_impls.push(generic_f16_4x1.mmm());
    ops.mmm_impls.push(generic_f32_4x4.mmm());
    ops.mmm_impls.push(generic_f32_4x1.mmm());
    ops.mmm_impls.push(generic_f64_4x4.mmm());
    ops.mmm_impls.push(generic_f64_4x1.mmm());
    ops.mmm_impls.push(generic_i32_4x4.mmm());
    ops.mmm_impls.push(generic_i32_4x1.mmm());
}

#[cfg(test)]
mod test {

    #[test]
    fn kits() {
        let mut ops = crate::generic();
        super::plug(&mut ops);
    }
}