// vortex_array/compute/take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_scalar::Scalar;
10
11use crate::arrays::ConstantArray;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
13use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef, Canonical, IntoArray};
16
17pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
18    if indices.is_empty() {
19        return Ok(Canonical::empty(
20            &array
21                .dtype()
22                .union_nullability(indices.dtype().nullability()),
23        )
24        .into_array());
25    }
26
27    TAKE_FN
28        .invoke(&InvocationArgs {
29            inputs: &[array.into(), indices.into()],
30            options: &(),
31        })?
32        .unwrap_array()
33}
34
35pub static TAKE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
36    let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take));
37    for kernel in inventory::iter::<TakeKernelRef> {
38        compute.register_kernel(kernel.0.clone());
39    }
40    compute
41});
42
43pub struct Take;
44
45impl ComputeFnVTable for Take {
46    fn invoke(
47        &self,
48        args: &InvocationArgs,
49        kernels: &[ArcRef<dyn Kernel>],
50    ) -> VortexResult<Output> {
51        let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
52
53        // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to
54        //  the filter function since they're typically optimised for this case.
55        // TODO(ngates): if indices min is quite high, we could slice self and offset the indices
56        //  such that canonicalize does less work.
57
58        if indices.all_invalid()? {
59            return Ok(ConstantArray::new(
60                Scalar::null(array.dtype().as_nullable()),
61                indices.len(),
62            )
63            .into_array()
64            .into());
65        }
66
67        // We know that constant array don't need stats propagation, so we can avoid the overhead of
68        // computing derived stats and merging them in.
69        let derived_stats = (!array.is_constant()).then(|| derive_take_stats(array));
70
71        let taken = take_impl(array, indices, kernels)?;
72
73        if let Some(derived_stats) = derived_stats {
74            let mut stats = taken.statistics().to_owned();
75            stats.combine_sets(&derived_stats, array.dtype())?;
76            for (stat, val) in stats.into_iter() {
77                taken.statistics().set(stat, val)
78            }
79        }
80
81        Ok(taken.into())
82    }
83
84    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
85        let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
86
87        if !indices.dtype().is_int() {
88            vortex_bail!(
89                "Take indices must be an integer type, got {}",
90                indices.dtype()
91            );
92        }
93
94        Ok(array
95            .dtype()
96            .union_nullability(indices.dtype().nullability()))
97    }
98
99    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
100        let TakeArgs { indices, .. } = TakeArgs::try_from(args)?;
101        Ok(indices.len())
102    }
103
104    fn is_elementwise(&self) -> bool {
105        false
106    }
107}
108
109fn derive_take_stats(arr: &dyn Array) -> StatsSet {
110    let stats = arr.statistics().to_owned();
111
112    let is_constant = stats.get_as::<bool>(Stat::IsConstant);
113
114    let mut stats = stats.keep_inexact_stats(&[
115        // Cannot create values smaller than min or larger than max
116        Stat::Min,
117        Stat::Max,
118    ]);
119
120    if is_constant == Some(Precision::Exact(true)) {
121        // Any combination of elements from a constant array is still const
122        stats.set(Stat::IsConstant, Precision::exact(true));
123    }
124
125    stats
126}
127
128fn take_impl(
129    array: &dyn Array,
130    indices: &dyn Array,
131    kernels: &[ArcRef<dyn Kernel>],
132) -> VortexResult<ArrayRef> {
133    let args = InvocationArgs {
134        inputs: &[array.into(), indices.into()],
135        options: &(),
136    };
137
138    // First look for a TakeFrom specialized on the indices.
139    for kernel in TAKE_FROM_FN.kernels() {
140        if let Some(output) = kernel.invoke(&args)? {
141            return output.unwrap_array();
142        }
143    }
144    if let Some(output) = indices.invoke(&TAKE_FROM_FN, &args)? {
145        return output.unwrap_array();
146    }
147
148    // Then look for a Take kernel
149    for kernel in kernels {
150        if let Some(output) = kernel.invoke(&args)? {
151            return output.unwrap_array();
152        }
153    }
154    if let Some(output) = array.invoke(&TAKE_FN, &args)? {
155        return output.unwrap_array();
156    }
157
158    // Otherwise, canonicalize and try again.
159    if !array.is_canonical() {
160        log::debug!("No take implementation found for {}", array.encoding_id());
161        let canonical = array.to_canonical()?;
162        return take(canonical.as_ref(), indices);
163    }
164
165    vortex_bail!("No take implementation found for {}", array.encoding_id());
166}
167
168struct TakeArgs<'a> {
169    array: &'a dyn Array,
170    indices: &'a dyn Array,
171}
172
173impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> {
174    type Error = VortexError;
175
176    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
177        if value.inputs.len() != 2 {
178            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
179        }
180        let array = value.inputs[0]
181            .array()
182            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
183        let indices = value.inputs[1]
184            .array()
185            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
186        Ok(Self { array, indices })
187    }
188}
189
190pub trait TakeKernel: VTable {
191    /// Create a new array by taking the values from the `array` at the
192    /// given `indices`.
193    ///
194    /// # Panics
195    ///
196    /// Using `indices` that are invalid for the given `array` will cause a panic.
197    fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult<ArrayRef>;
198}
199
200/// A kernel that implements the filter function.
201pub struct TakeKernelRef(pub ArcRef<dyn Kernel>);
202inventory::collect!(TakeKernelRef);
203
204#[derive(Debug)]
205pub struct TakeKernelAdapter<V: VTable>(pub V);
206
207impl<V: VTable + TakeKernel> TakeKernelAdapter<V> {
208    pub const fn lift(&'static self) -> TakeKernelRef {
209        TakeKernelRef(ArcRef::new_ref(self))
210    }
211}
212
213impl<V: VTable + TakeKernel> Kernel for TakeKernelAdapter<V> {
214    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
215        let inputs = TakeArgs::try_from(args)?;
216        let Some(array) = inputs.array.as_opt::<V>() else {
217            return Ok(None);
218        };
219        Ok(Some(V::take(&self.0, array, inputs.indices)?.into()))
220    }
221}
222
223pub static TAKE_FROM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
224    let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom));
225    for kernel in inventory::iter::<TakeFromKernelRef> {
226        compute.register_kernel(kernel.0.clone());
227    }
228    compute
229});
230
231pub struct TakeFrom;
232
233impl ComputeFnVTable for TakeFrom {
234    fn invoke(
235        &self,
236        _args: &InvocationArgs,
237        _kernels: &[ArcRef<dyn Kernel>],
238    ) -> VortexResult<Output> {
239        vortex_bail!(
240            "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function"
241        )
242    }
243
244    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
245        Take.return_dtype(args)
246    }
247
248    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
249        Take.return_len(args)
250    }
251
252    fn is_elementwise(&self) -> bool {
253        Take.is_elementwise()
254    }
255}
256
257pub trait TakeFromKernel: VTable {
258    /// Create a new array by taking the values from the `array` at the
259    /// given `indices`.
260    fn take_from(&self, indices: &Self::Array, array: &dyn Array)
261    -> VortexResult<Option<ArrayRef>>;
262}
263
264pub struct TakeFromKernelRef(pub ArcRef<dyn Kernel>);
265inventory::collect!(TakeFromKernelRef);
266
267#[derive(Debug)]
268pub struct TakeFromKernelAdapter<V: VTable>(pub V);
269
270impl<V: VTable + TakeFromKernel> TakeFromKernelAdapter<V> {
271    pub const fn lift(&'static self) -> TakeFromKernelRef {
272        TakeFromKernelRef(ArcRef::new_ref(self))
273    }
274}
275
276impl<V: VTable + TakeFromKernel> Kernel for TakeFromKernelAdapter<V> {
277    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
278        let inputs = TakeArgs::try_from(args)?;
279        let Some(indices) = inputs.indices.as_opt::<V>() else {
280            return Ok(None);
281        };
282        Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from))
283    }
284}