tract_linalg/frame/
by_scalar.rs1use std::fmt::Debug;
2use std::marker::PhantomData;
3
4use crate::element_wise::{ElementWise, ElementWiseKer};
5use crate::element_wise_helper::map_slice_with_alignment;
6use crate::{LADatum, LinalgFn};
7use tract_data::internal::*;
8
9#[derive(Debug, Clone, new)]
12pub struct ByScalarImpl<K, T>
13where
14 T: LADatum,
15 K: ByScalarKer<T> + Clone,
16{
17 phantom: PhantomData<(K, T)>,
18}
19
20impl<K, T> ElementWise<T, T> for ByScalarImpl<K, T>
21where
22 T: LADatum,
23 K: ByScalarKer<T> + Clone,
24{
25 fn name(&self) -> &'static str {
26 K::name()
27 }
28 fn run_with_params(&self, vec: &mut [T], params: T) -> TractResult<()> {
29 map_slice_with_alignment(vec, |data| K::run(data, params), K::nr(), K::alignment_bytes())
30 }
31}
32
33pub trait ByScalarKer<T>: ElementWiseKer<T, T>
34where
35 T: LADatum,
36{
37 fn bin() -> Box<LinalgFn> {
38 Box::new(|a: &mut TensorView, b: &TensorView| {
39 let a_slice = a.as_slice_mut()?;
40 let b = b.as_slice()?[0];
41 (Self::ew()).run_with_params(a_slice, b)
42 })
43 }
44}
45
/// Declares by-scalar kernel type `$ker` operating on element type `$elem`.
///
/// Expands to an `ew_impl_wrap!` invocation that passes `$elem` both as the element
/// type and as the parameter type, followed by a marker [`ByScalarKer`] impl.
/// The `$params` argument is matched but never used: a by-scalar kernel's parameter
/// type is always its element type (presumably kept so call sites mirror the other
/// `*_impl_wrap!` macros — confirm before removing it).
macro_rules! by_scalar_impl_wrap {
    ($elem: ident, $ker: ident, $nr: expr, $align_items: expr, $params: ty, $body: item) => {
        paste! {
            ew_impl_wrap!($elem, $ker, $nr, $align_items, $elem, $body);

            impl crate::frame::by_scalar::ByScalarKer<$elem> for $ker {}
        }
    };
}
55
56#[cfg(test)]
57#[macro_use]
58pub mod test {
59 use crate::frame::element_wise::ElementWiseKer;
60 use crate::LADatum;
61 use num_traits::{AsPrimitive, Float};
62 use proptest::test_runner::TestCaseResult;
63
64 #[macro_export]
65 macro_rules! by_scalar_frame_tests {
66 ($cond:expr, $t: ty, $ker:ty, $func:expr) => {
67 paste::paste! {
68 proptest::proptest! {
69 #[test]
70 fn [<prop_ $ker:snake>](xs in proptest::collection::vec(-25f32..25.0, 0..100), scalar in -25f32..25f32) {
71 if $cond {
72 $crate::frame::by_scalar::test::test_by_scalar::<$ker, $t>(&*xs, scalar, $func).unwrap()
73 }
74 }
75 }
76 }
77 };
78 }
79
80 pub fn test_by_scalar<K: ElementWiseKer<T, T>, T: LADatum + Float>(
81 values: &[f32],
82 scalar: f32,
83 func: impl Fn(T, T) -> T,
84 ) -> TestCaseResult
85 where
86 f32: AsPrimitive<T>,
87 {
88 crate::setup_test_logger();
89 let values: Vec<T> = values.iter().copied().map(|x| x.as_()).collect();
90 crate::frame::element_wise::test::test_element_wise_params::<K, T, _, T>(
91 &values,
92 |a| (func)(a, scalar.as_()),
93 scalar.as_(),
94 )
95 }
96}