vortex_array/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::VortexError;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12use vortex_scalar::Scalar;
13
14use crate::Array;
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::IntoArray;
18use crate::arrays::ConstantArray;
19use crate::compute::ComputeFn;
20use crate::compute::ComputeFnVTable;
21use crate::compute::InvocationArgs;
22use crate::compute::Kernel;
23use crate::compute::Output;
24use crate::expr::stats::Precision;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::expr::stats::StatsProviderExt;
28use crate::stats::StatsSet;
29use crate::vtable::VTable;
30
31static TAKE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
32    let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take));
33    for kernel in inventory::iter::<TakeKernelRef> {
34        compute.register_kernel(kernel.0.clone());
35    }
36    compute
37});
38
39pub(crate) fn warm_up_vtable() -> usize {
40    TAKE_FN.kernels().len() + TAKE_FROM_FN.kernels().len()
41}
42
43/// Creates a new array using the elements from the input `array` indexed by `indices`.
44///
45/// For example, if we have an `array` `[1, 2, 3, 4, 5]` and `indices` `[4, 2]`, the resulting
46/// array would be `[5, 3]`.
47///
48/// The output array will have the same length as the `indices` array.
49pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
50    if indices.is_empty() {
51        return Ok(Canonical::empty(
52            &array
53                .dtype()
54                .union_nullability(indices.dtype().nullability()),
55        )
56        .into_array());
57    }
58
59    TAKE_FN
60        .invoke(&InvocationArgs {
61            inputs: &[array.into(), indices.into()],
62            options: &(),
63        })?
64        .unwrap_array()
65}
66
67#[doc(hidden)]
68pub struct Take;
69
70impl ComputeFnVTable for Take {
71    fn invoke(
72        &self,
73        args: &InvocationArgs,
74        kernels: &[ArcRef<dyn Kernel>],
75    ) -> VortexResult<Output> {
76        let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
77
78        // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to
79        //  the filter function since they're typically optimised for this case.
80        // TODO(ngates): if indices min is quite high, we could slice self and offset the indices
81        //  such that canonicalize does less work.
82
83        if indices.all_invalid() {
84            return Ok(ConstantArray::new(
85                Scalar::null(array.dtype().as_nullable()),
86                indices.len(),
87            )
88            .into_array()
89            .into());
90        }
91
92        let taken_array = take_impl(array, indices, kernels)?;
93
94        // We know that constant array don't need stats propagation, so we can avoid the overhead of
95        // computing derived stats and merging them in.
96        if !taken_array.is_constant() {
97            propagate_take_stats(array, &taken_array, indices)?;
98        }
99
100        Ok(taken_array.into())
101    }
102
103    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
104        let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
105
106        if !indices.dtype().is_int() {
107            vortex_bail!(
108                "Take indices must be an integer type, got {}",
109                indices.dtype()
110            );
111        }
112
113        Ok(array
114            .dtype()
115            .union_nullability(indices.dtype().nullability()))
116    }
117
118    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
119        let TakeArgs { indices, .. } = TakeArgs::try_from(args)?;
120        Ok(indices.len())
121    }
122
123    fn is_elementwise(&self) -> bool {
124        false
125    }
126}
127
128fn propagate_take_stats(
129    source: &dyn Array,
130    target: &dyn Array,
131    indices: &dyn Array,
132) -> VortexResult<()> {
133    target.statistics().with_mut_typed_stats_set(|mut st| {
134        if indices.all_valid() {
135            let is_constant = source.statistics().get_as::<bool>(Stat::IsConstant);
136            if is_constant == Some(Precision::Exact(true)) {
137                // Any combination of elements from a constant array is still const
138                st.set(Stat::IsConstant, Precision::exact(true));
139            }
140        }
141        let inexact_min_max = [Stat::Min, Stat::Max]
142            .into_iter()
143            .filter_map(|stat| {
144                source
145                    .statistics()
146                    .get(stat)
147                    .map(|v| (stat, v.map(|s| s.into_value()).into_inexact()))
148            })
149            .collect::<Vec<_>>();
150        st.combine_sets(
151            &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()),
152        )
153    })
154}
155
156fn take_impl(
157    array: &dyn Array,
158    indices: &dyn Array,
159    kernels: &[ArcRef<dyn Kernel>],
160) -> VortexResult<ArrayRef> {
161    let args = InvocationArgs {
162        inputs: &[array.into(), indices.into()],
163        options: &(),
164    };
165
166    // First look for a TakeFrom specialized on the indices.
167    for kernel in TAKE_FROM_FN.kernels() {
168        if let Some(output) = kernel.invoke(&args)? {
169            return output.unwrap_array();
170        }
171    }
172    if let Some(output) = indices.invoke(&TAKE_FROM_FN, &args)? {
173        return output.unwrap_array();
174    }
175
176    // Then look for a Take kernel
177    for kernel in kernels {
178        if let Some(output) = kernel.invoke(&args)? {
179            return output.unwrap_array();
180        }
181    }
182    if let Some(output) = array.invoke(&TAKE_FN, &args)? {
183        return output.unwrap_array();
184    }
185
186    // Otherwise, canonicalize and try again.
187    if !array.is_canonical() {
188        log::debug!("No take implementation found for {}", array.encoding_id());
189        let canonical = array.to_canonical();
190        return take(canonical.as_ref(), indices);
191    }
192
193    vortex_bail!("No take implementation found for {}", array.encoding_id());
194}
195
196struct TakeArgs<'a> {
197    array: &'a dyn Array,
198    indices: &'a dyn Array,
199}
200
201impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> {
202    type Error = VortexError;
203
204    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
205        if value.inputs.len() != 2 {
206            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
207        }
208        let array = value.inputs[0]
209            .array()
210            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
211        let indices = value.inputs[1]
212            .array()
213            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
214        Ok(Self { array, indices })
215    }
216}
217
218pub trait TakeKernel: VTable {
219    /// Create a new array by taking the values from the `array` at the
220    /// given `indices`.
221    ///
222    /// # Panics
223    ///
224    /// Using `indices` that are invalid for the given `array` will cause a panic.
225    fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult<ArrayRef>;
226}
227
228/// A kernel that implements the filter function.
229pub struct TakeKernelRef(pub ArcRef<dyn Kernel>);
230inventory::collect!(TakeKernelRef);
231
232#[derive(Debug)]
233pub struct TakeKernelAdapter<V: VTable>(pub V);
234
235impl<V: VTable + TakeKernel> TakeKernelAdapter<V> {
236    pub const fn lift(&'static self) -> TakeKernelRef {
237        TakeKernelRef(ArcRef::new_ref(self))
238    }
239}
240
241impl<V: VTable + TakeKernel> Kernel for TakeKernelAdapter<V> {
242    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
243        let inputs = TakeArgs::try_from(args)?;
244        let Some(array) = inputs.array.as_opt::<V>() else {
245            return Ok(None);
246        };
247        Ok(Some(V::take(&self.0, array, inputs.indices)?.into()))
248    }
249}
250
251static TAKE_FROM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
252    let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom));
253    for kernel in inventory::iter::<TakeFromKernelRef> {
254        compute.register_kernel(kernel.0.clone());
255    }
256    compute
257});
258
259pub struct TakeFrom;
260
261impl ComputeFnVTable for TakeFrom {
262    fn invoke(
263        &self,
264        _args: &InvocationArgs,
265        _kernels: &[ArcRef<dyn Kernel>],
266    ) -> VortexResult<Output> {
267        vortex_bail!(
268            "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function"
269        )
270    }
271
272    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
273        Take.return_dtype(args)
274    }
275
276    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
277        Take.return_len(args)
278    }
279
280    fn is_elementwise(&self) -> bool {
281        Take.is_elementwise()
282    }
283}
284
285pub trait TakeFromKernel: VTable {
286    /// Create a new array by taking the values from the `array` at the
287    /// given `indices`.
288    fn take_from(&self, indices: &Self::Array, array: &dyn Array)
289    -> VortexResult<Option<ArrayRef>>;
290}
291
292pub struct TakeFromKernelRef(pub ArcRef<dyn Kernel>);
293inventory::collect!(TakeFromKernelRef);
294
295#[derive(Debug)]
296pub struct TakeFromKernelAdapter<V: VTable>(pub V);
297
298impl<V: VTable + TakeFromKernel> TakeFromKernelAdapter<V> {
299    pub const fn lift(&'static self) -> TakeFromKernelRef {
300        TakeFromKernelRef(ArcRef::new_ref(self))
301    }
302}
303
304impl<V: VTable + TakeFromKernel> Kernel for TakeFromKernelAdapter<V> {
305    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
306        let inputs = TakeArgs::try_from(args)?;
307        let Some(indices) = inputs.indices.as_opt::<V>() else {
308            return Ok(None);
309        };
310        Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from))
311    }
312}