tract_core/ops/array/
trilu.rs1use crate::internal::*;
2
/// Triangular masking operator: for every matrix formed by the last two axes of its input,
/// keeps one triangle (relative to a diagonal offset supplied at evaluation time) and sets
/// the remaining elements to the element type's default value.
#[derive(Clone, Debug)]
pub struct Trilu {
    /// Selects which triangle survives: `true` keeps the upper one, `false` the lower one.
    pub upper: bool,
}
7
8impl Op for Trilu {
9 fn name(&self) -> Cow<str> {
10 "Trilu".into()
11 }
12
13 op_as_typed_op!();
14}
15
16impl EvalOp for Trilu {
17 fn is_stateless(&self) -> bool {
18 true
19 }
20
21 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
22 let (input, k) = args_2!(inputs);
23 let mut input = input.into_tensor();
24 let k = *k.to_scalar::<i64>()?;
25 fn eval_t<T: Datum>(tensor: &mut Tensor, upper: bool, k: i64) -> TractResult<()> {
26 let mut view = tensor.to_array_view_mut::<T>()?;
27 for coords in tract_ndarray::indices(view.shape()) {
28 let row = coords[view.ndim() - 2] as i64;
29 let col = coords[view.ndim() - 1] as i64;
30 if upper {
31 if col < row + k {
32 view[coords] = T::default();
33 }
34 } else if col > row + k {
35 view[coords] = T::default();
36 }
37 }
38 Ok(())
39 }
40 dispatch_datum!(eval_t(input.datum_type())(&mut input, self.upper, k))?;
41 Ok(tvec!(input.into_tvalue()))
42 }
43}
44
45impl TypedOp for Trilu {
46 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
47 Ok(tvec!(inputs[0].without_value()))
48 }
49
50 as_op!();
51}