Skip to main content

tract_core/ops/
logic.rs

// `unnecessary_cast`: the logical and/or closures below cast every operand with `as i64`,
// which is a no-op cast for the `i64` instantiation the macros generate.
// `bool_comparison`: NOTE(review): presumably needed for macro-generated code here or in the
// `comparison` / `ite` submodules — TODO confirm it is still required.
#![allow(clippy::bool_comparison)]
#![allow(clippy::unnecessary_cast)]
3
4mod comparison;
5mod ite;
6pub use comparison::Comp;
7pub use ite::IfThenElse;
8
9use ndarray::*;
10
11use crate::broadcast::multi_broadcast;
12use crate::internal::*;
13
14bin_to_super_type!(and, And,
15                   [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 && b as i64 != 0) as _);
16bin_to_super_type!(or, Or,
17                   [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 || b as i64 != 0) as _);
18bin_to_super_type!(xor, Xor, /*flip: commute, */ [bool] => |c, &a, &b| *c = a ^ b);
19
20element_wise!(not, Not, [bool] => |_, vs| {
21    vs.iter_mut().for_each(|a| *a = !*a);
22    Ok(())
23});
24
25#[derive(Debug, Clone, new, Default, Hash, PartialEq, Eq)]
26pub struct Iff;
27
28impl Iff {
29    pub unsafe fn eval_t<T: Datum>(
30        cond: &ArrayViewD<bool>,
31        out: &mut Tensor,
32        t: &Tensor,
33        f: &Tensor,
34    ) {
35        unsafe {
36            Zip::from(out.to_array_view_mut_unchecked::<T>())
37                .and_broadcast(cond)
38                .and_broadcast(t.to_array_view_unchecked::<T>())
39                .and_broadcast(f.to_array_view_unchecked::<T>())
40                .for_each(|r, c, t, f| *r = if *c { t.clone() } else { f.clone() })
41        }
42    }
43}
44
45impl Op for Iff {
46    fn name(&self) -> StaticName {
47        "Iff".into()
48    }
49    op_as_typed_op!();
50    impl_op_same_as!();
51}
52
53impl EvalOp for Iff {
54    fn is_stateless(&self) -> bool {
55        true
56    }
57
58    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
59        let (cond, t, f) = args_3!(inputs);
60        anyhow::ensure!(t.datum_type() == f.datum_type());
61        let shape: TVec<usize> = multi_broadcast(&[cond.shape(), t.shape(), f.shape()])?;
62        unsafe {
63            let mut result = Tensor::uninitialized_dt(t.datum_type(), &shape)?;
64            let cond = cond.to_dense_array_view::<bool>()?;
65            dispatch_datum_by_size!(Self::eval_t(t.datum_type())(&cond, &mut result, &t, &f));
66            Ok(tvec!(result.into_tvalue()))
67        }
68    }
69}
70
71impl TypedOp for Iff {
72    as_op!();
73
74    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
75        ensure!(inputs.len() == 3, "Iff expects 3 intputs.");
76        ensure!(inputs[1].datum_type == inputs[2].datum_type);
77        ensure!(inputs[0].datum_type.is::<bool>());
78        ensure!(inputs[0].rank() == inputs[1].rank());
79        ensure!(inputs[0].rank() == inputs[2].rank());
80        let shape = multi_broadcast(&[
81            inputs[0].shape.to_tvec(),
82            inputs[1].shape.to_tvec(),
83            inputs[2].shape.to_tvec(),
84        ])
85        .unwrap();
86        Ok(tvec!(inputs[1].datum_type.fact(shape)))
87    }
88
89    fn axes_mapping(
90        &self,
91        inputs: &[&TypedFact],
92        outputs: &[&TypedFact],
93    ) -> TractResult<AxesMapping> {
94        AxesMapping::natural(inputs, outputs)
95    }
96}
97
98bin_to_super_type!(bitand, BitAnd,
99                   [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = a & b);
100bin_to_super_type!(bitor, BitOr,
101                   [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = a | b);
102bin_to_super_type!(bitxor, BitXor,
103                   [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = a ^ b);
104
105element_wise!(bitnot, BitNot, [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |_, xs| {
106    xs.iter_mut().for_each(|x| *x = !*x);
107    Ok(())
108});