// tract_linalg/frame/unicast.rs

1use std::fmt::Debug;
2use std::marker::PhantomData;
3
4use tract_data::internal::TensorView;
5use tract_data::TractResult;
6
7use crate::frame::element_wise_helper::TempBuffer;
8use crate::{LADatum, LinalgFn};
9
/// Declares a zero-sized kernel type `$kernel` and implements
/// `crate::frame::unicast::UnicastKer<$datum>` for it.
///
/// Positional arguments:
/// 1. the datum type the kernel operates on,
/// 2. the name of the kernel type to declare (also returned by `name()`),
/// 3. the tile width returned by `nr()`,
/// 4. the alignment, in items, returned by `alignment_items()`,
/// 5. one item spliced verbatim into the impl (presumably the `run` method).
///
/// The expansion is wrapped in `paste!`, so `paste!` must be in scope where this is invoked.
macro_rules! unicast_impl_wrap {
    ($datum: ident, $kernel: ident, $tile: expr, $align_items: expr, $run_item: item) => {
        paste! {
            #[derive(Copy, Clone, Debug)]
            // Kernel names are not required to be CamelCase.
            #[allow(non_camel_case_types)]
            pub struct $kernel;

            impl crate::frame::unicast::UnicastKer<$datum> for $kernel {
                #[inline(always)]
                fn nr() -> usize {
                    $tile
                }

                #[inline(always)]
                fn alignment_items() -> usize {
                    $align_items
                }

                #[inline(always)]
                fn name() -> &'static str {
                    stringify!($kernel)
                }

                $run_item
            }
        }
    };
}
35
/// Dynamically-dispatchable element-wise binary kernel over `T`.
///
/// `run` combines `a` and `b` element by element and stores the result in `a`. That is what
/// `UnicastImpl`, the implementation in this file, does.
pub trait Unicast<T>: Send + Sync + Debug + dyn_clone::DynClone
where
    T: Copy + Debug + PartialEq + Send + Sync,
{
    /// Name of the underlying kernel.
    fn name(&self) -> &'static str;
    /// Applies the kernel to `a` and `b` element-wise, writing the results into `a`.
    fn run(&self, a: &mut [T], b: &[T]) -> TractResult<()>;
}

// Makes boxed `dyn Unicast<T>` trait objects cloneable through `dyn_clone`.
dyn_clone::clone_trait_object!(<T> Unicast<T> where T: Copy);
45
46#[derive(Debug, Clone, new)]
47pub struct UnicastImpl<K, T>
48where
49    T: LADatum,
50    K: UnicastKer<T> + Clone,
51{
52    phantom: PhantomData<(K, T)>,
53}
54
55impl<K, T> UnicastImpl<K, T>
56where
57    T: LADatum,
58    K: UnicastKer<T> + Clone,
59{
60}
61impl<K, T> Unicast<T> for UnicastImpl<K, T>
62where
63    T: LADatum,
64    K: UnicastKer<T> + Clone,
65{
66    fn name(&self) -> &'static str {
67        K::name()
68    }
69    fn run(&self, a: &mut [T], b: &[T]) -> TractResult<()> {
70        unicast_with_alignment(a, b, |a, b| K::run(a, b), K::nr(), K::alignment_bytes())
71    }
72}
73
/// Statically-dispatched element-wise binary kernel over datum type `T`.
///
/// Implementors supply three things: the tile width (`nr`), the required alignment and the
/// raw `run` routine. Implementations can be generated with `unicast_impl_wrap!`.
/// `UnicastImpl` / `unicast_with_alignment` handle unaligned prefixes and ragged tails
/// around `run`.
pub trait UnicastKer<T>: Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static
where
    T: LADatum,
{
    /// Kernel name.
    fn name() -> &'static str;
    /// Required alignment in bytes.
    ///
    /// Defaults to `alignment_items()` elements of `T`.
    fn alignment_bytes() -> usize {
        Self::alignment_items() * T::datum_type().size_of()
    }
    /// Required alignment, expressed as a number of `T` elements.
    fn alignment_items() -> usize;
    /// Tile width, in elements.
    fn nr() -> usize;
    /// Raw kernel: combines `a` and `b` element-wise and writes the results into `a`.
    ///
    /// When driven through `unicast_with_alignment`, both slices have a length that is a
    /// non-zero multiple of `nr()`.
    fn run(a: &mut [T], b: &[T]);
    /// Wraps this kernel as a boxed `LinalgFn` operating on two tensor views in place.
    ///
    /// Any error from `as_slice_mut` / `as_slice` is propagated. NOTE(review): presumably
    /// that means a datum-type or layout mismatch — confirm.
    fn bin() -> Box<LinalgFn> {
        Box::new(|a: &mut TensorView, b: &TensorView| {
            let a_slice = a.as_slice_mut()?;
            let b_slice = b.as_slice()?;
            UnicastImpl::<Self, T>::new().run(a_slice, b_slice)
        })
    }
}
93
94std::thread_local! {
95    static TMP: std::cell::RefCell<(TempBuffer, TempBuffer)> = std::cell::RefCell::new((TempBuffer::default(), TempBuffer::default()));
96}
97
98pub(crate) fn unicast_with_alignment<T>(
99    a: &mut [T],
100    b: &[T],
101    f: impl Fn(&mut [T], &[T]),
102    nr: usize,
103    alignment_bytes: usize,
104) -> TractResult<()>
105where
106    T: LADatum,
107{
108    if a.is_empty() {
109        return Ok(());
110    }
111    unsafe {
112        TMP.with(|buffers| {
113            let mut buffers = buffers.borrow_mut();
114            buffers.0.ensure(nr * T::datum_type().size_of(), alignment_bytes);
115            buffers.1.ensure(nr * T::datum_type().size_of(), alignment_bytes);
116            let tmp_a = std::slice::from_raw_parts_mut(buffers.0.buffer as *mut T, nr);
117            let tmp_b = std::slice::from_raw_parts_mut(buffers.1.buffer as *mut T, nr);
118            let mut compute_via_temp_buffer = |a: &mut [T], b: &[T]| {
119                tmp_a[..a.len()].copy_from_slice(a);
120                tmp_b[..b.len()].copy_from_slice(b);
121                f(tmp_a, tmp_b);
122                a.copy_from_slice(&tmp_a[..a.len()])
123            };
124
125            let mut num_element_processed = 0;
126            let a_prefix_len = a.as_ptr().align_offset(alignment_bytes).min(a.len());
127            let b_prefix_len = b.as_ptr().align_offset(alignment_bytes).min(b.len());
128            assert!(
129                a_prefix_len == b_prefix_len,
130                "Both inputs should be of the same alignement, got {a_prefix_len:?}, {b_prefix_len:?}"
131            );
132            let mut applied_prefix_len = 0;
133            if a_prefix_len > 0 {
134                // Incomplete tile needs to be created to process unaligned data.
135                let sub_a = &mut a[..a_prefix_len];
136                let sub_b = &b[..a_prefix_len];
137                compute_via_temp_buffer(sub_a, sub_b);
138                num_element_processed += a_prefix_len;
139                applied_prefix_len = a_prefix_len;
140            }
141
142            let num_complete_tiles = (a.len() - applied_prefix_len) / nr;
143            if num_complete_tiles > 0 {
144                // Process all tiles that are complete.
145                let sub_a = &mut a[applied_prefix_len..][..(num_complete_tiles * nr)];
146                let sub_b = &b[applied_prefix_len..][..(num_complete_tiles * nr)];
147                f(sub_a, sub_b);
148                num_element_processed += num_complete_tiles * nr;
149            }
150
151            if num_element_processed < a.len() {
152                // Incomplete tile needs to be created to process remaining elements.
153                compute_via_temp_buffer(
154                    &mut a[num_element_processed..],
155                    &b[num_element_processed..],
156                );
157            }
158        })
159    }
160    Ok(())
161}
162
#[cfg(test)]
#[macro_use]
pub mod test {
    use super::*;
    use crate::LADatum;
    use proptest::test_runner::{TestCaseError, TestCaseResult};
    use tract_data::internal::*;
    use tract_num_traits::{AsPrimitive, Float};

    /// Runs kernel `K` (through `UnicastImpl`) on `a` and `b` in place, then compares the
    /// result with `reference` applied element-wise to the original inputs.
    ///
    /// A mismatch is reported as a `TestCaseError`; an error from the kernel itself panics.
    pub fn test_unicast<K: UnicastKer<T>, T: LADatum>(
        a: &mut [T],
        b: &[T],
        reference: impl Fn(T, T) -> T,
    ) -> TestCaseResult {
        crate::setup_test_logger();
        let op = UnicastImpl::<K, T>::new();
        // Computed before `run`, which overwrites `a`.
        let expected = a.iter().zip(b.iter()).map(|(a, b)| (reference)(*a, *b)).collect::<Vec<_>>();
        op.run(a, b).unwrap();
        tensor1(a)
            .close_enough(&tensor1(&expected), true)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        Ok(())
    }

    /// Converts `f32` inputs to `T` and copies them into tensors allocated with
    /// `vector_size()` alignment, then delegates to [`test_unicast`].
    pub fn test_unicast_t<K: UnicastKer<T>, T: LADatum + Float>(
        a: &[f32],
        b: &[f32],
        func: impl Fn(T, T) -> T,
    ) -> TestCaseResult
    where
        f32: AsPrimitive<T>,
    {
        crate::setup_test_logger();
        let vec_a: Vec<T> = a.iter().copied().map(|x| x.as_()).collect();
        // We allocate a tensor to ensure allocation is done with alignment.
        // NOTE(review): the safety contract of `from_slice_align` is not visible here — confirm.
        let mut a = unsafe { Tensor::from_slice_align(vec_a.as_slice(), vector_size()).unwrap() };
        let vec_b: Vec<T> = b.iter().copied().map(|x| x.as_()).collect();
        // We allocate a tensor to ensure allocation is done with alignment.
        let b = unsafe { Tensor::from_slice_align(vec_b.as_slice(), vector_size()).unwrap() };
        crate::frame::unicast::test::test_unicast::<K, _>(
            a.as_slice_mut::<T>().unwrap(),
            b.as_slice::<T>().unwrap(),
            func,
        )
    }

    /// Generates tests for a unicast kernel `$ker` over datum type `$t`, checked against
    /// the element-wise reference `$func`. Each test body only runs when `$cond` evaluates
    /// to true (presumably a CPU-feature check — confirm at call sites).
    ///
    /// Generated tests:
    /// - `prop_<ker>`: property test on equal-length vectors of length 0..100, with values
    ///   drawn from `-25.0..25.0`;
    /// - `empty_<ker>`: runs the kernel on two empty inputs.
    #[macro_export]
    macro_rules! unicast_frame_tests {
        ($cond:expr, $t: ty, $ker:ty, $func:expr) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<prop_ $ker:snake>](
                        (a, b) in (0..100_usize).prop_flat_map(|len| (vec![-25f32..25.0; len], vec![-25f32..25.0; len]))
                    ) {
                        if $cond {
                            $crate::frame::unicast::test::test_unicast_t::<$ker, $t>(&*a, &*b, $func).unwrap()
                        }
                    }
                }

                #[test]
                fn [<empty_ $ker:snake>]() {
                    if $cond {
                        $crate::frame::unicast::test::test_unicast_t::<$ker, $t>(&[], &[], $func).unwrap()
                    }
                }
            }
        };
    }
}