Skip to main content

tract_core/ops/array/gather_nd.rs

1use crate::internal::*;
2use tract_ndarray::prelude::*;
3
4#[derive(Debug, Clone, new, Hash)]
5pub struct GatherNd {
6    pub batch_dims: usize,
7}
8
9impl GatherNd {
10    fn compute_shape<D: DimLike>(
11        &self,
12        data_shape: &[D],
13        indices_shape: &[D],
14    ) -> TractResult<TVec<D>> {
15        let mut shape: TVec<D> = indices_shape.into();
16        let n = shape.pop().unwrap().to_usize()?;
17        shape.extend(data_shape[n + self.batch_dims..].iter().cloned());
18        Ok(shape)
19    }
20
21    unsafe fn eval_t<T: Datum>(
22        &self,
23        output: &mut Tensor,
24        data: &Tensor,
25        indices: &ArrayViewD<i32>,
26    ) -> TractResult<()> {
27        let batch_dims = self.batch_dims;
28        assert_eq!(output.shape()[..batch_dims], data.shape()[..batch_dims]);
29        assert_eq!(output.shape()[..batch_dims], indices.shape()[..batch_dims]);
30        let batch_size = data.shape().iter().take(batch_dims).product();
31        let n = indices.shape()[indices.ndim() - 1];
32
33        let remaining = indices.shape().iter().skip(batch_dims).rev().skip(1).product();
34        let indices_shape_op = tvec!(batch_size, remaining, n);
35        let reshaped_indices: ArrayViewD<i32> =
36            indices.view().into_shape_with_order(&*indices_shape_op).unwrap();
37
38        let mut data_shape_op: TVec<usize> =
39            data.shape().iter().skip(batch_dims).copied().collect();
40        data_shape_op.insert(0, batch_size);
41        let data_dense = data.try_as_dense()?;
42        let reshaped_data = unsafe {
43            data_dense
44                .to_array_view_unchecked::<T>()
45                .into_shape_with_order(&*data_shape_op)
46                .unwrap()
47        };
48
49        let mut output_shape_op: TVec<usize> =
50            data.shape().iter().skip(n + batch_dims).copied().collect();
51        output_shape_op.insert(0, batch_size * remaining);
52        let mut output_dense = output.try_as_dense_mut()?;
53        let mut output = unsafe {
54            output_dense
55                .to_array_view_mut_unchecked::<T>()
56                .into_shape_with_order(&*output_shape_op)
57                .unwrap()
58        };
59
60        for b in 0..batch_size {
61            let mut i = reshaped_data.view();
62            i.index_axis_inplace(Axis(0), b);
63            let mut coords = reshaped_indices.view();
64            coords.index_axis_inplace(Axis(0), b);
65
66            for ix in 0..remaining {
67                let mut coords = coords.view();
68                coords.index_axis_inplace(Axis(0), ix);
69
70                let mut i = i.view();
71                for x in coords {
72                    i.index_axis_inplace(Axis(0), *x as usize);
73                }
74
75                let mut o = output.view_mut();
76                o.index_axis_inplace(Axis(0), b * remaining + ix);
77                o.assign(&i);
78            }
79        }
80        Ok(())
81    }
82}
83
84impl Op for GatherNd {
85    fn name(&self) -> StaticName {
86        "GatherNd".into()
87    }
88
89    op_as_typed_op!();
90}
91
92impl EvalOp for GatherNd {
93    fn is_stateless(&self) -> bool {
94        true
95    }
96
97    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
98        let (data, indices) = args_2!(inputs);
99        let shape = self.compute_shape(data.shape(), indices.shape())?;
100        let indices = indices.cast_to::<i32>()?;
101        let indices = indices.to_dense_array_view::<i32>()?;
102        unsafe {
103            let mut output = Tensor::uninitialized_dt(data.datum_type(), &shape)?;
104            dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
105                self,
106                &mut output,
107                &data,
108                &indices
109            ))?;
110            Ok(tvec!(output.into_tvalue()))
111        }
112    }
113}
114
115impl TypedOp for GatherNd {
116    as_op!();
117
118    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
119        let shape = self.compute_shape(&inputs[0].shape.to_tvec(), &inputs[1].shape.to_tvec())?;
120        Ok(tvec!(inputs[0].datum_type.fact(&shape)))
121    }
122
123    fn declutter(
124        &self,
125        model: &TypedModel,
126        node: &TypedNode,
127    ) -> TractResult<Option<TypedModelPatch>> {
128        if let Some(indices) = &model.outlet_fact(node.inputs[1])?.konst {
129            if indices.rank() == 2 && indices.shape()[0] == 1 {
130                let mut patch = TypedModelPatch::default();
131                let mut wire = patch.tap_model(model, node.inputs[0])?;
132                for (axis, &i) in
133                    indices.cast_to::<i32>()?.try_as_dense()?.as_slice::<i32>()?.iter().enumerate()
134                {
135                    wire = patch.wire_node(
136                        format!("{}-slice-axis-{}", node.name, axis),
137                        crate::ops::array::Slice::new(axis, i as usize, (i + 1) as usize),
138                        &[wire],
139                    )?[0];
140                }
141                for i in (0..indices.shape()[1]).rev() {
142                    wire = patch.wire_node(
143                        format!("{}-remove_axis_{}", node.name, i),
144                        crate::ops::change_axes::AxisOp::Rm(i),
145                        &[wire],
146                    )?[0];
147                }
148                wire = patch.wire_node(
149                    format!("{}-add_axis", node.name),
150                    crate::ops::change_axes::AxisOp::Add(0),
151                    &[wire],
152                )?[0];
153                patch.shunt_outside(model, node.id.into(), wire)?;
154                return Ok(Some(patch));
155            }
156        }
157        Ok(None)
158    }
159}