Skip to main content

tract_core/ops/array/
trilu.rs

1use crate::internal::*;
2
/// Triangular-mask operator.
///
/// At evaluation time it keeps either the upper or the lower triangle of every
/// matrix formed by the last two axes of its first input. The diagonal is
/// shifted by an offset `k` supplied as a second input. Every other element is
/// overwritten with the datum type's default value.
#[derive(Clone, Debug)]
pub struct Trilu {
    /// When `true`, elements whose column/row offset is below `k` are cleared
    /// (upper triangle kept). When `false`, elements whose offset is above `k`
    /// are cleared (lower triangle kept).
    pub upper: bool,
}
7
8impl Op for Trilu {
9    fn name(&self) -> StaticName {
10        "Trilu".into()
11    }
12
13    op_as_typed_op!();
14}
15
16impl EvalOp for Trilu {
17    fn is_stateless(&self) -> bool {
18        true
19    }
20
21    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
22        let (input, k) = args_2!(inputs);
23        let mut input = input.into_tensor();
24        let k = *k.try_as_dense()?.to_scalar::<i64>()?;
25        fn eval_t<T: Datum>(tensor: &mut Tensor, upper: bool, k: i64) -> TractResult<()> {
26            let mut tensor_dense = tensor.try_as_dense_mut()?;
27            let mut view = tensor_dense.to_array_view_mut::<T>()?;
28            for coords in tract_ndarray::indices(view.shape()) {
29                let row = coords[view.ndim() - 2] as i64;
30                let col = coords[view.ndim() - 1] as i64;
31                if upper {
32                    if col < row + k {
33                        view[coords] = T::default();
34                    }
35                } else if col > row + k {
36                    view[coords] = T::default();
37                }
38            }
39            Ok(())
40        }
41        dispatch_datum!(eval_t(input.datum_type())(&mut input, self.upper, k))?;
42        Ok(tvec!(input.into_tvalue()))
43    }
44}
45
46impl TypedOp for Trilu {
47    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
48        Ok(tvec!(inputs[0].without_value()))
49    }
50
51    as_op!();
52}