vortex_array/compute/
take.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_scalar::Scalar;
10
11use crate::arrays::{ConstantArray, ConstantVTable};
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
13use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef, Canonical, IntoArray};
16
17static TAKE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
18 let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take));
19 for kernel in inventory::iter::<TakeKernelRef> {
20 compute.register_kernel(kernel.0.clone());
21 }
22 compute
23});
24
25pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
32 if indices.is_empty() {
33 return Ok(Canonical::empty(
34 &array
35 .dtype()
36 .union_nullability(indices.dtype().nullability()),
37 )
38 .into_array());
39 }
40
41 TAKE_FN
42 .invoke(&InvocationArgs {
43 inputs: &[array.into(), indices.into()],
44 options: &(),
45 })?
46 .unwrap_array()
47}
48
/// The [`ComputeFnVTable`] implementation backing the `take` compute function.
#[doc(hidden)]
pub struct Take;
51
52impl ComputeFnVTable for Take {
53 fn invoke(
54 &self,
55 args: &InvocationArgs,
56 kernels: &[ArcRef<dyn Kernel>],
57 ) -> VortexResult<Output> {
58 let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
59
60 if indices.all_invalid()? {
66 return Ok(ConstantArray::new(
67 Scalar::null(array.dtype().as_nullable()),
68 indices.len(),
69 )
70 .into_array()
71 .into());
72 }
73
74 let taken_array = take_impl(array, indices, kernels)?;
75
76 if !taken_array.is::<ConstantVTable>() {
79 let derived_stats = derive_take_stats(array);
80
81 let mut stats = taken_array.statistics().to_owned();
84 stats.combine_sets(&derived_stats, array.dtype())?;
85
86 for (stat, val) in stats {
87 taken_array.statistics().set(stat, val)
90 }
91 }
92
93 Ok(taken_array.into())
94 }
95
96 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
97 let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
98
99 if !indices.dtype().is_int() {
100 vortex_bail!(
101 "Take indices must be an integer type, got {}",
102 indices.dtype()
103 );
104 }
105
106 Ok(array
107 .dtype()
108 .union_nullability(indices.dtype().nullability()))
109 }
110
111 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
112 let TakeArgs { indices, .. } = TakeArgs::try_from(args)?;
113 Ok(indices.len())
114 }
115
116 fn is_elementwise(&self) -> bool {
117 false
118 }
119}
120
121fn derive_take_stats(arr: &dyn Array) -> StatsSet {
122 let stats = arr.statistics().to_owned();
123
124 let is_constant = arr.statistics().get_as::<bool>(Stat::IsConstant);
125
126 let mut stats = stats.keep_inexact_stats(&[
127 Stat::Min,
129 Stat::Max,
130 ]);
131
132 if is_constant == Some(Precision::Exact(true)) {
133 stats.set(Stat::IsConstant, Precision::exact(true));
135 }
136
137 stats
138}
139
140fn take_impl(
141 array: &dyn Array,
142 indices: &dyn Array,
143 kernels: &[ArcRef<dyn Kernel>],
144) -> VortexResult<ArrayRef> {
145 let args = InvocationArgs {
146 inputs: &[array.into(), indices.into()],
147 options: &(),
148 };
149
150 for kernel in TAKE_FROM_FN.kernels() {
152 if let Some(output) = kernel.invoke(&args)? {
153 return output.unwrap_array();
154 }
155 }
156 if let Some(output) = indices.invoke(&TAKE_FROM_FN, &args)? {
157 return output.unwrap_array();
158 }
159
160 for kernel in kernels {
162 if let Some(output) = kernel.invoke(&args)? {
163 return output.unwrap_array();
164 }
165 }
166 if let Some(output) = array.invoke(&TAKE_FN, &args)? {
167 return output.unwrap_array();
168 }
169
170 if !array.is_canonical() {
172 log::debug!("No take implementation found for {}", array.encoding_id());
173 let canonical = array.to_canonical()?;
174 return take(canonical.as_ref(), indices);
175 }
176
177 vortex_bail!("No take implementation found for {}", array.encoding_id());
178}
179
/// The two operands of a take invocation, borrowed from [`InvocationArgs`].
struct TakeArgs<'a> {
    /// The array values are gathered from (the first input).
    array: &'a dyn Array,
    /// The positions to gather (the second input).
    indices: &'a dyn Array,
}
184
185impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> {
186 type Error = VortexError;
187
188 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
189 if value.inputs.len() != 2 {
190 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
191 }
192 let array = value.inputs[0]
193 .array()
194 .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
195 let indices = value.inputs[1]
196 .array()
197 .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
198 Ok(Self { array, indices })
199 }
200}
201
/// A `take` implementation specialized for arrays of the encoding described by this vtable.
///
/// Used through [`TakeKernelAdapter`], which only calls it when the source array has this
/// encoding. The adapter's `lift` produces a [`TakeKernelRef`] that `inventory` collects and
/// that is registered on the `take` compute function.
pub trait TakeKernel: VTable {
    /// Gathers the elements of `array` at the positions in `indices`.
    fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult<ArrayRef>;
}
211
/// A type-erased take kernel.
///
/// Instances are collected by `inventory` and registered on the `take` compute function
/// when it is first initialized.
pub struct TakeKernelRef(pub ArcRef<dyn Kernel>);
inventory::collect!(TakeKernelRef);
215
/// Adapts a [`TakeKernel`] implemented on vtable `V` into a type-erased [`Kernel`].
#[derive(Debug)]
pub struct TakeKernelAdapter<V: VTable>(pub V);
218
impl<V: VTable + TakeKernel> TakeKernelAdapter<V> {
    /// Wraps this `'static` adapter into a [`TakeKernelRef`] without allocating.
    ///
    /// Being `const`, it can be used in static contexts (presumably `inventory::submit!`).
    pub const fn lift(&'static self) -> TakeKernelRef {
        TakeKernelRef(ArcRef::new_ref(self))
    }
}
224
225impl<V: VTable + TakeKernel> Kernel for TakeKernelAdapter<V> {
226 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
227 let inputs = TakeArgs::try_from(args)?;
228 let Some(array) = inputs.array.as_opt::<V>() else {
229 return Ok(None);
230 };
231 Ok(Some(V::take(&self.0, array, inputs.indices)?.into()))
232 }
233}
234
235static TAKE_FROM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
236 let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom));
237 for kernel in inventory::iter::<TakeFromKernelRef> {
238 compute.register_kernel(kernel.0.clone());
239 }
240 compute
241});
242
/// The [`ComputeFnVTable`] implementation backing the auxiliary `take_from` compute function.
///
/// It exists only to host [`TakeFromKernel`]s that `take` uses; invoking it directly fails.
pub struct TakeFrom;
244
245impl ComputeFnVTable for TakeFrom {
246 fn invoke(
247 &self,
248 _args: &InvocationArgs,
249 _kernels: &[ArcRef<dyn Kernel>],
250 ) -> VortexResult<Output> {
251 vortex_bail!(
252 "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function"
253 )
254 }
255
256 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
257 Take.return_dtype(args)
258 }
259
260 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
261 Take.return_len(args)
262 }
263
264 fn is_elementwise(&self) -> bool {
265 Take.is_elementwise()
266 }
267}
268
/// A `take` implementation keyed on the encoding of the *indices* (this vtable), rather than
/// on the encoding of the array being taken from.
///
/// Used through [`TakeFromKernelAdapter`]. Returning `Ok(None)` declines, and `take` then
/// tries its next implementation.
pub trait TakeFromKernel: VTable {
    /// Gathers the elements of `array` at the positions in `indices`.
    ///
    /// Returns `Ok(None)` if this kernel does not handle the given combination.
    fn take_from(&self, indices: &Self::Array, array: &dyn Array)
    -> VortexResult<Option<ArrayRef>>;
}
275
/// A type-erased `take_from` kernel.
///
/// Instances are collected by `inventory` and registered on the `take_from` compute function
/// when it is first initialized.
pub struct TakeFromKernelRef(pub ArcRef<dyn Kernel>);
inventory::collect!(TakeFromKernelRef);
278
/// Adapts a [`TakeFromKernel`] implemented on vtable `V` into a type-erased [`Kernel`].
#[derive(Debug)]
pub struct TakeFromKernelAdapter<V: VTable>(pub V);
281
impl<V: VTable + TakeFromKernel> TakeFromKernelAdapter<V> {
    /// Wraps this `'static` adapter into a [`TakeFromKernelRef`] without allocating.
    ///
    /// Being `const`, it can be used in static contexts (presumably `inventory::submit!`).
    pub const fn lift(&'static self) -> TakeFromKernelRef {
        TakeFromKernelRef(ArcRef::new_ref(self))
    }
}
287
288impl<V: VTable + TakeFromKernel> Kernel for TakeFromKernelAdapter<V> {
289 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
290 let inputs = TakeArgs::try_from(args)?;
291 let Some(indices) = inputs.indices.as_opt::<V>() else {
292 return Ok(None);
293 };
294 Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from))
295 }
296}