// tract_core/ops/array/gather_nd.rs

1use crate::internal::*;
2use tract_ndarray::prelude::*;
3
/// ONNX-style GatherND: gathers slices of a data tensor addressed by
/// coordinate tuples stored along the innermost axis of an indices tensor.
#[derive(Debug, Clone, new, Hash)]
pub struct GatherNd {
    // Number of leading axes shared (and iterated jointly) by data and
    // indices before the gathered coordinates apply.
    pub batch_dims: usize,
}
8
impl GatherNd {
    /// Compute the output shape of the gather.
    ///
    /// The result is the indices shape with its innermost axis removed
    /// (that axis holds the coordinate tuples, of length `n`), followed by
    /// the trailing data axes the tuples do not address, i.e.
    /// `data_shape[self.batch_dims + n ..]`.
    ///
    /// Errors if the innermost indices dimension is not a concrete integer.
    /// NOTE(review): panics if `indices_shape` is empty (`pop().unwrap()`)
    /// or if `n + batch_dims` exceeds the data rank — presumably inputs are
    /// validated upstream; confirm.
    fn compute_shape<D: DimLike>(
        &self,
        data_shape: &[D],
        indices_shape: &[D],
    ) -> TractResult<TVec<D>> {
        let mut shape: TVec<D> = indices_shape.into();
        // n: number of coordinates in each index tuple (last indices axis).
        let n = shape.pop().unwrap().to_usize()?;
        shape.extend(data_shape[n + self.batch_dims..].iter().cloned());
        Ok(shape)
    }

    /// Perform the gather proper, writing every element of `output`.
    ///
    /// The three tensors are reinterpreted through flattened views:
    /// * indices as `[batch_size, remaining, n]`,
    /// * data as `[batch_size, <data axes after the batch axes>...]`,
    /// * output as `[batch_size * remaining, <trailing data axes>...]`,
    /// where `remaining` is the product of the indices axes between the
    /// batch axes and the final coordinate axis.
    ///
    /// # Safety
    /// `T` must be the datum type of both `data` and `output`, and `output`
    /// must already have the shape returned by `compute_shape` (it may be
    /// uninitialized: every element is assigned in the loop below).
    unsafe fn eval_t<T: Datum>(
        &self,
        output: &mut Tensor,
        data: &Tensor,
        indices: &ArrayViewD<i32>,
    ) {
        let batch_dims = self.batch_dims;
        // The leading batch axes must agree across all three tensors.
        assert_eq!(output.shape()[..batch_dims], data.shape()[..batch_dims]);
        assert_eq!(output.shape()[..batch_dims], indices.shape()[..batch_dims]);
        let batch_size = data.shape().iter().take(batch_dims).product();
        // n: length of each coordinate tuple (innermost indices axis).
        let n = indices.shape()[indices.ndim() - 1];

        // Number of coordinate tuples gathered per batch element: product
        // of the indices axes between the batch axes and the tuple axis.
        let remaining = indices.shape().iter().skip(batch_dims).rev().skip(1).product();
        let indices_shape_op = tvec!(batch_size, remaining, n);
        let reshaped_indices: ArrayViewD<i32> =
            indices.view().into_shape_with_order(&*indices_shape_op).unwrap();

        // Collapse the batch axes of data into a single leading axis.
        let mut data_shape_op: TVec<usize> =
            data.shape().iter().skip(batch_dims).copied().collect();
        data_shape_op.insert(0, batch_size);
        let reshaped_data = unsafe {
            data.to_array_view_unchecked::<T>().into_shape_with_order(&*data_shape_op).unwrap()
        };

        // Collapse the batch and tuple axes of output into one leading axis
        // of length batch_size * remaining, keeping the trailing data axes.
        let mut output_shape_op: TVec<usize> =
            data.shape().iter().skip(n + batch_dims).copied().collect();
        output_shape_op.insert(0, batch_size * remaining);
        let mut output = unsafe {
            output
                .to_array_view_mut_unchecked::<T>()
                .into_shape_with_order(&*output_shape_op)
                .unwrap()
        };

        for b in 0..batch_size {
            // Views restricted to the current batch element.
            let mut i = reshaped_data.view();
            i.index_axis_inplace(Axis(0), b);
            let mut coords = reshaped_indices.view();
            coords.index_axis_inplace(Axis(0), b);

            for ix in 0..remaining {
                // The ix-th coordinate tuple of this batch element.
                let mut coords = coords.view();
                coords.index_axis_inplace(Axis(0), ix);

                // Drill down one data axis per coordinate; what remains is
                // the slice to copy out.
                // NOTE(review): negative coordinates are not handled here
                // (`as usize` would wrap); presumably they are normalized
                // upstream — confirm against callers.
                let mut i = i.view();
                for x in coords {
                    i.index_axis_inplace(Axis(0), *x as usize);
                }

                let mut o = output.view_mut();
                o.index_axis_inplace(Axis(0), b * remaining + ix);
                o.assign(&i);
            }
        }
    }
}
77
78impl Op for GatherNd {
79    fn name(&self) -> StaticName {
80        "GatherNd".into()
81    }
82
83    op_as_typed_op!();
84}
85
impl EvalOp for GatherNd {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Evaluate the gather on concrete tensors.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (data, indices) = args_2!(inputs);
        // Output shape derived from both input shapes (see compute_shape).
        let shape = self.compute_shape(data.shape(), indices.shape())?;
        // Normalize indices to i32 whatever the incoming integer type.
        let indices = indices.cast_to::<i32>()?;
        let indices = indices.to_array_view::<i32>()?;
        unsafe {
            // SAFETY: eval_t is dispatched on data's own datum type and
            // assigns every element of the uninitialized output tensor.
            let mut output = Tensor::uninitialized_dt(data.datum_type(), &shape)?;
            dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
                self,
                &mut output,
                &data,
                &indices
            ));
            Ok(tvec!(output.into_tvalue()))
        }
    }
}
108
109impl TypedOp for GatherNd {
110    as_op!();
111
112    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
113        let shape = self.compute_shape(&inputs[0].shape.to_tvec(), &inputs[1].shape.to_tvec())?;
114        Ok(tvec!(inputs[0].datum_type.fact(&shape)))
115    }
116
117    fn declutter(
118        &self,
119        model: &TypedModel,
120        node: &TypedNode,
121    ) -> TractResult<Option<TypedModelPatch>> {
122        if let Some(indices) = &model.outlet_fact(node.inputs[1])?.konst {
123            if indices.rank() == 2 && indices.shape()[0] == 1 {
124                let mut patch = TypedModelPatch::default();
125                let mut wire = patch.tap_model(model, node.inputs[0])?;
126                for (axis, &i) in indices.cast_to::<i32>()?.as_slice::<i32>()?.iter().enumerate() {
127                    wire = patch.wire_node(
128                        format!("{}-slice-axis-{}", node.name, axis),
129                        crate::ops::array::Slice::new(axis, i as usize, (i + 1) as usize),
130                        &[wire],
131                    )?[0];
132                }
133                for i in (0..indices.shape()[1]).rev() {
134                    wire = patch.wire_node(
135                        format!("{}-remove_axis_{}", node.name, i),
136                        crate::ops::change_axes::AxisOp::Rm(i),
137                        &[wire],
138                    )?[0];
139                }
140                wire = patch.wire_node(
141                    format!("{}-add_axis", node.name),
142                    crate::ops::change_axes::AxisOp::Add(0),
143                    &[wire],
144                )?[0];
145                patch.shunt_outside(model, node.id.into(), wire)?;
146                return Ok(Some(patch));
147            }
148        }
149        Ok(None)
150    }
151}