tract_core/ops/array/
one_hot.rs1use tract_data::itertools::Itertools;
2
3use crate::internal::*;
4
5#[derive(Debug, PartialEq, Eq, Clone, Hash)]
6pub struct OneHot {
7 pub axis: usize,
8 pub dim: usize,
9 pub off: Arc<Tensor>,
10 pub on: Arc<Tensor>,
11}
12
13impl Op for OneHot {
14 fn name(&self) -> Cow<str> {
15 "Onehot".into()
16 }
17
18 op_as_typed_op!();
19}
20
21impl TypedOp for OneHot {
22 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
23 let mut shape = inputs[0].shape.to_tvec();
24 shape.insert(self.axis, self.dim.to_dim());
25 Ok(tvec!(self.off.datum_type().fact(&*shape)))
26 }
27
28 fn axes_mapping(
29 &self,
30 inputs: &[&TypedFact],
31 outputs: &[&TypedFact],
32 ) -> TractResult<AxesMapping> {
33 let axes = (0..inputs[0].rank())
34 .zip('a'..)
35 .map(|(i, repr)| {
36 Axis::new(repr, inputs.len(), outputs.len())
37 .input(0, i)
38 .output(0, i + (i >= self.axis) as usize)
39 })
40 .chain(std::iter::once(
41 Axis::new('Z', inputs.len(), outputs.len()).output(0, self.axis),
42 ))
43 .collect_vec();
44 AxesMapping::new(inputs.len(), outputs.len(), axes)
45 }
46
47 as_op!();
48}
49
50impl EvalOp for OneHot {
51 fn is_stateless(&self) -> bool {
52 true
53 }
54
55 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
56 let input = args_1!(inputs);
57 let mut shape: TVec<usize> = input.shape().into();
58 shape.insert(self.axis, self.dim);
59 unsafe {
60 let mut output = self.off.broadcast_scalar_to_shape(&shape)?;
61 dispatch_datum_by_size!(Self::eval_t(self.off.datum_type())(
62 self,
63 &input,
64 &mut output
65 ))?;
66 Ok(tvec!(output.into_tvalue()))
67 }
68 }
69}
70
71impl OneHot {
72 unsafe fn eval_t<T: Datum + Clone>(
73 &self,
74 input: &Tensor,
75 output: &mut Tensor,
76 ) -> TractResult<()> {
77 let on = self.on.to_scalar_unchecked::<T>();
78 let mut shape: TVec<usize> = input.shape().into();
79 shape.insert(self.axis, self.dim);
80 let mut array = output.to_array_view_mut_unchecked::<T>();
81 let input = input.cast_to::<i32>()?;
82 let input = input.to_array_view::<i32>()?;
83 for icoord in tract_ndarray::indices_of(&input) {
84 use tract_ndarray::Dimension;
85 let mut ocoord: Vec<usize> = icoord.slice().into();
86 let coord = input[&icoord];
87 let coord = if coord < 0 { coord + self.dim as i32 } else { coord } as usize;
88 ocoord.insert(self.axis, coord);
89 array[&*ocoord] = on.clone();
90 }
91 Ok(())
92 }
93}