tract_linalg/frame/lut.rs

1use std::fmt;
2use std::hash::Hash;
3use std::marker::PhantomData;
4use tract_data::internal::*;
5
6pub trait Lut: fmt::Debug + dyn_clone::DynClone + Send + Sync {
7    fn table(&self) -> &[u8];
8    fn run(&self, buf: &mut [u8]);
9}
10
11dyn_clone::clone_trait_object!(Lut);
12
13#[derive(Debug, Clone, Hash)]
14pub struct LutImpl<K: LutKer> {
15    table: Tensor,
16    _boo: PhantomData<K>,
17}
18
19impl<K: LutKer> LutImpl<K> {
20    pub fn new(table: &[u8]) -> LutImpl<K> {
21        unsafe {
22            LutImpl {
23                table: Tensor::from_raw_aligned::<u8>(
24                    &[table.len()],
25                    table,
26                    K::table_alignment_bytes(),
27                )
28                .unwrap(),
29                _boo: PhantomData,
30            }
31        }
32    }
33}
34
35impl<K: LutKer> Lut for LutImpl<K> {
36    fn table(&self) -> &[u8] {
37        self.table.as_slice().unwrap()
38    }
39
40    fn run(&self, buf: &mut [u8]) {
41        unsafe {
42            let table: *const u8 = self.table.as_ptr_unchecked();
43            let align = K::input_alignment_bytes();
44            let aligned_start = (buf.as_ptr() as usize).next_multiple_of(align);
45            let prefix = (aligned_start - buf.as_ptr() as usize).min(buf.len());
46            for i in 0..(prefix as isize) {
47                let ptr = buf.as_mut_ptr().offset(i);
48                *ptr = *table.offset(*ptr as isize);
49            }
50            let remaining = buf.len() - prefix;
51            if remaining == 0 {
52                return;
53            }
54            let n = K::n();
55            let aligned_len = remaining / n * n;
56            if aligned_len > 0 {
57                K::run(buf.as_mut_ptr().add(prefix), aligned_len, table);
58            }
59            let remaining = buf.len() - aligned_len - prefix;
60            for i in 0..remaining {
61                let ptr = buf.as_mut_ptr().add(i + prefix + aligned_len);
62                *ptr = *table.offset(*ptr as isize);
63            }
64        }
65    }
66}
67
/// A low-level kernel that applies a lookup table to a block of bytes in place.
///
/// `LutImpl::run` handles the unaligned head and the ragged tail of a buffer
/// itself and only hands the aligned middle section to [`LutKer::run`].
pub trait LutKer: Send + Sync + Clone + Hash + fmt::Debug {
    /// Name of the kernel.
    fn name() -> &'static str;

    /// Chunk size in bytes: the length passed to [`LutKer::run`] is always a multiple of it.
    fn n() -> usize;

    /// Alignment, in bytes, of the `buf` pointer passed to [`LutKer::run`].
    fn input_alignment_bytes() -> usize;

    /// Alignment, in bytes, used by `LutImpl::new` when allocating the table.
    fn table_alignment_bytes() -> usize;

    /// Replaces each of the `len` bytes starting at `buf` with the table entry it indexes.
    ///
    /// # Safety
    /// As invoked by `LutImpl::run`:
    /// - `buf` is valid for reads and writes of `len` bytes and aligned to
    ///   `input_alignment_bytes()`;
    /// - `len` is a non-zero multiple of `n()`;
    /// - `table` points at the table allocated with `table_alignment_bytes()` alignment.
    ///
    /// NOTE(review): every byte in the buffer is presumably required to be a valid index into
    /// the table — confirm this against the concrete kernels.
    unsafe fn run(buf: *mut u8, len: usize, table: *const u8);
}
75
#[cfg(test)]
#[macro_use]
pub mod test {
    use super::*;
    use proptest::prelude::*;

    /// A random lookup problem: a table plus input bytes that are all valid
    /// indices into it.
    #[derive(Debug)]
    pub struct LutProblem {
        pub table: Vec<u8>,
        pub data: Vec<u8>,
    }

    impl Arbitrary for LutProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_p: ()) -> Self::Strategy {
            // Tables go up to the full 256 entries. The index range is inclusive because
            // `256 as u8` wraps to 0, which would turn an exclusive range into an empty one.
            proptest::collection::vec(any::<u8>(), 1..=256)
                .prop_flat_map(|table| {
                    // `table.len()` is in 1..=256, so the highest index always fits in a u8.
                    let max_index = (table.len() - 1) as u8;
                    let data = proptest::collection::vec(0..=max_index, 0..100);
                    (Just(table), data)
                })
                .prop_map(|(table, data)| LutProblem { table, data })
                .boxed()
        }
    }

    impl LutProblem {
        /// Plain scalar lookup, used as the expected result.
        pub fn reference(&self) -> Vec<u8> {
            self.data.iter().map(|x| self.table[*x as usize]).collect()
        }

        /// Runs `LutImpl<K>` over a copy of `data` and returns the transformed bytes.
        pub fn test<K: LutKer>(&self) -> Vec<u8> {
            let lut = LutImpl::<K>::new(&self.table);
            let mut data = self.data.clone();
            lut.run(&mut data);
            data
        }

        /// Like [`LutProblem::test`], but the data starts `offset` bytes into a scratch
        /// buffer. Shifting the start address exercises the unaligned-head path of
        /// `LutImpl::run`.
        pub fn test_at_offset<K: LutKer>(&self, offset: usize) -> Vec<u8> {
            let lut = LutImpl::<K>::new(&self.table);
            let mut scratch = vec![0u8; offset + self.data.len()];
            scratch[offset..].copy_from_slice(&self.data);
            lut.run(&mut scratch[offset..]);
            scratch.split_off(offset)
        }
    }

    /// Generates the LUT test suite for kernel `$ker`. The tests only check results
    /// when `$cond` is true.
    #[macro_export]
    macro_rules! lut_frame_tests {
        ($cond:expr, $ker:ty) => {
            mod lut {
                use proptest::prelude::*;
                #[allow(unused_imports)]
                use $crate::frame::lut::test::*;

                proptest::proptest! {
                    #[test]
                    fn lut_prop(pb in any::<LutProblem>()) {
                        if $cond {
                            prop_assert_eq!(pb.test::<$ker>(), pb.reference())
                        }
                    }

                    #[test]
                    fn lut_prop_unaligned(pb in any::<LutProblem>(), offset in 0usize..64) {
                        if $cond {
                            prop_assert_eq!(pb.test_at_offset::<$ker>(offset), pb.reference())
                        }
                    }
                }

                #[test]
                fn test_empty() {
                    let pb = LutProblem { table: vec![0], data: vec![] };
                    assert_eq!(pb.test::<$ker>(), pb.reference())
                }
            }
        };
    }
}