// tract_linalg/frame/by_scalar.rs

1use std::{fmt::Debug, marker::PhantomData};
2
3use tract_data::{TractResult, internal::TensorView};
4
5use crate::{LADatum, element_wise::ElementWiseKer, LinalgFn};
6
7use super::{ElementWise, element_wise_helper::map_slice_with_alignment};
8
9
10/// Generic implementation struct that unify all by scalar kernels.
11/// A by scalar operation is an ElementWise operation with a scalar paramerer.
12#[derive(Debug, Clone, new)]
13pub struct ByScalarImpl<K, T>
14where
15    T: LADatum,
16    K: ByScalarKer<T> + Clone,
17{
18    phantom: PhantomData<(K, T)>,
19}
20
21impl<K, T> ElementWise<T, T> for ByScalarImpl<K, T>
22where
23    T: LADatum,
24    K: ByScalarKer<T> + Clone,
25{
26    fn name(&self) -> &'static str {
27        K::name()
28    }
29    fn run_with_params(&self, vec: &mut [T], params: T) -> TractResult<()> {
30        map_slice_with_alignment(vec, |data| K::run(data, params), K::nr(), K::alignment_bytes())
31    }
32}
33
34
35pub trait ByScalarKer<T>: ElementWiseKer<T, T>
36where
37    T: LADatum
38{
39    fn bin() -> Box<LinalgFn> {
40        Box::new(|a: &mut TensorView, b: &TensorView| {
41            let a_slice = a.as_slice_mut()?;
42            let b = b.as_slice()?[0];
43            (Self::ew()).run_with_params(a_slice, b)
44        })
45    }
46}
47
/// Declares a by-scalar kernel `$ker` operating on element type `$t`.
///
/// Expands to an `ew_impl_wrap!` invocation, presumably defining `$ker` as an
/// element-wise kernel. That is followed by an empty `ByScalarKer<$t>` impl, so
/// the kernel picks up the default `bin()` adapter.
///
/// The parameter-type argument is accepted but ignored. A by-scalar kernel's
/// parameter type is always its element type, so `$t` is forwarded in its place.
macro_rules! by_scalar_impl_wrap {
    ($t:ident, $ker:ident, $nr:expr, $alignment_items:expr, $ignored_params:ty, $run:item) => {
        paste! {
            ew_impl_wrap!($t, $ker, $nr, $alignment_items, $t, $run);

            impl crate::frame::by_scalar::ByScalarKer<$t> for $ker {}
        }
    };
}
57
58
#[cfg(test)]
#[macro_use]
pub mod test {
    use crate::frame::element_wise::ElementWiseKer;
    use crate::LADatum;
    use num_traits::{AsPrimitive, Float};
    use proptest::test_runner::TestCaseResult;

    /// Generates a proptest case checking kernel `$ker` over element type `$t`
    /// against the reference binary function `$func`.
    ///
    /// Inputs are an `f32` vector of 0 to 99 values in `[-25, 25)` and an `f32`
    /// scalar in the same range. The check only runs when `$cond` evaluates to true.
    #[macro_export]
    macro_rules! by_scalar_frame_tests {
        ($cond:expr, $t: ty, $ker:ty, $func:expr) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<prop_ $ker:snake>](xs in proptest::collection::vec(-25f32..25.0, 0..100), scalar in -25f32..25f32) {
                        if $cond {
                            // Propagate the TestCaseError so proptest reports it directly
                            // instead of through an unwinding panic.
                            $crate::frame::by_scalar::test::test_by_scalar::<$ker, $t>(&*xs, scalar, $func)?;
                        }
                    }
                }
            }
        };
    }

    /// Checks kernel `K` against the reference `func` for the given `values` and `scalar`.
    ///
    /// Both inputs are converted from `f32` to `T`. The comparison is delegated to
    /// `test_element_wise_params`, with `scalar` as the kernel parameter and
    /// `func(x, scalar)` as the expected per-element result.
    pub fn test_by_scalar<K: ElementWiseKer<T, T>, T: LADatum + Float>(
        values: &[f32],
        scalar: f32,
        func: impl Fn(T, T) -> T,
    ) -> TestCaseResult
    where
        f32: AsPrimitive<T>,
    {
        crate::setup_test_logger();
        let values: Vec<T> = values.iter().copied().map(|x| x.as_()).collect();
        // Convert the scalar once, not once per element inside the reference closure.
        let scalar: T = scalar.as_();
        crate::frame::element_wise::test::test_element_wise_params::<K, T, _, T>(
            &values,
            |a| func(a, scalar),
            scalar,
        )
    }
}