tract_linalg/frame/
lut.rs1use std::fmt;
2use std::hash::Hash;
3use std::marker::PhantomData;
4use tract_data::internal::*;
5
6pub trait Lut: fmt::Debug + dyn_clone::DynClone + Send + Sync {
7 fn table(&self) -> &[u8];
8 fn run(&self, buf: &mut [u8]);
9}
10
11dyn_clone::clone_trait_object!(Lut);
12
13#[derive(Debug, Clone, Hash)]
14pub struct LutImpl<K: LutKer> {
15 table: Tensor,
16 _boo: PhantomData<K>,
17}
18
19impl<K: LutKer> LutImpl<K> {
20 pub fn new(table: &[u8]) -> LutImpl<K> {
21 unsafe {
22 LutImpl {
23 table: Tensor::from_raw_aligned::<u8>(
24 &[table.len()],
25 table,
26 K::table_alignment_bytes(),
27 )
28 .unwrap(),
29 _boo: PhantomData,
30 }
31 }
32 }
33}
34
35impl<K: LutKer> Lut for LutImpl<K> {
36 fn table(&self) -> &[u8] {
37 self.table.as_slice().unwrap()
38 }
39
40 fn run(&self, buf: &mut [u8]) {
41 unsafe {
42 let table: *const u8 = self.table.as_ptr_unchecked();
43 let align = K::input_alignment_bytes();
44 let aligned_start = (buf.as_ptr() as usize).next_multiple_of(align);
45 let prefix = (aligned_start - buf.as_ptr() as usize).min(buf.len());
46 for i in 0..(prefix as isize) {
47 let ptr = buf.as_mut_ptr().offset(i);
48 *ptr = *table.offset(*ptr as isize);
49 }
50 let remaining = buf.len() - prefix;
51 if remaining == 0 {
52 return;
53 }
54 let n = K::n();
55 let aligned_len = remaining / n * n;
56 if aligned_len > 0 {
57 K::run(buf.as_mut_ptr().add(prefix), aligned_len, table);
58 }
59 let remaining = buf.len() - aligned_len - prefix;
60 for i in 0..remaining {
61 let ptr = buf.as_mut_ptr().add(i + prefix + aligned_len);
62 *ptr = *table.offset(*ptr as isize);
63 }
64 }
65 }
66}
67
/// A lookup kernel: the low-level routine `LutImpl` hands the aligned bulk of a buffer to.
pub trait LutKer: Clone + fmt::Debug + Send + Sync + Hash {
    /// Name of the kernel.
    fn name() -> &'static str;

    /// Granularity of the kernel: `LutImpl` only calls `run` with lengths that are a
    /// multiple of this value.
    fn n() -> usize;

    /// Alignment, in bytes, of the `buf` pointer that `LutImpl` passes to `run`.
    fn input_alignment_bytes() -> usize;

    /// Alignment, in bytes, that `LutImpl::new` uses when allocating the table.
    fn table_alignment_bytes() -> usize;

    /// Maps `len` bytes starting at `buf` through `table`, in place.
    ///
    /// # Safety
    ///
    /// As called from `LutImpl::run`:
    /// * `buf` is valid for reads and writes of `len` bytes;
    /// * `buf` is aligned to `input_alignment_bytes()`;
    /// * `len` is a non-zero multiple of `n()`;
    /// * `table` points to the table storage;
    /// * every input byte must be a valid index into that table.
    unsafe fn run(buf: *mut u8, len: usize, table: *const u8);
}
75
#[cfg(test)]
#[macro_use]
pub mod test {
    use super::*;
    use proptest::prelude::*;

    /// A lookup problem: a table and input bytes that are all valid indices into it.
    #[derive(Debug)]
    pub struct LutProblem {
        pub table: Vec<u8>,
        pub data: Vec<u8>,
    }

    impl Arbitrary for LutProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        /// Generates a table of 1 to 256 entries and 0 to 99 input bytes, each strictly
        /// below the table length. The 256 bound is inclusive, so the full table is covered.
        fn arbitrary_with(_p: ()) -> Self::Strategy {
            proptest::collection::vec(any::<u8>(), 1..=256)
                .prop_flat_map(|table| {
                    // The table is never empty, so the largest valid index fits in a `u8`
                    // even for 256 entries (whereas `table.len() as u8` would wrap to 0).
                    let max_index = (table.len() - 1) as u8;
                    let data = proptest::collection::vec(0..=max_index, 0..100);
                    (Just(table), data)
                })
                .prop_map(|(table, data)| LutProblem { table, data })
                .boxed()
        }
    }

    impl LutProblem {
        /// Expected output: each input byte mapped through the table by plain indexing.
        pub fn reference(&self) -> Vec<u8> {
            self.data.iter().map(|x| self.table[*x as usize]).collect()
        }

        /// Output of `LutImpl<K>` run on a fresh copy of the input.
        pub fn test<K: LutKer>(&self) -> Vec<u8> {
            let lut = LutImpl::<K>::new(&self.table);
            let mut data = self.data.clone();
            lut.run(&mut data);
            data
        }

        /// Same as [`LutProblem::test`], but the input is placed `offset` bytes into a larger
        /// buffer before running. This exercises the unaligned head that `run` maps before
        /// calling the kernel; a freshly allocated `Vec` usually starts well aligned.
        pub fn test_at_offset<K: LutKer>(&self, offset: usize) -> Vec<u8> {
            let lut = LutImpl::<K>::new(&self.table);
            let mut buf = vec![0u8; offset + self.data.len()];
            let window = &mut buf[offset..];
            window.copy_from_slice(&self.data);
            lut.run(window);
            window.to_vec()
        }
    }

    #[macro_export]
    macro_rules! lut_frame_tests {
        ($cond:expr, $ker:ty) => {
            mod lut {
                use proptest::prelude::*;
                #[allow(unused_imports)]
                use $crate::frame::lut::test::*;

                proptest::proptest! {
                    #[test]
                    fn lut_prop(pb in any::<LutProblem>()) {
                        if $cond {
                            prop_assert_eq!(pb.test::<$ker>(), pb.reference())
                        }
                    }

                    #[test]
                    fn lut_prop_misaligned(pb in any::<LutProblem>(), offset in 0usize..64) {
                        if $cond {
                            prop_assert_eq!(pb.test_at_offset::<$ker>(offset), pb.reference())
                        }
                    }
                }

                #[test]
                fn test_empty() {
                    let pb = LutProblem { table: vec![0], data: vec![] };
                    assert_eq!(pb.test::<$ker>(), pb.reference())
                }
            }
        };
    }
}