Skip to main content

tract_core/ops/array/
gather.rs

1use crate::internal::*;
2use crate::ops::einsum::block_quant_aware_input_shape;
3use crate::ops::matmul::pack::OptSimpleMatMulPack;
4use ndarray::*;
5use tract_linalg::block_quant::BlockQuantValue;
6use tract_linalg::mmm::MMMInputValue;
7
/// Gather: selects slices of the data input along `axis` according to an
/// i64 `indices` tensor (second input), ONNX-`Gather`-style.
#[derive(Debug, Clone, Hash, PartialEq)]
pub struct Gather {
    // Axis of the data input along which indices pick slices.
    pub axis: usize,
    // Explicit output datum type. Required when the data input is opaque
    // (compressed/pre-packed), since the output type can not be derived
    // from the input datum type in that case (see output_facts).
    pub output_type: Option<DatumType>,
}
13
// Operator identification and standard tract trait plumbing.
impl Op for Gather {
    fn name(&self) -> Cow<str> {
        "Gather".into()
    }

    // Macros provide TypedOp downcast and structural same-as comparison.
    op_as_typed_op!();
    impl_op_same_as!();
}
22
23impl Gather {
24    pub fn new(axis: usize) -> Gather {
25        Gather { axis, output_type: None }
26    }
27
28    pub fn compute_output_shape<D: DimLike>(
29        &self,
30        input_shape: &[D],
31        indices_shape: &[D],
32    ) -> TractResult<TVec<D>> {
33        ensure!(input_shape.len() > self.axis);
34        let mut output_shape: TVec<D> = input_shape[..self.axis].into();
35        output_shape.extend(indices_shape.iter().cloned());
36        output_shape.extend(input_shape[self.axis + 1..].iter().cloned());
37        Ok(output_shape)
38    }
39
40    fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<Tensor> {
41        let data_view = unsafe { data.to_array_view_unchecked::<T>() }; // copy only
42        let indices = indices.to_array_view::<i64>()?;
43        let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
44        let mut output = unsafe { Tensor::uninitialized::<T>(output_shape)? };
45        let mut output_view = output.to_array_view_mut::<T>()?;
46        for coords in tract_ndarray::indices(output_shape) {
47            let ocoords = coords.as_array_view();
48            let ocoords = ocoords.as_slice().unwrap();
49            let mut icoords: TVec<usize> = ocoords[0..self.axis].into();
50            let kcoords = &ocoords[self.axis..][..indices.ndim()];
51            let k = indices[kcoords];
52            let k = if k < 0 { k + data_view.shape()[self.axis] as i64 } else { k } as usize;
53            icoords.push(k);
54            icoords.extend(ocoords[self.axis + indices.ndim()..].iter().copied());
55            output_view[ocoords] = data_view.get(&*icoords).context("Invalid gather")?.clone();
56        }
57        unsafe { output.set_datum_type(data.datum_type()) };
58        Ok(output)
59    }
60
61    fn eval_bq<F: Datum>(&self, data: &BlockQuantValue, indices: &TValue) -> TractResult<Tensor> {
62        ensure!(self.axis == 0);
63        ensure!(data.fact.shape().len() == 2);
64        let data_shape = &data.fact.shape();
65        let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
66        let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
67        let indices_slice = indices.as_slice::<i64>()?;
68        let vector_len = data_shape[1];
69
70        let block_len = data.fact.format.block_len();
71        let block_bytes = data.fact.format.block_bytes();
72        if F::datum_type() == f16::datum_type() {
73            let output_slice = output.as_slice_mut::<f16>()?;
74            for (pos, ix) in indices_slice.iter().enumerate() {
75                let slice = &mut output_slice[pos * vector_len..][..vector_len];
76                for i in (0..vector_len).step_by(block_len) {
77                    let offset = data_shape[1] * *ix as usize + i;
78                    let block_id = offset / block_len;
79                    data.fact.format.dequant_block_f16(
80                        &data.value[block_id * block_bytes..][..block_bytes],
81                        &mut slice[i..i + block_len],
82                    );
83                }
84            }
85        } else {
86            let output_slice = output.as_slice_mut::<f32>()?;
87            for (pos, ix) in indices_slice.iter().enumerate() {
88                let slice = &mut output_slice[pos * vector_len..][..vector_len];
89                for i in (0..vector_len).step_by(block_len) {
90                    let offset = data_shape[1] * *ix as usize + i;
91                    let block_id = offset / block_len;
92                    data.fact.format.dequant_block_f32(
93                        &data.value[block_id * block_bytes..][..block_bytes],
94                        &mut slice[i..i + block_len],
95                    );
96                }
97            }
98        }
99        Ok(output)
100    }
101
102    fn eval_input_store<F: Datum>(
103        &self,
104        data: &dyn MMMInputValue,
105        indices: &TValue,
106    ) -> TractResult<Tensor> {
107        ensure!(self.axis == 0);
108        let data_shape = &[data.mn(), data.k()];
109        let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
110        let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
111        let indices_slice = indices.as_slice::<i64>()?;
112        let vector_len = data_shape[1];
113        if F::datum_type() == f16::datum_type() {
114            let output_slice = output.as_slice_mut::<f16>()?;
115            for (pos, m) in indices_slice.iter().enumerate() {
116                let slice = &mut output_slice[pos * vector_len..][..vector_len];
117                data.extract_at_mn_f16(*m as usize, slice)?;
118            }
119        } else {
120            let output_slice = output.as_slice_mut::<f32>()?;
121            for (pos, m) in indices_slice.iter().enumerate() {
122                let slice = &mut output_slice[pos * vector_len..][..vector_len];
123                data.extract_at_mn_f32(*m as usize, slice)?;
124            }
125        }
126        Ok(output)
127    }
128}
129
impl TypedOp for Gather {
    as_op!();

    // Type inference: output carries data's datum type (or the explicit
    // output_type when the input is opaque), with the gather output shape.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        if let Some(dt) = self.output_type {
            // An explicit output type must match the data input, except when
            // the input is opaque (compressed) data.
            ensure!(
                inputs[0].datum_type.is_opaque() || inputs[0].datum_type == dt,
                "Inconsistent datum_type in Gather: attribute is {:?}, but inputs[0] is {:?}",
                dt,
                inputs[0].datum_type
            );
        } else {
            ensure!(!inputs[0].datum_type.is_opaque(),
                "Gather applied to compressed data requires an explicit datum_type attribute for its output");
        }
        // Indices are required to be i64.
        ensure!(inputs[1].datum_type == i64::datum_type());
        if inputs[0].datum_type.is_opaque() {
            // Opaque payload: recover the logical data shape before computing
            // the output shape.
            let data_shape = block_quant_aware_input_shape(inputs[0])?;
            Ok(tvec!(self
                .output_type
                .unwrap()
                .fact(&*self.compute_output_shape(&data_shape, &inputs[1].shape)?)))
        } else {
            Ok(tvec!(inputs[0]
                .datum_type
                .fact(&*self.compute_output_shape(&inputs[0].shape, &inputs[1].shape)?)))
        }
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (input_fact, indices_fact) = args_2!(model.node_input_facts(node.id)?);
        // Rewrite 1: gathering a single constant index from numeric data is
        // equivalent to a width-1 Slice on the same axis.
        if let Some(indices) = indices_fact.konst.as_ref() {
            if indices.rank() == 1 && indices.len() == 1 && input_fact.datum_type.is_number() {
                let mut patch = TypedModelPatch::default();
                let mut wire = patch.tap_model(model, node.inputs[0])?;
                let index = indices.cast_to_scalar::<i64>()?;
                // Resolve a negative constant index against the axis dimension.
                let index = if index < 0 {
                    let data_fact = model.outlet_fact(node.inputs[0])?;
                    data_fact.shape[self.axis].clone() + index.to_dim()
                } else {
                    index.to_dim()
                };
                wire = patch.wire_node(
                    format!("{}.slice", node.name),
                    crate::ops::array::Slice {
                        axis: self.axis,
                        start: index.clone(),
                        end: index + 1,
                    },
                    &[wire],
                )?[0];
                patch.shunt_outside(model, node.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        // Rewrite 2: if the constant data input also feeds an
        // OptSimpleMatMulPack node, redirect this Gather's data input to that
        // sibling's (packed) output.
        if input_fact.konst.is_some() {
            // look for a OptSimpleMatMulPack *sibling*
            if let Some(sibling) = model
                .outlet_successors(node.inputs[0])
                .iter()
                .find(|o| o.node != node.id && model.node(o.node).op_is::<OptSimpleMatMulPack>())
            {
                let mut patch = TypedModelPatch::default();
                let mut taps = patch.taps(model, &node.inputs)?;
                taps[0] = patch.tap_model(model, sibling.node.into())?;
                let wire = patch.wire_node(&node.name, self.clone(), &taps)?[0];
                patch.shunt_outside(model, node.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
}
207
208impl EvalOp for Gather {
209    fn is_stateless(&self) -> bool {
210        true
211    }
212
213    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
214        let (data, indices) = args_2!(inputs);
215        let result = if let Ok(opaque) = data.to_scalar::<Opaque>() {
216            let dt = self.output_type.unwrap();
217            if let Some(data) = opaque.downcast_ref::<BlockQuantValue>() {
218                dispatch_floatlike!(Self::eval_bq(dt)(self, data, &indices))?
219            } else if let Some(data) = opaque.downcast_ref::<Box<dyn MMMInputValue>>() {
220                dispatch_floatlike!(Self::eval_input_store(dt)(self, &**data, &indices))?
221            } else {
222                bail!("Can't use Gather on {:?} input", data);
223            }
224        } else {
225            dispatch_datum_by_size!(Self::eval_t(data.datum_type())(self, data, &indices))?
226        };
227        Ok(tvec!(result.into_tvalue()))
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    // A rank-0 index must produce a rank-0 output holding data[idx].
    // The original loop (2..3) only exercised a single index; cover them all.
    #[test]
    fn test_should_gather_scalar_index() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let gatherer = Gather::new(0);
        for idx in 0..3 {
            let index = Tensor::from(arr0(idx));
            let outputs =
                gatherer.eval(tvec![data.clone().into_tvalue(), index.into_tvalue()]).unwrap();
            let output = &outputs[0];
            assert_eq!(output.shape().len(), 0);
            assert_eq!(*output.to_scalar::<i64>().unwrap(), idx + 1);
        }
    }

    // Negative indices count from the end of the gathered axis: -1 -> last.
    #[test]
    fn test_should_gather_negative_scalar_index() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let gatherer = Gather::new(0);
        let index = Tensor::from(arr0(-1i64));
        let outputs =
            gatherer.eval(tvec![data.clone().into_tvalue(), index.into_tvalue()]).unwrap();
        let output = &outputs[0];
        assert_eq!(output.shape().len(), 0);
        assert_eq!(*output.to_scalar::<i64>().unwrap(), 3);
    }
}