tract_core/ops/array/trilu.rs

1use crate::internal::*;
2
/// Operator that keeps either the upper or the lower triangular part of the two
/// trailing axes of its input and overwrites every other element with the datum
/// type's default value. The diagonal offset `k` is supplied as a second input at
/// evaluation time.
#[derive(Clone, Debug)]
pub struct Trilu {
    /// When `true`, elements with `col - row >= k` are kept; when `false`,
    /// elements with `col - row <= k` are kept.
    pub upper: bool,
}
7
8impl Op for Trilu {
9    fn name(&self) -> Cow<str> {
10        "Trilu".into()
11    }
12
13    op_as_typed_op!();
14}
15
16impl EvalOp for Trilu {
17    fn is_stateless(&self) -> bool {
18        true
19    }
20
21    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
22        let (input, k) = args_2!(inputs);
23        let mut input = input.into_tensor();
24        let k = *k.to_scalar::<i64>()?;
25        fn eval_t<T: Datum>(tensor: &mut Tensor, upper: bool, k: i64) -> TractResult<()> {
26            let mut view = tensor.to_array_view_mut::<T>()?;
27            for coords in tract_ndarray::indices(view.shape()) {
28                let row = coords[view.ndim() - 2] as i64;
29                let col = coords[view.ndim() - 1] as i64;
30                if upper {
31                    if col < row + k {
32                        view[coords] = T::default();
33                    }
34                } else if col > row + k {
35                    view[coords] = T::default();
36                }
37            }
38            Ok(())
39        }
40        dispatch_datum!(eval_t(input.datum_type())(&mut input, self.upper, k))?;
41        Ok(tvec!(input.into_tvalue()))
42    }
43}
44
45impl TypedOp for Trilu {
46    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
47        Ok(tvec!(inputs[0].without_value()))
48    }
49
50    as_op!();
51}