// tract_core/ops/array/gather.rs
use crate::internal::*;
use crate::ops::einsum::block_quant_aware_input_shape;
use crate::ops::matmul::pack::OptSimpleMatMulPack;
use ndarray::*;
use tract_linalg::block_quant::BlockQuantStorage;
use tract_linalg::mmm::{MMMInputValue, PackedMatrixStorage};
7
/// Gather operator: indexes into a data tensor along `axis` with an i64
/// indices tensor, replacing the gathered axis by the indices shape.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Gather {
    // Axis of the data tensor along which indices are resolved.
    pub axis: usize,
    // Explicit output datum type. Mandatory when the input is compressed
    // (block-quantized or pre-packed) storage; None means "same as input".
    pub output_type: Option<DatumType>,
}
13
impl Op for Gather {
    fn name(&self) -> StaticName {
        "Gather".into()
    }

    // Boilerplate downcast to the TypedOp implementation below.
    op_as_typed_op!();
}
21
22impl Gather {
23    pub fn new(axis: usize) -> Gather {
24        Gather { axis, output_type: None }
25    }
26
27    pub fn compute_output_shape<D: DimLike>(
28        &self,
29        input_shape: &[D],
30        indices_shape: &[D],
31    ) -> TractResult<TVec<D>> {
32        ensure!(input_shape.len() > self.axis);
33        let mut output_shape: TVec<D> = input_shape[..self.axis].into();
34        output_shape.extend(indices_shape.iter().cloned());
35        output_shape.extend(input_shape[self.axis + 1..].iter().cloned());
36        Ok(output_shape)
37    }
38
39    fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<Tensor> {
40        let data_plain = data.try_as_plain()?;
41        let data_view = unsafe { data_plain.to_array_view_unchecked::<T>() };
42        let indices = indices.to_plain_array_view::<i64>()?;
43        let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
44        let mut output = unsafe { Tensor::uninitialized::<T>(output_shape)? };
45        let mut output_plain = output.try_as_plain_mut()?;
46        let mut output_view = output_plain.to_array_view_mut::<T>()?;
47
48        let data_shape = data.shape();
49        let data_axis = self.axis;
50
51        let block_len = data_shape[data_axis + 1..].iter().product::<usize>();
52
53        let can_block_copy = data_shape[..data_axis].iter().all(|&d| d == 1)
54            && output_shape[..data_axis].iter().all(|&d| d == 1)
55            && data_view.is_standard_layout()
56            && output_view.is_standard_layout();
57
58        if can_block_copy {
59            let mut out_offset = 0;
60            let input_slice = data_view.as_slice().unwrap();
61            let output_slice = &mut output_view.as_slice_mut().unwrap();
62            for idx_coords in indices.indexed_iter() {
63                let index = *idx_coords.1;
64                let axis_len = data_shape[data_axis] as i64;
65                let resolved_index = if index < 0 { index + axis_len } else { index };
66                let resolved_index = resolved_index as usize;
67
68                let input_offset = resolved_index * block_len;
69
70                output_slice[out_offset..out_offset + block_len]
71                    .clone_from_slice(&input_slice[input_offset..input_offset + block_len]);
72                out_offset += block_len;
73            }
74        } else {
75            let ic_len = self.axis + 1 + output_shape.len() - (self.axis + indices.ndim());
76            let mut icoords = vec![0; ic_len];
77            let axis = self.axis;
78            for coords in tract_ndarray::indices(output_shape) {
79                let ocoords = coords.as_array_view();
80                let ocoords = ocoords.as_slice().unwrap();
81
82                let kcoords = &ocoords[self.axis..][..indices.ndim()];
83                let k = indices[kcoords];
84                let k = if k < 0 { k + data_view.shape()[self.axis] as i64 } else { k } as usize;
85                icoords[0..axis].copy_from_slice(&ocoords[..self.axis]);
86                icoords[self.axis] = k;
87                icoords[self.axis + 1..].clone_from_slice(&ocoords[self.axis + indices.ndim()..]);
88                output_view[ocoords] =
89                    data_view.get(&*icoords).cloned().context("Invalid gather")?;
90            }
91            unsafe { output.set_datum_type(data.datum_type()) };
92        }
93        Ok(output)
94    }
95
96    fn eval_bq<F: Datum>(
97        &self,
98        data: &BlockQuantStorage,
99        m: usize,
100        k: usize,
101        indices: &TValue,
102    ) -> TractResult<Tensor> {
103        ensure!(self.axis == 0);
104        let data_shape = &[m, k];
105        let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
106        let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
107        let indices_plain = indices.try_as_plain()?;
108        let indices_slice = indices_plain.as_slice::<i64>()?;
109        let vector_len = k;
110        let blob = data.value();
111
112        let block_len = data.format().block_len();
113        let block_bytes = data.format().block_bytes();
114        if F::datum_type() == f16::datum_type() {
115            let mut output_plain = output.try_as_plain_mut()?;
116            let output_slice = output_plain.as_slice_mut::<f16>()?;
117            for (pos, ix) in indices_slice.iter().enumerate() {
118                let slice = &mut output_slice[pos * vector_len..][..vector_len];
119                for i in (0..vector_len).step_by(block_len) {
120                    let offset = k * *ix as usize + i;
121                    let block_id = offset / block_len;
122                    data.format().dequant_block_f16(
123                        &blob[block_id * block_bytes..][..block_bytes],
124                        &mut slice[i..i + block_len],
125                    );
126                }
127            }
128        } else {
129            let mut output_plain = output.try_as_plain_mut()?;
130            let output_slice = output_plain.as_slice_mut::<f32>()?;
131            for (pos, ix) in indices_slice.iter().enumerate() {
132                let slice = &mut output_slice[pos * vector_len..][..vector_len];
133                for i in (0..vector_len).step_by(block_len) {
134                    let offset = k * *ix as usize + i;
135                    let block_id = offset / block_len;
136                    data.format().dequant_block_f32(
137                        &blob[block_id * block_bytes..][..block_bytes],
138                        &mut slice[i..i + block_len],
139                    );
140                }
141            }
142        }
143        Ok(output)
144    }
145
146    fn eval_input_store<F: Datum>(
147        &self,
148        data: &dyn MMMInputValue,
149        indices: &TValue,
150    ) -> TractResult<Tensor> {
151        ensure!(self.axis == 0);
152        let data_shape = &[data.mn(), data.k()];
153        let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
154        let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
155        let indices_plain = indices.try_as_plain()?;
156        let indices_slice = indices_plain.as_slice::<i64>()?;
157        let vector_len = data_shape[1];
158        if F::datum_type() == f16::datum_type() {
159            let mut output_plain = output.try_as_plain_mut()?;
160            let output_slice = output_plain.as_slice_mut::<f16>()?;
161            for (pos, m) in indices_slice.iter().enumerate() {
162                let slice = &mut output_slice[pos * vector_len..][..vector_len];
163                data.extract_at_mn_f16(*m as usize, slice)?;
164            }
165        } else {
166            let mut output_plain = output.try_as_plain_mut()?;
167            let output_slice = output_plain.as_slice_mut::<f32>()?;
168            for (pos, m) in indices_slice.iter().enumerate() {
169                let slice = &mut output_slice[pos * vector_len..][..vector_len];
170                data.extract_at_mn_f32(*m as usize, slice)?;
171            }
172        }
173        Ok(output)
174    }
175}
176
impl TypedOp for Gather {
    as_op!();

    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        // inputs[0]: data, inputs[1]: indices (must be i64).
        if let Some(dt) = self.output_type {
            // An explicit output type must match the input type, unless the
            // input is exotic (compressed/packed) storage.
            ensure!(
                inputs[0].is_exotic() || inputs[0].datum_type == dt,
                "Inconsistent datum_type in Gather: attribute is {:?}, but inputs[0] is {:?}",
                dt,
                inputs[0].datum_type
            );
        } else {
            ensure!(
                inputs[0].is_plain(),
                "Gather applied to compressed data requires an explicit datum_type attribute for its output"
            );
        }
        ensure!(inputs[1].datum_type == i64::datum_type());
        if inputs[0].is_exotic() {
            // Compressed input: use the block-quant aware logical shape, and
            // the mandatory explicit output type (checked above, so the
            // unwrap cannot fail).
            let data_shape = block_quant_aware_input_shape(inputs[0])?;
            Ok(tvec!(
                self.output_type
                    .unwrap()
                    .fact(&*self.compute_output_shape(&data_shape, &inputs[1].shape)?)
            ))
        } else {
            // Plain input: output keeps the input datum type.
            Ok(tvec!(
                inputs[0]
                    .datum_type
                    .fact(&*self.compute_output_shape(&inputs[0].shape, &inputs[1].shape)?)
            ))
        }
    }

    /// Graph simplifications:
    /// * constant rank-1 single-element indices over plain numeric data is
    ///   rewritten as a length-1 `Slice` (same output shape: the unit
    ///   indices axis replaces the gathered axis);
    /// * a gather on a constant input also feeding an `OptSimpleMatMulPack`
    ///   node is re-wired to read that sibling's output instead.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (input_fact, indices_fact) = args_2!(model.node_input_facts(node.id)?);
        if let Some(indices) = indices_fact.konst.as_ref()
            && indices.rank() == 1
            && indices.len() == 1
            && input_fact.is_plain()
            && input_fact.datum_type.is_number()
        {
            let mut patch = TypedModelPatch::default();
            let mut wire = patch.tap_model(model, node.inputs[0])?;
            let index = indices.cast_to_scalar::<i64>()?;
            // Resolve a negative index symbolically against the axis dim.
            let index = if index < 0 {
                let data_fact = model.outlet_fact(node.inputs[0])?;
                data_fact.shape[self.axis].clone() + index.to_dim()
            } else {
                index.to_dim()
            };
            wire = patch.wire_node(
                format!("{}.slice", node.name),
                crate::ops::array::Slice { axis: self.axis, start: index.clone(), end: index + 1 },
                &[wire],
            )?[0];
            patch.shunt_outside(model, node.id.into(), wire)?;
            return Ok(Some(patch));
        }
        if input_fact.konst.is_some() {
            // look for a OptSimpleMatMulPack *sibling*
            if let Some(sibling) = model
                .outlet_successors(node.inputs[0])
                .iter()
                .find(|o| o.node != node.id && model.node(o.node).op_is::<OptSimpleMatMulPack>())
            {
                let mut patch = TypedModelPatch::default();
                let mut taps = patch.taps(model, &node.inputs)?;
                // Re-tap the data input from the pack node's output.
                taps[0] = patch.tap_model(model, sibling.node.into())?;
                let wire = patch.wire_node(&node.name, self.clone(), &taps)?[0];
                patch.shunt_outside(model, node.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
}
258
impl EvalOp for Gather {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Dispatches on the data tensor's storage:
    /// * block-quantized storage -> `eval_bq`, with m and k taken from the
    ///   last two dims of the declared shape;
    /// * packed matrix storage with no batch dims -> `eval_input_store`;
    /// * anything else -> the generic `eval_t` over the plain datum type.
    /// The compressed paths unwrap `output_type`: it is mandatory for
    /// compressed inputs (enforced by `output_facts`).
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (data, indices) = args_2!(inputs);
        let result = if let Some(bqs) = data.storage_as::<BlockQuantStorage>() {
            let dt = self.output_type.unwrap();
            let m = data.shape()[data.rank() - 2];
            let k = *data.shape().last().unwrap();
            dispatch_floatlike!(Self::eval_bq(dt)(self, bqs, m, k, &indices))?
        } else if let Some(storage) = data.storage_as::<PackedMatrixStorage>()
            && storage.batch_shape().is_empty()
        {
            let dt = self.output_type.unwrap();
            let data_val = storage.value();
            dispatch_floatlike!(Self::eval_input_store(dt)(self, data_val, &indices))?
        } else {
            dispatch_datum!(Self::eval_t(data.datum_type())(self, data, &indices))?
        };
        Ok(tvec!(result.into_tvalue()))
    }
}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Gathering with a rank-0 (scalar) index must yield a rank-0 output
    /// holding the selected element.
    #[test]
    fn test_should_gather_scalar_index() {
        // data[i] == i + 1, so the expected scalar is always idx + 1.
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let gatherer = Gather::new(0);
        // Exercise every valid index (the loop previously ran only for 2).
        for idx in 0..3 {
            let index = Tensor::from(arr0(idx));
            let outputs =
                gatherer.eval(tvec![data.clone().into_tvalue(), index.into_tvalue()]).unwrap();
            let output = &outputs[0];
            assert_eq!(output.shape().len(), 0);
            assert_eq!(*output.try_as_plain().unwrap().to_scalar::<i64>().unwrap(), idx + 1);
        }
    }
}
301}