1#![allow(clippy::bool_comparison)]
2#![allow(clippy::unnecessary_cast)]
3
4mod comparison;
5mod ite;
6pub use ite::IfThenElse;
7pub use comparison::Comp;
8
9use ndarray::*;
10
11use crate::broadcast::multi_broadcast;
12use crate::internal::*;
13
14bin_to_super_type!(and, And,
15 [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 && b as i64 != 0) as _);
16bin_to_super_type!(or, Or,
17 [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 || b as i64 != 0) as _);
18bin_to_super_type!(xor, Xor, [bool] => |c, &a, &b| *c = a ^ b);
19
20element_wise!(not, Not, [bool] => |_, vs| {
21 vs.iter_mut().for_each(|a| *a = !*a);
22 Ok(())
23});
24
25#[derive(Debug, Clone, new, Default, Hash)]
26pub struct Iff;
27
28impl Iff {
29 pub unsafe fn eval_t<T: Datum>(
30 cond: &ArrayViewD<bool>,
31 out: &mut Tensor,
32 t: &Tensor,
33 f: &Tensor,
34 ) {
35 Zip::from(out.to_array_view_mut_unchecked::<T>())
36 .and_broadcast(cond)
37 .and_broadcast(t.to_array_view_unchecked::<T>())
38 .and_broadcast(f.to_array_view_unchecked::<T>())
39 .for_each(|r, c, t, f| *r = if *c { t.clone() } else { f.clone() })
40 }
41}
42
43impl Op for Iff {
44 fn name(&self) -> Cow<str> {
45 "Iff".into()
46 }
47 op_as_typed_op!();
48}
49
50impl EvalOp for Iff {
51 fn is_stateless(&self) -> bool {
52 true
53 }
54
55 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
56 let (cond, t, f) = args_3!(inputs);
57 anyhow::ensure!(t.datum_type() == f.datum_type());
58 let shape: TVec<usize> = multi_broadcast(&[cond.shape(), t.shape(), f.shape()])?;
59 unsafe {
60 let mut result = Tensor::uninitialized_dt(t.datum_type(), &shape)?;
61 let cond = cond.to_array_view::<bool>()?;
62 dispatch_datum_by_size!(Self::eval_t(t.datum_type())(&cond, &mut result, &t, &f));
63 Ok(tvec!(result.into_tvalue()))
64 }
65 }
66}
67
68impl TypedOp for Iff {
69 as_op!();
70
71 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
72 anyhow::ensure!(inputs.len() == 3, "Iff expects 3 intputs.");
73 if inputs[1].datum_type != inputs[2].datum_type {
74 bail!("Then and else tensors type mismatch ({:?} and {:?}).", inputs[1], inputs[2]);
75 }
76 if inputs[0].rank() != inputs[1].rank() || inputs[0].rank() != inputs[2].rank() {
77 bail!("Inconsistent ranks, {:?}", inputs);
78 }
79 let shape = multi_broadcast(&[
80 inputs[0].shape.to_tvec(),
81 inputs[1].shape.to_tvec(),
82 inputs[2].shape.to_tvec(),
83 ])
84 .unwrap();
85 Ok(tvec!(inputs[1].datum_type.fact(shape)))
86 }
87
88 fn axes_mapping(
89 &self,
90 inputs: &[&TypedFact],
91 outputs: &[&TypedFact],
92 ) -> TractResult<AxesMapping> {
93 AxesMapping::natural(inputs, outputs)
94 }
95}
96
97bin_to_super_type!(bitand, BitAnd,
98 [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = a & b);
99bin_to_super_type!(bitor, BitOr,
100 [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = a | b);
101bin_to_super_type!(bitxor, BitXor,
102 [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = a ^ b);
103
104element_wise!(bitnot, BitNot, [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |_, xs| {
105 xs.iter_mut().for_each(|x| *x = !*x);
106 Ok(())
107});