tract_core/ops/array/
gather_nd.rs1use crate::internal::*;
2use tract_ndarray::prelude::*;
3
4#[derive(Debug, Clone, new, Hash)]
5pub struct GatherNd {
6 pub batch_dims: usize,
7}
8
9impl GatherNd {
10 fn compute_shape<D: DimLike>(
11 &self,
12 data_shape: &[D],
13 indices_shape: &[D],
14 ) -> TractResult<TVec<D>> {
15 let mut shape: TVec<D> = indices_shape.into();
16 let n = shape.pop().unwrap().to_usize()?;
17 shape.extend(data_shape[n + self.batch_dims..].iter().cloned());
18 Ok(shape)
19 }
20
21 unsafe fn eval_t<T: Datum>(
22 &self,
23 output: &mut Tensor,
24 data: &Tensor,
25 indices: &ArrayViewD<i32>,
26 ) {
27 let batch_dims = self.batch_dims;
28 assert_eq!(output.shape()[..batch_dims], data.shape()[..batch_dims]);
29 assert_eq!(output.shape()[..batch_dims], indices.shape()[..batch_dims]);
30 let batch_size = data.shape().iter().take(batch_dims).product();
31 let n = indices.shape()[indices.ndim() - 1];
32
33 let remaining = indices.shape().iter().skip(batch_dims).rev().skip(1).product();
34 let indices_shape_op = tvec!(batch_size, remaining, n);
35 let reshaped_indices: ArrayViewD<i32> =
36 indices.view().into_shape_with_order(&*indices_shape_op).unwrap();
37
38 let mut data_shape_op: TVec<usize> =
39 data.shape().iter().skip(batch_dims).copied().collect();
40 data_shape_op.insert(0, batch_size);
41 let reshaped_data = unsafe {
42 data.to_array_view_unchecked::<T>().into_shape_with_order(&*data_shape_op).unwrap()
43 };
44
45 let mut output_shape_op: TVec<usize> =
46 data.shape().iter().skip(n + batch_dims).copied().collect();
47 output_shape_op.insert(0, batch_size * remaining);
48 let mut output = unsafe {
49 output
50 .to_array_view_mut_unchecked::<T>()
51 .into_shape_with_order(&*output_shape_op)
52 .unwrap()
53 };
54
55 for b in 0..batch_size {
56 let mut i = reshaped_data.view();
57 i.index_axis_inplace(Axis(0), b);
58 let mut coords = reshaped_indices.view();
59 coords.index_axis_inplace(Axis(0), b);
60
61 for ix in 0..remaining {
62 let mut coords = coords.view();
63 coords.index_axis_inplace(Axis(0), ix);
64
65 let mut i = i.view();
66 for x in coords {
67 i.index_axis_inplace(Axis(0), *x as usize);
68 }
69
70 let mut o = output.view_mut();
71 o.index_axis_inplace(Axis(0), b * remaining + ix);
72 o.assign(&i);
73 }
74 }
75 }
76}
77
78impl Op for GatherNd {
79 fn name(&self) -> StaticName {
80 "GatherNd".into()
81 }
82
83 op_as_typed_op!();
84}
85
86impl EvalOp for GatherNd {
87 fn is_stateless(&self) -> bool {
88 true
89 }
90
91 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
92 let (data, indices) = args_2!(inputs);
93 let shape = self.compute_shape(data.shape(), indices.shape())?;
94 let indices = indices.cast_to::<i32>()?;
95 let indices = indices.to_array_view::<i32>()?;
96 unsafe {
97 let mut output = Tensor::uninitialized_dt(data.datum_type(), &shape)?;
98 dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
99 self,
100 &mut output,
101 &data,
102 &indices
103 ));
104 Ok(tvec!(output.into_tvalue()))
105 }
106 }
107}
108
109impl TypedOp for GatherNd {
110 as_op!();
111
112 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
113 let shape = self.compute_shape(&inputs[0].shape.to_tvec(), &inputs[1].shape.to_tvec())?;
114 Ok(tvec!(inputs[0].datum_type.fact(&shape)))
115 }
116
117 fn declutter(
118 &self,
119 model: &TypedModel,
120 node: &TypedNode,
121 ) -> TractResult<Option<TypedModelPatch>> {
122 if let Some(indices) = &model.outlet_fact(node.inputs[1])?.konst {
123 if indices.rank() == 2 && indices.shape()[0] == 1 {
124 let mut patch = TypedModelPatch::default();
125 let mut wire = patch.tap_model(model, node.inputs[0])?;
126 for (axis, &i) in indices.cast_to::<i32>()?.as_slice::<i32>()?.iter().enumerate() {
127 wire = patch.wire_node(
128 format!("{}-slice-axis-{}", node.name, axis),
129 crate::ops::array::Slice::new(axis, i as usize, (i + 1) as usize),
130 &[wire],
131 )?[0];
132 }
133 for i in (0..indices.shape()[1]).rev() {
134 wire = patch.wire_node(
135 format!("{}-remove_axis_{}", node.name, i),
136 crate::ops::change_axes::AxisOp::Rm(i),
137 &[wire],
138 )?[0];
139 }
140 wire = patch.wire_node(
141 format!("{}-add_axis", node.name),
142 crate::ops::change_axes::AxisOp::Add(0),
143 &[wire],
144 )?[0];
145 patch.shunt_outside(model, node.id.into(), wire)?;
146 return Ok(Some(patch));
147 }
148 }
149 Ok(None)
150 }
151}