1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
use ndarray::*; use crate::broadcast::multi_broadcast; use crate::internal::*; use super::binary::commute; bin_to_super_type!(and, And, flip: commute, [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 && b as i64 != 0) as _); bin_to_super_type!(or, Or, flip: commute, [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 || b as i64 != 0) as _); bin_to_super_type!(xor, Xor, flip: commute, [bool] => |c, &a, &b| *c = a ^ b); bin_to_bool!(equals, Equals, flip: commute, [bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, TDim] => |c, a, b | *c = a == b ); bin_to_bool!(not_equals, NotEquals, flip: commute, [bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, TDim] => |c, a, b | *c = a != b ); bin_to_bool!(lesser, Lesser, [bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64] => |c, &a, &b | *c = a < b); bin_to_bool!(lesser_equal, LesserEqual, [bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64] => |c, &a, &b | *c = a <= b); bin_to_bool!(greater, Greater, [bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64] => |c, &a, &b | *c = a > b); bin_to_bool!(greater_equal, GreaterEqual, [bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64] => |c, &a, &b | *c = a >= b); element_wise!(not, Not, [bool] => |_, vs| { vs.iter_mut().for_each(|a| *a = !*a); Ok(()) }); #[derive(Debug, Clone, new, Default, Hash)] pub struct Iff; impl_dyn_hash!(Iff); impl Iff { pub unsafe fn eval_t<T: Datum>( cond: &ArrayViewD<bool>, out: &mut Tensor, t: &Tensor, f: &Tensor, ) { Zip::from(out.to_array_view_mut_unchecked::<T>()) .and_broadcast(cond) .and_broadcast(t.to_array_view_unchecked::<T>()) .and_broadcast(f.to_array_view_unchecked::<T>()) .for_each(|r, c, t, f| *r = if *c { t.clone() } else { f.clone() }) } } impl Op for Iff { fn name(&self) -> Cow<str> { "Iff".into() } op_core_mir!(); op_as_typed_op!(); } impl EvalOp for Iff { fn is_stateless(&self) -> bool { true } fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> { let (cond, t, f) = args_3!(inputs); let shape: TVec<usize> = multi_broadcast(&[cond.shape(), t.shape(), f.shape()]) .ok_or_else(|| { format_err!( "Incompatible shapes {:?}, {:?} and {:?}", cond.shape(), t.shape(), f.shape() ) })?; unsafe { let mut result = Tensor::uninitialized_dt(t.datum_type(), &*shape)?; let cond = cond.to_array_view::<bool>()?; dispatch_datum_by_size!(Self::eval_t(t.datum_type())(&cond, &mut result, &t, &f)); Ok(tvec!(result.into_arc_tensor())) } } } impl TypedOp for Iff { as_op!(); fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> { if inputs[1].datum_type != inputs[2].datum_type { bail!("Then and else tensors type mismatch ({:?} and {:?}).", inputs[1], inputs[2]); } let shape = multi_broadcast(&[ inputs[0].shape.to_tvec(), inputs[1].shape.to_tvec(), inputs[2].shape.to_tvec(), ]) .unwrap(); Ok(tvec!(TypedFact::dt_shape(inputs[1].datum_type, shape))) } }