tract_core/ops/array/gather.rs

1use crate::internal::*;
2use crate::ops::einsum::block_quant_aware_input_shape;
3use ndarray::*;
4use tract_linalg::block_quant::BlockQuantValue;
5use tract_linalg::mmm::MMMInputValue;
6
7#[derive(Debug, Clone, Hash, PartialEq)]
8pub struct Gather {
9    pub axis: usize,
10    pub output_type: Option<DatumType>,
11}
12
13impl Op for Gather {
14    fn name(&self) -> Cow<str> {
15        "Gather".into()
16    }
17
18    op_as_typed_op!();
19    impl_op_same_as!();
20}
21
22impl Gather {
23    pub fn new(axis: usize) -> Gather {
24        Gather { axis, output_type: None }
25    }
26
27    pub fn compute_output_shape<D: DimLike>(
28        &self,
29        input_shape: &[D],
30        indices_shape: &[D],
31    ) -> TractResult<TVec<D>> {
32        ensure!(input_shape.len() > self.axis);
33        let mut output_shape: TVec<D> = input_shape[..self.axis].into();
34        output_shape.extend(indices_shape.iter().cloned());
35        output_shape.extend(input_shape[self.axis + 1..].iter().cloned());
36        Ok(output_shape)
37    }
38
39    fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<Tensor> {
40        let data_view = unsafe { data.to_array_view_unchecked::<T>() }; // copy only
41        let indices = indices.to_array_view::<i64>()?;
42        let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
43        let mut output = unsafe { Tensor::uninitialized::<T>(output_shape)? };
44        let mut output_view = output.to_array_view_mut::<T>()?;
45        for coords in tract_ndarray::indices(output_shape) {
46            let ocoords = coords.as_array_view();
47            let ocoords = ocoords.as_slice().unwrap();
48            let mut icoords: TVec<usize> = ocoords[0..self.axis].into();
49            let kcoords = &ocoords[self.axis..][..indices.ndim()];
50            let k = indices[kcoords];
51            let k = if k < 0 { k + data_view.shape()[self.axis] as i64 } else { k } as usize;
52            icoords.push(k);
53            icoords.extend(ocoords[self.axis + indices.ndim()..].iter().copied());
54            output_view[ocoords] = data_view.get(&*icoords).context("Invalid gather")?.clone();
55        }
56        unsafe { output.set_datum_type(data.datum_type()) };
57        Ok(output)
58    }
59
60    fn eval_bq<F: Datum>(&self, data: &BlockQuantValue, indices: &TValue) -> TractResult<Tensor> {
61        ensure!(self.axis == 0);
62        ensure!(data.fact.shape().len() == 2);
63        let data_shape = &data.fact.shape();
64        let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
65        let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
66        let indices_slice = indices.as_slice::<i64>()?;
67        let vector_len = data_shape[1];
68        if F::datum_type() == f16::datum_type() {
69            let output_slice = output.as_slice_mut::<f16>()?;
70            for (pos, ix) in indices_slice.iter().enumerate() {
71                let slice = &mut output_slice[pos * vector_len..][..vector_len];
72                for (i, slot) in slice.iter_mut().enumerate() {
73                    let offset = data_shape[1] * *ix as usize + i;
74                    *slot = data.fact.format.extract_at_offset_f16(&data.value, offset)
75                }
76            }
77        } else {
78            let output_slice = output.as_slice_mut::<f32>()?;
79            for (pos, ix) in indices_slice.iter().enumerate() {
80                let slice = &mut output_slice[pos * vector_len..][..vector_len];
81                for (i, slot) in slice.iter_mut().enumerate() {
82                    let offset = data_shape[1] * *ix as usize + i;
83                    *slot = data.fact.format.extract_at_offset_f32(&data.value, offset)
84                }
85            }
86        }
87        Ok(output)
88    }
89
90    fn eval_input_store<F: Datum>(
91        &self,
92        data: &dyn MMMInputValue,
93        indices: &TValue,
94    ) -> TractResult<Tensor> {
95        ensure!(self.axis == 0);
96        let data_shape = &[data.mn(), data.k()];
97        let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
98        let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
99        let indices_slice = indices.as_slice::<i64>()?;
100        let vector_len = data_shape[1];
101        if F::datum_type() == f16::datum_type() {
102            let output_slice = output.as_slice_mut::<f16>()?;
103            for (pos, m) in indices_slice.iter().enumerate() {
104                let slice = &mut output_slice[pos * vector_len..][..vector_len];
105                data.extract_at_mn_f16(*m as usize, slice)?;
106            }
107        } else {
108            let output_slice = output.as_slice_mut::<f32>()?;
109            for (pos, m) in indices_slice.iter().enumerate() {
110                let slice = &mut output_slice[pos * vector_len..][..vector_len];
111                data.extract_at_mn_f32(*m as usize, slice)?;
112            }
113        }
114        Ok(output)
115    }
116}
117
118impl TypedOp for Gather {
119    as_op!();
120
121    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
122        if let Some(dt) = self.output_type {
123            ensure!(
124                inputs[0].datum_type.is_opaque() || inputs[0].datum_type == dt,
125                "Inconsistent datum_type in Gather: attribute is {:?}, but inputs[0] is {:?}",
126                dt,
127                inputs[0].datum_type
128            );
129        } else {
130            ensure!(!inputs[0].datum_type.is_opaque(),
131                "Gather applied to compressed data requires an explicit datum_type attribute for its output");
132        }
133        ensure!(inputs[1].datum_type == i64::datum_type());
134        if inputs[0].datum_type.is_opaque() {
135            let data_shape = block_quant_aware_input_shape(inputs[0])?;
136            Ok(tvec!(self
137                .output_type
138                .unwrap()
139                .fact(&*self.compute_output_shape(&data_shape, &inputs[1].shape)?)))
140        } else {
141            Ok(tvec!(inputs[0]
142                .datum_type
143                .fact(&*self.compute_output_shape(&inputs[0].shape, &inputs[1].shape)?)))
144        }
145    }
146
147    fn declutter(
148        &self,
149        model: &TypedModel,
150        node: &TypedNode,
151    ) -> TractResult<Option<TypedModelPatch>> {
152        let (input_fact, indices_fact) = args_2!(model.node_input_facts(node.id)?);
153        if let Some(indices) = indices_fact.konst.as_ref() {
154            if indices.rank() == 1 && indices.len() == 1 && input_fact.datum_type.is_number() {
155                let mut patch = TypedModelPatch::default();
156                let mut wire = patch.tap_model(model, node.inputs[0])?;
157                let index = indices.cast_to_scalar::<i64>()?;
158                let index = if index < 0 {
159                    let data_fact = model.outlet_fact(node.inputs[0])?;
160                    data_fact.shape[self.axis].clone() + index.to_dim()
161                } else {
162                    index.to_dim()
163                };
164                wire = patch.wire_node(
165                    format!("{}.slice", node.name),
166                    crate::ops::array::Slice {
167                        axis: self.axis,
168                        start: index.clone(),
169                        end: index + 1,
170                    },
171                    &[wire],
172                )?[0];
173                patch.shunt_outside(model, node.id.into(), wire)?;
174                return Ok(Some(patch));
175            }
176        }
177        Ok(None)
178    }
179}
180
181impl EvalOp for Gather {
182    fn is_stateless(&self) -> bool {
183        true
184    }
185
186    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
187        let (data, indices) = args_2!(inputs);
188        let result = if let Ok(opaque) = data.to_scalar::<Opaque>() {
189            let dt = self.output_type.unwrap();
190            if let Some(data) = opaque.downcast_ref::<BlockQuantValue>() {
191                dispatch_floatlike!(Self::eval_bq(dt)(self, data, &indices))?
192            } else if let Some(data) = opaque.downcast_ref::<Box<dyn MMMInputValue>>() {
193                dispatch_floatlike!(Self::eval_input_store(dt)(self, &**data, &indices))?
194            } else {
195                bail!("Can't use Gather on {:?} input", data);
196            }
197        } else {
198            dispatch_datum_by_size!(Self::eval_t(data.datum_type())(self, data, &indices))?
199        };
200        Ok(tvec!(result.into_tvalue()))
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a `Gather` along axis 0 of `data` with the given `indices` tensor.
    fn gather0(data: &Tensor, indices: Tensor) -> TractResult<TVec<TValue>> {
        Gather::new(0).eval(tvec![data.clone().into_tvalue(), indices.into_tvalue()])
    }

    #[test]
    fn test_should_gather_scalar_index() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        // Cover every valid index, not only the last one.
        for idx in 0..3 {
            let outputs = gather0(&data, Tensor::from(arr0(idx))).unwrap();
            let output = &outputs[0];
            assert_eq!(output.shape().len(), 0);
            assert_eq!(*output.to_scalar::<i64>().unwrap(), idx + 1);
        }
    }

    #[test]
    fn test_should_gather_negative_scalar_index() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        // -1 is the last element, -3 the first.
        for idx in -3i64..0 {
            let outputs = gather0(&data, Tensor::from(arr0(idx))).unwrap();
            assert_eq!(outputs[0].shape().len(), 0);
            assert_eq!(*outputs[0].to_scalar::<i64>().unwrap(), idx + 4);
        }
    }

    #[test]
    fn test_should_gather_vector_indices() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let outputs = gather0(&data, Tensor::from(arr1(&[2i64, 0, 2]))).unwrap();
        assert_eq!(outputs[0].shape(), &[3usize]);
        assert_eq!(outputs[0].as_slice::<i64>().unwrap(), &[3i64, 1, 3]);
    }

    #[test]
    fn test_should_reject_out_of_bounds_index() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        assert!(gather0(&data, Tensor::from(arr0(3i64))).is_err());
        assert!(gather0(&data, Tensor::from(arr0(-4i64))).is_err());
    }
}