tract_linalg/frame/element_wise.rs

1use std::fmt::Debug;
2use std::marker::PhantomData;
3
4use tract_data::TractResult;
5
6use crate::LADatum;
7
8use super::element_wise_helper::map_slice_with_alignment;
9
/// Defines a unit-struct kernel `$kernel` and implements `ElementWiseKer<$t, $p>` for it.
///
/// Positional arguments: element type, kernel name (used both as the struct name and as the
/// string returned by `name()`), tile size returned by `nr()`, alignment in items returned by
/// `alignment_items()`, parameter type, and the `fn run(..)` item spliced into the impl.
macro_rules! ew_impl_wrap {
    ($t: ident, $kernel: ident, $tile: expr, $align: expr, $p: ty, $body: item) => {
        // NOTE(review): nothing below uses `[<...>]`, so this `paste!` appears redundant;
        // confirm no caller relies on it before removing.
        paste! {
            #[derive(Copy, Clone, Debug)]
            #[allow(non_camel_case_types)]
            pub struct $kernel;

            impl crate::frame::element_wise::ElementWiseKer<$t, $p> for $kernel {
                #[inline(always)]
                fn name() -> &'static str {
                    stringify!($kernel)
                }

                #[inline(always)]
                fn nr() -> usize {
                    $tile
                }

                #[inline(always)]
                fn alignment_items() -> usize {
                    $align
                }

                // The caller-provided `run` implementation.
                $body
            }
        }
    };
}
35
/// Declares an externally implemented element-wise kernel and wraps it as an `ElementWiseKer`.
///
/// For a kernel named `$name` this expands to a private module `sys_$name` holding the
/// `extern_kernel!` declaration, plus (through `ew_impl_wrap!`) a unit struct `$name` whose
/// `run` forwards the slice's pointer and length to that external symbol.
///
/// * `ew_impl!(t, name, nr, align_items)` — kernel signature `(ptr: *mut t, count: usize)`,
///   parameter type `()`.
/// * `ew_impl!(t, name, nr, align_items, P)` — kernel signature
///   `(ptr: *mut t, count: usize, params: P)`.
macro_rules! ew_impl {
    ($t: ident, $name: ident, $tile: expr, $align: expr) => {
        paste! {
            mod [<sys_ $name>] {
                // Presumably imported so that `$t` may be `f16` inside this module.
                #[allow(unused_imports)]
                use tract_data::prelude::f16;
                extern_kernel!(fn $name(ptr: *mut $t, count: usize) -> ());
            }
            ew_impl_wrap!($t, $name, $tile, $align, (),
                #[inline(never)]
                fn run(slice: &mut [$t], _: ()) {
                    // NOTE(review): soundness relies on the external kernel touching only
                    // `count` items starting at `ptr`; confirm against the kernel sources.
                    unsafe { [<sys_ $name>]::$name(slice.as_mut_ptr(), slice.len()) }
                }
            );
        }
    };
    ($t: ident, $name: ident, $tile: expr, $align: expr, $p: ty) => {
        paste! {
            mod [<sys_ $name>] {
                // Presumably imported so that `$t` may be `f16` inside this module.
                #[allow(unused_imports)]
                use tract_data::prelude::f16;
                extern_kernel!(fn $name(ptr: *mut $t, count: usize, params: $p) -> ());
            }
            ew_impl_wrap!($t, $name, $tile, $align, $p,
                #[inline(never)]
                fn run(slice: &mut [$t], extra: $p) {
                    // NOTE(review): soundness relies on the external kernel touching only
                    // `count` items starting at `ptr`; confirm against the kernel sources.
                    unsafe { [<sys_ $name>]::$name(slice.as_mut_ptr(), slice.len(), extra) }
                }
            );
        }
    };
}
68
69pub trait ElementWise<T, Params = ()>: Send + Sync + Debug + dyn_clone::DynClone
70where
71    Params: Copy + Send + Sync + Debug + 'static + Default,
72    T: Copy + Debug + PartialEq + Send + Sync,
73{
74    fn name(&self) -> &'static str;
75    fn run(&self, vec: &mut [T]) -> TractResult<()> {
76        self.run_with_params(vec, Params::default())
77    }
78    fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult<()>;
79}
80
81dyn_clone::clone_trait_object!(<T, Params> ElementWise<T, Params> where T: Copy, Params: Copy);
82
83#[derive(Debug, Clone, new)]
84pub struct ElementWiseImpl<K, T, Params = ()>
85where
86    T: LADatum,
87    Params: Copy + Send + Sync + Debug + 'static + Default,
88    K: ElementWiseKer<T, Params> + Clone,
89{
90    phantom: PhantomData<(K, T, Params)>,
91}
92
93impl<K, T, Params> ElementWise<T, Params> for ElementWiseImpl<K, T, Params>
94where
95    T: LADatum,
96    Params: Copy + Send + Sync + Debug + 'static + Default,
97    K: ElementWiseKer<T, Params> + Clone,
98{
99    fn name(&self) -> &'static str {
100        K::name()
101    }
102    fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult<()> {
103        map_slice_with_alignment(vec, |data| K::run(data, params), K::nr(), K::alignment_bytes())
104    }
105}
106
107pub trait ElementWiseKer<T, Params = ()>:
108    Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static
109where
110    Params: Copy + Send + Sync + Debug + 'static + Default,
111    T: LADatum,
112{
113    fn name() -> &'static str;
114    fn alignment_bytes() -> usize {
115        Self::alignment_items() * T::datum_type().size_of()
116    }
117    fn alignment_items() -> usize;
118    fn nr() -> usize;
119    fn run(vec: &mut [T], params: Params);
120    fn ew() -> Box<dyn ElementWise<T, Params>> {
121        Box::new(ElementWiseImpl::<Self, T, Params>::new())
122    }
123}
124
#[cfg(test)]
pub mod test {
    use crate::{frame::element_wise::*, LADatum};
    use proptest::test_runner::{TestCaseError, TestCaseResult};
    use tract_data::internal::*;

    /// Checks kernel `K` (no extra parameters) against the scalar `reference` function.
    ///
    /// Equivalent to [`test_element_wise_params`] with `params = ()`.
    pub fn test_element_wise<K: ElementWiseKer<T, ()>, T: LADatum, F: Fn(T) -> T>(
        values: &[T],
        reference: F,
    ) -> TestCaseResult {
        test_element_wise_params::<K, T, F, ()>(values, reference, ())
    }

    /// Runs kernel `K` with `params` over a copy of `values` and compares the result with
    /// `reference` applied to each element.
    ///
    /// The input is zero-padded to at least `K::nr()` items before both computations.
    /// Results are compared with `Tensor::close_enough(.., true)`. Errors from the kernel run
    /// and from the comparison are both reported as `TestCaseError::fail` with the error's
    /// root cause.
    pub fn test_element_wise_params<
        K: ElementWiseKer<T, Params>,
        T: LADatum,
        F: Fn(T) -> T,
        Params,
    >(
        values: &[T],
        reference: F,
        params: Params,
    ) -> TestCaseResult
    where
        Params: Copy + Send + Sync + Debug + 'static + Default,
    {
        crate::setup_test_logger();
        let op = ElementWiseImpl::<K, T, Params>::new();
        let mut values = values.to_vec();
        // Pad with zeros up to one kernel tile (presumably so the kernel path, not only a
        // remainder path, is exercised — confirm).
        if values.len() < K::nr() {
            values.resize(K::nr(), T::zero());
        }
        let expected = values.iter().copied().map(reference).collect::<Vec<_>>();
        let mut found = values;
        // Report a kernel error as a test-case failure rather than panicking, consistent
        // with how the comparison below is reported.
        op.run_with_params(&mut found, params)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        tensor1(&found)
            .close_enough(&tensor1(&expected), true)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        Ok(())
    }
}