vortex_array/compute/take.rs

1use std::sync::LazyLock;
2
3use arcref::ArcRef;
4use vortex_dtype::DType;
5use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
6use vortex_scalar::Scalar;
7
8use crate::arrays::ConstantArray;
9use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
10use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
11use crate::vtable::VTable;
12use crate::{Array, ArrayRef, IntoArray};
13
14pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
15    TAKE_FN
16        .invoke(&InvocationArgs {
17            inputs: &[array.into(), indices.into()],
18            options: &(),
19        })?
20        .unwrap_array()
21}
22
23pub static TAKE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
24    let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take));
25    for kernel in inventory::iter::<TakeKernelRef> {
26        compute.register_kernel(kernel.0.clone());
27    }
28    compute
29});
30
31pub struct Take;
32
33impl ComputeFnVTable for Take {
34    fn invoke(
35        &self,
36        args: &InvocationArgs,
37        kernels: &[ArcRef<dyn Kernel>],
38    ) -> VortexResult<Output> {
39        let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
40
41        // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to
42        //  the filter function since they're typically optimised for this case.
43        // TODO(ngates): if indices min is quite high, we could slice self and offset the indices
44        //  such that canonicalize does less work.
45
46        if indices.all_invalid()? {
47            return Ok(ConstantArray::new(
48                Scalar::null(array.dtype().as_nullable()),
49                indices.len(),
50            )
51            .into_array()
52            .into());
53        }
54
55        // We know that constant array don't need stats propagation, so we can avoid the overhead of
56        // computing derived stats and merging them in.
57        let derived_stats = (!array.is_constant()).then(|| derive_take_stats(array));
58
59        let taken = take_impl(array, indices, kernels)?;
60
61        if let Some(derived_stats) = derived_stats {
62            let mut stats = taken.statistics().to_owned();
63            stats.combine_sets(&derived_stats, array.dtype())?;
64            for (stat, val) in stats.into_iter() {
65                taken.statistics().set(stat, val)
66            }
67        }
68
69        Ok(taken.into())
70    }
71
72    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
73        let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
74
75        if !indices.dtype().is_int() {
76            vortex_bail!(
77                "Take indices must be an integer type, got {}",
78                indices.dtype()
79            );
80        }
81
82        Ok(array
83            .dtype()
84            .union_nullability(indices.dtype().nullability()))
85    }
86
87    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
88        let TakeArgs { indices, .. } = TakeArgs::try_from(args)?;
89        Ok(indices.len())
90    }
91
92    fn is_elementwise(&self) -> bool {
93        false
94    }
95}
96
97fn derive_take_stats(arr: &dyn Array) -> StatsSet {
98    let stats = arr.statistics().to_owned();
99
100    let is_constant = stats.get_as::<bool>(Stat::IsConstant);
101
102    let mut stats = stats.keep_inexact_stats(&[
103        // Cannot create values smaller than min or larger than max
104        Stat::Min,
105        Stat::Max,
106    ]);
107
108    if is_constant == Some(Precision::Exact(true)) {
109        // Any combination of elements from a constant array is still const
110        stats.set(Stat::IsConstant, Precision::exact(true));
111    }
112
113    stats
114}
115
116fn take_impl(
117    array: &dyn Array,
118    indices: &dyn Array,
119    kernels: &[ArcRef<dyn Kernel>],
120) -> VortexResult<ArrayRef> {
121    let args = InvocationArgs {
122        inputs: &[array.into(), indices.into()],
123        options: &(),
124    };
125
126    // First look for a TakeFrom specialized on the indices.
127    for kernel in TAKE_FROM_FN.kernels() {
128        if let Some(output) = kernel.invoke(&args)? {
129            return output.unwrap_array();
130        }
131    }
132    if let Some(output) = indices.invoke(&TAKE_FROM_FN, &args)? {
133        return output.unwrap_array();
134    }
135
136    // Then look for a Take kernel
137    for kernel in kernels {
138        if let Some(output) = kernel.invoke(&args)? {
139            return output.unwrap_array();
140        }
141    }
142    if let Some(output) = array.invoke(&TAKE_FN, &args)? {
143        return output.unwrap_array();
144    }
145
146    // Otherwise, canonicalize and try again.
147    if !array.is_canonical() {
148        log::debug!("No take implementation found for {}", array.encoding_id());
149        let canonical = array.to_canonical()?;
150        return take(canonical.as_ref(), indices);
151    }
152
153    vortex_bail!("No take implementation found for {}", array.encoding_id());
154}
155
156struct TakeArgs<'a> {
157    array: &'a dyn Array,
158    indices: &'a dyn Array,
159}
160
161impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> {
162    type Error = VortexError;
163
164    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
165        if value.inputs.len() != 2 {
166            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
167        }
168        let array = value.inputs[0]
169            .array()
170            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
171        let indices = value.inputs[1]
172            .array()
173            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
174        Ok(Self { array, indices })
175    }
176}
177
178pub trait TakeKernel: VTable {
179    /// Create a new array by taking the values from the `array` at the
180    /// given `indices`.
181    ///
182    /// # Panics
183    ///
184    /// Using `indices` that are invalid for the given `array` will cause a panic.
185    fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult<ArrayRef>;
186}
187
188/// A kernel that implements the filter function.
189pub struct TakeKernelRef(pub ArcRef<dyn Kernel>);
190inventory::collect!(TakeKernelRef);
191
192#[derive(Debug)]
193pub struct TakeKernelAdapter<V: VTable>(pub V);
194
195impl<V: VTable + TakeKernel> TakeKernelAdapter<V> {
196    pub const fn lift(&'static self) -> TakeKernelRef {
197        TakeKernelRef(ArcRef::new_ref(self))
198    }
199}
200
201impl<V: VTable + TakeKernel> Kernel for TakeKernelAdapter<V> {
202    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
203        let inputs = TakeArgs::try_from(args)?;
204        let Some(array) = inputs.array.as_opt::<V>() else {
205            return Ok(None);
206        };
207        Ok(Some(V::take(&self.0, array, inputs.indices)?.into()))
208    }
209}
210
211pub static TAKE_FROM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
212    let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom));
213    for kernel in inventory::iter::<TakeFromKernelRef> {
214        compute.register_kernel(kernel.0.clone());
215    }
216    compute
217});
218
219pub struct TakeFrom;
220
221impl ComputeFnVTable for TakeFrom {
222    fn invoke(
223        &self,
224        _args: &InvocationArgs,
225        _kernels: &[ArcRef<dyn Kernel>],
226    ) -> VortexResult<Output> {
227        vortex_bail!(
228            "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function"
229        )
230    }
231
232    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
233        Take.return_dtype(args)
234    }
235
236    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
237        Take.return_len(args)
238    }
239
240    fn is_elementwise(&self) -> bool {
241        Take.is_elementwise()
242    }
243}
244
245pub trait TakeFromKernel: VTable {
246    /// Create a new array by taking the values from the `array` at the
247    /// given `indices`.
248    fn take_from(&self, indices: &Self::Array, array: &dyn Array)
249    -> VortexResult<Option<ArrayRef>>;
250}
251
252pub struct TakeFromKernelRef(pub ArcRef<dyn Kernel>);
253inventory::collect!(TakeFromKernelRef);
254
255#[derive(Debug)]
256pub struct TakeFromKernelAdapter<V: VTable>(pub V);
257
258impl<V: VTable + TakeFromKernel> TakeFromKernelAdapter<V> {
259    pub const fn lift(&'static self) -> TakeFromKernelRef {
260        TakeFromKernelRef(ArcRef::new_ref(self))
261    }
262}
263
264impl<V: VTable + TakeFromKernel> Kernel for TakeFromKernelAdapter<V> {
265    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
266        let inputs = TakeArgs::try_from(args)?;
267        let Some(indices) = inputs.indices.as_opt::<V>() else {
268            return Ok(None);
269        };
270        Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from))
271    }
272}