vortex_array/compute/
take.rs

1use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
2use vortex_scalar::Scalar;
3
4use crate::arrays::ConstantArray;
5use crate::builders::ArrayBuilder;
6use crate::encoding::Encoding;
7use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
8use crate::{Array, ArrayRef, IntoArray};
9
10pub trait TakeFn<A> {
11    /// Create a new array by taking the values from the `array` at the
12    /// given `indices`.
13    ///
14    /// # Panics
15    ///
16    /// Using `indices` that are invalid for the given `array` will cause a panic.
17    fn take(&self, array: A, indices: &dyn Array) -> VortexResult<ArrayRef>;
18
19    /// Has the same semantics as `Self::take` but materializes the result into the provided
20    /// builder.
21    fn take_into(
22        &self,
23        array: A,
24        indices: &dyn Array,
25        builder: &mut dyn ArrayBuilder,
26    ) -> VortexResult<()> {
27        builder.extend_from_array(&self.take(array, indices)?)
28    }
29}
30
31impl<E: Encoding> TakeFn<&dyn Array> for E
32where
33    E: for<'a> TakeFn<&'a E::Array>,
34{
35    fn take(&self, array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
36        let array_ref = array
37            .as_any()
38            .downcast_ref::<E::Array>()
39            .vortex_expect("Failed to downcast array");
40        TakeFn::take(self, array_ref, indices)
41    }
42
43    fn take_into(
44        &self,
45        array: &dyn Array,
46        indices: &dyn Array,
47        builder: &mut dyn ArrayBuilder,
48    ) -> VortexResult<()> {
49        let array_ref = array
50            .as_any()
51            .downcast_ref::<E::Array>()
52            .vortex_expect("Failed to downcast array");
53        TakeFn::take_into(self, array_ref, indices, builder)
54    }
55}
56
57pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
58    // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to
59    //  the filter function since they're typically optimised for this case.
60    // TODO(ngates): if indices min is quite high, we could slice self and offset the indices
61    //  such that canonicalize does less work.
62    if indices.all_invalid()? {
63        return Ok(
64            ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
65                .into_array(),
66        );
67    }
68
69    if !indices.dtype().is_int() {
70        vortex_bail!(
71            "Take indices must be an integer type, got {}",
72            indices.dtype()
73        );
74    }
75
76    // We know that constant array don't need stats propagation, so we can avoid the overhead of
77    // computing derived stats and merging them in.
78    let derived_stats = (!array.is_constant()).then(|| derive_take_stats(array));
79
80    let taken = take_impl(array, indices)?;
81
82    if let Some(derived_stats) = derived_stats {
83        let mut stats = taken.statistics().to_owned();
84        stats.combine_sets(&derived_stats, array.dtype())?;
85        for (stat, val) in stats.into_iter() {
86            taken.statistics().set(stat, val)
87        }
88    }
89
90    debug_assert_eq!(
91        taken.len(),
92        indices.len(),
93        "Take length mismatch {}",
94        array.encoding()
95    );
96    #[cfg(debug_assertions)]
97    {
98        // If either the indices or the array are nullable, the result should be nullable.
99        let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
100        assert_eq!(
101            taken.dtype(),
102            &array.dtype().with_nullability(expected_nullability),
103            "Take result ({}) should be nullable if either the indices ({}) or the array ({}) are nullable. ({})",
104            taken.dtype(),
105            indices.dtype().nullability().verbose_display(),
106            array.dtype().nullability().verbose_display(),
107            array.encoding(),
108        );
109    }
110
111    Ok(taken)
112}
113
114pub fn take_into(
115    array: &dyn Array,
116    indices: &dyn Array,
117    builder: &mut dyn ArrayBuilder,
118) -> VortexResult<()> {
119    if indices.all_invalid()? {
120        builder.append_nulls(indices.len());
121        return Ok(());
122    }
123
124    if array.is_empty() && !indices.is_empty() {
125        vortex_bail!("Cannot take_into from an empty array");
126    }
127
128    #[cfg(debug_assertions)]
129    {
130        // If either the indices or the array are nullable, the result should be nullable.
131        let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
132        assert_eq!(
133            builder.dtype(),
134            &array.dtype().with_nullability(expected_nullability),
135            "Take_into result ({}) should be nullable if, and only if, either the indices ({}) or the array ({}) are nullable. ({})",
136            builder.dtype(),
137            indices.dtype().nullability().verbose_display(),
138            array.dtype().nullability().verbose_display(),
139            array.encoding(),
140        );
141    }
142
143    if !indices.dtype().is_int() {
144        vortex_bail!(
145            "Take indices must be an integer type, got {}",
146            indices.dtype()
147        );
148    }
149
150    let before_len = builder.len();
151
152    // We know that constant array don't need stats propagation, so we can avoid the overhead of
153    // computing derived stats and merging them in.
154    take_into_impl(array, indices, builder)?;
155
156    let after_len = builder.len();
157
158    debug_assert_eq!(
159        after_len - before_len,
160        indices.len(),
161        "Take_into length mismatch {}",
162        array.encoding()
163    );
164
165    Ok(())
166}
167
168fn derive_take_stats(arr: &dyn Array) -> StatsSet {
169    let stats = arr.statistics().to_owned();
170
171    let is_constant = stats.get_as::<bool>(Stat::IsConstant);
172
173    let mut stats = stats.keep_inexact_stats(&[
174        // Cannot create values smaller than min or larger than max
175        Stat::Min,
176        Stat::Max,
177    ]);
178
179    if is_constant == Some(Precision::Exact(true)) {
180        // Any combination of elements from a constant array is still const
181        stats.set(Stat::IsConstant, Precision::exact(true));
182    }
183
184    stats
185}
186
187fn take_impl(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
188    // First look for a TakeFrom specialized on the indices.
189    if let Some(take_from_fn) = indices.vtable().take_from_fn() {
190        if let Some(arr) = take_from_fn.take_from(indices, array)? {
191            return Ok(arr);
192        }
193    }
194
195    // If TakeFn defined for the encoding, delegate to TakeFn.
196    // If we know from stats that indices are all valid, we can avoid all bounds checks.
197    if let Some(take_fn) = array.vtable().take_fn() {
198        return take_fn.take(array, indices);
199    }
200
201    // Otherwise, flatten and try again.
202    log::debug!("No take implementation found for {}", array.encoding());
203    let canonical = array.to_canonical()?.into_array();
204    let vtable = canonical.vtable();
205    let canonical_take_fn = vtable
206        .take_fn()
207        .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
208
209    canonical_take_fn.take(&canonical, indices)
210}
211
212fn take_into_impl(
213    array: &dyn Array,
214    indices: &dyn Array,
215    builder: &mut dyn ArrayBuilder,
216) -> VortexResult<()> {
217    let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
218    let result_dtype = array.dtype().with_nullability(result_nullability);
219    if &result_dtype != builder.dtype() {
220        vortex_bail!(
221            "TakeIntoFn {} had a builder with a different dtype {} to the resulting array dtype {}",
222            array.encoding(),
223            builder.dtype(),
224            result_dtype,
225        );
226    }
227    if let Some(take_fn) = array.vtable().take_fn() {
228        return take_fn.take_into(array, indices, builder);
229    }
230
231    // Otherwise, flatten and try again.
232    log::debug!("No take_into implementation found for {}", array.encoding());
233    let canonical = array.to_canonical()?.into_array();
234    let vtable = canonical.vtable();
235    let canonical_take_fn = vtable
236        .take_fn()
237        .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
238
239    canonical_take_fn.take_into(&canonical, indices, builder)
240}