vortex_array/compute/
take.rs

1use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
2use vortex_scalar::Scalar;
3
4use crate::arrays::ConstantArray;
5use crate::builders::ArrayBuilder;
6use crate::encoding::Encoding;
7use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
8use crate::{Array, ArrayRef, IntoArray};
9
10pub trait TakeFn<A> {
11    /// Create a new array by taking the values from the `array` at the
12    /// given `indices`.
13    ///
14    /// # Panics
15    ///
16    /// Using `indices` that are invalid for the given `array` will cause a panic.
17    fn take(&self, array: A, indices: &dyn Array) -> VortexResult<ArrayRef>;
18
19    /// Has the same semantics as `Self::take` but materializes the result into the provided
20    /// builder.
21    fn take_into(
22        &self,
23        array: A,
24        indices: &dyn Array,
25        builder: &mut dyn ArrayBuilder,
26    ) -> VortexResult<()> {
27        builder.extend_from_array(&self.take(array, indices)?)
28    }
29}
30
31impl<E: Encoding> TakeFn<&dyn Array> for E
32where
33    E: for<'a> TakeFn<&'a E::Array>,
34{
35    fn take(&self, array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
36        let array_ref = array
37            .as_any()
38            .downcast_ref::<E::Array>()
39            .vortex_expect("Failed to downcast array");
40        TakeFn::take(self, array_ref, indices)
41    }
42
43    fn take_into(
44        &self,
45        array: &dyn Array,
46        indices: &dyn Array,
47        builder: &mut dyn ArrayBuilder,
48    ) -> VortexResult<()> {
49        let array_ref = array
50            .as_any()
51            .downcast_ref::<E::Array>()
52            .vortex_expect("Failed to downcast array");
53        TakeFn::take_into(self, array_ref, indices, builder)
54    }
55}
56
57pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
58    // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to
59    //  the filter function since they're typically optimised for this case.
60    // TODO(ngates): if indices min is quite high, we could slice self and offset the indices
61    //  such that canonicalize does less work.
62    if indices.all_invalid()? {
63        return Ok(
64            ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
65                .into_array(),
66        );
67    }
68
69    if !indices.dtype().is_int() {
70        vortex_bail!(
71            "Take indices must be an integer type, got {}",
72            indices.dtype()
73        );
74    }
75
76    // We know that constant array don't need stats propagation, so we can avoid the overhead of
77    // computing derived stats and merging them in.
78    let derived_stats = (!array.is_constant()).then(|| derive_take_stats(array));
79
80    let taken = take_impl(array, indices)?;
81
82    if let Some(derived_stats) = derived_stats {
83        let mut stats = taken.statistics().to_owned();
84        stats.combine_sets(&derived_stats, array.dtype())?;
85        for (stat, val) in stats.into_iter() {
86            taken.statistics().set(stat, val)
87        }
88    }
89
90    assert_eq!(
91        taken.len(),
92        indices.len(),
93        "Take length mismatch {}",
94        array.encoding()
95    );
96    // If either the indices or the array are nullable, the result should be nullable.
97    let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
98    assert_eq!(
99        taken.dtype(),
100        &array.dtype().with_nullability(expected_nullability),
101        "Take result ({}) should be nullable if either the indices ({}) or the array ({}) are nullable. ({})",
102        taken.dtype(),
103        indices.dtype().nullability().verbose_display(),
104        array.dtype().nullability().verbose_display(),
105        array.encoding(),
106    );
107
108    Ok(taken)
109}
110
111pub fn take_into(
112    array: &dyn Array,
113    indices: &dyn Array,
114    builder: &mut dyn ArrayBuilder,
115) -> VortexResult<()> {
116    if indices.all_invalid()? {
117        builder.append_nulls(indices.len());
118        return Ok(());
119    }
120
121    if array.is_empty() && !indices.is_empty() {
122        vortex_bail!("Cannot take_into from an empty array");
123    }
124
125    // If either the indices or the array are nullable, the result should be nullable.
126    let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
127    assert_eq!(
128        builder.dtype(),
129        &array.dtype().with_nullability(expected_nullability),
130        "Take_into result ({}) should be nullable if, and only if, either the indices ({}) or the array ({}) are nullable. ({})",
131        builder.dtype(),
132        indices.dtype().nullability().verbose_display(),
133        array.dtype().nullability().verbose_display(),
134        array.encoding(),
135    );
136
137    if !indices.dtype().is_int() {
138        vortex_bail!(
139            "Take indices must be an integer type, got {}",
140            indices.dtype()
141        );
142    }
143
144    let before_len = builder.len();
145
146    // We know that constant array don't need stats propagation, so we can avoid the overhead of
147    // computing derived stats and merging them in.
148    take_into_impl(array, indices, builder)?;
149
150    let after_len = builder.len();
151
152    assert_eq!(
153        after_len - before_len,
154        indices.len(),
155        "Take_into length mismatch {}",
156        array.encoding()
157    );
158
159    Ok(())
160}
161
162fn derive_take_stats(arr: &dyn Array) -> StatsSet {
163    let stats = arr.statistics().to_owned();
164
165    let is_constant = stats.get_as::<bool>(Stat::IsConstant);
166
167    let mut stats = stats.keep_inexact_stats(&[
168        // Cannot create values smaller than min or larger than max
169        Stat::Min,
170        Stat::Max,
171    ]);
172
173    if is_constant == Some(Precision::Exact(true)) {
174        // Any combination of elements from a constant array is still const
175        stats.set(Stat::IsConstant, Precision::exact(true));
176    }
177
178    stats
179}
180
181fn take_impl(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
182    // First look for a TakeFrom specialized on the indices.
183    if let Some(take_from_fn) = indices.vtable().take_from_fn() {
184        if let Some(arr) = take_from_fn.take_from(indices, array)? {
185            return Ok(arr);
186        }
187    }
188
189    // If TakeFn defined for the encoding, delegate to TakeFn.
190    // If we know from stats that indices are all valid, we can avoid all bounds checks.
191    if let Some(take_fn) = array.vtable().take_fn() {
192        return take_fn.take(array, indices);
193    }
194
195    // Otherwise, flatten and try again.
196    log::debug!("No take implementation found for {}", array.encoding());
197    let canonical = array.to_canonical()?.into_array();
198    let vtable = canonical.vtable();
199    let canonical_take_fn = vtable
200        .take_fn()
201        .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
202
203    canonical_take_fn.take(&canonical, indices)
204}
205
206fn take_into_impl(
207    array: &dyn Array,
208    indices: &dyn Array,
209    builder: &mut dyn ArrayBuilder,
210) -> VortexResult<()> {
211    let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
212    let result_dtype = array.dtype().with_nullability(result_nullability);
213    if &result_dtype != builder.dtype() {
214        vortex_bail!(
215            "TakeIntoFn {} had a builder with a different dtype {} to the resulting array dtype {}",
216            array.encoding(),
217            builder.dtype(),
218            result_dtype,
219        );
220    }
221    if let Some(take_fn) = array.vtable().take_fn() {
222        return take_fn.take_into(array, indices, builder);
223    }
224
225    // Otherwise, flatten and try again.
226    log::debug!("No take_into implementation found for {}", array.encoding());
227    let canonical = array.to_canonical()?.into_array();
228    let vtable = canonical.vtable();
229    let canonical_take_fn = vtable
230        .take_fn()
231        .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
232
233    canonical_take_fn.take_into(&canonical, indices, builder)
234}