tract_linalg/frame/
by_scalar.rs1use std::{fmt::Debug, marker::PhantomData};
2
3use tract_data::{TractResult, internal::TensorView};
4
5use crate::{LADatum, element_wise::ElementWiseKer, LinalgFn};
6
7use super::{ElementWise, element_wise_helper::map_slice_with_alignment};
8
9
10#[derive(Debug, Clone, new)]
13pub struct ByScalarImpl<K, T>
14where
15 T: LADatum,
16 K: ByScalarKer<T> + Clone,
17{
18 phantom: PhantomData<(K, T)>,
19}
20
21impl<K, T> ElementWise<T, T> for ByScalarImpl<K, T>
22where
23 T: LADatum,
24 K: ByScalarKer<T> + Clone,
25{
26 fn name(&self) -> &'static str {
27 K::name()
28 }
29 fn run_with_params(&self, vec: &mut [T], params: T) -> TractResult<()> {
30 map_slice_with_alignment(vec, |data| K::run(data, params), K::nr(), K::alignment_bytes())
31 }
32}
33
34
35pub trait ByScalarKer<T>: ElementWiseKer<T, T>
36where
37 T: LADatum
38{
39 fn bin() -> Box<LinalgFn> {
40 Box::new(|a: &mut TensorView, b: &TensorView| {
41 let a_slice = a.as_slice_mut()?;
42 let b = b.as_slice()?[0];
43 (Self::ew()).run_with_params(a_slice, b)
44 })
45 }
46}
47
/// Declares a by-scalar kernel: forwards everything to `ew_impl_wrap!`, reusing the
/// datum type as the kernel's parameter type, then implements
/// `crate::frame::by_scalar::ByScalarKer<$datum>` for `$kernel`.
///
/// NOTE(review): the `$params` fragment is matched but never used. `ByScalarKer<T>`
/// requires `ElementWiseKer<T, T>`, so the parameter type must equal `$datum` anyway;
/// it is presumably kept for call-site symmetry with sibling `*_impl_wrap!` macros.
macro_rules! by_scalar_impl_wrap {
    ($datum: ident, $kernel: ident, $nr: expr, $alignment_items: expr, $params: ty, $run: item) => {
        paste! {
            ew_impl_wrap!($datum, $kernel, $nr, $alignment_items, $datum, $run);

            impl crate::frame::by_scalar::ByScalarKer<$datum> for $kernel {}
        }
    };
}
57
58
#[cfg(test)]
#[macro_use]
pub mod test {
    use crate::frame::element_wise::ElementWiseKer;
    use crate::LADatum;
    use num_traits::{AsPrimitive, Float};
    use proptest::test_runner::TestCaseResult;

    /// Generates a proptest case named `prop_<kernel in snake_case>` that checks
    /// kernel `$ker` on datum `$t` against the reference `$func(element, scalar)`.
    ///
    /// Inputs are up to 99 `f32` values and one scalar, all drawn from [-25, 25).
    /// The check is skipped when `$cond` evaluates to false.
    #[macro_export]
    macro_rules! by_scalar_frame_tests {
        ($cond:expr, $t: ty, $ker:ty, $func:expr) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<prop_ $ker:snake>](
                        xs in proptest::collection::vec(-25f32..25.0, 0..100),
                        scalar in -25f32..25f32
                    ) {
                        if $cond {
                            let outcome = $crate::frame::by_scalar::test::test_by_scalar::<$ker, $t>(
                                &xs[..],
                                scalar,
                                $func,
                            );
                            outcome.unwrap()
                        }
                    }
                }
            }
        };
    }

    /// Converts `values` and `scalar` from `f32` to `T`, then defers to the generic
    /// element-wise test helper. That helper compares kernel `K` with the reference
    /// `func(value, scalar)`.
    pub fn test_by_scalar<K: ElementWiseKer<T, T>, T: LADatum + Float>(
        values: &[f32],
        scalar: f32,
        func: impl Fn(T, T) -> T,
    ) -> TestCaseResult
    where
        f32: AsPrimitive<T>,
    {
        crate::setup_test_logger();
        // Convert the scalar once up front; `T: Float` implies `T: Copy`, so the
        // closure below can reuse it freely.
        let scalar_t: T = scalar.as_();
        let inputs: Vec<T> = values.iter().map(|&v| v.as_()).collect();
        crate::frame::element_wise::test::test_element_wise_params::<K, T, _, T>(
            &inputs,
            |x| func(x, scalar_t),
            scalar_t,
        )
    }
}