vortex_array/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_scalar::Scalar;
10
11use crate::arrays::{ConstantArray, ConstantVTable};
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
13use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef, Canonical, IntoArray};
16
17static TAKE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
18    let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take));
19    for kernel in inventory::iter::<TakeKernelRef> {
20        compute.register_kernel(kernel.0.clone());
21    }
22    compute
23});
24
25pub(crate) fn warm_up_vtable() -> usize {
26    TAKE_FN.kernels().len() + TAKE_FROM_FN.kernels().len()
27}
28
29/// Creates a new array using the elements from the input `array` indexed by `indices`.
30///
31/// For example, if we have an `array` `[1, 2, 3, 4, 5]` and `indices` `[4, 2]`, the resulting
32/// array would be `[5, 3]`.
33///
34/// The output array will have the same length as the `indices` array.
35pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
36    if indices.is_empty() {
37        return Ok(Canonical::empty(
38            &array
39                .dtype()
40                .union_nullability(indices.dtype().nullability()),
41        )
42        .into_array());
43    }
44
45    TAKE_FN
46        .invoke(&InvocationArgs {
47            inputs: &[array.into(), indices.into()],
48            options: &(),
49        })?
50        .unwrap_array()
51}
52
53#[doc(hidden)]
54pub struct Take;
55
56impl ComputeFnVTable for Take {
57    fn invoke(
58        &self,
59        args: &InvocationArgs,
60        kernels: &[ArcRef<dyn Kernel>],
61    ) -> VortexResult<Output> {
62        let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
63
64        // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to
65        //  the filter function since they're typically optimised for this case.
66        // TODO(ngates): if indices min is quite high, we could slice self and offset the indices
67        //  such that canonicalize does less work.
68
69        if indices.all_invalid() {
70            return Ok(ConstantArray::new(
71                Scalar::null(array.dtype().as_nullable()),
72                indices.len(),
73            )
74            .into_array()
75            .into());
76        }
77
78        let taken_array = take_impl(array, indices, kernels)?;
79
80        // We know that constant array don't need stats propagation, so we can avoid the overhead of
81        // computing derived stats and merging them in.
82        if !taken_array.is::<ConstantVTable>() {
83            let derived_stats = derive_take_stats(array);
84
85            // TODO(robert): Ideally, we want to have a `combine_sets` method available on a
86            // `StatsSetRef` so we don't have to incur a clone here in `.to_owned()`.
87            let mut stats = taken_array.statistics().to_owned();
88            stats.combine_sets(&derived_stats, array.dtype())?;
89
90            for (stat, val) in stats {
91                // Alternatively, use a monoidal pattern here to set `stat = val`, or if it already
92                // exists, combine the two stats (similar to how `combine_sets` does it).
93                taken_array.statistics().set(stat, val)
94            }
95        }
96
97        Ok(taken_array.into())
98    }
99
100    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
101        let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
102
103        if !indices.dtype().is_int() {
104            vortex_bail!(
105                "Take indices must be an integer type, got {}",
106                indices.dtype()
107            );
108        }
109
110        Ok(array
111            .dtype()
112            .union_nullability(indices.dtype().nullability()))
113    }
114
115    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
116        let TakeArgs { indices, .. } = TakeArgs::try_from(args)?;
117        Ok(indices.len())
118    }
119
120    fn is_elementwise(&self) -> bool {
121        false
122    }
123}
124
125fn derive_take_stats(arr: &dyn Array) -> StatsSet {
126    let stats = arr.statistics().to_owned();
127
128    let is_constant = arr.statistics().get_as::<bool>(Stat::IsConstant);
129
130    let mut stats = stats.keep_inexact_stats(&[
131        // Cannot create values smaller than min or larger than max
132        Stat::Min,
133        Stat::Max,
134    ]);
135
136    if is_constant == Some(Precision::Exact(true)) {
137        // Any combination of elements from a constant array is still const
138        stats.set(Stat::IsConstant, Precision::exact(true));
139    }
140
141    stats
142}
143
144fn take_impl(
145    array: &dyn Array,
146    indices: &dyn Array,
147    kernels: &[ArcRef<dyn Kernel>],
148) -> VortexResult<ArrayRef> {
149    let args = InvocationArgs {
150        inputs: &[array.into(), indices.into()],
151        options: &(),
152    };
153
154    // First look for a TakeFrom specialized on the indices.
155    for kernel in TAKE_FROM_FN.kernels() {
156        if let Some(output) = kernel.invoke(&args)? {
157            return output.unwrap_array();
158        }
159    }
160    if let Some(output) = indices.invoke(&TAKE_FROM_FN, &args)? {
161        return output.unwrap_array();
162    }
163
164    // Then look for a Take kernel
165    for kernel in kernels {
166        if let Some(output) = kernel.invoke(&args)? {
167            return output.unwrap_array();
168        }
169    }
170    if let Some(output) = array.invoke(&TAKE_FN, &args)? {
171        return output.unwrap_array();
172    }
173
174    // Otherwise, canonicalize and try again.
175    if !array.is_canonical() {
176        log::debug!("No take implementation found for {}", array.encoding_id());
177        let canonical = array.to_canonical();
178        return take(canonical.as_ref(), indices);
179    }
180
181    vortex_bail!("No take implementation found for {}", array.encoding_id());
182}
183
184struct TakeArgs<'a> {
185    array: &'a dyn Array,
186    indices: &'a dyn Array,
187}
188
189impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> {
190    type Error = VortexError;
191
192    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
193        if value.inputs.len() != 2 {
194            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
195        }
196        let array = value.inputs[0]
197            .array()
198            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
199        let indices = value.inputs[1]
200            .array()
201            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
202        Ok(Self { array, indices })
203    }
204}
205
206pub trait TakeKernel: VTable {
207    /// Create a new array by taking the values from the `array` at the
208    /// given `indices`.
209    ///
210    /// # Panics
211    ///
212    /// Using `indices` that are invalid for the given `array` will cause a panic.
213    fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult<ArrayRef>;
214}
215
216/// A kernel that implements the filter function.
217pub struct TakeKernelRef(pub ArcRef<dyn Kernel>);
218inventory::collect!(TakeKernelRef);
219
220#[derive(Debug)]
221pub struct TakeKernelAdapter<V: VTable>(pub V);
222
223impl<V: VTable + TakeKernel> TakeKernelAdapter<V> {
224    pub const fn lift(&'static self) -> TakeKernelRef {
225        TakeKernelRef(ArcRef::new_ref(self))
226    }
227}
228
229impl<V: VTable + TakeKernel> Kernel for TakeKernelAdapter<V> {
230    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
231        let inputs = TakeArgs::try_from(args)?;
232        let Some(array) = inputs.array.as_opt::<V>() else {
233            return Ok(None);
234        };
235        Ok(Some(V::take(&self.0, array, inputs.indices)?.into()))
236    }
237}
238
239static TAKE_FROM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
240    let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom));
241    for kernel in inventory::iter::<TakeFromKernelRef> {
242        compute.register_kernel(kernel.0.clone());
243    }
244    compute
245});
246
247pub struct TakeFrom;
248
249impl ComputeFnVTable for TakeFrom {
250    fn invoke(
251        &self,
252        _args: &InvocationArgs,
253        _kernels: &[ArcRef<dyn Kernel>],
254    ) -> VortexResult<Output> {
255        vortex_bail!(
256            "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function"
257        )
258    }
259
260    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
261        Take.return_dtype(args)
262    }
263
264    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
265        Take.return_len(args)
266    }
267
268    fn is_elementwise(&self) -> bool {
269        Take.is_elementwise()
270    }
271}
272
273pub trait TakeFromKernel: VTable {
274    /// Create a new array by taking the values from the `array` at the
275    /// given `indices`.
276    fn take_from(&self, indices: &Self::Array, array: &dyn Array)
277    -> VortexResult<Option<ArrayRef>>;
278}
279
280pub struct TakeFromKernelRef(pub ArcRef<dyn Kernel>);
281inventory::collect!(TakeFromKernelRef);
282
283#[derive(Debug)]
284pub struct TakeFromKernelAdapter<V: VTable>(pub V);
285
286impl<V: VTable + TakeFromKernel> TakeFromKernelAdapter<V> {
287    pub const fn lift(&'static self) -> TakeFromKernelRef {
288        TakeFromKernelRef(ArcRef::new_ref(self))
289    }
290}
291
292impl<V: VTable + TakeFromKernel> Kernel for TakeFromKernelAdapter<V> {
293    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
294        let inputs = TakeArgs::try_from(args)?;
295        let Some(indices) = inputs.indices.as_opt::<V>() else {
296            return Ok(None);
297        };
298        Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from))
299    }
300}