1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
use tract_nnef::internal::*;
use tract_nnef::ser::tdim;
use tract_nnef::tract_core::trivial_op_state_freeeze;

/// Declares the `tract_pulse_mask` primitive in `registry`.
///
/// The primitive takes a tensor `input` plus the `axis`, `begin`, `end` and
/// `value` attributes and yields a single `output` tensor. Loading goes
/// through [`deser`], dumping through [`ser`].
pub fn register(registry: &mut Registry) {
    let parameters = [
        TypeName::Scalar.tensor().named("input"),
        TypeName::Integer.named("axis"),
        TypeName::Integer.named("begin"),
        TypeName::Integer.named("end"),
        TypeName::Scalar.named("value"),
    ];
    let results = [("output", TypeName::Scalar.tensor())];

    registry.register_primitive("tract_pulse_mask", &parameters, &results, deser);
    registry.register_dumper(ser)
}

/// Dumps a [`PulseMask`] node as a `tract_pulse_mask` invocation.
///
/// The node's first input is looked up in `ast.mapping` and used as the sole
/// positional argument; `axis`, `begin`, `end` and `value` are emitted as
/// named arguments.
///
/// # Errors
///
/// Fails if the fill value cannot be converted to an `f32` scalar.
fn ser(ast: &mut IntoAst, node: &TypedNode, op: &PulseMask) -> TractResult<Option<Arc<RValue>>> {
    let wire = ast.mapping[&node.inputs[0]].clone();
    // `cast_to_scalar` is fallible: surface the error to the caller and give
    // `numeric` the bare number, not the `Result` wrapping it.
    let value = op.value.cast_to_scalar::<f32>()?;
    let params = vec![
        ("axis", numeric(op.axis)),
        ("begin", numeric(op.begin)),
        ("end", tdim(&op.end)),
        ("value", numeric(value)),
    ];
    Ok(Some(invocation("tract_pulse_mask", &[wire], &params)))
}

/// Builds a [`PulseMask`] from a resolved `tract_pulse_mask` invocation and
/// wires it onto the `input` argument.
///
/// Arguments are read in the order `input`, `axis`, `begin`, `value`, `end`;
/// the first one that fails to resolve aborts the load. `value` is read as an
/// `f32` and stored as a rank-0 tensor. `end` is resolved inside
/// `allowing_new_symbols` (presumably so that its expression may introduce
/// symbols the model has not seen yet — confirm against `ModelBuilder`).
fn deser(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
    let input = invocation.named_arg_as(builder, "input")?;
    let mask = PulseMask {
        axis: invocation.named_arg_as(builder, "axis")?,
        begin: invocation.named_arg_as(builder, "begin")?,
        value: tensor0(invocation.named_arg_as::<f32>(builder, "value")?),
        end: builder.allowing_new_symbols(|scoped| invocation.named_arg_as(scoped, "end"))?,
    };
    builder.wire(mask, &[input])
}

/// Per-session state of a [`PulseMask`] node.
#[derive(Clone, Debug, Default, Hash)]
struct PulseMaskOpState {
    /// Position, along the masked axis, at which the next pulse starts.
    /// Starts at zero and grows by the pulse length on every evaluation.
    current_pos: usize,
}

impl OpState for PulseMaskOpState {
    /// Masks one pulse.
    ///
    /// Expects exactly one input tensor and an `op` that is a [`PulseMask`];
    /// the input is handed to [`PulseMaskOpState::pad`] and the resulting
    /// tensor is returned as the single output.
    ///
    /// # Errors
    ///
    /// Returns `"Wrong Op type"` when `op` is not a [`PulseMask`], and
    /// forwards any error raised while masking.
    fn eval(
        &mut self,
        session: &mut SessionState,
        op: &dyn Op,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let data = args_1!(inputs).into_tensor();
        let mask = op.downcast_ref::<PulseMask>().ok_or_else(|| format_err!("Wrong Op type"))?;
        self.pad(session, mask, data).map(|masked| tvec!(masked.into_tvalue()))
    }
}

impl PulseMaskOpState {
    /// Overwrites with `op.value` every position of the incoming pulse that
    /// lies outside `[op.begin, end)` along `op.axis`, then returns the pulse.
    ///
    /// The pulse covers stream positions `current_pos..current_pos + len`,
    /// where `len` is the input's extent along `op.axis`; `current_pos` is
    /// advanced by `len` on every call. `end` is `op.end` evaluated against
    /// the session's resolved symbols; when that does not yield a `usize` the
    /// stream is treated as unbounded (`usize::MAX`).
    fn pad(
        &mut self,
        session: &SessionState,
        op: &PulseMask,
        mut input: Tensor,
    ) -> TractResult<Tensor> {
        let len = input.shape()[op.axis];
        let first = self.current_pos;
        let last = first + len;
        self.current_pos = last;

        let stream_end =
            op.end.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);

        // A pulse lying entirely inside the valid range triggers neither
        // branch below and is forwarded untouched.
        let masks_head = first < op.begin;
        let masks_tail = last > stream_end;

        if masks_head {
            // Leading positions before `op.begin`, capped at the pulse length.
            let fill_up_to = (op.begin - first).min(len);
            // NOTE(review): the range stays within `0..len` along `op.axis`;
            // other preconditions of `fill_slice_constant` are not visible
            // from here — confirm.
            unsafe {
                dispatch_copy_by_size!(crate::pad::fill_slice_constant(input.datum_type())(
                    &mut input,
                    &op.value,
                    op.axis,
                    0..fill_up_to
                ))
            };
        }

        if masks_tail {
            // Trailing positions at or past `stream_end`; the whole pulse when
            // the overshoot exceeds its length.
            let fill_from = len.saturating_sub(last - stream_end);
            // NOTE(review): the range stays within `0..len` along `op.axis`;
            // other preconditions of `fill_slice_constant` are not visible
            // from here — confirm.
            unsafe {
                dispatch_copy_by_size!(crate::pad::fill_slice_constant(input.datum_type())(
                    &mut input,
                    &op.value,
                    op.axis,
                    fill_from..len
                ))
            }
        }

        Ok(input)
    }
}

/// Operator that replaces with a constant the parts of a pulsed stream lying
/// outside a `[begin, end)` window along one axis.
#[derive(Clone, Debug, Default, Hash)]
pub struct PulseMask {
    /// Axis along which pulses advance and along which masking is applied.
    pub axis: usize,
    /// First stream position that is left untouched.
    pub begin: usize,
    /// Stream position at which masking resumes; evaluated at run time
    /// against the session's resolved symbols.
    pub end: TDim,
    /// Value written into masked positions.
    pub value: Tensor,
}

impl Op for PulseMask {
    /// Operator name: `PulseMask`.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("PulseMask")
    }

    /// One line summarising the axis and the `begin`/`end` bounds.
    fn info(&self) -> TractResult<Vec<String>> {
        let Self { axis, begin, end, .. } = self;
        let summary = format!("axis: {} begin: {} end: {}", axis, begin, end);
        Ok(vec![summary])
    }

    op_as_typed_op!();
}

impl EvalOp for PulseMask {
    fn is_stateless(&self) -> bool {
        false
    }

    fn state(
        &self,
        _session: &mut SessionState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        Ok(Some(Box::<PulseMaskOpState>::default()))
    }
}

impl TypedOp for PulseMask {
    /// The single output has exactly the same fact as the first input.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let passthrough: TypedFact = inputs[0].clone();
        Ok(tvec!(passthrough))
    }

    as_op!();
}

// NOTE(review): the macro body is not visible in this file; presumably it
// generates the freeze/unfreeze plumbing for `PulseMaskOpState` — confirm
// against `tract_core`.
trivial_op_state_freeeze!(PulseMaskOpState);