use tract_data::itertools::Itertools;
use crate::internal::*;
/// Operator that expands a tensor of indices into a one-hot encoded tensor.
///
/// The output has the shape of the input with one extra axis of length `dim`
/// inserted at position `axis`. The output is first filled with the `off`
/// value, then the `on` value is written at the position selected by each
/// input element along the inserted axis.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct OneHot {
    /// Position at which the new axis is inserted into the output shape.
    pub axis: usize,
    /// Length of the inserted axis.
    pub dim: usize,
    /// Value broadcast over the whole output before the hot positions are
    /// written. Its datum type is also the datum type of the output.
    pub off: Arc<Tensor>,
    /// Scalar value written at each position selected by the input.
    // NOTE(review): presumably expected to share the datum type of `off`;
    // it is read through an unchecked scalar accessor — confirm.
    pub on: Arc<Tensor>,
}
impl Op for OneHot {
    /// Human-readable operator name.
    fn name(&self) -> Cow<str> {
        // A static string never needs to be owned, so borrow it explicitly.
        Cow::Borrowed("Onehot")
    }

    op_as_typed_op!();
}
impl TypedOp for OneHot {
    /// Computes the single output fact: the shape of the first input with an
    /// extra axis of length `dim` inserted at `axis`, typed like the `off` value.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let mut dims = inputs[0].shape.to_tvec();
        dims.insert(self.axis, self.dim.to_dim());
        let fact = self.off.datum_type().fact(&*dims);
        Ok(tvec!(fact))
    }

    /// Describes how axes flow through the operator.
    ///
    /// Every axis of the first input maps to one output axis, shifted right by
    /// one when it sits at or after the inserted axis. The inserted axis itself
    /// (labelled `'Z'`) only exists on the output side.
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let (n_in, n_out) = (inputs.len(), outputs.len());
        // Output position of the input axis `ix`: axes before the insertion
        // point keep their position, the others move one slot to the right.
        let shifted = |ix: usize| if ix < self.axis { ix } else { ix + 1 };
        let inserted = Axis::new('Z', n_in, n_out).output(0, self.axis);
        let axes = ('a'..)
            .take(inputs[0].rank())
            .enumerate()
            .map(|(ix, repr)| Axis::new(repr, n_in, n_out).input(0, ix).output(0, shifted(ix)))
            .chain(Some(inserted))
            .collect_vec();
        AxesMapping::new(n_in, n_out, axes)
    }

    as_op!();
}
impl EvalOp for OneHot {
    /// The operator keeps no state between evaluations.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Evaluates the operator on its single input.
    ///
    /// The output is created by broadcasting the `off` value to the output
    /// shape, then `eval_t` writes the `on` value at the selected positions.
    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let indices = args_1!(inputs);

        // Output shape: input shape with the new axis spliced in.
        let mut out_shape: TVec<usize> = indices.shape().into();
        out_shape.insert(self.axis, self.dim);

        // SAFETY: `buffer` is derived from `self.off`, and `eval_t` is
        // dispatched on `self.off.datum_type()`, so the element type used to
        // view the buffer comes from the tensor the buffer was built from.
        // NOTE(review): this also relies on `self.on` having the same datum
        // type as `self.off` — confirm at construction sites.
        let filled = unsafe {
            let mut buffer = self.off.broadcast_scalar_to_shape(&out_shape)?;
            dispatch_datum_by_size!(Self::eval_t(self.off.datum_type())(
                self,
                &indices,
                &mut buffer
            ))?;
            buffer
        };
        Ok(tvec!(filled.into_tvalue()))
    }

    /// No per-session state is needed, so none is ever created.
    fn state(
        &self,
        _session: &mut SessionState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        Ok(None)
    }
}
impl OneHot {
    /// Writes the `on` value into `output` at every position selected by `input`.
    ///
    /// `input` is cast to `i32` and each element is interpreted as an index
    /// along the inserted axis. Negative indices count from the end of that
    /// axis (`-1` is the last slot). Indices outside `[-dim, dim - 1]` select
    /// nothing: the corresponding output slice is left untouched (i.e. at the
    /// value `output` was filled with by the caller) instead of panicking on an
    /// out-of-bounds write.
    ///
    /// # Errors
    ///
    /// Fails if `input` cannot be cast to `i32` or viewed as an `i32` array.
    ///
    /// # Safety
    ///
    /// `self.on` and `output` are accessed through their `_unchecked`
    /// accessors with element type `T`. The caller must ensure `T` is a valid
    /// element type for both tensors, and that `output` has the shape of
    /// `input` with an axis of length `self.dim` inserted at `self.axis`.
    unsafe fn eval_t<T: Datum + Clone>(
        &self,
        input: &Tensor,
        output: &mut Tensor,
    ) -> TractResult<()> {
        use tract_ndarray::Dimension;

        let on = self.on.to_scalar_unchecked::<T>();
        let mut array = output.to_array_view_mut_unchecked::<T>();
        let input = input.cast_to::<i32>()?;
        let input = input.to_array_view::<i32>()?;

        // Work in i64 so that neither `dim` nor the wrapped index can overflow
        // the way an `as i32` conversion of `dim` could.
        let depth = self.dim as i64;

        for icoord in tract_ndarray::indices_of(&input) {
            let raw = i64::from(input[&icoord]);
            let wrapped = if raw < 0 { raw + depth } else { raw };
            // Out-of-range indices produce an all-"off" slice rather than an
            // out-of-bounds write.
            if wrapped < 0 || wrapped >= depth {
                continue;
            }
            let mut ocoord: Vec<usize> = icoord.slice().into();
            ocoord.insert(self.axis, wrapped as usize);
            array[&*ocoord] = on.clone();
        }
        Ok(())
    }
}