1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use ndarray::*;

use crate::broadcast::multi_broadcast;
use crate::internal::*;

use super::binary::commute;

// Logical AND / OR over bool and integer tensors. Each operand counts as "true" when it is
// non-zero: operands are cast to i64 first. For the unsigned types that cast is a bit
// reinterpretation, so a non-zero value never becomes 0. The boolean result is then cast
// (`as _`) to the output element type chosen by the macro.
// `flip: commute` presumably declares the operands swappable (see `super::binary::commute`).
bin_to_super_type!(and, And, flip: commute,
                   [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 && b as i64 != 0) as _);
bin_to_super_type!(or, Or, flip: commute,
                   [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 || b as i64 != 0) as _);
// Exclusive OR, defined for bool tensors only.
bin_to_super_type!(xor, Xor, flip: commute, [bool] => |c, &a, &b| *c = a ^ b);
// Element-wise equality / inequality producing bool tensors. Unlike the ops below, the
// closures compare through references (no `&a` destructuring), presumably because `TDim`
// is in the type list and is not `Copy`.
bin_to_bool!(equals, Equals, flip: commute,
 [bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, TDim] => |c, a, b | *c = a == b
);
bin_to_bool!(not_equals, NotEquals, flip: commute,
 [bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, TDim] => |c, a, b | *c = a != b
);

// Element-wise ordering comparisons producing bool tensors. These carry no `flip: commute`
// because swapping the operands of an ordering comparison changes the result. With f32/f64
// operands, any comparison involving NaN yields `false` (Rust `PartialOrd` semantics).
bin_to_bool!(lesser, Lesser, [bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64] => |c, &a, &b | *c = a < b);
bin_to_bool!(lesser_equal, LesserEqual, [bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64] => |c, &a, &b | *c = a <= b);
bin_to_bool!(greater, Greater, [bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64] => |c, &a, &b | *c = a > b);
bin_to_bool!(greater_equal, GreaterEqual, [bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64] => |c, &a, &b | *c = a >= b);

// Logical negation for bool tensors: every element is flipped in place.
element_wise!(not, Not, [bool] => |_, vs| {
    for v in vs.iter_mut() {
        *v = !*v;
    }
    Ok(())
});

/// Element-wise conditional select over three inputs: a `bool` condition tensor and two
/// value tensors ("then" and "else") of the same datum type. Each output element is taken
/// from "then" where the condition holds and from "else" otherwise.
#[derive(Clone, Debug, Default, Hash, new)]
pub struct Iff;

impl_dyn_hash!(Iff);

impl Iff {
    /// Fills `out` with `if cond[i] { t[i] } else { f[i] }` for every element of `out`.
    /// `cond`, `t` and `f` are broadcast against `out`'s shape.
    ///
    /// # Safety
    /// `out`, `t` and `f` are accessed through the `_unchecked` view accessors as `T`. The
    /// caller must guarantee that all three tensors actually hold `T`-compatible elements.
    /// `EvalOp::eval` picks `T` by dispatching on `t`'s datum type.
    ///
    /// # Panics
    /// Panics (from ndarray's `and_broadcast`) if `cond`, `t` or `f` cannot be broadcast
    /// to `out`'s shape.
    pub unsafe fn eval_t<T: Datum>(
        cond: &ArrayViewD<bool>,
        out: &mut Tensor,
        t: &Tensor,
        f: &Tensor,
    ) {
        let then_view = t.to_array_view_unchecked::<T>();
        let else_view = f.to_array_view_unchecked::<T>();
        Zip::from(out.to_array_view_mut_unchecked::<T>())
            .and_broadcast(cond)
            .and_broadcast(then_view)
            .and_broadcast(else_view)
            .for_each(|slot, &pick_then, then_v, else_v| {
                *slot = if pick_then { then_v.clone() } else { else_v.clone() };
            })
    }
}

impl Op for Iff {
    fn name(&self) -> Cow<str> {
        "Iff".into()
    }
    op_core_mir!();
    op_as_typed_op!();
}

impl EvalOp for Iff {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let (cond, t, f) = args_3!(inputs);
        let shape: TVec<usize> = multi_broadcast(&[cond.shape(), t.shape(), f.shape()])
            .ok_or_else(|| {
                format_err!(
                    "Incompatible shapes {:?}, {:?} and {:?}",
                    cond.shape(),
                    t.shape(),
                    f.shape()
                )
            })?;
        unsafe {
            let mut result = Tensor::uninitialized_dt(t.datum_type(), &*shape)?;
            let cond = cond.to_array_view::<bool>()?;
            dispatch_datum_by_size!(Self::eval_t(t.datum_type())(&cond, &mut result, &t, &f));
            Ok(tvec!(result.into_arc_tensor()))
        }
    }
}

impl TypedOp for Iff {
    as_op!();

    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        anyhow::ensure!(inputs.len() == 3, "Iff expects 3 intputs.");
        if inputs[1].datum_type != inputs[2].datum_type {
            bail!("Then and else tensors type mismatch ({:?} and {:?}).", inputs[1], inputs[2]);
        }
        if inputs[0].rank() != inputs[1].rank() || inputs[0].rank() != inputs[2].rank() {
            bail!("Inconsistent ranks, {:?}", inputs);
        }
        let shape = multi_broadcast(&[
            inputs[0].shape.to_tvec(),
            inputs[1].shape.to_tvec(),
            inputs[2].shape.to_tvec(),
        ])
        .unwrap();
        Ok(tvec!(TypedFact::dt_shape(inputs[1].datum_type, shape)))
    }

    fn invariants(
        &self,
        inputs: &[&TypedFact],
        _outputs: &[&TypedFact],
    ) -> TractResult<Invariants> {
        let a = &inputs[0];
        let b = &inputs[1];
        let c = &inputs[2];
        assert!(a.rank() == b.rank() && b.rank() == c.rank());
        let rank = a.rank();
        Ok((0..rank)
            .into_iter()
            .map(|axis| AxisInfo {
                inputs: tvec!(Some(axis), Some(axis), Some(axis)),
                outputs: tvec!(Some(axis)),
                period: 1,
                disposable: true,
            })
            .collect())
    }
}