1use crate::internal::*;
2use crate::ops::einsum::block_quant_aware_input_shape;
3use crate::ops::matmul::pack::OptSimpleMatMulPack;
4use ndarray::*;
5use tract_linalg::block_quant::BlockQuantStorage;
6use tract_linalg::mmm::{MMMInputValue, PackedMatrixStorage};
7
8#[derive(Debug, Clone, Hash, PartialEq, Eq)]
9pub struct Gather {
10 pub axis: usize,
11 pub output_type: Option<DatumType>,
12}
13
14impl Op for Gather {
15 fn name(&self) -> StaticName {
16 "Gather".into()
17 }
18
19 op_as_typed_op!();
20}
21
22impl Gather {
23 pub fn new(axis: usize) -> Gather {
24 Gather { axis, output_type: None }
25 }
26
27 pub fn compute_output_shape<D: DimLike>(
28 &self,
29 input_shape: &[D],
30 indices_shape: &[D],
31 ) -> TractResult<TVec<D>> {
32 ensure!(input_shape.len() > self.axis);
33 let mut output_shape: TVec<D> = input_shape[..self.axis].into();
34 output_shape.extend(indices_shape.iter().cloned());
35 output_shape.extend(input_shape[self.axis + 1..].iter().cloned());
36 Ok(output_shape)
37 }
38
39 fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<Tensor> {
40 let data_plain = data.try_as_plain()?;
41 let data_view = unsafe { data_plain.to_array_view_unchecked::<T>() };
42 let indices = indices.to_plain_array_view::<i64>()?;
43 let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
44 let mut output = unsafe { Tensor::uninitialized::<T>(output_shape)? };
45 let mut output_plain = output.try_as_plain_mut()?;
46 let mut output_view = output_plain.to_array_view_mut::<T>()?;
47
48 let data_shape = data.shape();
49 let data_axis = self.axis;
50
51 let block_len = data_shape[data_axis + 1..].iter().product::<usize>();
52
53 let can_block_copy = data_shape[..data_axis].iter().all(|&d| d == 1)
54 && output_shape[..data_axis].iter().all(|&d| d == 1)
55 && data_view.is_standard_layout()
56 && output_view.is_standard_layout();
57
58 if can_block_copy {
59 let mut out_offset = 0;
60 let input_slice = data_view.as_slice().unwrap();
61 let output_slice = &mut output_view.as_slice_mut().unwrap();
62 for idx_coords in indices.indexed_iter() {
63 let index = *idx_coords.1;
64 let axis_len = data_shape[data_axis] as i64;
65 let resolved_index = if index < 0 { index + axis_len } else { index };
66 let resolved_index = resolved_index as usize;
67
68 let input_offset = resolved_index * block_len;
69
70 output_slice[out_offset..out_offset + block_len]
71 .clone_from_slice(&input_slice[input_offset..input_offset + block_len]);
72 out_offset += block_len;
73 }
74 } else {
75 let ic_len = self.axis + 1 + output_shape.len() - (self.axis + indices.ndim());
76 let mut icoords = vec![0; ic_len];
77 let axis = self.axis;
78 for coords in tract_ndarray::indices(output_shape) {
79 let ocoords = coords.as_array_view();
80 let ocoords = ocoords.as_slice().unwrap();
81
82 let kcoords = &ocoords[self.axis..][..indices.ndim()];
83 let k = indices[kcoords];
84 let k = if k < 0 { k + data_view.shape()[self.axis] as i64 } else { k } as usize;
85 icoords[0..axis].copy_from_slice(&ocoords[..self.axis]);
86 icoords[self.axis] = k;
87 icoords[self.axis + 1..].clone_from_slice(&ocoords[self.axis + indices.ndim()..]);
88 output_view[ocoords] =
89 data_view.get(&*icoords).cloned().context("Invalid gather")?;
90 }
91 unsafe { output.set_datum_type(data.datum_type()) };
92 }
93 Ok(output)
94 }
95
96 fn eval_bq<F: Datum>(
97 &self,
98 data: &BlockQuantStorage,
99 m: usize,
100 k: usize,
101 indices: &TValue,
102 ) -> TractResult<Tensor> {
103 ensure!(self.axis == 0);
104 let data_shape = &[m, k];
105 let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
106 let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
107 let indices_plain = indices.try_as_plain()?;
108 let indices_slice = indices_plain.as_slice::<i64>()?;
109 let vector_len = k;
110 let blob = data.value();
111
112 let block_len = data.format().block_len();
113 let block_bytes = data.format().block_bytes();
114 if F::datum_type() == f16::datum_type() {
115 let mut output_plain = output.try_as_plain_mut()?;
116 let output_slice = output_plain.as_slice_mut::<f16>()?;
117 for (pos, ix) in indices_slice.iter().enumerate() {
118 let slice = &mut output_slice[pos * vector_len..][..vector_len];
119 for i in (0..vector_len).step_by(block_len) {
120 let offset = k * *ix as usize + i;
121 let block_id = offset / block_len;
122 data.format().dequant_block_f16(
123 &blob[block_id * block_bytes..][..block_bytes],
124 &mut slice[i..i + block_len],
125 );
126 }
127 }
128 } else {
129 let mut output_plain = output.try_as_plain_mut()?;
130 let output_slice = output_plain.as_slice_mut::<f32>()?;
131 for (pos, ix) in indices_slice.iter().enumerate() {
132 let slice = &mut output_slice[pos * vector_len..][..vector_len];
133 for i in (0..vector_len).step_by(block_len) {
134 let offset = k * *ix as usize + i;
135 let block_id = offset / block_len;
136 data.format().dequant_block_f32(
137 &blob[block_id * block_bytes..][..block_bytes],
138 &mut slice[i..i + block_len],
139 );
140 }
141 }
142 }
143 Ok(output)
144 }
145
146 fn eval_input_store<F: Datum>(
147 &self,
148 data: &dyn MMMInputValue,
149 indices: &TValue,
150 ) -> TractResult<Tensor> {
151 ensure!(self.axis == 0);
152 let data_shape = &[data.mn(), data.k()];
153 let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
154 let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
155 let indices_plain = indices.try_as_plain()?;
156 let indices_slice = indices_plain.as_slice::<i64>()?;
157 let vector_len = data_shape[1];
158 if F::datum_type() == f16::datum_type() {
159 let mut output_plain = output.try_as_plain_mut()?;
160 let output_slice = output_plain.as_slice_mut::<f16>()?;
161 for (pos, m) in indices_slice.iter().enumerate() {
162 let slice = &mut output_slice[pos * vector_len..][..vector_len];
163 data.extract_at_mn_f16(*m as usize, slice)?;
164 }
165 } else {
166 let mut output_plain = output.try_as_plain_mut()?;
167 let output_slice = output_plain.as_slice_mut::<f32>()?;
168 for (pos, m) in indices_slice.iter().enumerate() {
169 let slice = &mut output_slice[pos * vector_len..][..vector_len];
170 data.extract_at_mn_f32(*m as usize, slice)?;
171 }
172 }
173 Ok(output)
174 }
175}
176
177impl TypedOp for Gather {
178 as_op!();
179
180 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
181 if let Some(dt) = self.output_type {
182 ensure!(
183 inputs[0].is_exotic() || inputs[0].datum_type == dt,
184 "Inconsistent datum_type in Gather: attribute is {:?}, but inputs[0] is {:?}",
185 dt,
186 inputs[0].datum_type
187 );
188 } else {
189 ensure!(
190 inputs[0].is_plain(),
191 "Gather applied to compressed data requires an explicit datum_type attribute for its output"
192 );
193 }
194 ensure!(inputs[1].datum_type == i64::datum_type());
195 if inputs[0].is_exotic() {
196 let data_shape = block_quant_aware_input_shape(inputs[0])?;
197 Ok(tvec!(
198 self.output_type
199 .unwrap()
200 .fact(&*self.compute_output_shape(&data_shape, &inputs[1].shape)?)
201 ))
202 } else {
203 Ok(tvec!(
204 inputs[0]
205 .datum_type
206 .fact(&*self.compute_output_shape(&inputs[0].shape, &inputs[1].shape)?)
207 ))
208 }
209 }
210
211 fn axes_mapping(
212 &self,
213 inputs: &[&TypedFact],
214 _outputs: &[&TypedFact],
215 ) -> TractResult<AxesMapping> {
216 if !inputs[0].is_plain() {
224 return AxesMapping::disconnected(
225 inputs,
226 &[&inputs[0].datum_type.fact(&[0i64.to_dim()])],
227 );
228 }
229 let data_rank = inputs[0].rank();
230 let indices_rank = inputs[1].rank();
231 let mut axes: TVec<crate::axes::Axis> = tvec!();
232 let mut alphabet = 'a'..;
233 for k in 0..self.axis {
234 axes.push(
235 crate::axes::Axis::new(alphabet.next().unwrap(), 2, 1).input(0, k).output(0, k),
236 );
237 }
238 axes.push(crate::axes::Axis::new(alphabet.next().unwrap(), 2, 1).input(0, self.axis));
239 for k in self.axis + 1..data_rank {
240 let out_pos = k - 1 + indices_rank;
241 axes.push(
242 crate::axes::Axis::new(alphabet.next().unwrap(), 2, 1)
243 .input(0, k)
244 .output(0, out_pos),
245 );
246 }
247 for k in 0..indices_rank {
248 let out_pos = self.axis + k;
249 axes.push(
250 crate::axes::Axis::new(alphabet.next().unwrap(), 2, 1)
251 .input(1, k)
252 .output(0, out_pos),
253 );
254 }
255 AxesMapping::new(2, 1, axes)
256 }
257
258 fn declutter(
259 &self,
260 model: &TypedModel,
261 node: &TypedNode,
262 ) -> TractResult<Option<TypedModelPatch>> {
263 let (input_fact, indices_fact) = args_2!(model.node_input_facts(node.id)?);
264 if let Some(indices) = indices_fact.konst.as_ref()
265 && indices.rank() == 1
266 && indices.len() == 1
267 && input_fact.is_plain()
268 && input_fact.datum_type.is_number()
269 {
270 let mut patch = TypedModelPatch::default();
271 let mut wire = patch.tap_model(model, node.inputs[0])?;
272 let index = indices.cast_to_scalar::<i64>()?;
273 let index = if index < 0 {
274 let data_fact = model.outlet_fact(node.inputs[0])?;
275 data_fact.shape[self.axis].clone() + index.to_dim()
276 } else {
277 index.to_dim()
278 };
279 wire = patch.wire_node(
280 format!("{}.slice", node.name),
281 crate::ops::array::Slice { axis: self.axis, start: index.clone(), end: index + 1 },
282 &[wire],
283 )?[0];
284 patch.shunt_outside(model, node.id.into(), wire)?;
285 return Ok(Some(patch));
286 }
287 if input_fact.konst.is_some() {
288 if let Some(sibling) = model
290 .outlet_successors(node.inputs[0])
291 .iter()
292 .find(|o| o.node != node.id && model.node(o.node).op_is::<OptSimpleMatMulPack>())
293 {
294 let mut patch = TypedModelPatch::default();
295 let mut taps = patch.taps(model, &node.inputs)?;
296 taps[0] = patch.tap_model(model, sibling.node.into())?;
297 let wire = patch.wire_node(&node.name, self.clone(), &taps)?[0];
298 patch.shunt_outside(model, node.id.into(), wire)?;
299 return Ok(Some(patch));
300 }
301 }
302 Ok(None)
303 }
304}
305
306impl EvalOp for Gather {
307 fn is_stateless(&self) -> bool {
308 true
309 }
310
311 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
312 let (data, indices) = args_2!(inputs);
313 let result = if let Some(bqs) = data.storage_as::<BlockQuantStorage>() {
314 let dt = self.output_type.unwrap();
315 let m = data.shape()[data.rank() - 2];
316 let k = *data.shape().last().unwrap();
317 dispatch_floatlike!(Self::eval_bq(dt)(self, bqs, m, k, &indices))?
318 } else if let Some(storage) = data.storage_as::<PackedMatrixStorage>()
319 && storage.batch_shape().is_empty()
320 {
321 let dt = self.output_type.unwrap();
322 let data_val = storage.value();
323 dispatch_floatlike!(Self::eval_input_store(dt)(self, data_val, &indices))?
324 } else {
325 dispatch_datum!(Self::eval_t(data.datum_type())(self, data, &indices))?
326 };
327 Ok(tvec!(result.into_tvalue()))
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Gathers from `[1, 2, 3]` along axis 0 with a rank-0 index, checks that
    /// the result is a scalar and returns it.
    fn gather_scalar(idx: i64) -> i64 {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let index = Tensor::from(arr0(idx));
        let gatherer = Gather::new(0);
        let outputs = gatherer.eval(tvec![data.into_tvalue(), index.into_tvalue()]).unwrap();
        let output = &outputs[0];
        assert_eq!(output.shape().len(), 0);
        let value = *output.try_as_plain().unwrap().to_scalar::<i64>().unwrap();
        value
    }

    #[test]
    fn test_should_gather_scalar_index() {
        // Every valid non-negative index, not just the last one.
        for idx in 0..3i64 {
            assert_eq!(gather_scalar(idx), idx + 1);
        }
    }

    #[test]
    fn test_should_gather_negative_scalar_index() {
        // Negative indices count from the end: -1 is the last element.
        for idx in -3..0i64 {
            assert_eq!(gather_scalar(idx), idx + 4);
        }
    }
}