tract_core/ops/array/
one_hot.rs

1use tract_data::itertools::Itertools;
2
3use crate::internal::*;
4
5#[derive(Debug, PartialEq, Eq, Clone, Hash)]
6pub struct OneHot {
7    pub axis: usize,
8    pub dim: usize,
9    pub off: Arc<Tensor>,
10    pub on: Arc<Tensor>,
11}
12
13impl Op for OneHot {
14    fn name(&self) -> Cow<str> {
15        "Onehot".into()
16    }
17
18    op_as_typed_op!();
19}
20
21impl TypedOp for OneHot {
22    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
23        let mut shape = inputs[0].shape.to_tvec();
24        shape.insert(self.axis, self.dim.to_dim());
25        Ok(tvec!(self.off.datum_type().fact(&*shape)))
26    }
27
28    fn axes_mapping(
29        &self,
30        inputs: &[&TypedFact],
31        outputs: &[&TypedFact],
32    ) -> TractResult<AxesMapping> {
33        let axes = (0..inputs[0].rank())
34            .zip('a'..)
35            .map(|(i, repr)| {
36                Axis::new(repr, inputs.len(), outputs.len())
37                    .input(0, i)
38                    .output(0, i + (i >= self.axis) as usize)
39            })
40            .chain(std::iter::once(
41                Axis::new('Z', inputs.len(), outputs.len()).output(0, self.axis),
42            ))
43            .collect_vec();
44        AxesMapping::new(inputs.len(), outputs.len(), axes)
45    }
46
47    as_op!();
48}
49
50impl EvalOp for OneHot {
51    fn is_stateless(&self) -> bool {
52        true
53    }
54
55    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
56        let input = args_1!(inputs);
57        let mut shape: TVec<usize> = input.shape().into();
58        shape.insert(self.axis, self.dim);
59        unsafe {
60            let mut output = self.off.broadcast_scalar_to_shape(&shape)?;
61            dispatch_datum_by_size!(Self::eval_t(self.off.datum_type())(
62                self,
63                &input,
64                &mut output
65            ))?;
66            Ok(tvec!(output.into_tvalue()))
67        }
68    }
69}
70
71impl OneHot {
72    unsafe fn eval_t<T: Datum + Clone>(
73        &self,
74        input: &Tensor,
75        output: &mut Tensor,
76    ) -> TractResult<()> {
77        let on = self.on.to_scalar_unchecked::<T>();
78        let mut shape: TVec<usize> = input.shape().into();
79        shape.insert(self.axis, self.dim);
80        let mut array = output.to_array_view_mut_unchecked::<T>();
81        let input = input.cast_to::<i32>()?;
82        let input = input.to_array_view::<i32>()?;
83        for icoord in tract_ndarray::indices_of(&input) {
84            use tract_ndarray::Dimension;
85            let mut ocoord: Vec<usize> = icoord.slice().into();
86            let coord = input[&icoord];
87            let coord = if coord < 0 { coord + self.dim as i32 } else { coord } as usize;
88            ocoord.insert(self.axis, coord);
89            array[&*ocoord] = on.clone();
90        }
91        Ok(())
92    }
93}