use std::fmt::Debug;
use crate::frame::mmm::FusedKerSpec;
use crate::LADatum;
use super::{MatMatMul, MatMatMulImpl, FusedSpec};
pub trait MatMatMulKer<TI>: Copy + Clone + Debug + Send + Sync + 'static
where
    TI: LADatum,
{
    fn name() -> &'static str;
    fn kernel(op: &[FusedKerSpec<TI>]) -> isize;
    fn mr() -> usize;
    fn nr() -> usize;
    fn alignment_bytes_packed_a() -> usize;
    fn end_padding_packed_a() -> usize;
    fn alignment_bytes_packed_b() -> usize;
    fn end_padding_packed_b() -> usize;
    #[allow(unused_variables)]
    fn prefetch(ptr: *const u8, len: usize) {}
    fn mmm() -> Box<dyn MatMatMul> {
        Box::<MatMatMulImpl<Self, TI>>::default()
    }
    #[allow(unused_variables)]
    fn can_fuse(spec: &FusedSpec) -> bool {
        true
    }
}
/// Declares the `#[cfg(test)]` module `test_<kernel>` holding the kernel, frame
/// and fuse test suites for a kernel working on `f16` operands and accumulators.
///
/// `$ker` is the kernel identifier; `$enabled` is forwarded untouched as the
/// first argument of every suite macro.
#[macro_export]
macro_rules! test_mmm_kernel_f16 {
    ($ker:ident, $enabled:expr) => {
        paste! {
            #[cfg(test)]
            #[allow(non_snake_case)]
            mod [<test_ $ker>] {
                mmm_kernel_tests!($enabled, $ker, f16, f16, f16, f16);
                mmm_frame_tests!($enabled, $ker, f16, f16, f16, f16);
                mmm_kernel_fuse_tests!($enabled, $ker, f16, f16);
            }
        }
    };
}
/// Declares the `#[cfg(test)]` module `test_<kernel>` holding the kernel, frame
/// and fuse test suites for a kernel working on `f32` operands and accumulators.
///
/// `$ker` is the kernel identifier; `$enabled` is forwarded untouched as the
/// first argument of every suite macro.
#[macro_export]
macro_rules! test_mmm_kernel_f32 {
    ($ker:ident, $enabled:expr) => {
        paste! {
            #[cfg(test)]
            #[allow(non_snake_case)]
            mod [<test_ $ker>] {
                mmm_kernel_tests!($enabled, $ker, f32, f32, f32, f32);
                mmm_frame_tests!($enabled, $ker, f32, f32, f32, f32);
                mmm_kernel_fuse_tests!($enabled, $ker, f32, f32);
            }
        }
    };
}
/// Declares the `#[cfg(test)]` module `test_<kernel>` holding the kernel, frame
/// and fuse test suites for a kernel working on `f64` operands and accumulators.
///
/// `$ker` is the kernel identifier; `$enabled` is forwarded untouched as the
/// first argument of every suite macro.
#[macro_export]
macro_rules! test_mmm_kernel_f64 {
    ($ker:ident, $enabled:expr) => {
        paste! {
            #[cfg(test)]
            #[allow(non_snake_case)]
            mod [<test_ $ker>] {
                mmm_kernel_tests!($enabled, $ker, f64, f64, f64, f64);
                mmm_frame_tests!($enabled, $ker, f64, f64, f64, f64);
                mmm_kernel_fuse_tests!($enabled, $ker, f64, f64);
            }
        }
    };
}
/// Declares the test modules for a kernel taking `i8` operands and accumulating
/// in `i32`.
///
/// Three `#[cfg(test)]` modules are generated from the kernel identifier `$k`:
/// * `test_<k>`: kernel, fuse and frame suites with `i8` output;
/// * `test_qi8_<k>`: quantized fuse suite with `i8` output;
/// * `test_qi32_<k>`: quantized fuse suite with `i32` output.
///
/// `$cond` is forwarded untouched as the first argument of every suite macro.
#[macro_export]
macro_rules! test_mmm_kernel_i32 {
    ($k: ident, $cond: expr) => {
        paste! {
            #[cfg(test)]
            #[allow(non_snake_case)]
            mod [<test_ $k>] {
                mmm_kernel_tests!($cond, $k, i8, i8, i8, i32);
                mmm_kernel_fuse_tests!($cond, $k, i8, i32);
                mmm_frame_tests!($cond, $k, i8, i8, i8, i32);
            }
            // The module names below embed the same kernel identifier as the one
            // above, so they carry the same lint exemption.
            #[cfg(test)]
            #[allow(non_snake_case)]
            mod [<test_qi8_ $k>] {
                qmmm_kernel_fuse_tests!($cond, $k, i8, i8, i8, i32);
            }
            #[cfg(test)]
            #[allow(non_snake_case)]
            mod [<test_qi32_ $k>] {
                qmmm_kernel_fuse_tests!($cond, $k, i8, i8, i32, i32);
            }
        }
    };
}
#[cfg(test)]
#[macro_use]
pub mod test {
    use super::*;
    use crate::frame::mmm::OutputStoreKer;
    use num_traits::{AsPrimitive, One, Zero};
    use proptest::collection::vec;
    use proptest::prelude::*;
    use std::fmt;
    use std::marker::PhantomData;
    use tract_data::internal::*;
    /// Generates a `kernel` module of tests calling a micro-kernel directly.
    ///
    /// Arguments: an enabling expression (every test body is skipped when it is
    /// false), the kernel identifier (looked up two modules above the generated
    /// one), then the A, B, C and accumulator element types.
    #[macro_export]
    macro_rules! mmm_kernel_tests {
        ($enabled:expr, $kernel:ident, $lhs:ty, $rhs:ty, $out:ty, $acc:ty) => {
            mod kernel {
                use super::super::$kernel;
                use num_traits::Zero;
                use proptest::prelude::*;
                #[allow(unused_imports)]
                use tract_data::prelude::f16;
                #[allow(unused_imports)]
                use $crate::frame::mmm::kernel::test;
                use $crate::frame::mmm::kernel::test::PackedPackedProblem;
                use $crate::frame::mmm::MatMatMulKer;

                // Randomized comparison of the kernel against the naive reference.
                proptest::proptest! {
                    #[test]
                    fn packed_packed_prop(problem in any::<PackedPackedProblem<$kernel, $lhs, $rhs, $out, $acc>>()) {
                        if $enabled {
                            prop_assert_eq!(problem.run(), problem.reference());
                        }
                    }
                }

                // All-ones operands for a handful of fixed depths.
                #[test]
                fn packed_packed_1() {
                    if $enabled {
                        test::packed_packed::<$kernel, $lhs, $rhs, $out, $acc>(1);
                    }
                }

                #[test]
                fn packed_packed_2() {
                    if $enabled {
                        test::packed_packed::<$kernel, $lhs, $rhs, $out, $acc>(2);
                    }
                }

                #[test]
                fn packed_packed_13() {
                    if $enabled {
                        test::packed_packed::<$kernel, $lhs, $rhs, $out, $acc>(13);
                    }
                }

                // Zero depth: both operands are empty.
                #[test]
                fn packed_packed_empty() {
                    if $enabled {
                        let problem = PackedPackedProblem::<$kernel, $lhs, $rhs, $out, $acc>::new(
                            0,
                            vec![<$lhs>::zero(); 0],
                            vec![<$rhs>::zero(); 0],
                            false,
                            false,
                        );
                        assert_eq!(problem.run(), problem.reference());
                    }
                }

                // Depth one, all-zero operands, transposed output plus scalar add.
                #[test]
                fn packed_packed_bug_1() {
                    if $enabled {
                        let problem = PackedPackedProblem::<$kernel, $lhs, $rhs, $out, $acc>::new(
                            1,
                            vec![<$lhs>::zero(); <$kernel>::mr()],
                            vec![<$rhs>::zero(); <$kernel>::nr()],
                            true,
                            true,
                        );
                        assert_eq!(problem.run(), problem.reference());
                    }
                }

                // Output stored with a zero column stride, for a handful of depths.
                #[test]
                fn packed_vec_k1() {
                    if $enabled {
                        test::packed_vec::<$kernel, $lhs, $rhs, $out, $acc>(1);
                    }
                }

                #[test]
                fn packed_vec_k2() {
                    if $enabled {
                        test::packed_vec::<$kernel, $lhs, $rhs, $out, $acc>(2);
                    }
                }

                #[test]
                fn packed_vec_k4() {
                    if $enabled {
                        test::packed_vec::<$kernel, $lhs, $rhs, $out, $acc>(4);
                    }
                }

                #[test]
                fn packed_vec_k13() {
                    if $enabled {
                        test::packed_vec::<$kernel, $lhs, $rhs, $out, $acc>(13);
                    }
                }
            }
        };
    }
    /// One single-tile test case for kernel `K`: packed A and B operands of
    /// depth `k`, plus the options applied when producing the `mr x nr` output.
    ///
    /// `TA`, `TB` and `TC` are the element types of A, B and the output; `TI` is
    /// the accumulator type of the kernel.
    #[derive(Debug, new)]
    pub struct PackedPackedProblem<K, TA, TB, TC, TI>
    where
        K: MatMatMulKer<TI>,
        TA: 'static + Debug + AsPrimitive<TI>,
        TB: 'static + Debug + AsPrimitive<TI>,
        TC: Copy + PartialEq + 'static + Debug,
        TI: LADatum + fmt::Display + AsPrimitive<TC>,
        usize: AsPrimitive<TA> + AsPrimitive<TB>,
    {
        /// Depth of the product (number of accumulation steps).
        pub k: usize,
        /// Packed A data; `reference` reads element `(m, k)` at `a[m + mr * k]`.
        pub a: Vec<TA>,
        /// Packed B data; `reference` reads element `(k, n)` at `b[n + nr * k]`.
        pub b: Vec<TB>,
        /// When true the output element `(m, n)` lives at `m + n * mr`,
        /// otherwise at `n + m * nr`.
        pub trans_c: bool,
        /// When true, one is added to every output element.
        pub add_one: bool,
        /// Ties the otherwise unused `K`, `TC` and `TI` parameters to the struct.
        pub _phantom: PhantomData<(K, TC, TI)>,
    }
    /// Random problem generation: a depth in `0..20`, two independent flags and
    /// operands filled with values converted from integers in `0..10`.
    impl<K, TA, TB, TC, TI> Arbitrary for PackedPackedProblem<K, TA, TB, TC, TI>
    where
        K: MatMatMulKer<TI>,
        TA: 'static + Debug + AsPrimitive<TI>,
        TB: 'static + Debug + AsPrimitive<TI>,
        TC: Copy + PartialEq + 'static + Debug,
        TI: LADatum + fmt::Display + AsPrimitive<TC>,
        usize: AsPrimitive<TA> + AsPrimitive<TB>,
    {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_: ()) -> Self::Strategy {
            // First draw the depth and the two options...
            let depth_and_flags = (0usize..20, any::<bool>(), any::<bool>());
            // ...then operands whose lengths are fixed by the drawn depth.
            let with_operands = depth_and_flags.prop_flat_map(|(depth, transposed, plus_one)| {
                let a_len = depth * K::mr();
                let b_len = depth * K::nr();
                let a_item = (0usize..10).prop_map(|v| v.as_());
                let b_item = (0usize..10).prop_map(|v| v.as_());
                (
                    Just(depth),
                    Just(transposed),
                    Just(plus_one),
                    vec(a_item, a_len..=a_len),
                    vec(b_item, b_len..=b_len),
                )
            });
            with_operands
                .prop_map(|(k, trans_c, add_one, a, b)| Self {
                    k,
                    a,
                    b,
                    trans_c,
                    add_one,
                    _phantom: PhantomData,
                })
                .boxed()
        }
    }
    impl<K, TA, TB, TC, TI> PackedPackedProblem<K, TA, TB, TC, TI>
    where
        K: MatMatMulKer<TI>,
        TA: 'static + Debug + AsPrimitive<TI> + Datum,
        TB: 'static + Debug + AsPrimitive<TI> + Datum,
        TC: Copy + Zero + PartialEq + 'static + Debug,
        TI: LADatum + fmt::Display + AsPrimitive<TC>,
        usize: AsPrimitive<TA> + AsPrimitive<TB>,
    {
        /// Computes the expected `mr * nr` output with plain nested loops.
        ///
        /// Accumulation happens in `TI`, starting from one when `add_one` is set
        /// and from zero otherwise; the result is converted to `TC` at the end.
        pub fn reference(&self) -> Vec<TC> {
            let (mr, nr) = (K::mr(), K::nr());
            let seed = if self.add_one { TI::one() } else { TI::zero() };
            let mut acc = vec![seed; mr * nr];
            for row in 0..mr {
                for col in 0..nr {
                    // The destination slot only depends on the output layout.
                    let slot = if self.trans_c { row + col * mr } else { col + row * nr };
                    for step in 0..self.k {
                        let lhs: TI = self.a[row + mr * step].as_();
                        let rhs: TI = self.b[col + nr * step].as_();
                        acc[slot] += lhs * rhs;
                    }
                }
            }
            acc.into_iter().map(|x| x.as_()).collect()
        }

        /// Runs kernel `K` on this problem and returns the `mr * nr` output buffer.
        ///
        /// Panics if the kernel returns a non-zero status.
        pub fn run(&self) -> Vec<TC> {
            let (mr, nr) = (K::mr(), K::nr());
            unsafe {
                // Each operand gets the end padding declared by the kernel and is
                // copied into a tensor with the alignment the kernel asks for.
                let padded_a = self
                    .a
                    .iter()
                    .cloned()
                    .chain(vec![0.as_(); K::end_padding_packed_a() * mr])
                    .collect::<Vec<_>>();
                let packed_a =
                    Tensor::from_slice_align(&padded_a, K::alignment_bytes_packed_a()).unwrap();
                let padded_b = self
                    .b
                    .iter()
                    .cloned()
                    .chain(vec![0.as_(); K::end_padding_packed_b() * nr])
                    .collect::<Vec<_>>();
                let packed_b =
                    Tensor::from_slice_align(&padded_b, K::alignment_bytes_packed_b()).unwrap();

                let mut out = vec![TC::zero(); mr * nr];
                let (row_stride, col_stride) = if self.trans_c { (1, mr) } else { (nr, 1) };
                let store = mmm_stride_storage(&mut out, row_stride, col_stride);

                // Op list: Clear, AddMatMul, [ScalarAdd(1)], Store, Done.
                let mut ops = tvec!(
                    FusedKerSpec::Clear,
                    FusedKerSpec::AddMatMul {
                        k: self.k,
                        pa: packed_a.as_ptr_unchecked::<u8>() as _,
                        pb: packed_b.as_ptr_unchecked::<TB>() as _,
                        cpu_variant: 0,
                    }
                );
                if self.add_one {
                    ops.push(FusedKerSpec::ScalarAdd(TI::one()));
                }
                ops.push(FusedKerSpec::Store(store));
                ops.push(FusedKerSpec::Done);

                let status = K::kernel(&ops);
                assert_eq!(status, 0);
                out
            }
        }
    }
    /// Checks kernel `K` against the reference on all-ones operands of depth `k`,
    /// with a non-transposed output and no scalar addition.
    ///
    /// Panics when the kernel output differs from the reference.
    pub fn packed_packed<K, TA, TB, TC, TI>(k: usize)
    where
        K: MatMatMulKer<TI>,
        TA: Copy + One + Datum + AsPrimitive<TI>,
        TB: Copy + One + Datum + AsPrimitive<TI>,
        TC: Copy + PartialEq + Zero + 'static + Debug,
        TI: LADatum + AsPrimitive<TC>,
        usize: AsPrimitive<TC> + AsPrimitive<TA> + AsPrimitive<TB>,
    {
        let ones_a = vec![TA::one(); K::mr() * k];
        let ones_b = vec![TB::one(); K::nr() * k];
        let problem = PackedPackedProblem::<K, TA, TB, TC, TI>::new(k, ones_a, ones_b, false, false);
        let (actual, expected) = (problem.run(), problem.reference());
        assert_eq!(actual, expected);
    }
    /// Describes `v` as a kernel output store whose row and column strides are
    /// `rsc` and `csc` elements, converted to bytes using the size of `T`.
    ///
    /// The returned descriptor holds a raw pointer into `v`.
    pub fn mmm_stride_storage<T: Copy>(v: &mut [T], rsc: usize, csc: usize) -> OutputStoreKer {
        let item_size = std::mem::size_of::<T>();
        let row_byte_stride = (item_size * rsc) as isize;
        let col_byte_stride = (item_size * csc) as isize;
        OutputStoreKer { ptr: v.as_mut_ptr() as _, row_byte_stride, col_byte_stride, item_size }
    }
    /// Runs kernel `K` on all-ones operands of depth `k`, storing the tile with a
    /// column stride of zero, and checks that the first `mr` outputs equal `k`.
    ///
    /// With a zero column stride every column of a row is written to the same
    /// slot, so only the first `mr` elements of the output buffer are compared.
    ///
    /// Panics if the kernel returns a non-zero status or the output differs.
    pub fn packed_vec<K, TA, TB, TC, TI>(k: usize)
    where
        K: MatMatMulKer<TI>,
        TA: Copy + One + AsPrimitive<TI> + Debug + Datum,
        TB: Copy + One + AsPrimitive<TI> + Debug + Datum,
        TC: Copy + PartialEq + Zero + 'static + Debug,
        TI: LADatum + AsPrimitive<TC>,
        usize: AsPrimitive<TC>,
    {
        let pa = unsafe {
            Tensor::from_slice_align(
                &vec![TA::one(); K::mr() * (k + K::end_padding_packed_a())],
                K::alignment_bytes_packed_a(),
            )
            .unwrap()
        };
        // Pad packed B by what the kernel declares, as is done for packed A,
        // rather than a fixed single row; never less than one row so the buffer
        // is at least as large as it used to be.
        let b_rows = k + K::end_padding_packed_b().max(1);
        let b = vec![TB::one(); b_rows * K::nr()];
        let mut c: Vec<TC> = vec![TC::zero(); K::mr() * K::nr()];
        // Row stride of one element, column stride of zero.
        let tile = mmm_stride_storage(&mut c, 1, 0);
        let pb = unsafe { Tensor::from_slice_align(&b, K::alignment_bytes_packed_b()).unwrap() };
        let non_linear_ops = tvec!(
            FusedKerSpec::Clear,
            FusedKerSpec::AddMatMul {
                pa: unsafe { pa.as_ptr_unchecked::<u8>() as _ },
                pb: unsafe { pb.as_ptr_unchecked::<u8>() as _ },
                k,
                cpu_variant: 0,
            },
            FusedKerSpec::Store(tile),
            FusedKerSpec::Done
        );
        let err = K::kernel(&non_linear_ops);
        assert_eq!(err, 0);
        // A dot product of `k` ones with `k` ones is `k`.
        let expected = vec![k.as_(); K::mr()];
        assert_eq!(c[..K::mr()], expected);
    }
}