Skip to main content

tract_linalg/frame/
lut.rs

1use std::fmt;
2use std::hash::Hash;
3use std::marker::PhantomData;
4use tract_data::internal::*;
5
6pub trait Lut: fmt::Debug + dyn_clone::DynClone + Send + Sync {
7    fn table(&self) -> &[u8];
8    fn run(&self, buf: &mut [u8]);
9}
10
11dyn_clone::clone_trait_object!(Lut);
12
13impl PartialEq for dyn Lut {
14    fn eq(&self, other: &Self) -> bool {
15        self.table() == other.table()
16    }
17}
18impl Eq for dyn Lut {}
19
20#[derive(Debug, Clone, Hash)]
21pub struct LutImpl<K: LutKer> {
22    table: Tensor,
23    _boo: PhantomData<K>,
24}
25
26impl<K: LutKer> LutImpl<K> {
27    pub fn new(table: &[u8]) -> LutImpl<K> {
28        unsafe {
29            LutImpl {
30                table: Tensor::from_raw_aligned::<u8>(
31                    &[table.len()],
32                    table,
33                    K::table_alignment_bytes(),
34                )
35                .unwrap(),
36                _boo: PhantomData,
37            }
38        }
39    }
40}
41
42impl<K: LutKer> Lut for LutImpl<K> {
43    fn table(&self) -> &[u8] {
44        self.table.try_as_plain().unwrap().as_slice().unwrap()
45    }
46
47    fn run(&self, buf: &mut [u8]) {
48        unsafe {
49            let table: *const u8 = self.table.as_ptr_unchecked();
50            let align = K::input_alignment_bytes();
51            let aligned_start = (buf.as_ptr() as usize).next_multiple_of(align);
52            let prefix = (aligned_start - buf.as_ptr() as usize).min(buf.len());
53            for i in 0..(prefix as isize) {
54                let ptr = buf.as_mut_ptr().offset(i);
55                *ptr = *table.offset(*ptr as isize);
56            }
57            let remaining = buf.len() - prefix;
58            if remaining == 0 {
59                return;
60            }
61            let n = K::n();
62            let aligned_len = remaining / n * n;
63            if aligned_len > 0 {
64                K::run(buf.as_mut_ptr().add(prefix), aligned_len, table);
65            }
66            let remaining = buf.len() - aligned_len - prefix;
67            for i in 0..remaining {
68                let ptr = buf.as_mut_ptr().add(i + prefix + aligned_len);
69                *ptr = *table.offset(*ptr as isize);
70            }
71        }
72    }
73}
74
/// A low-level kernel that applies a byte lookup table; it is wrapped by
/// `LutImpl`, which handles the unaligned edges of a buffer.
pub trait LutKer: Hash + Clone + Send + Sync + fmt::Debug {
    /// Identifier of this kernel.
    fn name() -> &'static str;

    /// Granularity of `run`: `LutImpl::run` only passes lengths that are a
    /// multiple of this value.
    fn n() -> usize;

    /// Alignment, in bytes, that `LutImpl::run` establishes for the `buf`
    /// pointer before calling `run`.
    fn input_alignment_bytes() -> usize;

    /// Alignment, in bytes, that `LutImpl::new` uses for the table storage.
    fn table_alignment_bytes() -> usize;

    /// Maps the `len` bytes starting at `buf` through `table`, in place.
    ///
    /// # Safety
    /// These are the conditions `LutImpl::run` establishes before calling;
    /// confirm against the individual kernels:
    /// - `buf` points to `len` writable bytes and is aligned to `input_alignment_bytes()`;
    /// - `len` is a non-zero multiple of `n()`;
    /// - `table` points to the table storage, aligned to `table_alignment_bytes()`.
    unsafe fn run(buf: *mut u8, len: usize, table: *const u8);
}
82
#[cfg(test)]
#[macro_use]
pub mod test {
    use super::*;
    use proptest::prelude::*;

    /// A lookup problem: a table and input bytes that are all valid indices
    /// into it.
    #[derive(Debug)]
    pub struct LutProblem {
        pub table: Vec<u8>,
        pub data: Vec<u8>,
    }

    impl Arbitrary for LutProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_p: ()) -> Self::Strategy {
            // Table sizes include 256, the production case where every byte
            // value is a valid index.
            proptest::collection::vec(any::<u8>(), 1..=256)
                .prop_flat_map(|table| {
                    // `len - 1` fits in a u8 because len <= 256. The inclusive
                    // range avoids the `256 as u8 == 0` wrap-around that an
                    // exclusive `0..len as u8` bound would hit.
                    let max_index = (table.len() - 1) as u8;
                    let data = proptest::collection::vec(0..=max_index, 0..100);
                    (Just(table), data)
                })
                .prop_map(|(table, data)| LutProblem { table, data })
                .boxed()
        }
    }

    impl LutProblem {
        /// Expected output, computed with plain indexing.
        pub fn reference(&self) -> Vec<u8> {
            self.data.iter().map(|x| self.table[*x as usize]).collect()
        }

        /// Output of `LutImpl` backed by kernel `K`, on a copy of the data.
        pub fn test<K: LutKer>(&self) -> Vec<u8> {
            let lut = LutImpl::<K>::new(&self.table);
            let mut data = self.data.clone();
            lut.run(&mut data);
            data
        }
    }

    /// Generates the LUT test suite for kernel `$ker`.
    ///
    /// The kernel-dependent tests are skipped when `$cond` is false (e.g. the
    /// required CPU feature is unavailable).
    #[macro_export]
    macro_rules! lut_frame_tests {
        ($cond:expr, $ker:ty) => {
            mod lut {
                use proptest::prelude::*;
                #[allow(unused_imports)]
                use $crate::frame::lut::test::*;

                proptest::proptest! {
                    #[test]
                    fn lut_prop(pb in any::<LutProblem>()) {
                        if $cond {
                            prop_assert_eq!(pb.test::<$ker>(), pb.reference())
                        }
                    }
                }

                #[test]
                fn test_empty() {
                    let pb = LutProblem { table: vec![0], data: vec![] };
                    assert_eq!(pb.test::<$ker>(), pb.reference())
                }

                // Every byte value through a full 256-entry table, over a buffer
                // long enough to reach the kernel path.
                #[test]
                fn test_full_table() {
                    if $cond {
                        let pb = LutProblem {
                            table: (0..=255u8)
                                .map(|x| x.wrapping_mul(7).wrapping_add(3))
                                .collect(),
                            data: (0..=255u8).cycle().take(1000).collect(),
                        };
                        assert_eq!(pb.test::<$ker>(), pb.reference())
                    }
                }
            }
        };
    }
}