vortex_array/compute/
take.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_scalar::Scalar;
10
11use crate::arrays::{ConstantArray, ConstantVTable};
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
13use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef, Canonical, IntoArray};
16
17static TAKE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
18 let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take));
19 for kernel in inventory::iter::<TakeKernelRef> {
20 compute.register_kernel(kernel.0.clone());
21 }
22 compute
23});
24
25pub(crate) fn warm_up_vtable() -> usize {
26 TAKE_FN.kernels().len() + TAKE_FROM_FN.kernels().len()
27}
28
29pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
36 if indices.is_empty() {
37 return Ok(Canonical::empty(
38 &array
39 .dtype()
40 .union_nullability(indices.dtype().nullability()),
41 )
42 .into_array());
43 }
44
45 TAKE_FN
46 .invoke(&InvocationArgs {
47 inputs: &[array.into(), indices.into()],
48 options: &(),
49 })?
50 .unwrap_array()
51}
52
53#[doc(hidden)]
54pub struct Take;
55
56impl ComputeFnVTable for Take {
57 fn invoke(
58 &self,
59 args: &InvocationArgs,
60 kernels: &[ArcRef<dyn Kernel>],
61 ) -> VortexResult<Output> {
62 let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
63
64 if indices.all_invalid() {
70 return Ok(ConstantArray::new(
71 Scalar::null(array.dtype().as_nullable()),
72 indices.len(),
73 )
74 .into_array()
75 .into());
76 }
77
78 let taken_array = take_impl(array, indices, kernels)?;
79
80 if !taken_array.is::<ConstantVTable>() {
83 let derived_stats = derive_take_stats(array);
84
85 let mut stats = taken_array.statistics().to_owned();
88 stats.combine_sets(&derived_stats, array.dtype())?;
89
90 for (stat, val) in stats {
91 taken_array.statistics().set(stat, val)
94 }
95 }
96
97 Ok(taken_array.into())
98 }
99
100 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
101 let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
102
103 if !indices.dtype().is_int() {
104 vortex_bail!(
105 "Take indices must be an integer type, got {}",
106 indices.dtype()
107 );
108 }
109
110 Ok(array
111 .dtype()
112 .union_nullability(indices.dtype().nullability()))
113 }
114
115 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
116 let TakeArgs { indices, .. } = TakeArgs::try_from(args)?;
117 Ok(indices.len())
118 }
119
120 fn is_elementwise(&self) -> bool {
121 false
122 }
123}
124
125fn derive_take_stats(arr: &dyn Array) -> StatsSet {
126 let stats = arr.statistics().to_owned();
127
128 let is_constant = arr.statistics().get_as::<bool>(Stat::IsConstant);
129
130 let mut stats = stats.keep_inexact_stats(&[
131 Stat::Min,
133 Stat::Max,
134 ]);
135
136 if is_constant == Some(Precision::Exact(true)) {
137 stats.set(Stat::IsConstant, Precision::exact(true));
139 }
140
141 stats
142}
143
144fn take_impl(
145 array: &dyn Array,
146 indices: &dyn Array,
147 kernels: &[ArcRef<dyn Kernel>],
148) -> VortexResult<ArrayRef> {
149 let args = InvocationArgs {
150 inputs: &[array.into(), indices.into()],
151 options: &(),
152 };
153
154 for kernel in TAKE_FROM_FN.kernels() {
156 if let Some(output) = kernel.invoke(&args)? {
157 return output.unwrap_array();
158 }
159 }
160 if let Some(output) = indices.invoke(&TAKE_FROM_FN, &args)? {
161 return output.unwrap_array();
162 }
163
164 for kernel in kernels {
166 if let Some(output) = kernel.invoke(&args)? {
167 return output.unwrap_array();
168 }
169 }
170 if let Some(output) = array.invoke(&TAKE_FN, &args)? {
171 return output.unwrap_array();
172 }
173
174 if !array.is_canonical() {
176 log::debug!("No take implementation found for {}", array.encoding_id());
177 let canonical = array.to_canonical();
178 return take(canonical.as_ref(), indices);
179 }
180
181 vortex_bail!("No take implementation found for {}", array.encoding_id());
182}
183
184struct TakeArgs<'a> {
185 array: &'a dyn Array,
186 indices: &'a dyn Array,
187}
188
189impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> {
190 type Error = VortexError;
191
192 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
193 if value.inputs.len() != 2 {
194 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
195 }
196 let array = value.inputs[0]
197 .array()
198 .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
199 let indices = value.inputs[1]
200 .array()
201 .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
202 Ok(Self { array, indices })
203 }
204}
205
206pub trait TakeKernel: VTable {
207 fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult<ArrayRef>;
214}
215
216pub struct TakeKernelRef(pub ArcRef<dyn Kernel>);
218inventory::collect!(TakeKernelRef);
219
220#[derive(Debug)]
221pub struct TakeKernelAdapter<V: VTable>(pub V);
222
223impl<V: VTable + TakeKernel> TakeKernelAdapter<V> {
224 pub const fn lift(&'static self) -> TakeKernelRef {
225 TakeKernelRef(ArcRef::new_ref(self))
226 }
227}
228
229impl<V: VTable + TakeKernel> Kernel for TakeKernelAdapter<V> {
230 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
231 let inputs = TakeArgs::try_from(args)?;
232 let Some(array) = inputs.array.as_opt::<V>() else {
233 return Ok(None);
234 };
235 Ok(Some(V::take(&self.0, array, inputs.indices)?.into()))
236 }
237}
238
239static TAKE_FROM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
240 let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom));
241 for kernel in inventory::iter::<TakeFromKernelRef> {
242 compute.register_kernel(kernel.0.clone());
243 }
244 compute
245});
246
247pub struct TakeFrom;
248
249impl ComputeFnVTable for TakeFrom {
250 fn invoke(
251 &self,
252 _args: &InvocationArgs,
253 _kernels: &[ArcRef<dyn Kernel>],
254 ) -> VortexResult<Output> {
255 vortex_bail!(
256 "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function"
257 )
258 }
259
260 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
261 Take.return_dtype(args)
262 }
263
264 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
265 Take.return_len(args)
266 }
267
268 fn is_elementwise(&self) -> bool {
269 Take.is_elementwise()
270 }
271}
272
273pub trait TakeFromKernel: VTable {
274 fn take_from(&self, indices: &Self::Array, array: &dyn Array)
277 -> VortexResult<Option<ArrayRef>>;
278}
279
280pub struct TakeFromKernelRef(pub ArcRef<dyn Kernel>);
281inventory::collect!(TakeFromKernelRef);
282
283#[derive(Debug)]
284pub struct TakeFromKernelAdapter<V: VTable>(pub V);
285
286impl<V: VTable + TakeFromKernel> TakeFromKernelAdapter<V> {
287 pub const fn lift(&'static self) -> TakeFromKernelRef {
288 TakeFromKernelRef(ArcRef::new_ref(self))
289 }
290}
291
292impl<V: VTable + TakeFromKernel> Kernel for TakeFromKernelAdapter<V> {
293 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
294 let inputs = TakeArgs::try_from(args)?;
295 let Some(indices) = inputs.indices.as_opt::<V>() else {
296 return Ok(None);
297 };
298 Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from))
299 }
300}