1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
use crate::internal::*; use tract_ndarray::prelude::*; #[derive(Debug, Clone, new, Hash)] pub struct GatherNd { pub batch_dims: usize, } impl_dyn_hash!(GatherNd); impl GatherNd { fn compute_shape<D: DimLike>( &self, data_shape: &[D], indices_shape: &[D], ) -> TractResult<TVec<D>> { let mut shape: TVec<D> = indices_shape.into(); let n = shape.pop().unwrap().to_usize()?; shape.extend(data_shape[n + self.batch_dims as usize..].iter().cloned()); Ok(shape) } unsafe fn eval_t<T: Datum>( &self, output: &mut Tensor, data: &Tensor, indices: &ArrayViewD<i32>, ) { let batch_dims = self.batch_dims as usize; assert_eq!(output.shape()[..batch_dims], data.shape()[..batch_dims]); assert_eq!(output.shape()[..batch_dims], indices.shape()[..batch_dims]); let batch_size = data.shape().iter().take(batch_dims).product(); let n = indices.shape()[indices.ndim() - 1]; let remaining = indices.shape().iter().skip(batch_dims).rev().skip(1).product(); let indices_shape_op = tvec!(batch_size, remaining, n); let reshaped_indices: ArrayViewD<i32> = indices.view().into_shape(&*indices_shape_op).unwrap(); let mut data_shape_op: TVec<usize> = data.shape().iter().skip(batch_dims).copied().collect(); data_shape_op.insert(0, batch_size); let reshaped_data = data.to_array_view_unchecked::<T>().into_shape(&*data_shape_op).unwrap(); let mut output_shape_op: TVec<usize> = data.shape().iter().skip(n + batch_dims).copied().collect(); output_shape_op.insert(0, batch_size * remaining); let mut output = output.to_array_view_mut_unchecked::<T>().into_shape(&*output_shape_op).unwrap(); for b in 0..batch_size { let mut i = reshaped_data.view(); i.index_axis_inplace(Axis(0), b); let mut coords = reshaped_indices.view(); coords.index_axis_inplace(Axis(0), b); for ix in 0..remaining { let mut coords = coords.view(); coords.index_axis_inplace(Axis(0), ix); let mut i = i.view(); for x in coords { i.index_axis_inplace(Axis(0), *x as usize); } let mut o = output.view_mut(); o.index_axis_inplace(Axis(0), b * remaining + ix); o.assign(&i); } } } } impl Op for GatherNd { fn name(&self) -> Cow<str> { "GatherNd".into() } op_core!(); op_as_typed_op!(); } impl EvalOp for GatherNd { fn is_stateless(&self) -> bool { true } fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> { let (data, indices) = args_2!(inputs); let shape = self.compute_shape(&data.shape(), &indices.shape())?; let indices = indices.cast_to::<i32>()?; let indices = indices.to_array_view::<i32>()?; unsafe { let mut output = Tensor::uninitialized_dt(data.datum_type(), &*shape)?; dispatch_datum_by_size!(Self::eval_t(data.datum_type())( self, &mut output, &data, &indices )); Ok(tvec!(output.into_arc_tensor())) } } } impl TypedOp for GatherNd { as_op!(); fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> { let shape = self.compute_shape(&inputs[0].shape.to_tvec(), &inputs[1].shape.to_tvec())?; Ok(tvec!(TypedFact::dt_shape(inputs[0].datum_type, &shape))) } fn declutter( &self, model: &TypedModel, node: &TypedNode, ) -> TractResult<Option<TypedModelPatch>> { if let Some(indices) = &model.outlet_fact(node.inputs[1])?.konst { if indices.rank() == 2 && indices.shape()[0] == 1 { let mut patch = TypedModelPatch::default(); let mut wire = patch.tap_model(model, node.inputs[0])?; for (axis, &i) in indices.cast_to::<i32>()?.as_slice::<i32>()?.iter().enumerate() { wire = patch.wire_node( format!("{}-slice-axis-{}", node.name, axis), crate::ops::array::Slice::new(axis, i as usize, (i + 1) as usize), &[wire], )?[0]; } for i in (0..indices.shape()[1]).rev() { wire = patch.wire_node( format!("{}-remove_axis_{}", node.name, i), crate::ops::change_axes::AxisOp::Rm(i), &[wire], )?[0]; } wire = patch.wire_node( format!("{}-add_axis", node.name), crate::ops::change_axes::AxisOp::Add(0), &[wire], )?[0]; patch.shunt_outside(model, node.id.into(), wire)?; return Ok(Some(patch)); } } Ok(None) } }