vortex_array/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_scalar::Scalar;
10
11use crate::arrays::ConstantArray;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
13use crate::stats::{Precision, Stat, StatsProvider, StatsProviderExt, StatsSet};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef, Canonical, IntoArray};
16
17static TAKE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
18    let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take));
19    for kernel in inventory::iter::<TakeKernelRef> {
20        compute.register_kernel(kernel.0.clone());
21    }
22    compute
23});
24
25pub(crate) fn warm_up_vtable() -> usize {
26    TAKE_FN.kernels().len() + TAKE_FROM_FN.kernels().len()
27}
28
29/// Creates a new array using the elements from the input `array` indexed by `indices`.
30///
31/// For example, if we have an `array` `[1, 2, 3, 4, 5]` and `indices` `[4, 2]`, the resulting
32/// array would be `[5, 3]`.
33///
34/// The output array will have the same length as the `indices` array.
35pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
36    if indices.is_empty() {
37        return Ok(Canonical::empty(
38            &array
39                .dtype()
40                .union_nullability(indices.dtype().nullability()),
41        )
42        .into_array());
43    }
44
45    TAKE_FN
46        .invoke(&InvocationArgs {
47            inputs: &[array.into(), indices.into()],
48            options: &(),
49        })?
50        .unwrap_array()
51}
52
53#[doc(hidden)]
54pub struct Take;
55
56impl ComputeFnVTable for Take {
57    fn invoke(
58        &self,
59        args: &InvocationArgs,
60        kernels: &[ArcRef<dyn Kernel>],
61    ) -> VortexResult<Output> {
62        let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
63
64        // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to
65        //  the filter function since they're typically optimised for this case.
66        // TODO(ngates): if indices min is quite high, we could slice self and offset the indices
67        //  such that canonicalize does less work.
68
69        if indices.all_invalid() {
70            return Ok(ConstantArray::new(
71                Scalar::null(array.dtype().as_nullable()),
72                indices.len(),
73            )
74            .into_array()
75            .into());
76        }
77
78        let taken_array = take_impl(array, indices, kernels)?;
79
80        // We know that constant array don't need stats propagation, so we can avoid the overhead of
81        // computing derived stats and merging them in.
82        if !taken_array.is_constant() {
83            propagate_take_stats(array, &taken_array, indices)?;
84        }
85
86        Ok(taken_array.into())
87    }
88
89    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
90        let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
91
92        if !indices.dtype().is_int() {
93            vortex_bail!(
94                "Take indices must be an integer type, got {}",
95                indices.dtype()
96            );
97        }
98
99        Ok(array
100            .dtype()
101            .union_nullability(indices.dtype().nullability()))
102    }
103
104    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
105        let TakeArgs { indices, .. } = TakeArgs::try_from(args)?;
106        Ok(indices.len())
107    }
108
109    fn is_elementwise(&self) -> bool {
110        false
111    }
112}
113
114fn propagate_take_stats(
115    source: &dyn Array,
116    target: &dyn Array,
117    indices: &dyn Array,
118) -> VortexResult<()> {
119    target.statistics().with_mut_typed_stats_set(|mut st| {
120        if indices.all_valid() {
121            let is_constant = source.statistics().get_as::<bool>(Stat::IsConstant);
122            if is_constant == Some(Precision::Exact(true)) {
123                // Any combination of elements from a constant array is still const
124                st.set(Stat::IsConstant, Precision::exact(true));
125            }
126        }
127        let inexact_min_max = [Stat::Min, Stat::Max]
128            .into_iter()
129            .filter_map(|stat| {
130                source
131                    .statistics()
132                    .get(stat)
133                    .map(|v| (stat, v.map(|s| s.into_value()).into_inexact()))
134            })
135            .collect::<Vec<_>>();
136        st.combine_sets(
137            &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()),
138        )
139    })
140}
141
142fn take_impl(
143    array: &dyn Array,
144    indices: &dyn Array,
145    kernels: &[ArcRef<dyn Kernel>],
146) -> VortexResult<ArrayRef> {
147    let args = InvocationArgs {
148        inputs: &[array.into(), indices.into()],
149        options: &(),
150    };
151
152    // First look for a TakeFrom specialized on the indices.
153    for kernel in TAKE_FROM_FN.kernels() {
154        if let Some(output) = kernel.invoke(&args)? {
155            return output.unwrap_array();
156        }
157    }
158    if let Some(output) = indices.invoke(&TAKE_FROM_FN, &args)? {
159        return output.unwrap_array();
160    }
161
162    // Then look for a Take kernel
163    for kernel in kernels {
164        if let Some(output) = kernel.invoke(&args)? {
165            return output.unwrap_array();
166        }
167    }
168    if let Some(output) = array.invoke(&TAKE_FN, &args)? {
169        return output.unwrap_array();
170    }
171
172    // Otherwise, canonicalize and try again.
173    if !array.is_canonical() {
174        log::debug!("No take implementation found for {}", array.encoding_id());
175        let canonical = array.to_canonical();
176        return take(canonical.as_ref(), indices);
177    }
178
179    vortex_bail!("No take implementation found for {}", array.encoding_id());
180}
181
182struct TakeArgs<'a> {
183    array: &'a dyn Array,
184    indices: &'a dyn Array,
185}
186
187impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> {
188    type Error = VortexError;
189
190    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
191        if value.inputs.len() != 2 {
192            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
193        }
194        let array = value.inputs[0]
195            .array()
196            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
197        let indices = value.inputs[1]
198            .array()
199            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
200        Ok(Self { array, indices })
201    }
202}
203
204pub trait TakeKernel: VTable {
205    /// Create a new array by taking the values from the `array` at the
206    /// given `indices`.
207    ///
208    /// # Panics
209    ///
210    /// Using `indices` that are invalid for the given `array` will cause a panic.
211    fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult<ArrayRef>;
212}
213
214/// A kernel that implements the filter function.
215pub struct TakeKernelRef(pub ArcRef<dyn Kernel>);
216inventory::collect!(TakeKernelRef);
217
218#[derive(Debug)]
219pub struct TakeKernelAdapter<V: VTable>(pub V);
220
221impl<V: VTable + TakeKernel> TakeKernelAdapter<V> {
222    pub const fn lift(&'static self) -> TakeKernelRef {
223        TakeKernelRef(ArcRef::new_ref(self))
224    }
225}
226
227impl<V: VTable + TakeKernel> Kernel for TakeKernelAdapter<V> {
228    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
229        let inputs = TakeArgs::try_from(args)?;
230        let Some(array) = inputs.array.as_opt::<V>() else {
231            return Ok(None);
232        };
233        Ok(Some(V::take(&self.0, array, inputs.indices)?.into()))
234    }
235}
236
237static TAKE_FROM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
238    let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom));
239    for kernel in inventory::iter::<TakeFromKernelRef> {
240        compute.register_kernel(kernel.0.clone());
241    }
242    compute
243});
244
245pub struct TakeFrom;
246
247impl ComputeFnVTable for TakeFrom {
248    fn invoke(
249        &self,
250        _args: &InvocationArgs,
251        _kernels: &[ArcRef<dyn Kernel>],
252    ) -> VortexResult<Output> {
253        vortex_bail!(
254            "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function"
255        )
256    }
257
258    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
259        Take.return_dtype(args)
260    }
261
262    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
263        Take.return_len(args)
264    }
265
266    fn is_elementwise(&self) -> bool {
267        Take.is_elementwise()
268    }
269}
270
271pub trait TakeFromKernel: VTable {
272    /// Create a new array by taking the values from the `array` at the
273    /// given `indices`.
274    fn take_from(&self, indices: &Self::Array, array: &dyn Array)
275    -> VortexResult<Option<ArrayRef>>;
276}
277
278pub struct TakeFromKernelRef(pub ArcRef<dyn Kernel>);
279inventory::collect!(TakeFromKernelRef);
280
281#[derive(Debug)]
282pub struct TakeFromKernelAdapter<V: VTable>(pub V);
283
284impl<V: VTable + TakeFromKernel> TakeFromKernelAdapter<V> {
285    pub const fn lift(&'static self) -> TakeFromKernelRef {
286        TakeFromKernelRef(ArcRef::new_ref(self))
287    }
288}
289
290impl<V: VTable + TakeFromKernel> Kernel for TakeFromKernelAdapter<V> {
291    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
292        let inputs = TakeArgs::try_from(args)?;
293        let Some(indices) = inputs.indices.as_opt::<V>() else {
294            return Ok(None);
295        };
296        Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from))
297    }
298}