tract_core/ops/array/
gather_nd.rs1use crate::internal::*;
2use tract_ndarray::prelude::*;
3
4#[derive(Debug, Clone, new, Hash)]
5pub struct GatherNd {
6 pub batch_dims: usize,
7}
8
9
10
11impl GatherNd {
12 fn compute_shape<D: DimLike>(
13 &self,
14 data_shape: &[D],
15 indices_shape: &[D],
16 ) -> TractResult<TVec<D>> {
17 let mut shape: TVec<D> = indices_shape.into();
18 let n = shape.pop().unwrap().to_usize()?;
19 shape.extend(data_shape[n + self.batch_dims..].iter().cloned());
20 Ok(shape)
21 }
22
23 unsafe fn eval_t<T: Datum>(
24 &self,
25 output: &mut Tensor,
26 data: &Tensor,
27 indices: &ArrayViewD<i32>,
28 ) {
29 let batch_dims = self.batch_dims;
30 assert_eq!(output.shape()[..batch_dims], data.shape()[..batch_dims]);
31 assert_eq!(output.shape()[..batch_dims], indices.shape()[..batch_dims]);
32 let batch_size = data.shape().iter().take(batch_dims).product();
33 let n = indices.shape()[indices.ndim() - 1];
34
35 let remaining = indices.shape().iter().skip(batch_dims).rev().skip(1).product();
36 let indices_shape_op = tvec!(batch_size, remaining, n);
37 let reshaped_indices: ArrayViewD<i32> =
38 indices.view().into_shape_with_order(&*indices_shape_op).unwrap();
39
40 let mut data_shape_op: TVec<usize> =
41 data.shape().iter().skip(batch_dims).copied().collect();
42 data_shape_op.insert(0, batch_size);
43 let reshaped_data =
44 data.to_array_view_unchecked::<T>().into_shape_with_order(&*data_shape_op).unwrap();
45
46 let mut output_shape_op: TVec<usize> =
47 data.shape().iter().skip(n + batch_dims).copied().collect();
48 output_shape_op.insert(0, batch_size * remaining);
49 let mut output =
50 output.to_array_view_mut_unchecked::<T>().into_shape_with_order(&*output_shape_op).unwrap();
51
52 for b in 0..batch_size {
53 let mut i = reshaped_data.view();
54 i.index_axis_inplace(Axis(0), b);
55 let mut coords = reshaped_indices.view();
56 coords.index_axis_inplace(Axis(0), b);
57
58 for ix in 0..remaining {
59 let mut coords = coords.view();
60 coords.index_axis_inplace(Axis(0), ix);
61
62 let mut i = i.view();
63 for x in coords {
64 i.index_axis_inplace(Axis(0), *x as usize);
65 }
66
67 let mut o = output.view_mut();
68 o.index_axis_inplace(Axis(0), b * remaining + ix);
69 o.assign(&i);
70 }
71 }
72 }
73}
74
75impl Op for GatherNd {
76 fn name(&self) -> Cow<str> {
77 "GatherNd".into()
78 }
79
80 op_as_typed_op!();
81}
82
83impl EvalOp for GatherNd {
84 fn is_stateless(&self) -> bool {
85 true
86 }
87
88 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
89 let (data, indices) = args_2!(inputs);
90 let shape = self.compute_shape(data.shape(), indices.shape())?;
91 let indices = indices.cast_to::<i32>()?;
92 let indices = indices.to_array_view::<i32>()?;
93 unsafe {
94 let mut output = Tensor::uninitialized_dt(data.datum_type(), &shape)?;
95 dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
96 self,
97 &mut output,
98 &data,
99 &indices
100 ));
101 Ok(tvec!(output.into_tvalue()))
102 }
103 }
104}
105
106impl TypedOp for GatherNd {
107 as_op!();
108
109 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
110 let shape = self.compute_shape(&inputs[0].shape.to_tvec(), &inputs[1].shape.to_tvec())?;
111 Ok(tvec!(inputs[0].datum_type.fact(&shape)))
112 }
113
114 fn declutter(
115 &self,
116 model: &TypedModel,
117 node: &TypedNode,
118 ) -> TractResult<Option<TypedModelPatch>> {
119 if let Some(indices) = &model.outlet_fact(node.inputs[1])?.konst {
120 if indices.rank() == 2 && indices.shape()[0] == 1 {
121 let mut patch = TypedModelPatch::default();
122 let mut wire = patch.tap_model(model, node.inputs[0])?;
123 for (axis, &i) in indices.cast_to::<i32>()?.as_slice::<i32>()?.iter().enumerate() {
124 wire = patch.wire_node(
125 format!("{}-slice-axis-{}", node.name, axis),
126 crate::ops::array::Slice::new(axis, i as usize, (i + 1) as usize),
127 &[wire],
128 )?[0];
129 }
130 for i in (0..indices.shape()[1]).rev() {
131 wire = patch.wire_node(
132 format!("{}-remove_axis_{}", node.name, i),
133 crate::ops::change_axes::AxisOp::Rm(i),
134 &[wire],
135 )?[0];
136 }
137 wire = patch.wire_node(
138 format!("{}-add_axis", node.name),
139 crate::ops::change_axes::AxisOp::Add(0),
140 &[wire],
141 )?[0];
142 patch.shunt_outside(model, node.id.into(), wire)?;
143 return Ok(Some(patch));
144 }
145 }
146 Ok(None)
147 }
148}