vortex_array/compute/
take.rs

1use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
2use vortex_scalar::Scalar;
3
4use crate::arrays::ConstantArray;
5use crate::builders::ArrayBuilder;
6use crate::encoding::Encoding;
7use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
8use crate::{Array, ArrayRef, IntoArray};
9
10pub trait TakeFn<A> {
11    /// Create a new array by taking the values from the `array` at the
12    /// given `indices`.
13    ///
14    /// # Panics
15    ///
16    /// Using `indices` that are invalid for the given `array` will cause a panic.
17    fn take(&self, array: A, indices: &dyn Array) -> VortexResult<ArrayRef>;
18
19    /// Has the same semantics as `Self::take` but materializes the result into the provided
20    /// builder.
21    fn take_into(
22        &self,
23        array: A,
24        indices: &dyn Array,
25        builder: &mut dyn ArrayBuilder,
26    ) -> VortexResult<()> {
27        builder.extend_from_array(&self.take(array, indices)?)
28    }
29}
30
31impl<E: Encoding> TakeFn<&dyn Array> for E
32where
33    E: for<'a> TakeFn<&'a E::Array>,
34{
35    fn take(&self, array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
36        let array_ref = array
37            .as_any()
38            .downcast_ref::<E::Array>()
39            .vortex_expect("Failed to downcast array");
40        TakeFn::take(self, array_ref, indices)
41    }
42
43    fn take_into(
44        &self,
45        array: &dyn Array,
46        indices: &dyn Array,
47        builder: &mut dyn ArrayBuilder,
48    ) -> VortexResult<()> {
49        let array_ref = array
50            .as_any()
51            .downcast_ref::<E::Array>()
52            .vortex_expect("Failed to downcast array");
53        TakeFn::take_into(self, array_ref, indices, builder)
54    }
55}
56
57pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
58    // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to
59    //  the filter function since they're typically optimised for this case.
60    // TODO(ngates): if indices min is quite high, we could slice self and offset the indices
61    //  such that canonicalize does less work.
62    if indices.all_invalid()? {
63        return Ok(
64            ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
65                .into_array(),
66        );
67    }
68
69    if !indices.dtype().is_int() {
70        vortex_bail!(
71            "Take indices must be an integer type, got {}",
72            indices.dtype()
73        );
74    }
75
76    // We know that constant array don't need stats propagation, so we can avoid the overhead of
77    // computing derived stats and merging them in.
78    let derived_stats = (!array.is_constant()).then(|| derive_take_stats(array));
79
80    let taken = take_impl(array, indices)?;
81
82    if let Some(derived_stats) = derived_stats {
83        let mut stats = taken.statistics().to_owned();
84        stats.combine_sets(&derived_stats, array.dtype())?;
85        for (stat, val) in stats.into_iter() {
86            taken.statistics().set(stat, val)
87        }
88    }
89
90    debug_assert_eq!(
91        taken.len(),
92        indices.len(),
93        "Take length mismatch {}",
94        array.encoding()
95    );
96    #[cfg(debug_assertions)]
97    {
98        // If either the indices or the array are nullable, the result should be nullable.
99        let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
100        assert_eq!(
101            taken.dtype(),
102            &array.dtype().with_nullability(expected_nullability),
103            "Take result ({}) should be nullable if either the indices ({}) or the array ({}) are nullable. ({})",
104            taken.dtype(),
105            indices.dtype().nullability().verbose_display(),
106            array.dtype().nullability().verbose_display(),
107            array.encoding(),
108        );
109    }
110
111    Ok(taken)
112}
113
114pub fn take_into(
115    array: &dyn Array,
116    indices: &dyn Array,
117    builder: &mut dyn ArrayBuilder,
118) -> VortexResult<()> {
119    if array.is_empty() && !indices.is_empty() {
120        vortex_bail!("Cannot take_into from an empty array");
121    }
122
123    #[cfg(debug_assertions)]
124    {
125        // If either the indices or the array are nullable, the result should be nullable.
126        let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
127        assert_eq!(
128            builder.dtype(),
129            &array.dtype().with_nullability(expected_nullability),
130            "Take_into result ({}) should be nullable if, and only if, either the indices ({}) or the array ({}) are nullable. ({})",
131            builder.dtype(),
132            indices.dtype().nullability().verbose_display(),
133            array.dtype().nullability().verbose_display(),
134            array.encoding(),
135        );
136    }
137
138    if !indices.dtype().is_int() {
139        vortex_bail!(
140            "Take indices must be an integer type, got {}",
141            indices.dtype()
142        );
143    }
144
145    let before_len = builder.len();
146
147    // We know that constant array don't need stats propagation, so we can avoid the overhead of
148    // computing derived stats and merging them in.
149    take_into_impl(array, indices, builder)?;
150
151    let after_len = builder.len();
152
153    debug_assert_eq!(
154        after_len - before_len,
155        indices.len(),
156        "Take_into length mismatch {}",
157        array.encoding()
158    );
159
160    Ok(())
161}
162
163fn derive_take_stats(arr: &dyn Array) -> StatsSet {
164    let stats = arr.statistics().to_owned();
165
166    let is_constant = stats.get_as::<bool>(Stat::IsConstant);
167
168    let mut stats = stats.keep_inexact_stats(&[
169        // Cannot create values smaller than min or larger than max
170        Stat::Min,
171        Stat::Max,
172    ]);
173
174    if is_constant == Some(Precision::Exact(true)) {
175        // Any combination of elements from a constant array is still const
176        stats.set(Stat::IsConstant, Precision::exact(true));
177    }
178
179    stats
180}
181
182fn take_impl(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
183    // First look for a TakeFrom specialized on the indices.
184    if let Some(take_from_fn) = indices.vtable().take_from_fn() {
185        if let Some(arr) = take_from_fn.take_from(indices, array)? {
186            return Ok(arr);
187        }
188    }
189
190    // If TakeFn defined for the encoding, delegate to TakeFn.
191    // If we know from stats that indices are all valid, we can avoid all bounds checks.
192    if let Some(take_fn) = array.vtable().take_fn() {
193        return take_fn.take(array, indices);
194    }
195
196    // Otherwise, flatten and try again.
197    log::debug!("No take implementation found for {}", array.encoding());
198    let canonical = array.to_canonical()?.into_array();
199    let vtable = canonical.vtable();
200    let canonical_take_fn = vtable
201        .take_fn()
202        .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
203
204    canonical_take_fn.take(&canonical, indices)
205}
206
207fn take_into_impl(
208    array: &dyn Array,
209    indices: &dyn Array,
210    builder: &mut dyn ArrayBuilder,
211) -> VortexResult<()> {
212    let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
213    let result_dtype = array.dtype().with_nullability(result_nullability);
214    if &result_dtype != builder.dtype() {
215        vortex_bail!(
216            "TakeIntoFn {} had a builder with a different dtype {} to the resulting array dtype {}",
217            array.encoding(),
218            builder.dtype(),
219            result_dtype,
220        );
221    }
222    if let Some(take_fn) = array.vtable().take_fn() {
223        return take_fn.take_into(array, indices, builder);
224    }
225
226    // Otherwise, flatten and try again.
227    log::debug!("No take_into implementation found for {}", array.encoding());
228    let canonical = array.to_canonical()?.into_array();
229    let vtable = canonical.vtable();
230    let canonical_take_fn = vtable
231        .take_fn()
232        .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
233
234    canonical_take_fn.take_into(&canonical, indices, builder)
235}