tract_core/ops/array/
gather.rs

1use crate::internal::*;
2use crate::ops::einsum::block_quant_aware_input_shape;
3use ndarray::*;
4use tract_linalg::frame::block_quant::BlockQuantValue;
5
6#[derive(Debug, Clone, new, Hash)]
7pub struct Gather {
8    pub axis: usize,
9}
10
11impl Op for Gather {
12    fn name(&self) -> Cow<str> {
13        "Gather".into()
14    }
15
16    op_as_typed_op!();
17}
18
19impl Gather {
20    pub fn compute_output_shape<D: DimLike>(
21        &self,
22        input_shape: &[D],
23        indices_shape: &[D],
24    ) -> TractResult<TVec<D>> {
25        ensure!(input_shape.len() > self.axis);
26        let mut output_shape: TVec<D> = input_shape[..self.axis].into();
27        output_shape.extend(indices_shape.iter().cloned());
28        output_shape.extend(input_shape[self.axis + 1..].iter().cloned());
29        Ok(output_shape)
30    }
31
32    fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<Tensor> {
33        let data_view = unsafe { data.to_array_view_unchecked::<T>() }; // copy only
34        let indices = indices.to_array_view::<i64>()?;
35        let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
36        let mut output = unsafe { Tensor::uninitialized::<T>(output_shape)? };
37        let mut output_view = output.to_array_view_mut::<T>()?;
38        for coords in tract_ndarray::indices(output_shape) {
39            let ocoords = coords.as_array_view();
40            let ocoords = ocoords.as_slice().unwrap();
41            let mut icoords: TVec<usize> = ocoords[0..self.axis].into();
42            let kcoords = &ocoords[self.axis..][..indices.ndim()];
43            let k = indices[kcoords];
44            let k = if k < 0 { k + data_view.shape()[self.axis] as i64 } else { k } as usize;
45            icoords.push(k);
46            icoords.extend(ocoords[self.axis + indices.ndim()..].iter().copied());
47            output_view[ocoords] = data_view.get(&*icoords).context("Invalid gather")?.clone();
48        }
49        unsafe { output.set_datum_type(data.datum_type()) };
50        Ok(output)
51    }
52
53    fn eval_bq_to_f16(&self, data: &BlockQuantValue, indices: &TValue) -> TractResult<Tensor> {
54        ensure!(self.axis == 0);
55        ensure!(data.fact.shape.len() == 2);
56        let data_shape = &data.fact.shape;
57        let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
58        let mut output = unsafe { Tensor::uninitialized::<f16>(output_shape)? };
59        let indices_slice = indices.as_slice::<i64>()?;
60        let vector_len = data_shape[1];
61        let output_slice = output.as_slice_mut::<f16>()?;
62        for (pos, ix) in indices_slice.iter().enumerate() {
63            let slice = &mut output_slice[pos * vector_len..][..vector_len];
64            for (i, slot) in slice.iter_mut().enumerate() {
65                let offset = data_shape[1] * *ix as usize + i;
66                *slot = data.fact.format.extract_at_offset_f16(&data.value, offset)
67            }
68        }
69        Ok(output)
70    }
71}
72
73impl TypedOp for Gather {
74    as_op!();
75
76    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
77        ensure!(inputs[1].datum_type == i64::datum_type());
78        if inputs[0].datum_type.is_opaque() {
79            let data_shape = block_quant_aware_input_shape(inputs[0])?;
80            Ok(tvec!(f16::fact(&*self.compute_output_shape(&data_shape, &inputs[1].shape)?)))
81        } else {
82            Ok(tvec!(inputs[0]
83                .datum_type
84                .fact(&*self.compute_output_shape(&inputs[0].shape, &inputs[1].shape)?)))
85        }
86    }
87
88    fn declutter(
89        &self,
90        model: &TypedModel,
91        node: &TypedNode,
92    ) -> TractResult<Option<TypedModelPatch>> {
93        let (input_fact, indices_fact) = args_2!(model.node_input_facts(node.id)?);
94        if let Some(indices) = indices_fact.konst.as_ref() {
95            if indices.rank() == 1 && indices.len() == 1 && input_fact.datum_type.is_number() {
96                let mut patch = TypedModelPatch::default();
97                let mut wire = patch.tap_model(model, node.inputs[0])?;
98                let index = indices.cast_to_scalar::<i64>()?;
99                let index = if index < 0 {
100                    let data_fact = model.outlet_fact(node.inputs[0])?;
101                    data_fact.shape[self.axis].clone() + index.to_dim()
102                } else {
103                    index.to_dim()
104                };
105                wire = patch.wire_node(
106                    format!("{}.slice", node.name),
107                    crate::ops::array::Slice {
108                        axis: self.axis,
109                        start: index.clone(),
110                        end: index + 1,
111                    },
112                    &[wire],
113                )?[0];
114                patch.shunt_outside(model, node.id.into(), wire)?;
115                return Ok(Some(patch));
116            }
117        }
118        Ok(None)
119    }
120}
121
122impl EvalOp for Gather {
123    fn is_stateless(&self) -> bool {
124        true
125    }
126
127    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
128        let (data, indices) = args_2!(inputs);
129        let result = if data.datum_type().is_opaque() {
130            let data = data
131                .to_scalar::<Opaque>()?
132                .downcast_ref::<BlockQuantValue>()
133                .context("Expected a BlockQuantValue")?;
134            self.eval_bq_to_f16(data, &indices)?
135        } else {
136            dispatch_datum_by_size!(Self::eval_t(data.datum_type())(self, data, &indices))?
137        };
138        Ok(tvec!(result.into_tvalue()))
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a gather along axis 0 of `[1, 2, 3]` with a single scalar index.
    fn gather_scalar(idx: i64) -> TractResult<TVec<TValue>> {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let index = Tensor::from(arr0(idx));
        Gather::new(0).eval(tvec![data.into_tvalue(), index.into_tvalue()])
    }

    #[test]
    fn test_should_gather_scalar_index() {
        // Cover every valid position, not only the last one.
        for idx in 0..3i64 {
            let outputs = gather_scalar(idx).unwrap();
            let output = &outputs[0];
            assert_eq!(output.shape().len(), 0);
            assert_eq!(*output.to_scalar::<i64>().unwrap(), idx + 1);
        }
    }

    #[test]
    fn test_should_gather_negative_scalar_index() {
        // Negative indices count from the end: -1 is the last element.
        for idx in -3..0i64 {
            let outputs = gather_scalar(idx).unwrap();
            let output = &outputs[0];
            assert_eq!(output.shape().len(), 0);
            assert_eq!(*output.to_scalar::<i64>().unwrap(), idx + 4);
        }
    }

    #[test]
    fn test_should_reject_out_of_range_index() {
        assert!(gather_scalar(3).is_err());
        assert!(gather_scalar(-4).is_err());
    }
}