tract_linalg/frame/
by_scalar.rs

1use std::fmt::Debug;
2use std::marker::PhantomData;
3
4use crate::element_wise::{ElementWise, ElementWiseKer};
5use crate::element_wise_helper::map_slice_with_alignment;
6use crate::{LADatum, LinalgFn};
7use tract_data::internal::*;
8
/// Generic implementation struct that unifies all by-scalar kernels.
///
/// A by-scalar operation is an `ElementWise` operation with a scalar parameter.
/// The struct holds no runtime data: the kernel type `K` and the element type
/// `T` are carried only through the `PhantomData` marker field.
#[derive(Debug, Clone, new)]
pub struct ByScalarImpl<K, T>
where
    T: LADatum,
    K: ByScalarKer<T> + Clone,
{
    // Marker only: ties `K` and `T` to the struct without storing a value of either.
    phantom: PhantomData<(K, T)>,
}
19
20impl<K, T> ElementWise<T, T> for ByScalarImpl<K, T>
21where
22    T: LADatum,
23    K: ByScalarKer<T> + Clone,
24{
25    fn name(&self) -> &'static str {
26        K::name()
27    }
28    fn run_with_params(&self, vec: &mut [T], params: T) -> TractResult<()> {
29        map_slice_with_alignment(vec, |data| K::run(data, params), K::nr(), K::alignment_bytes())
30    }
31}
32
33pub trait ByScalarKer<T>: ElementWiseKer<T, T>
34where
35    T: LADatum,
36{
37    fn bin() -> Box<LinalgFn> {
38        Box::new(|a: &mut TensorView, b: &TensorView| {
39            let a_slice = a.as_slice_mut()?;
40            let b = b.as_slice()?[0];
41            (Self::ew()).run_with_params(a_slice, b)
42        })
43    }
44}
45
/// Declares a by-scalar kernel `$func` over the element type `$ti`.
///
/// The expansion invokes `ew_impl_wrap!` with `$ti` as both the element type
/// and the parameter type, then adds an empty `ByScalarKer<$ti>` impl for
/// `$func`, so the kernel picks up the trait's default methods.
///
/// NOTE(review): `$params` is matched but never used in the expansion;
/// presumably it is kept so the argument list mirrors `ew_impl_wrap!` — confirm.
macro_rules! by_scalar_impl_wrap {
    ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty, $run: item) => {
        paste! {
            // The parameter type is forced to `$ti`, matching `ByScalarKer<$ti>: ElementWiseKer<$ti, $ti>`.
            ew_impl_wrap!($ti, $func, $nr, $alignment_items, $ti, $run);

            impl crate::frame::by_scalar::ByScalarKer<$ti> for $func {}
        }
    };
}
55
#[cfg(test)]
#[macro_use]
pub mod test {
    use crate::frame::element_wise::ElementWiseKer;
    use crate::LADatum;
    use num_traits::{AsPrimitive, Float};
    use proptest::test_runner::TestCaseResult;

    /// Generates a property test named `prop_<kernel in snake case>` for a
    /// by-scalar kernel.
    ///
    /// The test draws a vector of 0 to 99 `f32` values and one `f32` scalar,
    /// all from `-25.0..25.0`. When `$cond` is true, it runs `test_by_scalar`
    /// for `$ker` over element type `$t` with `$func` as the reference
    /// operation, and unwraps the result.
    #[macro_export]
    macro_rules! by_scalar_frame_tests {
        ($cond:expr, $t: ty, $ker:ty, $func:expr) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<prop_ $ker:snake>](xs in proptest::collection::vec(-25f32..25.0, 0..100), scalar in -25f32..25f32) {
                        if $cond {
                            $crate::frame::by_scalar::test::test_by_scalar::<$ker, $t>(&*xs, scalar, $func).unwrap()
                        }
                    }
                }
            }
        };
    }

    /// Checks kernel `K` against a reference binary function.
    ///
    /// `values` and `scalar` are converted from `f32` to `T`. The check itself
    /// is delegated to `test_element_wise_params`, with `func(x, scalar)` as
    /// the per-element reference and the converted scalar as the kernel
    /// parameter.
    pub fn test_by_scalar<K: ElementWiseKer<T, T>, T: LADatum + Float>(
        values: &[f32],
        scalar: f32,
        func: impl Fn(T, T) -> T,
    ) -> TestCaseResult
    where
        f32: AsPrimitive<T>,
    {
        crate::setup_test_logger();
        // Convert the scalar once; the same value feeds both the reference
        // closure and the kernel parameter.
        let scalar_t: T = scalar.as_();
        let inputs: Vec<T> = values.iter().map(|&v| v.as_()).collect();
        let reference = |x: T| func(x, scalar_t);
        crate::frame::element_wise::test::test_element_wise_params::<K, T, _, T>(
            &inputs,
            reference,
            scalar_t,
        )
    }
}