vortex_array/compute/
take.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_scalar::Scalar;
10
11use crate::arrays::ConstantArray;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
13use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef, Canonical, IntoArray};
16
17pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
18 if indices.is_empty() {
19 return Ok(Canonical::empty(
20 &array
21 .dtype()
22 .union_nullability(indices.dtype().nullability()),
23 )
24 .into_array());
25 }
26
27 TAKE_FN
28 .invoke(&InvocationArgs {
29 inputs: &[array.into(), indices.into()],
30 options: &(),
31 })?
32 .unwrap_array()
33}
34
35pub static TAKE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
36 let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take));
37 for kernel in inventory::iter::<TakeKernelRef> {
38 compute.register_kernel(kernel.0.clone());
39 }
40 compute
41});
42
43pub struct Take;
44
45impl ComputeFnVTable for Take {
46 fn invoke(
47 &self,
48 args: &InvocationArgs,
49 kernels: &[ArcRef<dyn Kernel>],
50 ) -> VortexResult<Output> {
51 let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
52
53 if indices.all_invalid()? {
59 return Ok(ConstantArray::new(
60 Scalar::null(array.dtype().as_nullable()),
61 indices.len(),
62 )
63 .into_array()
64 .into());
65 }
66
67 let derived_stats = (!array.is_constant()).then(|| derive_take_stats(array));
70
71 let taken = take_impl(array, indices, kernels)?;
72
73 if let Some(derived_stats) = derived_stats {
74 let mut stats = taken.statistics().to_owned();
75 stats.combine_sets(&derived_stats, array.dtype())?;
76 for (stat, val) in stats.into_iter() {
77 taken.statistics().set(stat, val)
78 }
79 }
80
81 Ok(taken.into())
82 }
83
84 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
85 let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
86
87 if !indices.dtype().is_int() {
88 vortex_bail!(
89 "Take indices must be an integer type, got {}",
90 indices.dtype()
91 );
92 }
93
94 Ok(array
95 .dtype()
96 .union_nullability(indices.dtype().nullability()))
97 }
98
99 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
100 let TakeArgs { indices, .. } = TakeArgs::try_from(args)?;
101 Ok(indices.len())
102 }
103
104 fn is_elementwise(&self) -> bool {
105 false
106 }
107}
108
109fn derive_take_stats(arr: &dyn Array) -> StatsSet {
110 let stats = arr.statistics().to_owned();
111
112 let is_constant = stats.get_as::<bool>(Stat::IsConstant);
113
114 let mut stats = stats.keep_inexact_stats(&[
115 Stat::Min,
117 Stat::Max,
118 ]);
119
120 if is_constant == Some(Precision::Exact(true)) {
121 stats.set(Stat::IsConstant, Precision::exact(true));
123 }
124
125 stats
126}
127
128fn take_impl(
129 array: &dyn Array,
130 indices: &dyn Array,
131 kernels: &[ArcRef<dyn Kernel>],
132) -> VortexResult<ArrayRef> {
133 let args = InvocationArgs {
134 inputs: &[array.into(), indices.into()],
135 options: &(),
136 };
137
138 for kernel in TAKE_FROM_FN.kernels() {
140 if let Some(output) = kernel.invoke(&args)? {
141 return output.unwrap_array();
142 }
143 }
144 if let Some(output) = indices.invoke(&TAKE_FROM_FN, &args)? {
145 return output.unwrap_array();
146 }
147
148 for kernel in kernels {
150 if let Some(output) = kernel.invoke(&args)? {
151 return output.unwrap_array();
152 }
153 }
154 if let Some(output) = array.invoke(&TAKE_FN, &args)? {
155 return output.unwrap_array();
156 }
157
158 if !array.is_canonical() {
160 log::debug!("No take implementation found for {}", array.encoding_id());
161 let canonical = array.to_canonical()?;
162 return take(canonical.as_ref(), indices);
163 }
164
165 vortex_bail!("No take implementation found for {}", array.encoding_id());
166}
167
168struct TakeArgs<'a> {
169 array: &'a dyn Array,
170 indices: &'a dyn Array,
171}
172
173impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> {
174 type Error = VortexError;
175
176 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
177 if value.inputs.len() != 2 {
178 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
179 }
180 let array = value.inputs[0]
181 .array()
182 .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
183 let indices = value.inputs[1]
184 .array()
185 .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
186 Ok(Self { array, indices })
187 }
188}
189
/// Encoding-specific implementation of take, selected by the encoding (`VTable`) of the
/// array being taken from.
///
/// [`TakeKernelRef`] values collected through `inventory` are registered on `TAKE_FN` when it
/// is first initialised. [`TakeKernelAdapter::lift`] produces such a value for an
/// implementation of this trait.
pub trait TakeKernel: VTable {
    /// Creates a new array holding the elements of `array` at the positions in `indices`.
    fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult<ArrayRef>;
}
199
/// A type-erased take kernel, collected through `inventory` and registered on `TAKE_FN`.
pub struct TakeKernelRef(pub ArcRef<dyn Kernel>);
inventory::collect!(TakeKernelRef);
203
/// Adapts a [`TakeKernel`] implementation for `V` to the type-erased [`Kernel`] interface.
#[derive(Debug)]
pub struct TakeKernelAdapter<V: VTable>(pub V);

impl<V: VTable + TakeKernel> TakeKernelAdapter<V> {
    /// Wraps this `'static` adapter, by reference, into a [`TakeKernelRef`].
    ///
    /// Being a `const fn`, it can be evaluated in constant contexts.
    pub const fn lift(&'static self) -> TakeKernelRef {
        TakeKernelRef(ArcRef::new_ref(self))
    }
}
212
213impl<V: VTable + TakeKernel> Kernel for TakeKernelAdapter<V> {
214 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
215 let inputs = TakeArgs::try_from(args)?;
216 let Some(array) = inputs.array.as_opt::<V>() else {
217 return Ok(None);
218 };
219 Ok(Some(V::take(&self.0, array, inputs.indices)?.into()))
220 }
221}
222
223pub static TAKE_FROM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
224 let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom));
225 for kernel in inventory::iter::<TakeFromKernelRef> {
226 compute.register_kernel(kernel.0.clone());
227 }
228 compute
229});
230
231pub struct TakeFrom;
232
233impl ComputeFnVTable for TakeFrom {
234 fn invoke(
235 &self,
236 _args: &InvocationArgs,
237 _kernels: &[ArcRef<dyn Kernel>],
238 ) -> VortexResult<Output> {
239 vortex_bail!(
240 "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function"
241 )
242 }
243
244 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
245 Take.return_dtype(args)
246 }
247
248 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
249 Take.return_len(args)
250 }
251
252 fn is_elementwise(&self) -> bool {
253 Take.is_elementwise()
254 }
255}
256
/// Encoding-specific implementation of take, selected by the encoding (`VTable`) of the
/// *indices* rather than of the array being taken from.
pub trait TakeFromKernel: VTable {
    /// Creates a new array holding the elements of `array` at the positions in `indices`.
    ///
    /// Returning `Ok(None)` declines the request, and the take dispatch then tries its other
    /// candidates.
    fn take_from(&self, indices: &Self::Array, array: &dyn Array)
    -> VortexResult<Option<ArrayRef>>;
}
263
/// A type-erased take-from kernel, collected through `inventory` and registered on
/// `TAKE_FROM_FN`.
pub struct TakeFromKernelRef(pub ArcRef<dyn Kernel>);
inventory::collect!(TakeFromKernelRef);
266
/// Adapts a [`TakeFromKernel`] implementation for `V` to the type-erased [`Kernel`] interface.
#[derive(Debug)]
pub struct TakeFromKernelAdapter<V: VTable>(pub V);

impl<V: VTable + TakeFromKernel> TakeFromKernelAdapter<V> {
    /// Wraps this `'static` adapter, by reference, into a [`TakeFromKernelRef`].
    ///
    /// Being a `const fn`, it can be evaluated in constant contexts.
    pub const fn lift(&'static self) -> TakeFromKernelRef {
        TakeFromKernelRef(ArcRef::new_ref(self))
    }
}
275
276impl<V: VTable + TakeFromKernel> Kernel for TakeFromKernelAdapter<V> {
277 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
278 let inputs = TakeArgs::try_from(args)?;
279 let Some(indices) = inputs.indices.as_opt::<V>() else {
280 return Ok(None);
281 };
282 Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from))
283 }
284}