use std::fmt;
use std::hash::Hash;
use std::marker::PhantomData;
use tract_data::internal::*;
/// A byte-to-byte lookup table that can be applied in place to a buffer.
///
/// Used as a trait object: `clone_trait_object!` below makes `Box<dyn Lut>` clonable.
pub trait Lut: fmt::Debug + dyn_clone::DynClone + Send + Sync {
    /// The table contents; entry `i` is the value that input byte `i` is mapped to.
    fn table(&self) -> &[u8];
    /// Replaces every byte `b` of `buf` with `table()[b]`, in place.
    fn run(&self, buf: &mut [u8]);
}
// Implements `Clone` for `Box<dyn Lut>` through `dyn_clone::DynClone`.
dyn_clone::clone_trait_object!(Lut);
/// [`Lut`] implementation whose bulk work is delegated to the kernel `K`.
///
/// The table is stored in a `Tensor` so that `LutImpl::new` can allocate it with the
/// alignment `K::table_alignment_bytes()` requests. `K` itself is only a type-level marker.
#[derive(Debug, Clone, Hash)]
pub struct LutImpl<K: LutKer> {
    /// Table bytes, stored as a 1-D `u8` tensor of `table.len()` elements.
    table: Tensor,
    /// Binds the kernel type to the struct without storing a value of it.
    _boo: PhantomData<K>,
}
impl<K: LutKer> LutImpl<K> {
    /// Builds a lookup table from `table`, copying it into a 1-D `u8` tensor whose storage is
    /// aligned to `K::table_alignment_bytes()`.
    ///
    /// # Panics
    ///
    /// Panics if the aligned tensor cannot be created.
    pub fn new(table: &[u8]) -> LutImpl<K> {
        let shape = [table.len()];
        // SAFETY: the raw bytes are reinterpreted as `u8`, for which every bit pattern is a
        // valid value, and `shape` describes exactly `table.len()` elements.
        let aligned = unsafe {
            Tensor::from_raw_aligned::<u8>(&shape, table, K::table_alignment_bytes())
        };
        LutImpl { table: aligned.unwrap(), _boo: PhantomData }
    }
}
impl<K: LutKer> Lut for LutImpl<K> {
    /// Returns the table bytes; entry `i` is the output for input byte `i`.
    fn table(&self) -> &[u8] {
        self.table.as_slice().unwrap()
    }

    /// Replaces every byte `b` of `buf` with `self.table()[b]`, in place.
    ///
    /// The buffer is processed in three parts:
    /// * a scalar head, covering the bytes before the first address aligned to
    ///   `K::input_alignment_bytes()`;
    /// * an aligned body whose length is a multiple of `K::n()`, handed to `K::run`;
    /// * a scalar tail with the leftover bytes.
    ///
    /// # Panics
    ///
    /// Panics if `buf` contains a byte that is not a valid index into the table. This can only
    /// happen when the table has fewer than 256 entries. Previously such input caused an
    /// out-of-bounds read.
    fn run(&self, buf: &mut [u8]) {
        let table = self.table();

        // Bounds-checked scalar lookup, used for the unaligned head and the ragged tail.
        let lookup = |bytes: &mut [u8]| {
            for b in bytes {
                *b = table[*b as usize];
            }
        };

        let align = K::input_alignment_bytes();
        let addr = buf.as_ptr() as usize;
        let prefix = ((addr + align - 1) / align * align - addr).min(buf.len());
        let (head, rest) = buf.split_at_mut(prefix);
        lookup(head);
        if rest.is_empty() {
            return;
        }

        let n = K::n();
        let (body, tail) = rest.split_at_mut(rest.len() / n * n);
        if !body.is_empty() {
            // The kernel does unchecked lookups. With a short table, make sure every input
            // byte indexes inside it before handing the body over.
            if table.len() <= u8::MAX as usize {
                assert!(
                    body.iter().all(|&b| (b as usize) < table.len()),
                    "LUT input byte out of range for a table of {} entries",
                    table.len()
                );
            }
            // NOTE(review): if a kernel always loads a full 256-byte table (as shuffle-based
            // SIMD kernels may do), shorter tables would still be over-read. Confirm the
            // `LutKer::run` contract for tables with fewer than 256 entries.
            //
            // SAFETY:
            // * `body` is an exclusively borrowed region of `body.len()` bytes.
            // * Its start is aligned to `align`: `head` absorbed the misalignment, and `rest`
            //   is non-empty here.
            // * Its length is a multiple of `n`.
            // * Every byte in it indexes into `table` (checked above, or trivially true for a
            //   256-entry table).
            // * `table` was allocated with `K::table_alignment_bytes()` in `LutImpl::new`.
            unsafe { K::run(body.as_mut_ptr(), body.len(), table.as_ptr()) };
        }
        lookup(tail);
    }
}
/// Low-level lookup kernel used by [`LutImpl`] for the bulk of a buffer.
///
/// All functions are associated (no `self`); a kernel is selected purely by type.
pub trait LutKer: Clone + fmt::Debug + Send + Sync + Hash {
    /// Kernel name (presumably for display/debugging — not used in this file).
    fn name() -> &'static str;
    /// Processing granularity: `LutImpl::run` only passes `run` lengths that are a multiple
    /// of this value. Any remainder is handled in scalar code.
    fn n() -> usize;
    /// Alignment, in bytes, of the buffer start passed to `run`. `LutImpl::run` handles the
    /// bytes before the first aligned address in scalar code.
    fn input_alignment_bytes() -> usize;
    /// Alignment, in bytes, used by `LutImpl::new` when allocating the table.
    fn table_alignment_bytes() -> usize;
    /// Replaces each of the `len` bytes at `buf` with the table entry it indexes.
    ///
    /// # Safety
    ///
    /// As called from `LutImpl::run`:
    /// * `buf` is valid for reads and writes of `len` bytes.
    /// * `buf` is aligned to `input_alignment_bytes()`.
    /// * `len` is a non-zero multiple of `n()`.
    /// * `table` points to the table allocated by `LutImpl::new`.
    ///
    /// NOTE(review): whether implementations may read a full 256 bytes from `table` regardless
    /// of its length is not visible here — confirm.
    unsafe fn run(buf: *mut u8, len: usize, table: *const u8);
}
#[cfg(test)]
#[macro_use]
pub mod test {
    use super::*;
    use proptest::prelude::*;

    /// A random lookup problem: a table plus an input buffer whose bytes all index into it.
    #[derive(Debug)]
    pub struct LutProblem {
        pub table: Vec<u8>,
        pub data: Vec<u8>,
    }

    impl Arbitrary for LutProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        /// Generates tables of 1 to 256 entries and up to 99 input bytes, each a valid index
        /// into the table.
        ///
        /// The upper bound is inclusive, so a complete 256-entry table covering every byte
        /// value is also generated.
        fn arbitrary_with(_p: ()) -> Self::Strategy {
            proptest::collection::vec(any::<u8>(), 1..=256)
                .prop_flat_map(|table| {
                    // `len - 1` always fits in a `u8` because `len <= 256`. The inclusive range
                    // avoids the `256 as u8 == 0` wrap an exclusive `0..len as u8` would hit.
                    let max_index = (table.len() - 1) as u8;
                    let data = proptest::collection::vec(0..=max_index, 0..100);
                    (Just(table), data)
                })
                .prop_map(|(table, data)| LutProblem { table, data })
                .boxed()
        }
    }

    impl LutProblem {
        /// Expected output: a plain indexed lookup of every input byte.
        pub fn reference(&self) -> Vec<u8> {
            self.data.iter().map(|x| self.table[*x as usize]).collect()
        }

        /// Output of `LutImpl<K>` run on a copy of the input.
        pub fn test<K: LutKer>(&self) -> Vec<u8> {
            let lut = LutImpl::<K>::new(&self.table);
            let mut data = self.data.clone();
            lut.run(&mut data);
            data
        }
    }

    /// Instantiates the LUT test suite for kernel `$ker`. The property test is skipped at
    /// runtime when `$cond` is false (e.g. when the CPU lacks the kernel's features).
    #[macro_export]
    macro_rules! lut_frame_tests {
        ($cond:expr, $ker:ty) => {
            mod lut {
                use proptest::prelude::*;
                #[allow(unused_imports)]
                use $crate::frame::lut::test::*;
                proptest::proptest! {
                    #[test]
                    fn lut_prop(pb in any::<LutProblem>()) {
                        if $cond {
                            prop_assert_eq!(pb.test::<$ker>(), pb.reference())
                        }
                    }
                }
                #[test]
                fn test_empty() {
                    let pb = LutProblem { table: vec![0], data: vec![] };
                    assert_eq!(pb.test::<$ker>(), pb.reference())
                }
            }
        };
    }
}