1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
use crate::internal::*;

#[derive(Debug, Clone)]
pub struct Trilu {
    pub upper: bool,
}

impl Op for Trilu {
    fn name(&self) -> Cow<str> {
        "Trilu".into()
    }

    op_as_typed_op!();
}

impl EvalOp for Trilu {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (input, k) = args_2!(inputs);
        let mut input = input.into_tensor();
        let k = *k.to_scalar::<i64>()?;
        fn eval_t<T: Datum>(tensor: &mut Tensor, upper: bool, k: i64) -> TractResult<()> {
            let mut view = tensor.to_array_view_mut::<T>()?;
            for coords in tract_ndarray::indices(view.shape()) {
                let row = coords[view.ndim() - 2] as i64;
                let col = coords[view.ndim() - 1] as i64;
                if upper {
                    if col < row + k {
                        view[coords] = T::default();
                    }
                } else if col > row + k {
                    view[coords] = T::default();
                }
            }
            Ok(())
        }
        dispatch_datum!(eval_t(input.datum_type())(&mut input, self.upper, k))?;
        Ok(tvec!(input.into_tvalue()))
    }
}

impl TypedOp for Trilu {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        Ok(tvec!(inputs[0].without_value()))
    }

    as_op!();
}