vortex_array/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_scalar::Scalar;
10
11use crate::arrays::{ConstantArray, ConstantVTable};
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
13use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef, Canonical, IntoArray};
16
17static TAKE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
18    let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take));
19    for kernel in inventory::iter::<TakeKernelRef> {
20        compute.register_kernel(kernel.0.clone());
21    }
22    compute
23});
24
25/// Creates a new array using the elements from the input `array` indexed by `indices`.
26///
27/// For example, if we have an `array` `[1, 2, 3, 4, 5]` and `indices` `[4, 2]`, the resulting
28/// array would be `[5, 3]`.
29///
30/// The output array will have the same length as the `indices` array.
31pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
32    if indices.is_empty() {
33        return Ok(Canonical::empty(
34            &array
35                .dtype()
36                .union_nullability(indices.dtype().nullability()),
37        )
38        .into_array());
39    }
40
41    TAKE_FN
42        .invoke(&InvocationArgs {
43            inputs: &[array.into(), indices.into()],
44            options: &(),
45        })?
46        .unwrap_array()
47}
48
49#[doc(hidden)]
50pub struct Take;
51
52impl ComputeFnVTable for Take {
53    fn invoke(
54        &self,
55        args: &InvocationArgs,
56        kernels: &[ArcRef<dyn Kernel>],
57    ) -> VortexResult<Output> {
58        let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
59
60        // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to
61        //  the filter function since they're typically optimised for this case.
62        // TODO(ngates): if indices min is quite high, we could slice self and offset the indices
63        //  such that canonicalize does less work.
64
65        if indices.all_invalid()? {
66            return Ok(ConstantArray::new(
67                Scalar::null(array.dtype().as_nullable()),
68                indices.len(),
69            )
70            .into_array()
71            .into());
72        }
73
74        let taken_array = take_impl(array, indices, kernels)?;
75
76        // We know that constant array don't need stats propagation, so we can avoid the overhead of
77        // computing derived stats and merging them in.
78        if !taken_array.is::<ConstantVTable>() {
79            let derived_stats = derive_take_stats(array);
80
81            // TODO(robert): Ideally, we want to have a `combine_sets` method available on a
82            // `StatsSetRef` so we don't have to incur a clone here in `.to_owned()`.
83            let mut stats = taken_array.statistics().to_owned();
84            stats.combine_sets(&derived_stats, array.dtype())?;
85
86            for (stat, val) in stats {
87                // Alternatively, use a monoidal pattern here to set `stat = val`, or if it already
88                // exists, combine the two stats (similar to how `combine_sets` does it).
89                taken_array.statistics().set(stat, val)
90            }
91        }
92
93        Ok(taken_array.into())
94    }
95
96    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
97        let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
98
99        if !indices.dtype().is_int() {
100            vortex_bail!(
101                "Take indices must be an integer type, got {}",
102                indices.dtype()
103            );
104        }
105
106        Ok(array
107            .dtype()
108            .union_nullability(indices.dtype().nullability()))
109    }
110
111    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
112        let TakeArgs { indices, .. } = TakeArgs::try_from(args)?;
113        Ok(indices.len())
114    }
115
116    fn is_elementwise(&self) -> bool {
117        false
118    }
119}
120
121fn derive_take_stats(arr: &dyn Array) -> StatsSet {
122    let stats = arr.statistics().to_owned();
123
124    let is_constant = arr.statistics().get_as::<bool>(Stat::IsConstant);
125
126    let mut stats = stats.keep_inexact_stats(&[
127        // Cannot create values smaller than min or larger than max
128        Stat::Min,
129        Stat::Max,
130    ]);
131
132    if is_constant == Some(Precision::Exact(true)) {
133        // Any combination of elements from a constant array is still const
134        stats.set(Stat::IsConstant, Precision::exact(true));
135    }
136
137    stats
138}
139
140fn take_impl(
141    array: &dyn Array,
142    indices: &dyn Array,
143    kernels: &[ArcRef<dyn Kernel>],
144) -> VortexResult<ArrayRef> {
145    let args = InvocationArgs {
146        inputs: &[array.into(), indices.into()],
147        options: &(),
148    };
149
150    // First look for a TakeFrom specialized on the indices.
151    for kernel in TAKE_FROM_FN.kernels() {
152        if let Some(output) = kernel.invoke(&args)? {
153            return output.unwrap_array();
154        }
155    }
156    if let Some(output) = indices.invoke(&TAKE_FROM_FN, &args)? {
157        return output.unwrap_array();
158    }
159
160    // Then look for a Take kernel
161    for kernel in kernels {
162        if let Some(output) = kernel.invoke(&args)? {
163            return output.unwrap_array();
164        }
165    }
166    if let Some(output) = array.invoke(&TAKE_FN, &args)? {
167        return output.unwrap_array();
168    }
169
170    // Otherwise, canonicalize and try again.
171    if !array.is_canonical() {
172        log::debug!("No take implementation found for {}", array.encoding_id());
173        let canonical = array.to_canonical()?;
174        return take(canonical.as_ref(), indices);
175    }
176
177    vortex_bail!("No take implementation found for {}", array.encoding_id());
178}
179
180struct TakeArgs<'a> {
181    array: &'a dyn Array,
182    indices: &'a dyn Array,
183}
184
185impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> {
186    type Error = VortexError;
187
188    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
189        if value.inputs.len() != 2 {
190            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
191        }
192        let array = value.inputs[0]
193            .array()
194            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
195        let indices = value.inputs[1]
196            .array()
197            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
198        Ok(Self { array, indices })
199    }
200}
201
202pub trait TakeKernel: VTable {
203    /// Create a new array by taking the values from the `array` at the
204    /// given `indices`.
205    ///
206    /// # Panics
207    ///
208    /// Using `indices` that are invalid for the given `array` will cause a panic.
209    fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult<ArrayRef>;
210}
211
212/// A kernel that implements the filter function.
213pub struct TakeKernelRef(pub ArcRef<dyn Kernel>);
214inventory::collect!(TakeKernelRef);
215
216#[derive(Debug)]
217pub struct TakeKernelAdapter<V: VTable>(pub V);
218
219impl<V: VTable + TakeKernel> TakeKernelAdapter<V> {
220    pub const fn lift(&'static self) -> TakeKernelRef {
221        TakeKernelRef(ArcRef::new_ref(self))
222    }
223}
224
225impl<V: VTable + TakeKernel> Kernel for TakeKernelAdapter<V> {
226    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
227        let inputs = TakeArgs::try_from(args)?;
228        let Some(array) = inputs.array.as_opt::<V>() else {
229            return Ok(None);
230        };
231        Ok(Some(V::take(&self.0, array, inputs.indices)?.into()))
232    }
233}
234
235static TAKE_FROM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
236    let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom));
237    for kernel in inventory::iter::<TakeFromKernelRef> {
238        compute.register_kernel(kernel.0.clone());
239    }
240    compute
241});
242
243pub struct TakeFrom;
244
245impl ComputeFnVTable for TakeFrom {
246    fn invoke(
247        &self,
248        _args: &InvocationArgs,
249        _kernels: &[ArcRef<dyn Kernel>],
250    ) -> VortexResult<Output> {
251        vortex_bail!(
252            "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function"
253        )
254    }
255
256    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
257        Take.return_dtype(args)
258    }
259
260    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
261        Take.return_len(args)
262    }
263
264    fn is_elementwise(&self) -> bool {
265        Take.is_elementwise()
266    }
267}
268
269pub trait TakeFromKernel: VTable {
270    /// Create a new array by taking the values from the `array` at the
271    /// given `indices`.
272    fn take_from(&self, indices: &Self::Array, array: &dyn Array)
273    -> VortexResult<Option<ArrayRef>>;
274}
275
276pub struct TakeFromKernelRef(pub ArcRef<dyn Kernel>);
277inventory::collect!(TakeFromKernelRef);
278
279#[derive(Debug)]
280pub struct TakeFromKernelAdapter<V: VTable>(pub V);
281
282impl<V: VTable + TakeFromKernel> TakeFromKernelAdapter<V> {
283    pub const fn lift(&'static self) -> TakeFromKernelRef {
284        TakeFromKernelRef(ArcRef::new_ref(self))
285    }
286}
287
288impl<V: VTable + TakeFromKernel> Kernel for TakeFromKernelAdapter<V> {
289    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
290        let inputs = TakeArgs::try_from(args)?;
291        let Some(indices) = inputs.indices.as_opt::<V>() else {
292            return Ok(None);
293        };
294        Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from))
295    }
296}