// tract_core/ops/array/gather.rs
use crate::internal::*;
2use crate::ops::einsum::block_quant_aware_input_shape;
3use ndarray::*;
4use tract_linalg::block_quant::BlockQuantValue;
5use tract_linalg::mmm::MMMInputValue;
6
/// ONNX-style Gather: picks slices of the data input along `axis` according to
/// an i64 indices tensor.
#[derive(Debug, Clone, Hash, PartialEq)]
pub struct Gather {
    /// Axis of the data input along which gathering is performed.
    pub axis: usize,
    /// Explicit output datum type. Required when the data input is opaque
    /// (block-quantized or pre-packed) — see `output_facts`; `None` means
    /// "same type as the data input".
    pub output_type: Option<DatumType>,
}
12
impl Op for Gather {
    fn name(&self) -> Cow<str> {
        "Gather".into()
    }

    // Standard boilerplate: expose this op as a TypedOp and derive same_as
    // from the PartialEq impl.
    op_as_typed_op!();
    impl_op_same_as!();
}
21
22impl Gather {
23 pub fn new(axis: usize) -> Gather {
24 Gather { axis, output_type: None }
25 }
26
27 pub fn compute_output_shape<D: DimLike>(
28 &self,
29 input_shape: &[D],
30 indices_shape: &[D],
31 ) -> TractResult<TVec<D>> {
32 ensure!(input_shape.len() > self.axis);
33 let mut output_shape: TVec<D> = input_shape[..self.axis].into();
34 output_shape.extend(indices_shape.iter().cloned());
35 output_shape.extend(input_shape[self.axis + 1..].iter().cloned());
36 Ok(output_shape)
37 }
38
39 fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<Tensor> {
40 let data_view = unsafe { data.to_array_view_unchecked::<T>() }; let indices = indices.to_array_view::<i64>()?;
42 let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
43 let mut output = unsafe { Tensor::uninitialized::<T>(output_shape)? };
44 let mut output_view = output.to_array_view_mut::<T>()?;
45 for coords in tract_ndarray::indices(output_shape) {
46 let ocoords = coords.as_array_view();
47 let ocoords = ocoords.as_slice().unwrap();
48 let mut icoords: TVec<usize> = ocoords[0..self.axis].into();
49 let kcoords = &ocoords[self.axis..][..indices.ndim()];
50 let k = indices[kcoords];
51 let k = if k < 0 { k + data_view.shape()[self.axis] as i64 } else { k } as usize;
52 icoords.push(k);
53 icoords.extend(ocoords[self.axis + indices.ndim()..].iter().copied());
54 output_view[ocoords] = data_view.get(&*icoords).context("Invalid gather")?.clone();
55 }
56 unsafe { output.set_datum_type(data.datum_type()) };
57 Ok(output)
58 }
59
60 fn eval_bq<F: Datum>(&self, data: &BlockQuantValue, indices: &TValue) -> TractResult<Tensor> {
61 ensure!(self.axis == 0);
62 ensure!(data.fact.shape().len() == 2);
63 let data_shape = &data.fact.shape();
64 let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
65 let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
66 let indices_slice = indices.as_slice::<i64>()?;
67 let vector_len = data_shape[1];
68 if F::datum_type() == f16::datum_type() {
69 let output_slice = output.as_slice_mut::<f16>()?;
70 for (pos, ix) in indices_slice.iter().enumerate() {
71 let slice = &mut output_slice[pos * vector_len..][..vector_len];
72 for (i, slot) in slice.iter_mut().enumerate() {
73 let offset = data_shape[1] * *ix as usize + i;
74 *slot = data.fact.format.extract_at_offset_f16(&data.value, offset)
75 }
76 }
77 } else {
78 let output_slice = output.as_slice_mut::<f32>()?;
79 for (pos, ix) in indices_slice.iter().enumerate() {
80 let slice = &mut output_slice[pos * vector_len..][..vector_len];
81 for (i, slot) in slice.iter_mut().enumerate() {
82 let offset = data_shape[1] * *ix as usize + i;
83 *slot = data.fact.format.extract_at_offset_f32(&data.value, offset)
84 }
85 }
86 }
87 Ok(output)
88 }
89
90 fn eval_input_store<F: Datum>(
91 &self,
92 data: &dyn MMMInputValue,
93 indices: &TValue,
94 ) -> TractResult<Tensor> {
95 ensure!(self.axis == 0);
96 let data_shape = &[data.mn(), data.k()];
97 let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
98 let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
99 let indices_slice = indices.as_slice::<i64>()?;
100 let vector_len = data_shape[1];
101 if F::datum_type() == f16::datum_type() {
102 let output_slice = output.as_slice_mut::<f16>()?;
103 for (pos, m) in indices_slice.iter().enumerate() {
104 let slice = &mut output_slice[pos * vector_len..][..vector_len];
105 data.extract_at_mn_f16(*m as usize, slice)?;
106 }
107 } else {
108 let output_slice = output.as_slice_mut::<f32>()?;
109 for (pos, m) in indices_slice.iter().enumerate() {
110 let slice = &mut output_slice[pos * vector_len..][..vector_len];
111 data.extract_at_mn_f32(*m as usize, slice)?;
112 }
113 }
114 Ok(output)
115 }
116}
117
118impl TypedOp for Gather {
119 as_op!();
120
121 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
122 if let Some(dt) = self.output_type {
123 ensure!(
124 inputs[0].datum_type.is_opaque() || inputs[0].datum_type == dt,
125 "Inconsistent datum_type in Gather: attribute is {:?}, but inputs[0] is {:?}",
126 dt,
127 inputs[0].datum_type
128 );
129 } else {
130 ensure!(!inputs[0].datum_type.is_opaque(),
131 "Gather applied to compressed data requires an explicit datum_type attribute for its output");
132 }
133 ensure!(inputs[1].datum_type == i64::datum_type());
134 if inputs[0].datum_type.is_opaque() {
135 let data_shape = block_quant_aware_input_shape(inputs[0])?;
136 Ok(tvec!(self
137 .output_type
138 .unwrap()
139 .fact(&*self.compute_output_shape(&data_shape, &inputs[1].shape)?)))
140 } else {
141 Ok(tvec!(inputs[0]
142 .datum_type
143 .fact(&*self.compute_output_shape(&inputs[0].shape, &inputs[1].shape)?)))
144 }
145 }
146
147 fn declutter(
148 &self,
149 model: &TypedModel,
150 node: &TypedNode,
151 ) -> TractResult<Option<TypedModelPatch>> {
152 let (input_fact, indices_fact) = args_2!(model.node_input_facts(node.id)?);
153 if let Some(indices) = indices_fact.konst.as_ref() {
154 if indices.rank() == 1 && indices.len() == 1 && input_fact.datum_type.is_number() {
155 let mut patch = TypedModelPatch::default();
156 let mut wire = patch.tap_model(model, node.inputs[0])?;
157 let index = indices.cast_to_scalar::<i64>()?;
158 let index = if index < 0 {
159 let data_fact = model.outlet_fact(node.inputs[0])?;
160 data_fact.shape[self.axis].clone() + index.to_dim()
161 } else {
162 index.to_dim()
163 };
164 wire = patch.wire_node(
165 format!("{}.slice", node.name),
166 crate::ops::array::Slice {
167 axis: self.axis,
168 start: index.clone(),
169 end: index + 1,
170 },
171 &[wire],
172 )?[0];
173 patch.shunt_outside(model, node.id.into(), wire)?;
174 return Ok(Some(patch));
175 }
176 }
177 Ok(None)
178 }
179}
180
181impl EvalOp for Gather {
182 fn is_stateless(&self) -> bool {
183 true
184 }
185
186 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
187 let (data, indices) = args_2!(inputs);
188 let result = if let Ok(opaque) = data.to_scalar::<Opaque>() {
189 let dt = self.output_type.unwrap();
190 if let Some(data) = opaque.downcast_ref::<BlockQuantValue>() {
191 dispatch_floatlike!(Self::eval_bq(dt)(self, data, &indices))?
192 } else if let Some(data) = opaque.downcast_ref::<Box<dyn MMMInputValue>>() {
193 dispatch_floatlike!(Self::eval_input_store(dt)(self, &**data, &indices))?
194 } else {
195 bail!("Can't use Gather on {:?} input", data);
196 }
197 } else {
198 dispatch_datum_by_size!(Self::eval_t(data.datum_type())(self, data, &indices))?
199 };
200 Ok(tvec!(result.into_tvalue()))
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_should_gather_scalar_index() {
        // data[idx] == idx + 1, so the expected scalar is easy to state.
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let gatherer = Gather::new(0);
        // Cover every valid index (the original loop `2..3` only tested one).
        for idx in 0..3 {
            let index = Tensor::from(arr0(idx));
            let outputs =
                gatherer.eval(tvec![data.clone().into_tvalue(), index.into_tvalue()]).unwrap();
            let output = &outputs[0];
            // A scalar (rank-0) index yields a rank-0 output.
            assert_eq!(output.shape().len(), 0);
            assert_eq!(*output.to_scalar::<i64>().unwrap(), idx + 1);
        }
    }
}