use std::fmt::Debug;
use std::marker::PhantomData;
use tract_data::TractResult;
use crate::LADatum;
use super::element_wise_helper::{map_reduce_slice_with_alignment, reduce_slice_with_alignment};
/// Declares a unit struct and implements `crate::frame::reduce::ReduceKer` for it.
///
/// Positional arguments:
/// 1. element type (an identifier such as a primitive type name),
/// 2. name of the generated struct, also returned verbatim by `name()`,
/// 3. expression returned by `nr()`,
/// 4. expression returned by `alignment_items()`,
/// 5. parameter type of the kernel,
/// 6. expression returned by `neutral()`,
/// 7. the `run` function item, spliced into the impl as written,
/// 8. the `reduce_two` function item, spliced into the impl as written.
macro_rules! reduce_impl_wrap {
    (
        $elem:ident,
        $kernel:ident,
        $nr:expr,
        $alignment_items:expr,
        $param_ty:ty,
        $neutral:expr,
        $run:item,
        $reduce_two:item
    ) => {
        paste! {
            #[derive(Copy, Clone, Debug)]
            #[allow(non_camel_case_types)]
            pub struct $kernel;

            impl crate::frame::reduce::ReduceKer<$elem, $param_ty> for $kernel {
                // The kernel reports the struct identifier as its name.
                #[inline(always)]
                fn name() -> &'static str {
                    stringify!($kernel)
                }

                #[inline(always)]
                fn alignment_items() -> usize {
                    $alignment_items
                }

                // Byte alignment is derived from the item alignment and the
                // in-memory size of one element.
                #[inline(always)]
                fn alignment_bytes() -> usize {
                    std::mem::size_of::<$elem>() * $alignment_items
                }

                #[inline(always)]
                fn nr() -> usize {
                    $nr
                }

                #[inline(always)]
                fn neutral() -> $elem {
                    $neutral
                }

                // Caller-provided function items complete the trait impl.
                $run
                $reduce_two
            }
        }
    };
}
/// Dynamically dispatchable reducer: takes a slice of `T` and yields a single `T`.
///
/// `Params` carries optional per-call configuration; it must be `Default` so
/// that [`Reduce::run`] can be called without supplying one.
pub trait Reduce<T, Params = ()>: Send + Sync + Debug + dyn_clone::DynClone
where
    Params: Copy + Send + Sync + Debug + 'static + Default,
    T: Copy + Debug + PartialEq + Send + Sync,
{
    /// Name identifying this reducer.
    fn name(&self) -> &'static str;

    /// Reduces `vec` to a single value using the supplied `params`.
    ///
    /// # Errors
    /// Returns whatever error the implementation reports.
    fn run_with_params(&self, vec: &[T], params: Params) -> TractResult<T>;

    /// Reduces `vec` to a single value using `Params::default()`.
    ///
    /// # Errors
    /// Forwards the error of [`Reduce::run_with_params`].
    fn run(&self, vec: &[T]) -> TractResult<T> {
        let default_params = Params::default();
        self.run_with_params(vec, default_params)
    }
}

// Lets boxed `dyn Reduce<T, Params>` trait objects be cloned (see the dyn-clone crate).
dyn_clone::clone_trait_object!(<T, Params> Reduce<T, Params> where T: Copy, Params: Copy);
/// Adapter exposing a statically dispatched [`ReduceKer`] through the
/// [`Reduce`] trait.
///
/// The struct holds no data; `K`, `T` and `Params` are only recorded through
/// `PhantomData`.
#[derive(Debug, Clone, new)]
pub struct ReduceImpl<K, T, Params = ()>
where
    T: LADatum,
    Params: Copy + Send + Sync + Debug + 'static + Default,
    K: ReduceKer<T, Params> + Clone,
{
    /// Zero-sized marker tying the adapter to its kernel, element and
    /// parameter types.
    phantom: PhantomData<(K, T, Params)>,
}

impl<K, T, Params> Reduce<T, Params> for ReduceImpl<K, T, Params>
where
    T: LADatum,
    Params: Copy + Send + Sync + Debug + 'static + Default,
    K: ReduceKer<T, Params> + Clone,
{
    /// Reports the name of the wrapped kernel.
    fn name(&self) -> &'static str {
        K::name()
    }

    /// Delegates to `reduce_slice_with_alignment`, handing it the kernel's
    /// `run` (with `params` bound), its `nr`, its byte alignment, its neutral
    /// element and its `reduce_two` combiner.
    fn run_with_params(&self, vec: &[T], params: Params) -> TractResult<T> {
        let nr = K::nr();
        let alignment = K::alignment_bytes();
        let neutral = K::neutral();
        reduce_slice_with_alignment(
            vec,
            |items| K::run(items, params),
            nr,
            alignment,
            neutral,
            K::reduce_two,
        )
    }
}
/// Static description of a reduction kernel over elements of type `T`.
///
/// Every method is an associated function (no `self`), so a kernel is fully
/// described by its type. [`ReduceKer::red`] wraps the kernel into a boxed
/// [`Reduce`] object.
pub trait ReduceKer<T, Params = ()>:
    Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static
where
    Params: Copy + Send + Sync + Debug + 'static + Default,
    T: LADatum,
{
    /// Name identifying the kernel.
    fn name() -> &'static str;

    /// Value forwarded as the `nr` argument of `reduce_slice_with_alignment`.
    // NOTE(review): presumably the number of items the kernel consumes per
    // step — confirm against the helper.
    fn nr() -> usize;

    /// Alignment requirement, expressed in items of `T`.
    fn alignment_items() -> usize;

    /// Alignment requirement in bytes: `alignment_items()` multiplied by the
    /// size reported by `T::datum_type()`.
    fn alignment_bytes() -> usize {
        let items = Self::alignment_items();
        let item_bytes = T::datum_type().size_of();
        items * item_bytes
    }

    /// Neutral element handed to `reduce_slice_with_alignment`.
    fn neutral() -> T;

    /// Combines two partial results into one.
    fn reduce_two(a: T, b: T) -> T;

    /// Reduces `vec` to a single value, given `params`.
    fn run(vec: &[T], params: Params) -> T;

    /// Boxes a [`ReduceImpl`] built around this kernel as a `dyn Reduce`.
    fn red() -> Box<dyn Reduce<T, Params>> {
        let adapter = ReduceImpl::<Self, T, Params>::new();
        Box::new(adapter)
    }
}
/// Declares a unit struct and implements `crate::frame::reduce::MapReduceKer`
/// for it.
///
/// Positional arguments:
/// 1. element type (an identifier such as a primitive type name),
/// 2. name of the generated struct, also returned verbatim by `name()`,
/// 3. expression returned by `nr()`,
/// 4. expression returned by `alignment_items()`,
/// 5. parameter type of the kernel,
/// 6. expression returned by `map_neutral()`,
/// 7. expression returned by `reduce_neutral()`,
/// 8. the `run` function item, spliced into the impl as written,
/// 9. the `reduce_two` function item, spliced into the impl as written.
#[allow(unused_macros)]
macro_rules! map_reduce_impl_wrap {
    (
        $elem:ident,
        $kernel:ident,
        $nr:expr,
        $alignment_items:expr,
        $param_ty:ty,
        $map_neutral:expr,
        $reduce_neutral:expr,
        $run:item,
        $reduce_two:item
    ) => {
        paste! {
            #[derive(Copy, Clone, Debug)]
            #[allow(non_camel_case_types)]
            pub struct $kernel;

            impl crate::frame::reduce::MapReduceKer<$elem, $param_ty> for $kernel {
                // The kernel reports the struct identifier as its name.
                #[inline(always)]
                fn name() -> &'static str {
                    stringify!($kernel)
                }

                #[inline(always)]
                fn alignment_items() -> usize {
                    $alignment_items
                }

                // Byte alignment is derived from the item alignment and the
                // in-memory size of one element.
                #[inline(always)]
                fn alignment_bytes() -> usize {
                    std::mem::size_of::<$elem>() * $alignment_items
                }

                #[inline(always)]
                fn nr() -> usize {
                    $nr
                }

                #[inline(always)]
                fn map_neutral() -> $elem {
                    $map_neutral
                }

                #[inline(always)]
                fn reduce_neutral() -> $elem {
                    $reduce_neutral
                }

                // Caller-provided function items complete the trait impl.
                $run
                $reduce_two
            }
        }
    };
}
/// Dynamically dispatchable map-reduce: receives a mutable slice of `T` and
/// yields a single `T`.
///
/// `Params` carries optional per-call configuration; it must be `Default` so
/// that [`MapReduce::run`] can be called without supplying one.
pub trait MapReduce<T, Params = ()>: Send + Sync + Debug + dyn_clone::DynClone
where
    Params: Copy + Send + Sync + Debug + 'static + Default,
    T: Copy + Debug + PartialEq + Send + Sync,
{
    /// Name identifying this operation.
    fn name(&self) -> &'static str;

    /// Processes `vec` (which the implementation may modify in place) with
    /// the supplied `params` and returns the reduced value.
    ///
    /// # Errors
    /// Returns whatever error the implementation reports.
    fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult<T>;

    /// Same as [`MapReduce::run_with_params`], using `Params::default()`.
    ///
    /// # Errors
    /// Forwards the error of [`MapReduce::run_with_params`].
    fn run(&self, vec: &mut [T]) -> TractResult<T> {
        let default_params = Params::default();
        self.run_with_params(vec, default_params)
    }
}

// Lets boxed `dyn MapReduce<T, Params>` trait objects be cloned (see the dyn-clone crate).
dyn_clone::clone_trait_object!(<T, Params> MapReduce<T, Params> where T: Copy, Params: Copy);
/// Adapter exposing a statically dispatched [`MapReduceKer`] through the
/// [`MapReduce`] trait.
///
/// The struct holds no data; `K`, `T` and `Params` are only recorded through
/// `PhantomData`.
#[derive(Debug, Clone, new)]
pub struct MapReduceImpl<K, T, Params = ()>
where
    T: LADatum,
    Params: Copy + Send + Sync + Debug + 'static + Default,
    K: MapReduceKer<T, Params> + Clone,
{
    /// Zero-sized marker tying the adapter to its kernel, element and
    /// parameter types.
    phantom: PhantomData<(K, T, Params)>,
}

impl<K, T, Params> MapReduce<T, Params> for MapReduceImpl<K, T, Params>
where
    T: LADatum,
    Params: Copy + Send + Sync + Debug + 'static + Default,
    K: MapReduceKer<T, Params> + Clone,
{
    /// Reports the name of the wrapped kernel.
    fn name(&self) -> &'static str {
        K::name()
    }

    /// Delegates to `map_reduce_slice_with_alignment`, handing it the
    /// kernel's `run` (with `params` bound), its `nr`, its byte alignment,
    /// both neutral elements and its `reduce_two` combiner.
    fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult<T> {
        let nr = K::nr();
        let alignment = K::alignment_bytes();
        let map_neutral = K::map_neutral();
        let reduce_neutral = K::reduce_neutral();
        map_reduce_slice_with_alignment(
            vec,
            |items| K::run(items, params),
            nr,
            alignment,
            map_neutral,
            reduce_neutral,
            K::reduce_two,
        )
    }
}
/// Static description of a map-reduce kernel over elements of type `T`.
///
/// Every method is an associated function (no `self`), so a kernel is fully
/// described by its type. [`MapReduceKer::red`] wraps the kernel into a boxed
/// [`MapReduce`] object.
pub trait MapReduceKer<T, Params = ()>:
    Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static
where
    Params: Copy + Send + Sync + Debug + 'static + Default,
    T: LADatum,
{
    /// Name identifying the kernel.
    fn name() -> &'static str;

    /// Value forwarded as the `nr` argument of
    /// `map_reduce_slice_with_alignment`.
    // NOTE(review): presumably the number of items the kernel consumes per
    // step — confirm against the helper.
    fn nr() -> usize;

    /// Alignment requirement, expressed in items of `T`.
    fn alignment_items() -> usize;

    /// Alignment requirement in bytes: `alignment_items()` multiplied by the
    /// size reported by `T::datum_type()`.
    fn alignment_bytes() -> usize {
        let items = Self::alignment_items();
        let item_bytes = T::datum_type().size_of();
        items * item_bytes
    }

    /// Neutral element for the map stage, as handed to
    /// `map_reduce_slice_with_alignment`.
    fn map_neutral() -> T;

    /// Neutral element for the reduce stage, as handed to
    /// `map_reduce_slice_with_alignment`.
    fn reduce_neutral() -> T;

    /// Combines two partial results into one.
    fn reduce_two(a: T, b: T) -> T;

    /// Processes `vec` through a mutable borrow and returns the reduced
    /// value, given `params`.
    fn run(vec: &mut [T], params: Params) -> T;

    /// Boxes a [`MapReduceImpl`] built around this kernel as a `dyn MapReduce`.
    fn red() -> Box<dyn MapReduce<T, Params>> {
        let adapter = MapReduceImpl::<Self, T, Params>::new();
        Box::new(adapter)
    }
}
/// Generic checkers that compare a kernel against a scalar reference
/// implementation. They return a proptest `TestCaseResult`.
#[cfg(test)]
pub mod test {
    use super::*;
    use proptest::test_runner::{TestCaseError, TestCaseResult};
    use tract_data::internal::*;
    use tract_data::itertools::Itertools;

    /// Checks a parameter-less reduce kernel `K` against `reference_reduce`.
    ///
    /// Shorthand for [`test_reduce_params`] with `()` as parameters.
    pub fn test_reduce<K: ReduceKer<T, ()>, T: LADatum>(
        values: &[T],
        neutral: T,
        reference_reduce: impl Fn(T, T) -> T,
    ) -> TestCaseResult {
        test_reduce_params::<K, T, ()>(values, neutral, reference_reduce, ())
    }

    /// Checks reduce kernel `K`, run with `params`, against a reference fold.
    ///
    /// The expected value is a left fold of `reference_reducer` over `values`
    /// seeded with `neutral`. The kernel is exercised through `K::red()`.
    ///
    /// # Panics
    /// Panics if the kernel run itself returns an error.
    pub fn test_reduce_params<K: ReduceKer<T, Params>, T: LADatum, Params>(
        values: &[T],
        neutral: T,
        reference_reducer: impl Fn(T, T) -> T,
        params: Params,
    ) -> TestCaseResult
    where
        Params: Copy + Send + Sync + Debug + 'static + Default,
    {
        crate::setup_test_logger();
        let op = K::red();
        let expected = values.iter().fold(neutral, |acc, i| reference_reducer(acc, *i));
        // A reduce kernel only reads its input, so the slice is passed as-is.
        let red = op.run_with_params(values, params).unwrap();
        // NOTE(review): this comparison passes `true` while the map-reduce
        // checker uses `Approximation::SuperApproximate` — confirm the
        // difference in tolerance is intended.
        tensor0(red)
            .close_enough(&tensor0(expected), true)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        Ok(())
    }

    /// Checks a parameter-less map-reduce kernel `K` against reference
    /// `reference_map` / `reference_reduce` functions.
    ///
    /// Shorthand for [`test_map_reduce_params`] with `()` as parameters.
    pub fn test_map_reduce<K: MapReduceKer<T, ()>, T: LADatum>(
        values: &[T],
        map_neutral: T,
        neutral: T,
        reference_map: impl Fn(T) -> T,
        reference_reduce: impl Fn(T, T) -> T,
    ) -> TestCaseResult {
        test_map_reduce_params::<K, T, ()>(
            values,
            map_neutral,
            neutral,
            reference_map,
            reference_reduce,
            (),
        )
    }

    /// Checks map-reduce kernel `K`, run with `params`, against references.
    ///
    /// Two things are verified:
    /// * the buffer left behind by the kernel matches `reference_map` applied
    ///   to every element of `values`;
    /// * the returned value matches a left fold of `reference_reducer` over
    ///   the mapped values, seeded with `neutral`.
    ///
    /// `_map_neutral` is accepted for signature symmetry with
    /// [`test_map_reduce`] but is not used by the check.
    ///
    /// # Panics
    /// Panics if the kernel run itself returns an error.
    pub fn test_map_reduce_params<K: MapReduceKer<T, Params>, T: LADatum, Params>(
        values: &[T],
        _map_neutral: T,
        neutral: T,
        reference_map: impl Fn(T) -> T,
        reference_reducer: impl Fn(T, T) -> T,
        params: Params,
    ) -> TestCaseResult
    where
        Params: Copy + Send + Sync + Debug + 'static + Default,
    {
        crate::setup_test_logger();
        let op = K::red();
        // The kernel needs a mutable buffer, so work on a copy of the input.
        let mut found = values.to_vec();
        let expected_values = values.iter().copied().map(reference_map).collect_vec();
        let expected_reduced =
            expected_values.iter().fold(neutral, |acc, i| reference_reducer(acc, *i));
        let red = op.run_with_params(&mut found, params).unwrap();
        tensor1(&found)
            .close_enough(&tensor1(&expected_values), Approximation::SuperApproximate)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        tensor0(red)
            .close_enough(&tensor0(expected_reduced), Approximation::SuperApproximate)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        Ok(())
    }
}