tract_core/ops/array/
gather.rs1use crate::internal::*;
2use crate::ops::einsum::block_quant_aware_input_shape;
3use crate::ops::matmul::pack::OptSimpleMatMulPack;
4use ndarray::*;
5use tract_linalg::block_quant::BlockQuantFact;
6use tract_linalg::mmm::MMMInputValue;
7
/// ONNX-style Gather: select slices of the data input along `axis` using an
/// integer indices tensor.
#[derive(Debug, Clone, Hash, PartialEq)]
pub struct Gather {
    /// Axis of the data input along which indices select slices.
    pub axis: usize,
    /// Forced output datum type. Required when the data input is opaque
    /// (block-quantized blob or pre-packed matmul input); otherwise the
    /// output keeps the input's datum type.
    pub output_type: Option<DatumType>,
}
13
impl Op for Gather {
    fn name(&self) -> StaticName {
        "Gather".into()
    }

    // Standard tract boilerplate: expose the typed-op view and
    // structural equality for graph deduplication.
    op_as_typed_op!();
    impl_op_same_as!();
}
22
23impl Gather {
24 pub fn new(axis: usize) -> Gather {
25 Gather { axis, output_type: None }
26 }
27
28 pub fn compute_output_shape<D: DimLike>(
29 &self,
30 input_shape: &[D],
31 indices_shape: &[D],
32 ) -> TractResult<TVec<D>> {
33 ensure!(input_shape.len() > self.axis);
34 let mut output_shape: TVec<D> = input_shape[..self.axis].into();
35 output_shape.extend(indices_shape.iter().cloned());
36 output_shape.extend(input_shape[self.axis + 1..].iter().cloned());
37 Ok(output_shape)
38 }
39
40 fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<Tensor> {
41 let data_view = unsafe { data.to_array_view_unchecked::<T>() };
42 let indices = indices.to_array_view::<i64>()?;
43 let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
44 let mut output = unsafe { Tensor::uninitialized::<T>(output_shape)? };
45 let mut output_view = output.to_array_view_mut::<T>()?;
46
47 let data_shape = data.shape();
48 let data_axis = self.axis;
49
50 let block_len = data_shape[data_axis + 1..].iter().product::<usize>();
51
52 let can_block_copy = data_shape[..data_axis].iter().all(|&d| d == 1)
53 && output_shape[..data_axis].iter().all(|&d| d == 1)
54 && data_view.is_standard_layout()
55 && output_view.is_standard_layout();
56
57 if can_block_copy {
58 let mut out_offset = 0;
59 let input_slice = data_view.as_slice().unwrap();
60 let output_slice = &mut output_view.as_slice_mut().unwrap();
61 for idx_coords in indices.indexed_iter() {
62 let index = *idx_coords.1;
63 let axis_len = data_shape[data_axis] as i64;
64 let resolved_index = if index < 0 { index + axis_len } else { index };
65 let resolved_index = resolved_index as usize;
66
67 let input_offset = resolved_index * block_len;
68
69 output_slice[out_offset..out_offset + block_len]
70 .clone_from_slice(&input_slice[input_offset..input_offset + block_len]);
71 out_offset += block_len;
72 }
73 } else {
74 let ic_len = self.axis + 1 + output_shape.len() - (self.axis + indices.ndim());
75 let mut icoords = vec![0; ic_len];
76 let axis = self.axis;
77 for coords in tract_ndarray::indices(output_shape) {
78 let ocoords = coords.as_array_view();
79 let ocoords = ocoords.as_slice().unwrap();
80
81 let kcoords = &ocoords[self.axis..][..indices.ndim()];
82 let k = indices[kcoords];
83 let k = if k < 0 { k + data_view.shape()[self.axis] as i64 } else { k } as usize;
84 icoords[0..axis].copy_from_slice(&ocoords[..self.axis]);
85 icoords[self.axis] = k;
86 icoords[self.axis + 1..].clone_from_slice(&ocoords[self.axis + indices.ndim()..]);
87 output_view[ocoords] =
88 data_view.get(&*icoords).cloned().context("Invalid gather")?;
89 }
90 unsafe { output.set_datum_type(data.datum_type()) };
91 }
92 Ok(output)
93 }
94
95 fn eval_bq<F: Datum>(&self, data: &BlobWithFact, indices: &TValue) -> TractResult<Tensor> {
96 let bqf = data.fact.downcast_ref::<BlockQuantFact>().context("Expected BlockQuantFact")?;
97 ensure!(self.axis == 0);
98 ensure!(bqf.shape().len() == 2);
99 let data_shape = &bqf.shape();
100 let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
101 let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
102 let indices_slice = indices.as_slice::<i64>()?;
103 let vector_len = data_shape[1];
104
105 let block_len = bqf.format.block_len();
106 let block_bytes = bqf.format.block_bytes();
107 if F::datum_type() == f16::datum_type() {
108 let output_slice = output.as_slice_mut::<f16>()?;
109 for (pos, ix) in indices_slice.iter().enumerate() {
110 let slice = &mut output_slice[pos * vector_len..][..vector_len];
111 for i in (0..vector_len).step_by(block_len) {
112 let offset = data_shape[1] * *ix as usize + i;
113 let block_id = offset / block_len;
114 bqf.format.dequant_block_f16(
115 &data.value[block_id * block_bytes..][..block_bytes],
116 &mut slice[i..i + block_len],
117 );
118 }
119 }
120 } else {
121 let output_slice = output.as_slice_mut::<f32>()?;
122 for (pos, ix) in indices_slice.iter().enumerate() {
123 let slice = &mut output_slice[pos * vector_len..][..vector_len];
124 for i in (0..vector_len).step_by(block_len) {
125 let offset = data_shape[1] * *ix as usize + i;
126 let block_id = offset / block_len;
127 bqf.format.dequant_block_f32(
128 &data.value[block_id * block_bytes..][..block_bytes],
129 &mut slice[i..i + block_len],
130 );
131 }
132 }
133 }
134 Ok(output)
135 }
136
137 fn eval_input_store<F: Datum>(
138 &self,
139 data: &dyn MMMInputValue,
140 indices: &TValue,
141 ) -> TractResult<Tensor> {
142 ensure!(self.axis == 0);
143 let data_shape = &[data.mn(), data.k()];
144 let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
145 let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
146 let indices_slice = indices.as_slice::<i64>()?;
147 let vector_len = data_shape[1];
148 if F::datum_type() == f16::datum_type() {
149 let output_slice = output.as_slice_mut::<f16>()?;
150 for (pos, m) in indices_slice.iter().enumerate() {
151 let slice = &mut output_slice[pos * vector_len..][..vector_len];
152 data.extract_at_mn_f16(*m as usize, slice)?;
153 }
154 } else {
155 let output_slice = output.as_slice_mut::<f32>()?;
156 for (pos, m) in indices_slice.iter().enumerate() {
157 let slice = &mut output_slice[pos * vector_len..][..vector_len];
158 data.extract_at_mn_f32(*m as usize, slice)?;
159 }
160 }
161 Ok(output)
162 }
163}
164
impl TypedOp for Gather {
    as_op!();

    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        // An explicit output type must match the input type unless the input
        // is opaque (compressed/packed payloads, whose logical type differs).
        if let Some(dt) = self.output_type {
            ensure!(
                inputs[0].datum_type.is_opaque() || inputs[0].datum_type == dt,
                "Inconsistent datum_type in Gather: attribute is {:?}, but inputs[0] is {:?}",
                dt,
                inputs[0].datum_type
            );
        } else {
            ensure!(
                !inputs[0].datum_type.is_opaque(),
                "Gather applied to compressed data requires an explicit datum_type attribute for its output"
            );
        }
        ensure!(inputs[1].datum_type == i64::datum_type());
        if inputs[0].datum_type.is_opaque() {
            // Opaque payloads carry their logical shape inside the fact.
            let data_shape = block_quant_aware_input_shape(inputs[0])?;
            Ok(tvec!(
                // unwrap is safe: the ensure! above guarantees output_type is
                // Some whenever the input is opaque.
                self.output_type
                    .unwrap()
                    .fact(&*self.compute_output_shape(&data_shape, &inputs[1].shape)?)
            ))
        } else {
            Ok(tvec!(
                inputs[0]
                    .datum_type
                    .fact(&*self.compute_output_shape(&inputs[0].shape, &inputs[1].shape)?)
            ))
        }
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (input_fact, indices_fact) = args_2!(model.node_input_facts(node.id)?);
        // Rule 1: a constant, single-element, rank-1 indices tensor on numeric
        // data is equivalent to a Slice of length 1 along the gather axis
        // (both produce an axis dimension of size 1).
        if let Some(indices) = indices_fact.konst.as_ref() {
            if indices.rank() == 1 && indices.len() == 1 && input_fact.datum_type.is_number() {
                let mut patch = TypedModelPatch::default();
                let mut wire = patch.tap_model(model, node.inputs[0])?;
                let index = indices.cast_to_scalar::<i64>()?;
                // Resolve a negative index against the (possibly symbolic)
                // axis dimension.
                let index = if index < 0 {
                    let data_fact = model.outlet_fact(node.inputs[0])?;
                    data_fact.shape[self.axis].clone() + index.to_dim()
                } else {
                    index.to_dim()
                };
                wire = patch.wire_node(
                    format!("{}.slice", node.name),
                    crate::ops::array::Slice {
                        axis: self.axis,
                        start: index.clone(),
                        end: index + 1,
                    },
                    &[wire],
                )?[0];
                patch.shunt_outside(model, node.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        // Rule 2: when the data input is a constant also consumed by a
        // matmul packing node, re-route this Gather to read the packed
        // sibling's output instead — presumably so the gather operates on the
        // packed payload (see eval_input_store) and the unpacked constant can
        // be dropped. NOTE(review): confirm intent against tract history.
        if input_fact.konst.is_some() {
            if let Some(sibling) = model
                .outlet_successors(node.inputs[0])
                .iter()
                .find(|o| o.node != node.id && model.node(o.node).op_is::<OptSimpleMatMulPack>())
            {
                let mut patch = TypedModelPatch::default();
                let mut taps = patch.taps(model, &node.inputs)?;
                taps[0] = patch.tap_model(model, sibling.node.into())?;
                let wire = patch.wire_node(&node.name, self.clone(), &taps)?[0];
                patch.shunt_outside(model, node.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
}
247
impl EvalOp for Gather {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (data, indices) = args_2!(inputs);
        // Opaque scalar inputs wrap compressed/packed payloads: dispatch on
        // the concrete payload type, producing the declared output type.
        let result = if let Ok(opaque) = data.to_scalar::<Opaque>() {
            // output_facts guarantees output_type is set for opaque inputs.
            let dt = self.output_type.unwrap();
            if let Some(data) = opaque.downcast_ref::<BlobWithFact>() {
                dispatch_floatlike!(Self::eval_bq(dt)(self, data, &indices))?
            } else if let Some(data) = opaque.downcast_ref::<Box<dyn MMMInputValue>>() {
                dispatch_floatlike!(Self::eval_input_store(dt)(self, &**data, &indices))?
            } else {
                bail!("Can't use Gather on {:?} input", data);
            }
        } else {
            // Plain tensors: dispatch the generic gather on the element type.
            dispatch_datum!(Self::eval_t(data.datum_type())(self, data, &indices))?
        };
        Ok(tvec!(result.into_tvalue()))
    }
}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// A rank-0 (scalar) index must drop the gathered axis entirely,
    /// producing a scalar output holding the selected element.
    #[test]
    fn test_should_gather_scalar_index() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let gatherer = Gather::new(0);
        // Exercise every valid index; the original loop (2..3) only covered
        // index 2.
        for idx in 0..3 {
            let index = Tensor::from(arr0(idx));
            let outputs =
                gatherer.eval(tvec![data.clone().into_tvalue(), index.into_tvalue()]).unwrap();
            let output = &outputs[0];
            assert_eq!(output.shape().len(), 0);
            assert_eq!(*output.to_scalar::<i64>().unwrap(), idx + 1);
        }
    }
}