Skip to main content

tract_core/ops/array/
gather.rs

1use crate::internal::*;
2use crate::ops::einsum::block_quant_aware_input_shape;
3use crate::ops::matmul::pack::OptSimpleMatMulPack;
4use ndarray::*;
5use tract_linalg::block_quant::BlockQuantFact;
6use tract_linalg::mmm::MMMInputValue;
7
8#[derive(Debug, Clone, Hash, PartialEq)]
9pub struct Gather {
10    pub axis: usize,
11    pub output_type: Option<DatumType>,
12}
13
14impl Op for Gather {
15    fn name(&self) -> StaticName {
16        "Gather".into()
17    }
18
19    op_as_typed_op!();
20    impl_op_same_as!();
21}
22
23impl Gather {
24    pub fn new(axis: usize) -> Gather {
25        Gather { axis, output_type: None }
26    }
27
28    pub fn compute_output_shape<D: DimLike>(
29        &self,
30        input_shape: &[D],
31        indices_shape: &[D],
32    ) -> TractResult<TVec<D>> {
33        ensure!(input_shape.len() > self.axis);
34        let mut output_shape: TVec<D> = input_shape[..self.axis].into();
35        output_shape.extend(indices_shape.iter().cloned());
36        output_shape.extend(input_shape[self.axis + 1..].iter().cloned());
37        Ok(output_shape)
38    }
39
40    fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<Tensor> {
41        let data_dense = data.try_as_dense()?;
42        let data_view = unsafe { data_dense.to_array_view_unchecked::<T>() };
43        let indices = indices.to_dense_array_view::<i64>()?;
44        let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
45        let mut output = unsafe { Tensor::uninitialized::<T>(output_shape)? };
46        let mut output_dense = output.try_as_dense_mut()?;
47        let mut output_view = output_dense.to_array_view_mut::<T>()?;
48
49        let data_shape = data.shape();
50        let data_axis = self.axis;
51
52        let block_len = data_shape[data_axis + 1..].iter().product::<usize>();
53
54        let can_block_copy = data_shape[..data_axis].iter().all(|&d| d == 1)
55            && output_shape[..data_axis].iter().all(|&d| d == 1)
56            && data_view.is_standard_layout()
57            && output_view.is_standard_layout();
58
59        if can_block_copy {
60            let mut out_offset = 0;
61            let input_slice = data_view.as_slice().unwrap();
62            let output_slice = &mut output_view.as_slice_mut().unwrap();
63            for idx_coords in indices.indexed_iter() {
64                let index = *idx_coords.1;
65                let axis_len = data_shape[data_axis] as i64;
66                let resolved_index = if index < 0 { index + axis_len } else { index };
67                let resolved_index = resolved_index as usize;
68
69                let input_offset = resolved_index * block_len;
70
71                output_slice[out_offset..out_offset + block_len]
72                    .clone_from_slice(&input_slice[input_offset..input_offset + block_len]);
73                out_offset += block_len;
74            }
75        } else {
76            let ic_len = self.axis + 1 + output_shape.len() - (self.axis + indices.ndim());
77            let mut icoords = vec![0; ic_len];
78            let axis = self.axis;
79            for coords in tract_ndarray::indices(output_shape) {
80                let ocoords = coords.as_array_view();
81                let ocoords = ocoords.as_slice().unwrap();
82
83                let kcoords = &ocoords[self.axis..][..indices.ndim()];
84                let k = indices[kcoords];
85                let k = if k < 0 { k + data_view.shape()[self.axis] as i64 } else { k } as usize;
86                icoords[0..axis].copy_from_slice(&ocoords[..self.axis]);
87                icoords[self.axis] = k;
88                icoords[self.axis + 1..].clone_from_slice(&ocoords[self.axis + indices.ndim()..]);
89                output_view[ocoords] =
90                    data_view.get(&*icoords).cloned().context("Invalid gather")?;
91            }
92            unsafe { output.set_datum_type(data.datum_type()) };
93        }
94        Ok(output)
95    }
96
97    fn eval_bq<F: Datum>(&self, data: &BlobWithFact, indices: &TValue) -> TractResult<Tensor> {
98        let bqf = data.fact.downcast_ref::<BlockQuantFact>().context("Expected BlockQuantFact")?;
99        ensure!(self.axis == 0);
100        ensure!(bqf.shape().len() == 2);
101        let data_shape = &bqf.shape();
102        let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
103        let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
104        let indices_dense = indices.try_as_dense()?;
105        let indices_slice = indices_dense.as_slice::<i64>()?;
106        let vector_len = data_shape[1];
107
108        let block_len = bqf.format.block_len();
109        let block_bytes = bqf.format.block_bytes();
110        if F::datum_type() == f16::datum_type() {
111            let mut output_dense = output.try_as_dense_mut()?;
112            let output_slice = output_dense.as_slice_mut::<f16>()?;
113            for (pos, ix) in indices_slice.iter().enumerate() {
114                let slice = &mut output_slice[pos * vector_len..][..vector_len];
115                for i in (0..vector_len).step_by(block_len) {
116                    let offset = data_shape[1] * *ix as usize + i;
117                    let block_id = offset / block_len;
118                    bqf.format.dequant_block_f16(
119                        &data.value[block_id * block_bytes..][..block_bytes],
120                        &mut slice[i..i + block_len],
121                    );
122                }
123            }
124        } else {
125            let mut output_dense = output.try_as_dense_mut()?;
126            let output_slice = output_dense.as_slice_mut::<f32>()?;
127            for (pos, ix) in indices_slice.iter().enumerate() {
128                let slice = &mut output_slice[pos * vector_len..][..vector_len];
129                for i in (0..vector_len).step_by(block_len) {
130                    let offset = data_shape[1] * *ix as usize + i;
131                    let block_id = offset / block_len;
132                    bqf.format.dequant_block_f32(
133                        &data.value[block_id * block_bytes..][..block_bytes],
134                        &mut slice[i..i + block_len],
135                    );
136                }
137            }
138        }
139        Ok(output)
140    }
141
142    fn eval_input_store<F: Datum>(
143        &self,
144        data: &dyn MMMInputValue,
145        indices: &TValue,
146    ) -> TractResult<Tensor> {
147        ensure!(self.axis == 0);
148        let data_shape = &[data.mn(), data.k()];
149        let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
150        let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
151        let indices_dense = indices.try_as_dense()?;
152        let indices_slice = indices_dense.as_slice::<i64>()?;
153        let vector_len = data_shape[1];
154        if F::datum_type() == f16::datum_type() {
155            let mut output_dense = output.try_as_dense_mut()?;
156            let output_slice = output_dense.as_slice_mut::<f16>()?;
157            for (pos, m) in indices_slice.iter().enumerate() {
158                let slice = &mut output_slice[pos * vector_len..][..vector_len];
159                data.extract_at_mn_f16(*m as usize, slice)?;
160            }
161        } else {
162            let mut output_dense = output.try_as_dense_mut()?;
163            let output_slice = output_dense.as_slice_mut::<f32>()?;
164            for (pos, m) in indices_slice.iter().enumerate() {
165                let slice = &mut output_slice[pos * vector_len..][..vector_len];
166                data.extract_at_mn_f32(*m as usize, slice)?;
167            }
168        }
169        Ok(output)
170    }
171}
172
173impl TypedOp for Gather {
174    as_op!();
175
176    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
177        if let Some(dt) = self.output_type {
178            ensure!(
179                inputs[0].datum_type.is_opaque() || inputs[0].datum_type == dt,
180                "Inconsistent datum_type in Gather: attribute is {:?}, but inputs[0] is {:?}",
181                dt,
182                inputs[0].datum_type
183            );
184        } else {
185            ensure!(
186                !inputs[0].datum_type.is_opaque(),
187                "Gather applied to compressed data requires an explicit datum_type attribute for its output"
188            );
189        }
190        ensure!(inputs[1].datum_type == i64::datum_type());
191        if inputs[0].datum_type.is_opaque() {
192            let data_shape = block_quant_aware_input_shape(inputs[0])?;
193            Ok(tvec!(
194                self.output_type
195                    .unwrap()
196                    .fact(&*self.compute_output_shape(&data_shape, &inputs[1].shape)?)
197            ))
198        } else {
199            Ok(tvec!(
200                inputs[0]
201                    .datum_type
202                    .fact(&*self.compute_output_shape(&inputs[0].shape, &inputs[1].shape)?)
203            ))
204        }
205    }
206
207    fn declutter(
208        &self,
209        model: &TypedModel,
210        node: &TypedNode,
211    ) -> TractResult<Option<TypedModelPatch>> {
212        let (input_fact, indices_fact) = args_2!(model.node_input_facts(node.id)?);
213        if let Some(indices) = indices_fact.konst.as_ref() {
214            if indices.rank() == 1 && indices.len() == 1 && input_fact.datum_type.is_number() {
215                let mut patch = TypedModelPatch::default();
216                let mut wire = patch.tap_model(model, node.inputs[0])?;
217                let index = indices.cast_to_scalar::<i64>()?;
218                let index = if index < 0 {
219                    let data_fact = model.outlet_fact(node.inputs[0])?;
220                    data_fact.shape[self.axis].clone() + index.to_dim()
221                } else {
222                    index.to_dim()
223                };
224                wire = patch.wire_node(
225                    format!("{}.slice", node.name),
226                    crate::ops::array::Slice {
227                        axis: self.axis,
228                        start: index.clone(),
229                        end: index + 1,
230                    },
231                    &[wire],
232                )?[0];
233                patch.shunt_outside(model, node.id.into(), wire)?;
234                return Ok(Some(patch));
235            }
236        }
237        if input_fact.konst.is_some() {
238            // look for a OptSimpleMatMulPack *sibling*
239            if let Some(sibling) = model
240                .outlet_successors(node.inputs[0])
241                .iter()
242                .find(|o| o.node != node.id && model.node(o.node).op_is::<OptSimpleMatMulPack>())
243            {
244                let mut patch = TypedModelPatch::default();
245                let mut taps = patch.taps(model, &node.inputs)?;
246                taps[0] = patch.tap_model(model, sibling.node.into())?;
247                let wire = patch.wire_node(&node.name, self.clone(), &taps)?[0];
248                patch.shunt_outside(model, node.id.into(), wire)?;
249                return Ok(Some(patch));
250            }
251        }
252        Ok(None)
253    }
254}
255
256impl EvalOp for Gather {
257    fn is_stateless(&self) -> bool {
258        true
259    }
260
261    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
262        let (data, indices) = args_2!(inputs);
263        let data_dense = data.try_as_dense();
264        let result = if let Some(opaque) =
265            data_dense.as_ref().ok().and_then(|d| d.to_scalar::<Opaque>().ok())
266        {
267            let dt = self.output_type.unwrap();
268            if let Some(data) = opaque.downcast_ref::<BlobWithFact>() {
269                dispatch_floatlike!(Self::eval_bq(dt)(self, data, &indices))?
270            } else if let Some(data) = opaque.downcast_ref::<Box<dyn MMMInputValue>>() {
271                dispatch_floatlike!(Self::eval_input_store(dt)(self, &**data, &indices))?
272            } else {
273                bail!("Can't use Gather on {:?} input", data);
274            }
275        } else {
276            dispatch_datum!(Self::eval_t(data.datum_type())(self, data, &indices))?
277        };
278        Ok(tvec!(result.into_tvalue()))
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// A rank-0 index must produce a rank-0 output holding the selected
    /// element. Every valid index of a 3-element vector is covered, including
    /// negative ones, which count from the end of the axis.
    #[test]
    fn test_should_gather_scalar_index() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let gatherer = Gather::new(0);
        for idx in -3i64..3 {
            let index = Tensor::from(arr0(idx));
            let outputs =
                gatherer.eval(tvec![data.clone().into_tvalue(), index.into_tvalue()]).unwrap();
            let output = &outputs[0];
            assert_eq!(output.shape().len(), 0);
            // data[i] == i + 1, and a negative idx wraps around the 3 entries.
            let expected = idx.rem_euclid(3) + 1;
            assert_eq!(*output.try_as_dense().unwrap().to_scalar::<i64>().unwrap(), expected);
        }
    }
}