use std::fmt::Debug;
use super::{MMMInput, OutputStore, OutputStoreKer};
use tract_data::internal::*;
/// Tie-breaking rule applied when scaling integer accumulators down
/// (see `QScale` / `RoundingShiftRight` in `FusedSpec`).
///
/// `Native` leaves the choice to the implementation; the other variants pin
/// a specific behaviour for values exactly halfway between two results.
// NOTE(review): #[repr(usize)] keeps the discriminant word-sized —
// check_non_linear_enum_size below asserts this; presumably the assembly
// kernels read the discriminant directly, so do not change the repr.
#[repr(usize)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RoundingPolicy {
    Native,
    Zero,
    Away,
    MinusInf,
    PlusInf,
    Even,
    Odd,
}
/// Element-wise binary operations that can be fused after the matrix product.
///
/// Only subtraction is order-sensitive: `Sub` and `SubF` are the two operand
/// orders (see `BinOp::flip`). Judging from the `ScalarSub`/`ScalarSubF`
/// kernel tests below, `Sub` computes `operand - accumulator` and `SubF` the
/// flipped form — TODO confirm against the lowering code.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum BinOp {
    Min,
    Max,
    Add,
    Mul,
    Sub,
    SubF,
}
impl BinOp {
pub fn flip(&self) -> BinOp {
use BinOp::*;
match self {
Sub => SubF,
SubF => Sub,
sym => *sym,
}
}
}
/// High-level (planner) description of an operator fused into the matmul
/// evaluation; lowered to the pointer-based `FusedKerSpec` before reaching
/// the micro-kernel.
#[derive(Clone, Debug)]
pub enum FusedSpec<'t> {
    /// Apply `BinOp` between every accumulator cell and one scalar.
    BinScalar(&'t Tensor, BinOp),
    /// Apply `BinOp` with one operand value per row.
    BinPerRow(TensorView<'t>, BinOp),
    /// Apply `BinOp` with one operand value per column.
    BinPerCol(TensorView<'t>, BinOp),
    /// c[r][c] += rows[r] * cols[c] (outer product of the two vectors).
    AddRowColProducts(&'t Tensor, &'t Tensor),
    /// Element-wise accumulation from another tile of identical geometry.
    AddUnicast(OutputStore),
    /// Leaky ReLU; the tensor holds the negative-slope factor alpha.
    LeakyRelu(&'t Tensor),
    /// Integer requantization: shift, rounding policy, multiplier.
    QScale(isize, RoundingPolicy, i32),
    /// Arithmetic shift right by the given amount, rounded per policy.
    RoundingShiftRight(usize, RoundingPolicy),
    /// Exact shift left by the given amount.
    ShiftLeft(usize),
    /// Write the accumulator tile to the output store.
    Store(OutputStore),
    /// The matrix product itself, expressed as a fusable op.
    AddMatMul { a: &'t dyn MMMInput, b: &'t dyn MMMInput },
}
impl<'t> FusedSpec<'t> {
    /// Whether this fused op would prefer the column dimension in the outer
    /// scan loop.
    // NOTE(review): currently a constant `false` for every variant —
    // presumably a hook for future layout-sensitive ops; confirm before
    // relying on it ever returning true.
    pub fn prefer_col_outer(&self) -> bool {
        false
    }
}
/// Micro-kernel-level counterpart of `FusedSpec`: tensors are lowered to raw
/// pointers and scalars so the kernel (possibly assembly) can consume them.
// NOTE(review): #[repr(C, usize)] fixes discriminant and payload layout;
// check_non_linear_enum_size pins the total size to 5 words (one
// discriminant word + the largest payload, an OutputStoreKer). Presumably
// the assembly kernels hard-code these offsets — keep layout stable.
#[repr(C, usize)]
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
#[rustfmt::skip]
pub enum FusedKerSpec<TI: Copy> {
    Done, Clear,
    // Scalar binary ops: the single operand is passed by value.
    ScalarMin(TI), ScalarMax(TI), ScalarAdd(TI), ScalarMul(TI), ScalarSub(TI), ScalarSubF(TI),
    LeakyRelu(TI),
    // Per-row / per-column ops: the operand vector is passed as a raw pointer.
    PerRowMin(*const TI), PerRowMax(*const TI), PerRowAdd(*const TI), PerRowMul(*const TI), PerRowSub(*const TI), PerRowSubF(*const TI),
    PerColMin(*const TI), PerColMax(*const TI), PerColAdd(*const TI), PerColMul(*const TI), PerColSub(*const TI), PerColSubF(*const TI),
    // Integer requantization and shifts.
    QScale(isize, RoundingPolicy, i32), RoundingShiftRight(usize, RoundingPolicy), ShiftLeft(usize),
    // Tile-level accumulation, outer-product add, and final store.
    AddUnicast(OutputStoreKer), AddRowColProducts(*const TI, *const TI), Store(OutputStoreKer),
    // The product itself; cpu_variant selects an implementation flavor.
    AddMatMul { k: usize, pa: *const u8, pb: *const u8, cpu_variant: usize },
}
#[cfg(test)]
#[macro_use]
pub mod test {
use crate::frame::mmm::storage::*;
use crate::frame::mmm::*;
use crate::generic::{ScaleShiftAndRound, Scaler};
use num_traits::{AsPrimitive, Bounded};
use proptest::prelude::*;
use tract_data::internal::*;
/// Pins the layout assumptions of the repr'd enums: `RoundingPolicy` is
/// exactly one word, and `FusedKerSpec<f32>` is one discriminant word plus
/// its largest payload (an `OutputStoreKer`), i.e. five words in total.
// NOTE(review): presumably the assembly kernels hard-code these sizes, so a
// new variant with a bigger payload must not be added casually.
#[test]
fn check_non_linear_enum_size() {
    assert_eq!(std::mem::size_of::<RoundingPolicy>(), std::mem::size_of::<usize>());
    assert_eq!(
        std::mem::size_of::<super::FusedKerSpec<f32>>(),
        std::mem::size_of::<usize>() + std::mem::size_of::<OutputStoreKer>()
    );
    assert_eq!(
        std::mem::size_of::<super::FusedKerSpec<f32>>(),
        5 * std::mem::size_of::<usize>()
    );
}
/// Generates a `fuse` test module exercising every fused micro-kernel op for
/// kernel `$ker` with cell type `$tc` and accumulator type `$ti`.
///
/// `$cond` is evaluated at test run time (typically a CPU-feature check):
/// when false, each generated test silently passes.
#[macro_export]
macro_rules! mmm_kernel_fuse_tests {
    ($cond:expr, $ker:ident, $tc:ty, $ti: ty) => {
        mod fuse {
            use super::super::$ker;
            use num_traits::Zero;
            #[allow(unused_imports)]
            use tract_data::prelude::f16;
            use tract_data::prelude::tensor0;
            #[allow(unused_imports)]
            use $crate::frame::mmm::fuse::test;
            use $crate::frame::mmm::fuse::test::tile;
            use $crate::frame::mmm::fuse::FusedKerSpec;
            use $crate::frame::mmm::{FusedSpec, MatMatMulKer};
            #[test]
            fn return_zeros() {
                if $cond {
                    test::return_zeros::<$ker, $tc, $ti>()
                }
            }
            #[test]
            fn store_non_contiguous() {
                if $cond {
                    test::store_non_contiguous::<$ker, $tc, $ti>()
                }
            }
            proptest::proptest! {
                // Arbitrary tile content must round-trip unchanged through
                // Clear + AddUnicast + Store.
                #[test]
                fn return_c_prop(c in tile::<$ker, $tc, $ti>()) {
                    if $cond {
                        test::return_c::<$ker, $tc, $ti>(&c)
                    }
                }
            }
            // PartialOrd-based min/max helpers (float cell types do not
            // implement Ord, so std::cmp::min/max are unusable here).
            fn fmin<T: PartialOrd>(a: T, b: T) -> T {
                if a < b {
                    a
                } else {
                    b
                }
            }
            fn fmax<T: PartialOrd>(a: T, b: T) -> T {
                if a > b {
                    a
                } else {
                    b
                }
            }
            // Declares one test per fused binary op: runs the kernel with
            // FusedKerSpec::$FKS through geometry `$geo` (scalar / per_row /
            // per_col) and checks every cell against the scalar closure `$f`.
            // `$extra_cond` allows gating on per-kernel capabilities.
            macro_rules! bin {
                ($FKS:ident, $geo:expr, $f:expr, $extra_cond:expr) => {
                    paste! {
                        #[test]
                        fn [<$FKS:snake>]() {
                            if $cond && $extra_cond {
                                test::$geo::<$ker, $tc, $ti>(FusedKerSpec::$FKS, $f);
                            }
                        }
                    }
                };
            }
            bin!(PerColMin, per_col, fmin, true);
            bin!(PerColMax, per_col, fmax, true);
            bin!(PerColAdd, per_col, |a, b| a + b, true);
            bin!(PerColMul, per_col, |a, b| a * b, true);
            bin!(PerColSub, per_col, |a, b| a - b, true);
            bin!(PerColSubF, per_col, |a, b| b - a, true);
            bin!(PerRowMin, per_row, fmin, true);
            bin!(PerRowMax, per_row, fmax, true);
            bin!(PerRowAdd, per_row, |a, b| a + b, true);
            bin!(PerRowMul, per_row, |a, b| a * b, true);
            bin!(PerRowSub, per_row, |a, b| a - b, true);
            bin!(PerRowSubF, per_row, |a, b| b - a, true);
            bin!(ScalarMin, scalar, fmin, true);
            bin!(ScalarMax, scalar, fmax, true);
            bin!(ScalarAdd, scalar, |a, b| a + b, true);
            bin!(ScalarMul, scalar, |a, b| a * b, true);
            bin!(ScalarSub, scalar, |a, b| a - b, true);
            bin!(ScalarSubF, scalar, |a, b| b - a, true);
            // LeakyRelu only runs when the kernel reports the fusion as
            // supported for this accumulator type.
            bin!(
                LeakyRelu,
                scalar,
                |a, b| if b > <$ti>::zero() { b } else { a * b },
                <$ker as MatMatMulKer<$ti>>::can_fuse(&FusedSpec::LeakyRelu(&tensor0(
                    <$ti>::zero()
                )))
            );
            #[test]
            fn return_c_add_row_col_product() {
                if $cond {
                    test::return_c_add_row_col_product::<$ker, $tc, $ti>()
                }
            }
            #[test]
            fn return_c_plus_d() {
                if $cond {
                    test::return_c_plus_d::<$ker, $tc, $ti>()
                }
            }
            #[test]
            fn return_c_clear() {
                if $cond {
                    test::return_c_clear::<$ker, $tc, $ti>()
                }
            }
        }
    };
}
/// Generates a `fuseq` test module covering integer requantization (QScale
/// and rounding shifts) for quantized kernel `$ker`.
///
/// `$cond` gates every test at run time. `$ta`/`$tb` (operand types) are
/// currently unused by the generated tests.
#[macro_export]
macro_rules! qmmm_kernel_fuse_tests {
    ($cond:expr, $ker:ident, $ta:ty, $tb:ty, $tc:ty, $ti: ty) => {
        mod fuseq {
            use $crate::frame::mmm::fuse::RoundingPolicy;
            #[allow(unused_imports)]
            use $crate::frame::mmm::fuse::test;
            use $crate::frame::mmm::fuse::test::QScaleProblem;
            use $crate::frame::mmm::kernel::MatMatMulKer;
            use $crate::generic::Scaler;
            use proptest::prelude::*;
            use super::super::$ker;
            // One deterministic test per (scale shape x rounding policy).
            // The tile is always filled with values centered on zero so both
            // signs are exercised. Scales: +/-0.5 hit the tie-breaking rule,
            // 0.25 and 4 the power-of-two paths, 1/5 and 14 the general
            // multiplier paths.
            macro_rules! test_q_scale {
                ($policy: ident) => {
                    paste! {
                        #[test]
                        fn [<return_q_scale_halfpos_ $policy:lower>]() {
                            if $cond {
                                let len = (<$ker>::mr() * <$ker>::nr()) as i64;
                                let v = (0..len).map(|i| (i - len / 2) as $tc).collect();
                                QScaleProblem::<$ker, $tc, $ti>::new(v, Scaler::new(0.5f32, RoundingPolicy::$policy)).run()
                            }
                        }
                        #[test]
                        fn [<return_q_scale_halfneg_ $policy:lower>]() {
                            if $cond {
                                let len = (<$ker>::mr() * <$ker>::nr()) as i64;
                                let v = (0..len).map(|i| (i - len / 2) as $tc).collect();
                                QScaleProblem::<$ker, $tc, $ti>::new(v, Scaler::new(-0.5f32, RoundingPolicy::$policy)).run()
                            }
                        }
                        #[test]
                        fn [<return_q_scale_pot_ $policy:lower>]() {
                            if $cond {
                                let len = (<$ker>::mr() * <$ker>::nr()) as i64;
                                let v = (0..len).map(|i| (i - len / 2) as $tc).collect();
                                QScaleProblem::<$ker, $tc, $ti>::new(v, Scaler::new(0.25f32, RoundingPolicy::$policy)).run()
                            }
                        }
                        #[test]
                        fn [<return_q_scale_nonpot_ $policy:lower>]() {
                            if $cond {
                                let len = (<$ker>::mr() * <$ker>::nr()) as i64;
                                let v = (0..len).map(|i| (i - len / 2) as $tc).collect();
                                QScaleProblem::<$ker, $tc, $ti>::new(v, Scaler::new(1f32 / 5., RoundingPolicy::$policy)).run()
                            }
                        }
                        #[test]
                        fn [<return_q_scale_bigpot_ $policy:lower>]() {
                            if $cond {
                                let len = (<$ker>::mr() * <$ker>::nr()) as i64;
                                let v = (0..len).map(|i| (i - len / 2) as $tc).collect();
                                QScaleProblem::<$ker, $tc, $ti>::new(v, Scaler::new(4f32, RoundingPolicy::$policy)).run()
                            }
                        }
                        #[test]
                        fn [<return_q_scale_bignonpot_ $policy:lower>]() {
                            if $cond {
                                let len = (<$ker>::mr() * <$ker>::nr()) as i64;
                                let v = (0..len).map(|i| (i - len / 2) as $tc).collect();
                                QScaleProblem::<$ker, $tc, $ti>::new(v, Scaler::new(14., RoundingPolicy::$policy)).run()
                            }
                        }
                    }
                }
            }
            // Every policy except Native (which has no portable reference).
            test_q_scale!(Zero);
            test_q_scale!(Away);
            test_q_scale!(MinusInf);
            test_q_scale!(PlusInf);
            test_q_scale!(Even);
            test_q_scale!(Odd);
            proptest::proptest! {
                #[test]
                fn return_q_scale_prop(pb in any::<QScaleProblem<$ker, $tc, $ti>>()) {
                    if $cond {
                        pb.run()
                    }
                }
            }
            #[test]
            fn return_c_scale_bigpot() {
                if $cond {
                    test::return_c_scale_bigpot::<$ker, $tc, $ti>()
                }
            }
        }
    };
}
/// Builds an `OutputStoreKer` viewing `v` as a row-major tile with `rsc`
/// elements per row (columns contiguous in memory).
// NOTE(review): the pointer is const-cast from a shared borrow; callers that
// hand this to a writing op (FusedKerSpec::Store) end up mutating through a
// pointer derived from `&[T]`, which is unsound under Rust aliasing rules —
// a `&mut`-based variant would be safer for write destinations.
pub fn mmm_stride_storage<T: Copy>(v: &[T], rsc: usize) -> OutputStoreKer {
    OutputStoreKer {
        ptr: v.as_ptr() as _,
        row_byte_stride: (std::mem::size_of::<T>() * rsc) as isize,
        col_byte_stride: std::mem::size_of::<T>() as isize,
        item_size: std::mem::size_of::<T>(),
    }
}
use crate::LADatum;
pub fn return_zeros<K, TC, TI>()
where
K: MatMatMulKer<TI>,
TC: LADatum,
TI: LADatum + Bounded + PartialEq,
{
let v = vec![TC::max_value(); K::mr() * K::nr()];
let c = mmm_stride_storage(&v, K::nr());
let non_linear = tvec![FusedKerSpec::Clear, FusedKerSpec::Store(c), FusedKerSpec::Done];
let err = K::kernel(&non_linear);
assert_eq!(err, 0);
let expected = vec![TC::zero(); v.len()];
assert_eq!(v, expected);
}
/// Checks that `Store` honours arbitrary strides instead of assuming a
/// contiguous tile.
///
/// The backing buffer is 3x wider and 5x taller than the tile; only the
/// strided cells must be zeroed, every other cell must keep its sentinel.
pub fn store_non_contiguous<K, TC, TI>()
where
    K: MatMatMulKer<TI>,
    TC: LADatum,
    TI: LADatum + Bounded + PartialEq,
{
    let mut v = vec![TC::max_value(); K::mr() * 5 * K::nr() * 3];
    // Mutable pointer: the kernel writes through this store, so it must not
    // be derived from a shared borrow.
    let c = OutputStoreKer {
        ptr: v.as_mut_ptr() as _,
        row_byte_stride: (std::mem::size_of::<TC>() * 3 * K::nr() * 5) as isize,
        col_byte_stride: std::mem::size_of::<TC>() as isize * 3,
        item_size: std::mem::size_of::<TC>(),
    };
    let non_linear = tvec![FusedKerSpec::Clear, FusedKerSpec::Store(c), FusedKerSpec::Done];
    let err = K::kernel(&non_linear);
    assert_eq!(err, 0);
    // Expected: zero exactly at every third column of every strided row.
    let mut expected = vec![TC::max_value(); v.len()];
    for c in 0..K::nr() {
        for r in 0..K::mr() {
            expected[c * 3 + r * 3 * 5 * K::nr()] = TC::zero();
        }
    }
    assert_eq!(v, expected);
}
pub fn fused_ops<K, TC, TI, E>(c: &[TC], ops: &[FusedKerSpec<TI>], expect: E)
where
K: MatMatMulKer<TI>,
TC: Datum + AsPrimitive<TI>,
TI: LADatum + AsPrimitive<TC>,
E: Fn(usize, usize, TI) -> TI,
{
assert!(c.len() == K::mr() * K::nr());
let v = c.to_vec();
let c = mmm_stride_storage(&v, K::nr());
let mut ops = ops.to_vec();
ops.insert(0, FusedKerSpec::AddUnicast(c));
ops.insert(0, FusedKerSpec::Clear);
ops.push(FusedKerSpec::Store(c));
ops.push(FusedKerSpec::Done);
let expected = (0..v.len())
.map(|ix| expect(ix / K::nr(), ix % K::nr(), v[ix].as_()).as_())
.collect::<Vec<TC>>();
let err = K::kernel(&ops);
assert_eq!(err, 0);
if v != expected {
println!("found, expected:");
for m in 0..K::mr() {
for n in 0..K::nr() {
use nu_ansi_term::Color::*;
let f = v[m * K::nr() + n];
let e = expected[m * K::nr() + n];
let color = if f != e { Red } else { Green };
print!("{} ", color.paint(format!("{:4}", f)));
}
print!(" ");
for n in 0..K::nr() {
print!("{:4} ", expected[m * K::nr() + n]);
}
println!();
}
}
assert_eq!(v, expected);
}
/// Checks that a tile round-trips unchanged through the kernel when no
/// fused op is applied.
///
/// The `+ 1 - 1` dance keeps the expected value equal to the input while
/// still going through the accumulator type's arithmetic.
pub fn return_c<K, TC, TI>(v: &[TC])
where
    K: MatMatMulKer<TI>,
    TC: LADatum + AsPrimitive<TI>,
    TI: LADatum + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let one: TI = 1.as_();
    fused_ops::<K, TC, TI, _>(v, &[], move |_, _, acc| acc + one - one)
}
/// Checks `AddUnicast`: a second tile `d` is element-wise added to the
/// initial tile.
pub fn return_c_plus_d<K, TC, TI>()
where
    K: MatMatMulKer<TI>,
    TC: LADatum + AsPrimitive<TI>,
    TI: LADatum + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let len = K::mr() * K::nr();
    // Tile holds 0, 1, 2, …; `d` cycles through 0..7 in steps of 3 so that
    // neighbouring cells get distinct addends.
    let tile: Vec<TC> = (0..len).map(|i| i.as_()).collect();
    let d: Vec<TI> = (0..len).map(|i| (3 * i % 7).as_()).collect();
    let d_store = mmm_stride_storage(&d, K::nr());
    fused_ops::<K, TC, TI, _>(&tile, &[FusedKerSpec::AddUnicast(d_store)], |row, col, acc| {
        acc + d[row * K::nr() + col]
    });
}
/// Checks a per-column fused op: `op` builds the kernel-level spec from the
/// operand pointer, `f(operand, cell)` is the scalar reference.
pub fn per_col<K, TC, TI>(op: impl Fn(*const TI) -> FusedKerSpec<TI>, f: impl Fn(TI, TI) -> TI)
where
    K: MatMatMulKer<TI>,
    TC: LADatum + AsPrimitive<TI>,
    TI: LADatum + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    // Tile filled with 0, 1, 2, … row-major; one operand per column: 1..=nr.
    let tile: Vec<TC> = (0..K::mr() * K::nr()).map(|i| i.as_()).collect();
    let operand: Vec<TI> = (1..=K::nr()).map(|i| i.as_()).collect();
    fused_ops::<K, TC, TI, _>(&tile, &[op(operand.as_ptr())], |_, col, acc| {
        f(operand[col], acc)
    })
}
/// Checks a per-row fused op: `op` builds the kernel-level spec from the
/// operand pointer, `f(operand, cell)` is the scalar reference.
pub fn per_row<K, TC, TI>(op: impl Fn(*const TI) -> FusedKerSpec<TI>, f: impl Fn(TI, TI) -> TI)
where
    K: MatMatMulKer<TI>,
    TC: LADatum + AsPrimitive<TI>,
    TI: LADatum + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    // Tile filled with 0, 1, 2, … row-major; one operand per row: 1..=mr.
    let tile: Vec<TC> = (0..K::mr() * K::nr()).map(|i| i.as_()).collect();
    let operand: Vec<TI> = (1..=K::mr()).map(|i| i.as_()).collect();
    fused_ops::<K, TC, TI, _>(&tile, &[op(operand.as_ptr())], |row, _, acc| {
        f(operand[row], acc)
    })
}
/// Checks a scalar fused op: `op` builds the kernel-level spec from a single
/// operand value, `f(operand, cell)` is the scalar reference.
pub fn scalar<K, TC, TI>(op: impl Fn(TI) -> FusedKerSpec<TI>, f: impl Fn(TI, TI) -> TI)
where
    K: MatMatMulKer<TI>,
    TC: LADatum + AsPrimitive<TI>,
    TI: LADatum + AsPrimitive<TC>,
    isize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    // Tile values centered on zero so both signs are exercised.
    let len = (K::mr() * K::nr()) as isize;
    let half = len / 2;
    let tile: Vec<TC> = (0..len).map(|i| (i - half).as_()).collect();
    let operand: TI = 5.as_();
    fused_ops::<K, TC, TI, _>(&tile, &[op(operand)], |_, _, acc| f(operand, acc))
}
/// Checks `AddRowColProducts`: each cell gets `rows[r] * cols[c]` added.
pub fn return_c_add_row_col_product<K, TC, TI>()
where
    K: MatMatMulKer<TI>,
    TC: LADatum + AsPrimitive<TI>,
    TI: LADatum + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let len = K::mr() * K::nr();
    // Tile holds 1..=len; row factors start at 3, column factors at 2, so
    // no product is trivially 0 or 1.
    let tile: Vec<TC> = (1..=len).map(|i| i.as_()).collect();
    let row_factors: Vec<TI> = (3..3 + K::mr()).map(|i| i.as_()).collect();
    let col_factors: Vec<TI> = (2..2 + K::nr()).map(|i| i.as_()).collect();
    fused_ops::<K, TC, TI, _>(
        &tile,
        &[FusedKerSpec::AddRowColProducts(row_factors.as_ptr(), col_factors.as_ptr())],
        |row, col, acc| acc + col_factors[col] * row_factors[row],
    )
}
/// Checks `Clear`: whatever the tile held, the output must be all zeros.
pub fn return_c_clear<K, TC, TI>()
where
    K: MatMatMulKer<TI>,
    TC: LADatum + AsPrimitive<TI>,
    TI: LADatum + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let tile: Vec<TC> = (0..K::mr() * K::nr()).map(|i| i.as_()).collect();
    fused_ops::<K, TC, TI, _>(&tile, &[FusedKerSpec::Clear], |_, _, _| 0.as_())
}
/// Checks `ShiftLeft(1)` against the scalar `q_shl` reference, on a tile of
/// values centered on zero.
pub fn return_c_scale_bigpot<K, TC, TI>()
where
    K: MatMatMulKer<TI>,
    TC: LADatum + AsPrimitive<TI>,
    TI: LADatum + AsPrimitive<TC> + ScaleShiftAndRound,
    isize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let len = K::mr() * K::nr();
    let start = -(len as isize) / 2;
    let tile: Vec<TC> = (start..start + len as isize).map(|i| i.as_()).collect();
    fused_ops::<K, TC, TI, _>(&tile, &[FusedKerSpec::ShiftLeft(1)], |_, _, acc| acc.q_shl(1))
}
/// A requantization test case: a tile of accumulator values and the
/// `Scaler` (multiplier + rounding policy) to apply to them.
// NOTE(review): #[derive(new)] presumably comes from the `new` crate and
// generates `QScaleProblem::new(c, scaler)` (PhantomData defaulted), as used
// by the qmmm_kernel_fuse_tests macro — confirm against the crate root.
#[derive(Debug, new)]
pub struct QScaleProblem<K, TC, TI>
where
    K: MatMatMulKer<TI>,
    TC: LADatum,
    TI: LADatum + AsPrimitive<TC>,
    i64: AsPrimitive<TC>,
{
    // Initial tile content, mr * nr values.
    pub c: Vec<TC>,
    // Scale and rounding policy under test.
    pub scaler: Scaler,
    // Zero-sized marker carrying the kernel and type parameters.
    pub boo: std::marker::PhantomData<(K, TC, TI)>,
}
impl<K, TC, TI> Arbitrary for QScaleProblem<K, TC, TI>
where
    K: MatMatMulKer<TI>,
    TC: LADatum + Arbitrary,
    TI: LADatum + AsPrimitive<TC>,
    i64: AsPrimitive<TC>,
{
    type Parameters = ();
    type Strategy = BoxedStrategy<Self>;
    // Strategy: a tile of small signed values and a scale assembled from a
    // power-of-two exponent and an optional fractional multiplier, combined
    // with every rounding policy except Native.
    fn arbitrary_with(_p: ()) -> Self::Strategy {
        use RoundingPolicy::*;
        let len = K::mr() * K::nr();
        (
            // tile: exactly len cells, values in [-20, 20)
            proptest::collection::vec((-20i64..20).prop_map(|i| i.as_()), len..=len),
            // power-of-two exponent of the scale
            -5i32..5,
            // multiplier: 1 (pure power of two) or a fraction in [0, 1)
            prop_oneof!(Just(1f32), 0f32..1f32),
            proptest::prop_oneof![
                Just(Zero),
                Just(Away),
                Just(PlusInf),
                Just(MinusInf),
                Just(Odd),
                Just(Even)
            ],
        )
            .prop_map(|(c, scale_pot, scale_mult, policy)| QScaleProblem {
                c,
                scaler: Scaler::new(scale_mult * 2f32.powi(scale_pot), policy),
                boo: std::marker::PhantomData,
            })
            .boxed()
    }
}
impl<K, TC, TI> QScaleProblem<K, TC, TI>
where
    K: MatMatMulKer<TI>,
    TC: LADatum + AsPrimitive<TI>,
    TI: LADatum + AsPrimitive<TC> + ScaleShiftAndRound + AsPrimitive<i64>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
    i64: AsPrimitive<TC>,
{
    /// Lowers the scaler to its fused-op form, runs the kernel with the
    /// matching `FusedKerSpec`, and checks every cell against the scalar
    /// reference (`q_scale` / `q_shr` / `q_shl`).
    pub fn run(&self) {
        match self.scaler.as_fused_spec() {
            FusedSpec::QScale(shift, policy, mult) => fused_ops::<K, TC, TI, _>(
                &self.c,
                &[FusedKerSpec::QScale(shift, policy, mult)],
                |_, _, acc| acc.q_scale(self.scaler),
            ),
            FusedSpec::RoundingShiftRight(shift, policy) => fused_ops::<K, TC, TI, _>(
                &self.c,
                &[FusedKerSpec::RoundingShiftRight(shift, policy)],
                |_, _, acc| acc.q_shr(shift, policy),
            ),
            FusedSpec::ShiftLeft(shift) => fused_ops::<K, TC, TI, _>(
                &self.c,
                &[FusedKerSpec::ShiftLeft(shift)],
                |_, _, acc| acc.q_shl(shift),
            ),
            // A scaler can only lower to one of the three ops above.
            _ => unreachable!(),
        }
    }
}
/// Proptest strategy producing a tile of exactly `mr * nr` cells, each drawn
/// from the full `i8` range then converted to the cell type.
pub fn tile<K, TC, TI>() -> BoxedStrategy<Vec<TC>>
where
    K: MatMatMulKer<TI>,
    TC: LADatum,
    TI: LADatum + AsPrimitive<TC>,
    i8: AsPrimitive<TC>,
{
    let size = K::mr() * K::nr();
    let cell = any::<i8>().prop_map(|x| x.as_());
    proptest::collection::vec(cell, size..=size).boxed()
}
}