// tract_core/ops/array/gather.rs
use crate::internal::*;
use crate::ops::einsum::block_quant_aware_input_shape;
use crate::ops::matmul::pack::OptSimpleMatMulPack;
use ndarray::*;
use tract_linalg::block_quant::BlockQuantStorage;
use tract_linalg::mmm::{MMMInputValue, PackedMatrixStorage};

/// Gather operator: selects slices of the data input along `axis` according
/// to an `i64` indices tensor (see `compute_output_shape` for the resulting
/// shape rule).
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Gather {
    /// Axis of the data tensor along which indices are resolved.
    pub axis: usize,
    /// Optional output datum type. Mandatory when the data input is
    /// compressed (block-quantized or pre-packed), where the stored type
    /// differs from the dequantized output type (enforced in `output_facts`).
    pub output_type: Option<DatumType>,
}
13
impl Op for Gather {
    // Op boilerplate: human-readable name and typed-op downcast hook.
    fn name(&self) -> StaticName {
        "Gather".into()
    }

    op_as_typed_op!();
}
21
22impl Gather {
    /// Builds a Gather along `axis` with no output-type override (suitable
    /// for plain, uncompressed data).
    pub fn new(axis: usize) -> Gather {
        Gather { axis, output_type: None }
    }
26
27 pub fn compute_output_shape<D: DimLike>(
28 &self,
29 input_shape: &[D],
30 indices_shape: &[D],
31 ) -> TractResult<TVec<D>> {
32 ensure!(input_shape.len() > self.axis);
33 let mut output_shape: TVec<D> = input_shape[..self.axis].into();
34 output_shape.extend(indices_shape.iter().cloned());
35 output_shape.extend(input_shape[self.axis + 1..].iter().cloned());
36 Ok(output_shape)
37 }
38
39 fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<Tensor> {
40 let data_plain = data.try_as_plain()?;
41 let data_view = unsafe { data_plain.to_array_view_unchecked::<T>() };
42 let indices = indices.to_plain_array_view::<i64>()?;
43 let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
44 let mut output = unsafe { Tensor::uninitialized::<T>(output_shape)? };
45 let mut output_plain = output.try_as_plain_mut()?;
46 let mut output_view = output_plain.to_array_view_mut::<T>()?;
47
48 let data_shape = data.shape();
49 let data_axis = self.axis;
50
51 let block_len = data_shape[data_axis + 1..].iter().product::<usize>();
52
53 let can_block_copy = data_shape[..data_axis].iter().all(|&d| d == 1)
54 && output_shape[..data_axis].iter().all(|&d| d == 1)
55 && data_view.is_standard_layout()
56 && output_view.is_standard_layout();
57
58 if can_block_copy {
59 let mut out_offset = 0;
60 let input_slice = data_view.as_slice().unwrap();
61 let output_slice = &mut output_view.as_slice_mut().unwrap();
62 for idx_coords in indices.indexed_iter() {
63 let index = *idx_coords.1;
64 let axis_len = data_shape[data_axis] as i64;
65 let resolved_index = if index < 0 { index + axis_len } else { index };
66 let resolved_index = resolved_index as usize;
67
68 let input_offset = resolved_index * block_len;
69
70 output_slice[out_offset..out_offset + block_len]
71 .clone_from_slice(&input_slice[input_offset..input_offset + block_len]);
72 out_offset += block_len;
73 }
74 } else {
75 let ic_len = self.axis + 1 + output_shape.len() - (self.axis + indices.ndim());
76 let mut icoords = vec![0; ic_len];
77 let axis = self.axis;
78 for coords in tract_ndarray::indices(output_shape) {
79 let ocoords = coords.as_array_view();
80 let ocoords = ocoords.as_slice().unwrap();
81
82 let kcoords = &ocoords[self.axis..][..indices.ndim()];
83 let k = indices[kcoords];
84 let k = if k < 0 { k + data_view.shape()[self.axis] as i64 } else { k } as usize;
85 icoords[0..axis].copy_from_slice(&ocoords[..self.axis]);
86 icoords[self.axis] = k;
87 icoords[self.axis + 1..].clone_from_slice(&ocoords[self.axis + indices.ndim()..]);
88 output_view[ocoords] =
89 data_view.get(&*icoords).cloned().context("Invalid gather")?;
90 }
91 unsafe { output.set_datum_type(data.datum_type()) };
92 }
93 Ok(output)
94 }
95
    /// Gathers rows out of block-quantized storage, dequantizing into `F`
    /// (`f16` or `f32`) on the fly.
    ///
    /// Only `axis == 0` is supported: the compressed payload is addressed as
    /// an `m x k` row-major matrix whose rows are made of whole quantization
    /// blocks.
    fn eval_bq<F: Datum>(
        &self,
        data: &BlockQuantStorage,
        m: usize,
        k: usize,
        indices: &TValue,
    ) -> TractResult<Tensor> {
        ensure!(self.axis == 0);
        let data_shape = &[m, k];
        let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
        let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
        let indices_plain = indices.try_as_plain()?;
        // NOTE(review): indices are used as-is (`*ix as usize` below) —
        // negative indices are NOT resolved here, unlike eval_t. Confirm
        // callers guarantee non-negative indices for compressed inputs.
        let indices_slice = indices_plain.as_slice::<i64>()?;
        let vector_len = k;
        let blob = data.value();

        // Quantization block geometry: elements per block, bytes per block.
        let block_len = data.format().block_len();
        let block_bytes = data.format().block_bytes();
        if F::datum_type() == f16::datum_type() {
            let mut output_plain = output.try_as_plain_mut()?;
            let output_slice = output_plain.as_slice_mut::<f16>()?;
            for (pos, ix) in indices_slice.iter().enumerate() {
                // Destination row for this gathered index.
                let slice = &mut output_slice[pos * vector_len..][..vector_len];
                // Dequantize the source row one quantization block at a time.
                for i in (0..vector_len).step_by(block_len) {
                    let offset = k * *ix as usize + i;
                    let block_id = offset / block_len;
                    data.format().dequant_block_f16(
                        &blob[block_id * block_bytes..][..block_bytes],
                        &mut slice[i..i + block_len],
                    );
                }
            }
        } else {
            // Same traversal as above, dequantizing to f32.
            let mut output_plain = output.try_as_plain_mut()?;
            let output_slice = output_plain.as_slice_mut::<f32>()?;
            for (pos, ix) in indices_slice.iter().enumerate() {
                let slice = &mut output_slice[pos * vector_len..][..vector_len];
                for i in (0..vector_len).step_by(block_len) {
                    let offset = k * *ix as usize + i;
                    let block_id = offset / block_len;
                    data.format().dequant_block_f32(
                        &blob[block_id * block_bytes..][..block_bytes],
                        &mut slice[i..i + block_len],
                    );
                }
            }
        }
        Ok(output)
    }
145
146 fn eval_input_store<F: Datum>(
147 &self,
148 data: &dyn MMMInputValue,
149 indices: &TValue,
150 ) -> TractResult<Tensor> {
151 ensure!(self.axis == 0);
152 let data_shape = &[data.mn(), data.k()];
153 let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
154 let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
155 let indices_plain = indices.try_as_plain()?;
156 let indices_slice = indices_plain.as_slice::<i64>()?;
157 let vector_len = data_shape[1];
158 if F::datum_type() == f16::datum_type() {
159 let mut output_plain = output.try_as_plain_mut()?;
160 let output_slice = output_plain.as_slice_mut::<f16>()?;
161 for (pos, m) in indices_slice.iter().enumerate() {
162 let slice = &mut output_slice[pos * vector_len..][..vector_len];
163 data.extract_at_mn_f16(*m as usize, slice)?;
164 }
165 } else {
166 let mut output_plain = output.try_as_plain_mut()?;
167 let output_slice = output_plain.as_slice_mut::<f32>()?;
168 for (pos, m) in indices_slice.iter().enumerate() {
169 let slice = &mut output_slice[pos * vector_len..][..vector_len];
170 data.extract_at_mn_f32(*m as usize, slice)?;
171 }
172 }
173 Ok(output)
174 }
175}
176
177impl TypedOp for Gather {
178 as_op!();
179
180 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
181 if let Some(dt) = self.output_type {
182 ensure!(
183 inputs[0].is_exotic() || inputs[0].datum_type == dt,
184 "Inconsistent datum_type in Gather: attribute is {:?}, but inputs[0] is {:?}",
185 dt,
186 inputs[0].datum_type
187 );
188 } else {
189 ensure!(
190 inputs[0].is_plain(),
191 "Gather applied to compressed data requires an explicit datum_type attribute for its output"
192 );
193 }
194 ensure!(inputs[1].datum_type == i64::datum_type());
195 if inputs[0].is_exotic() {
196 let data_shape = block_quant_aware_input_shape(inputs[0])?;
197 Ok(tvec!(
198 self.output_type
199 .unwrap()
200 .fact(&*self.compute_output_shape(&data_shape, &inputs[1].shape)?)
201 ))
202 } else {
203 Ok(tvec!(
204 inputs[0]
205 .datum_type
206 .fact(&*self.compute_output_shape(&inputs[0].shape, &inputs[1].shape)?)
207 ))
208 }
209 }
210
    /// Graph simplifications:
    /// * a constant, rank-1, single-element indices tensor over plain numeric
    ///   data is rewritten as a length-1 `Slice` on `self.axis`;
    /// * when the constant data input also feeds an `OptSimpleMatMulPack`
    ///   sibling, this Gather is re-wired to read the sibling's (packed)
    ///   output instead of the raw constant.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (input_fact, indices_fact) = args_2!(model.node_input_facts(node.id)?);
        if let Some(indices) = indices_fact.konst.as_ref()
            && indices.rank() == 1
            && indices.len() == 1
            && input_fact.is_plain()
            && input_fact.datum_type.is_number()
        {
            let mut patch = TypedModelPatch::default();
            let mut wire = patch.tap_model(model, node.inputs[0])?;
            let index = indices.cast_to_scalar::<i64>()?;
            // Resolve a negative index against the (possibly symbolic) axis
            // dimension before building the slice bounds.
            let index = if index < 0 {
                let data_fact = model.outlet_fact(node.inputs[0])?;
                data_fact.shape[self.axis].clone() + index.to_dim()
            } else {
                index.to_dim()
            };
            wire = patch.wire_node(
                format!("{}.slice", node.name),
                crate::ops::array::Slice { axis: self.axis, start: index.clone(), end: index + 1 },
                &[wire],
            )?[0];
            patch.shunt_outside(model, node.id.into(), wire)?;
            return Ok(Some(patch));
        }
        if input_fact.konst.is_some() {
            if let Some(sibling) = model
                .outlet_successors(node.inputs[0])
                .iter()
                .find(|o| o.node != node.id && model.node(o.node).op_is::<OptSimpleMatMulPack>())
            {
                let mut patch = TypedModelPatch::default();
                let mut taps = patch.taps(model, &node.inputs)?;
                // Tap the pack node's output (outlet 0) in place of the raw
                // constant data input.
                taps[0] = patch.tap_model(model, sibling.node.into())?;
                let wire = patch.wire_node(&node.name, self.clone(), &taps)?[0];
                patch.shunt_outside(model, node.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
257}
258
impl EvalOp for Gather {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (data, indices) = args_2!(inputs);
        // Dispatch on the data storage: block-quantized payloads and
        // batch-less packed matmul inputs get dedicated dequantizing kernels,
        // anything else goes through the generic per-datum-type gather.
        let result = if let Some(bqs) = data.storage_as::<BlockQuantStorage>() {
            // output_type is mandatory for compressed inputs (enforced by
            // output_facts), hence the unwrap.
            let dt = self.output_type.unwrap();
            let m = data.shape()[data.rank() - 2];
            let k = *data.shape().last().unwrap();
            dispatch_floatlike!(Self::eval_bq(dt)(self, bqs, m, k, &indices))?
        } else if let Some(storage) = data.storage_as::<PackedMatrixStorage>()
            && storage.batch_shape().is_empty()
        {
            let dt = self.output_type.unwrap();
            let data_val = storage.value();
            dispatch_floatlike!(Self::eval_input_store(dt)(self, data_val, &indices))?
        } else {
            dispatch_datum!(Self::eval_t(data.datum_type())(self, data, &indices))?
        };
        Ok(tvec!(result.into_tvalue()))
    }
}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// A rank-0 index tensor must yield a rank-0 (scalar) output.
    #[test]
    fn test_should_gather_scalar_index() {
        let op = Gather::new(0);
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        for wanted in 2..3 {
            let index_tensor = Tensor::from(arr0(wanted));
            let result =
                op.eval(tvec![data.clone().into_tvalue(), index_tensor.into_tvalue()]).unwrap();
            let scalar = &result[0];
            assert_eq!(scalar.shape().len(), 0);
            // data[wanted] == wanted + 1 for this fixture.
            assert_eq!(*scalar.try_as_plain().unwrap().to_scalar::<i64>().unwrap(), wanted + 1);
        }
    }
}