use tract_nnef::internal::*;
use tract_nnef::ser::tdim;
use tract_nnef::tract_core::trivial_op_state_freeeze;
/// Declares the `tract_pulse_mask` primitive on `registry` and hooks up its dumper.
///
/// The primitive takes a scalar tensor `input`, integer `axis`, `begin` and `end`
/// arguments and a scalar `value`, and produces a single scalar tensor `output`.
/// `deser` is registered as its loader and `ser` as the dumper.
pub fn register(registry: &mut Registry) {
    let parameters = [
        TypeName::Scalar.tensor().named("input"),
        TypeName::Integer.named("axis"),
        TypeName::Integer.named("begin"),
        TypeName::Integer.named("end"),
        TypeName::Scalar.named("value"),
    ];
    let results = [("output", TypeName::Scalar.tensor())];
    registry.register_primitive("tract_pulse_mask", &parameters, &results, deser);
    registry.register_dumper(ser)
}
/// Dumps a `PulseMask` node as a `tract_pulse_mask` invocation.
///
/// The node's first input is looked up in `ast.mapping` and passed as the only
/// positional argument; `axis`, `begin`, `end` and `value` are emitted as named
/// arguments, with `value` cast to an `f32` scalar.
///
/// # Errors
///
/// Fails if the fill value cannot be cast to an `f32` scalar.
fn ser(ast: &mut IntoAst, node: &TypedNode, op: &PulseMask) -> TractResult<Option<Arc<RValue>>> {
    let wire = ast.mapping[&node.inputs[0]].clone();
    let params = vec![
        ("axis", numeric(op.axis)),
        ("begin", numeric(op.begin)),
        ("end", tdim(&op.end)),
        // The cast is fallible: propagate a failure with `?` so that `numeric`
        // receives the bare `f32` rather than the unresolved cast result.
        ("value", numeric(op.value.cast_to_scalar::<f32>()?)),
    ];
    Ok(Some(invocation("tract_pulse_mask", &[wire], &params)))
}
fn deser(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
    let wire = invocation.named_arg_as(builder, "input")?;
    let axis = invocation.named_arg_as(builder, "axis")?;
    let begin = invocation.named_arg_as(builder, "begin")?;
    let value: Tensor = tensor0(invocation.named_arg_as::<f32>(builder, "value")?);
    let end = builder.allowing_new_symbols(|builder| invocation.named_arg_as(builder, "end"))?;
    let op = PulseMask { axis, begin, end, value };
    builder.wire(op, &[wire])
}
/// Per-session running state of a `PulseMask` operator.
#[derive(Debug, Clone, Default, Hash)]
struct PulseMaskOpState {
    /// Number of positions along the masked axis consumed by previous pulses;
    /// starts at 0 (derived `Default`) and is advanced by `pad` on every call.
    current_pos: usize,
}
impl OpState for PulseMaskOpState {
    /// Masks one pulse: takes the single input tensor, checks that `op` is a
    /// `PulseMask`, and returns the tensor produced by `pad` as the only output.
    ///
    /// Fails with "Wrong Op type" when `op` is not a `PulseMask`.
    fn eval(
        &mut self,
        session: &mut SessionState,
        op: &dyn Op,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let pulse = args_1!(inputs).into_tensor();
        let mask = op.downcast_ref::<PulseMask>().ok_or_else(|| format_err!("Wrong Op type"))?;
        self.pad(session, mask, pulse).map(|masked| tvec!(masked.into_tvalue()))
    }
}
impl PulseMaskOpState {
    /// Overwrites the part of one pulse that lies outside `[op.begin, end)` with `op.value`.
    ///
    /// The pulse covers stream positions `current_pos .. current_pos + pulse`, where
    /// `pulse` is the extent of `input` along `op.axis`; `current_pos` is advanced by
    /// that extent on every call. `end` is `op.end` evaluated against the session's
    /// resolved symbols, falling back to `usize::MAX` when it does not yield a `usize`.
    /// Positions before `op.begin` and positions from `end` onwards are filled with
    /// `op.value`; everything else in `input` is returned unchanged.
    ///
    /// Indexing the shape with `op.axis` panics if the axis is out of range.
    fn pad(
        &mut self,
        session: &SessionState,
        op: &PulseMask,
        mut input: Tensor,
    ) -> TractResult<Tensor> {
        let pulse = input.shape()[op.axis];
        let start = self.current_pos;
        let stop = start + pulse;
        self.current_pos = stop;
        let end = op.end.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);
        // A pulse entirely inside [op.begin, end) triggers neither branch below and
        // is handed back untouched.
        if start < op.begin {
            // Leading part of the pulse that precedes `op.begin`, capped at the pulse length.
            let head = (op.begin - start).min(pulse);
            // NOTE(review): `0..head` stays within `0..pulse` on `op.axis`; confirm this
            // is the full contract expected by `fill_slice_constant`.
            unsafe {
                dispatch_copy_by_size!(crate::pad::fill_slice_constant(input.datum_type())(
                    &mut input,
                    &op.value,
                    op.axis,
                    0..head
                ))
            };
        }
        if stop > end {
            // Trailing part of the pulse at or past `end`, capped at the pulse length.
            let tail_start = pulse - (stop - end).min(pulse);
            // NOTE(review): `tail_start..pulse` stays within `0..pulse` on `op.axis`;
            // confirm this is the full contract expected by `fill_slice_constant`.
            unsafe {
                dispatch_copy_by_size!(crate::pad::fill_slice_constant(input.datum_type())(
                    &mut input,
                    &op.value,
                    op.axis,
                    tail_start..pulse
                ))
            }
        }
        Ok(input)
    }
}
/// Streaming operator that replaces with `value` everything outside the
/// `[begin, end)` window of stream positions along `axis`.
#[derive(Debug, Clone, Default, Hash)]
pub struct PulseMask {
    /// Axis whose extent is taken as the pulse length and along which values are filled.
    pub axis: usize,
    /// First stream position that is kept; earlier positions are overwritten.
    pub begin: usize,
    /// Stream position from which values are overwritten; evaluated at run time
    /// against the session's resolved symbols.
    pub end: TDim,
    /// Constant used to fill the masked positions.
    pub value: Tensor,
}
impl Op for PulseMask {
    /// Operator name, always `"PulseMask"`.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("PulseMask")
    }
    /// One-line summary of the `axis`, `begin` and `end` settings.
    fn info(&self) -> TractResult<Vec<String>> {
        let Self { axis, begin, end, .. } = self;
        let summary = format!("axis: {} begin: {} end: {}", axis, begin, end);
        Ok(vec![summary])
    }
    op_as_typed_op!();
}
impl EvalOp for PulseMask {
    fn is_stateless(&self) -> bool {
        false
    }
    fn state(
        &self,
        _session: &mut SessionState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        Ok(Some(Box::<PulseMaskOpState>::default()))
    }
}
impl TypedOp for PulseMask {
    /// The single output has exactly the same fact as the first input.
    ///
    /// Panics if `inputs` is empty.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let input_fact: &TypedFact = inputs[0];
        Ok(tvec!(input_fact.clone()))
    }
    as_op!();
}
// Applies tract's `trivial_op_state_freeeze!` helper macro to `PulseMaskOpState`
// (the triple-e spelling is the macro's actual name, as imported at the top of the
// file). NOTE(review): presumably this generates the freeze/unfreeze plumbing for
// the state type — confirm against the macro definition.
trivial_op_state_freeeze!(PulseMaskOpState);