tract_core/ops/array/
trilu.rs1use crate::internal::*;
2
/// Keeps one triangle of the two innermost axes of a tensor and overwrites
/// every other element with the datum's `Default` value.
///
/// Named after (and presumably mirroring) the ONNX `Trilu` operator. The
/// diagonal offset `k` is not stored on the op: it arrives at evaluation time
/// as the second input.
#[derive(Debug, Clone)]
pub struct Trilu {
    /// `true`: keep elements where `col - row >= k` (upper triangle).
    /// `false`: keep elements where `col - row <= k` (lower triangle).
    pub upper: bool,
}

8impl Op for Trilu {
9 fn name(&self) -> StaticName {
10 "Trilu".into()
11 }
12
13 op_as_typed_op!();
14}
15
16impl EvalOp for Trilu {
17 fn is_stateless(&self) -> bool {
18 true
19 }
20
21 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
22 let (input, k) = args_2!(inputs);
23 let mut input = input.into_tensor();
24 let k = *k.try_as_dense()?.to_scalar::<i64>()?;
25 fn eval_t<T: Datum>(tensor: &mut Tensor, upper: bool, k: i64) -> TractResult<()> {
26 let mut tensor_dense = tensor.try_as_dense_mut()?;
27 let mut view = tensor_dense.to_array_view_mut::<T>()?;
28 for coords in tract_ndarray::indices(view.shape()) {
29 let row = coords[view.ndim() - 2] as i64;
30 let col = coords[view.ndim() - 1] as i64;
31 if upper {
32 if col < row + k {
33 view[coords] = T::default();
34 }
35 } else if col > row + k {
36 view[coords] = T::default();
37 }
38 }
39 Ok(())
40 }
41 dispatch_datum!(eval_t(input.datum_type())(&mut input, self.upper, k))?;
42 Ok(tvec!(input.into_tvalue()))
43 }
44}
45
46impl TypedOp for Trilu {
47 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
48 Ok(tvec!(inputs[0].without_value()))
49 }
50
51 as_op!();
52}