tract_linalg/frame/
lut.rs1use std::fmt;
2use std::hash::Hash;
3use std::marker::PhantomData;
4use tract_data::internal::*;
5
6pub trait Lut: fmt::Debug + dyn_clone::DynClone + Send + Sync {
7 fn table(&self) -> &[u8];
8 fn run(&self, buf: &mut [u8]);
9}
10
11dyn_clone::clone_trait_object!(Lut);
12
13impl PartialEq for dyn Lut {
14 fn eq(&self, other: &Self) -> bool {
15 self.table() == other.table()
16 }
17}
18impl Eq for dyn Lut {}
19
20#[derive(Debug, Clone, Hash)]
21pub struct LutImpl<K: LutKer> {
22 table: Tensor,
23 _boo: PhantomData<K>,
24}
25
26impl<K: LutKer> LutImpl<K> {
27 pub fn new(table: &[u8]) -> LutImpl<K> {
28 unsafe {
29 LutImpl {
30 table: Tensor::from_raw_aligned::<u8>(
31 &[table.len()],
32 table,
33 K::table_alignment_bytes(),
34 )
35 .unwrap(),
36 _boo: PhantomData,
37 }
38 }
39 }
40}
41
42impl<K: LutKer> Lut for LutImpl<K> {
43 fn table(&self) -> &[u8] {
44 self.table.try_as_plain().unwrap().as_slice().unwrap()
45 }
46
47 fn run(&self, buf: &mut [u8]) {
48 unsafe {
49 let table: *const u8 = self.table.as_ptr_unchecked();
50 let align = K::input_alignment_bytes();
51 let aligned_start = (buf.as_ptr() as usize).next_multiple_of(align);
52 let prefix = (aligned_start - buf.as_ptr() as usize).min(buf.len());
53 for i in 0..(prefix as isize) {
54 let ptr = buf.as_mut_ptr().offset(i);
55 *ptr = *table.offset(*ptr as isize);
56 }
57 let remaining = buf.len() - prefix;
58 if remaining == 0 {
59 return;
60 }
61 let n = K::n();
62 let aligned_len = remaining / n * n;
63 if aligned_len > 0 {
64 K::run(buf.as_mut_ptr().add(prefix), aligned_len, table);
65 }
66 let remaining = buf.len() - aligned_len - prefix;
67 for i in 0..remaining {
68 let ptr = buf.as_mut_ptr().add(i + prefix + aligned_len);
69 *ptr = *table.offset(*ptr as isize);
70 }
71 }
72 }
73}
74
/// A kernel that applies a lookup table to a run of bytes in place.
///
/// `LutImpl` uses the associated functions to lay out its work: the table is
/// stored with `table_alignment_bytes()` alignment, and `run` only receives the
/// `input_alignment_bytes()`-aligned middle part of a buffer, with a length that
/// is a multiple of `n()`.
pub trait LutKer: Clone + fmt::Debug + Send + Sync + Hash {
    /// A static name identifying this kernel.
    fn name() -> &'static str;

    /// Granularity in bytes: `LutImpl::run` only passes lengths that are multiples of it.
    fn n() -> usize;

    /// Alignment, in bytes, that `LutImpl::run` establishes for the buffer passed to `run`.
    fn input_alignment_bytes() -> usize;

    /// Alignment, in bytes, that `LutImpl::new` uses when storing the table.
    fn table_alignment_bytes() -> usize;

    /// Replaces each of the `len` bytes starting at `buf` with its entry in `table`.
    ///
    /// # Safety
    ///
    /// As called from `LutImpl::run`, `buf` is `input_alignment_bytes()`-aligned
    /// and valid for reads and writes of `len` bytes, `len` is a non-zero multiple
    /// of `n()`, and `table` points to the table stored by `LutImpl::new`.
    /// NOTE(review): no table length is passed — confirm implementations read only
    /// `table[b]` for the input bytes `b`.
    unsafe fn run(buf: *mut u8, len: usize, table: *const u8);
}

#[cfg(test)]
#[macro_use]
pub mod test {
    use super::*;
    use proptest::prelude::*;

    /// A random table together with input bytes that all index into it.
    #[derive(Debug)]
    pub struct LutProblem {
        pub table: Vec<u8>,
        pub data: Vec<u8>,
    }

    impl Arbitrary for LutProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_p: ()) -> Self::Strategy {
            // Tables of 1..=256 entries, so the full-size 256-entry table and the
            // input byte 255 are covered too.
            proptest::collection::vec(any::<u8>(), 1..=256)
                .prop_flat_map(|table| {
                    // `table.len()` is in 1..=256, so the largest valid index fits in a u8.
                    let max_index = (table.len() - 1) as u8;
                    let data = proptest::collection::vec(0..=max_index, 0..100);
                    (Just(table), data)
                })
                .prop_map(|(table, data)| LutProblem { table, data })
                .boxed()
        }
    }

    impl LutProblem {
        /// Expected output: each data byte looked up in the table.
        pub fn reference(&self) -> Vec<u8> {
            self.data.iter().map(|x| self.table[*x as usize]).collect()
        }

        /// Output of `LutImpl<K>` applied to a copy of the data.
        pub fn test<K: LutKer>(&self) -> Vec<u8> {
            let lut = LutImpl::<K>::new(&self.table);
            let mut data = self.data.clone();
            lut.run(&mut data);
            data
        }
    }

    /// Generates LUT tests for kernel `$ker`; they only run when `$cond` holds
    /// (e.g. when the CPU supports the kernel).
    #[macro_export]
    macro_rules! lut_frame_tests {
        ($cond:expr, $ker:ty) => {
            mod lut {
                use proptest::prelude::*;
                #[allow(unused_imports)]
                use $crate::frame::lut::test::*;

                proptest::proptest! {
                    #[test]
                    fn lut_prop(pb in any::<LutProblem>()) {
                        if $cond {
                            prop_assert_eq!(pb.test::<$ker>(), pb.reference())
                        }
                    }
                }

                #[test]
                fn test_empty() {
                    // Guarded like `lut_prop`, so an unsupported kernel is never exercised.
                    if $cond {
                        let pb = LutProblem { table: vec![0], data: vec![] };
                        assert_eq!(pb.test::<$ker>(), pb.reference())
                    }
                }
            }
        };
    }
}