Skip to main content

vortex_array/arrays/dict/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use super::DictArray;
7use super::DictVTable;
8use crate::Array;
9use crate::ArrayRef;
10use crate::Canonical;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::arrays::ConstantArray;
14use crate::expr::stats::Precision;
15use crate::expr::stats::Stat;
16use crate::expr::stats::StatsProvider;
17use crate::expr::stats::StatsProviderExt;
18use crate::kernel::ExecuteParentKernel;
19use crate::matcher::Matcher;
20use crate::optimizer::rules::ArrayParentReduceRule;
21use crate::scalar::Scalar;
22use crate::stats::StatsSet;
23use crate::vtable::VTable;
24
25pub trait TakeReduce: VTable {
26    /// Take elements from an array at the given indices without reading buffers.
27    ///
28    /// This trait is for take implementations that can operate purely on array metadata and
29    /// structure without needing to read or execute on the underlying buffers. Implementations
30    /// should return `None` if taking requires buffer access.
31    ///
32    /// # Preconditions
33    ///
34    /// The indices are guaranteed to be non-empty.
35    fn take(array: &Self::Array, indices: &dyn Array) -> VortexResult<Option<ArrayRef>>;
36}
37
38pub trait TakeExecute: VTable {
39    /// Take elements from an array at the given indices, potentially reading buffers.
40    ///
41    /// Unlike [`TakeReduce`], this trait is for take implementations that may need to read
42    /// and execute on the underlying buffers to produce the result.
43    ///
44    /// # Preconditions
45    ///
46    /// The indices are guaranteed to be non-empty.
47    fn take(
48        array: &Self::Array,
49        indices: &dyn Array,
50        ctx: &mut ExecutionCtx,
51    ) -> VortexResult<Option<ArrayRef>>;
52}
53
54/// Common preconditions for take operations that apply to all arrays.
55///
56/// Returns `Some(result)` if the precondition short-circuits the take operation,
57/// or `None` if the take should proceed normally.
58fn precondition<V: VTable>(array: &V::Array, indices: &dyn Array) -> Option<ArrayRef> {
59    // Fast-path for empty indices.
60    if indices.is_empty() {
61        let result_dtype = array
62            .dtype()
63            .clone()
64            .union_nullability(indices.dtype().nullability());
65        return Some(Canonical::empty(&result_dtype).into_array());
66    }
67
68    // Fast-path for empty arrays: all indices must be null, return all-invalid result.
69    if array.is_empty() {
70        return Some(
71            ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
72                .into_array(),
73        );
74    }
75
76    None
77}
78
/// Wrapper that plugs a [`TakeReduce`] implementation into the parent-reduce rule
/// machinery for dictionary parents.
#[derive(Debug, Default)]
pub struct TakeReduceAdaptor<V>(pub V);
81
82impl<V> ArrayParentReduceRule<V> for TakeReduceAdaptor<V>
83where
84    V: TakeReduce,
85{
86    type Parent = DictVTable;
87
88    fn reduce_parent(
89        &self,
90        array: &V::Array,
91        parent: &DictArray,
92        child_idx: usize,
93    ) -> VortexResult<Option<ArrayRef>> {
94        // Only handle the values child (index 1), not the codes child (index 0).
95        if child_idx != 1 {
96            return Ok(None);
97        }
98        if let Some(result) = precondition::<V>(array, parent.codes()) {
99            return Ok(Some(result));
100        }
101        let result = <V as TakeReduce>::take(array, parent.codes())?;
102        if let Some(ref taken) = result {
103            propagate_take_stats(&**array, taken.as_ref(), parent.codes())?;
104        }
105        Ok(result)
106    }
107}
108
/// Wrapper that plugs a [`TakeExecute`] implementation into the parent-execution kernel
/// machinery for dictionary parents.
#[derive(Debug, Default)]
pub struct TakeExecuteAdaptor<V>(pub V);
111
112impl<V> ExecuteParentKernel<V> for TakeExecuteAdaptor<V>
113where
114    V: TakeExecute,
115{
116    type Parent = DictVTable;
117
118    fn execute_parent(
119        &self,
120        array: &V::Array,
121        parent: <Self::Parent as Matcher>::Match<'_>,
122        child_idx: usize,
123        ctx: &mut ExecutionCtx,
124    ) -> VortexResult<Option<ArrayRef>> {
125        // Only handle the values child (index 1), not the codes child (index 0).
126        if child_idx != 1 {
127            return Ok(None);
128        }
129        if let Some(result) = precondition::<V>(array, parent.codes()) {
130            return Ok(Some(result));
131        }
132        let result = <V as TakeExecute>::take(array, parent.codes(), ctx)?;
133        if let Some(ref taken) = result {
134            propagate_take_stats(&**array, taken.as_ref(), parent.codes())?;
135        }
136        Ok(result)
137    }
138}
139
140pub(crate) fn propagate_take_stats(
141    source: &dyn Array,
142    target: &dyn Array,
143    indices: &dyn Array,
144) -> VortexResult<()> {
145    target.statistics().with_mut_typed_stats_set(|mut st| {
146        if indices.all_valid().unwrap_or(false) {
147            let is_constant = source.statistics().get_as::<bool>(Stat::IsConstant);
148            if is_constant == Some(Precision::Exact(true)) {
149                // Any combination of elements from a constant array is still const
150                st.set(Stat::IsConstant, Precision::exact(true));
151            }
152        }
153        let inexact_min_max = [Stat::Min, Stat::Max]
154            .into_iter()
155            .filter_map(|stat| {
156                source
157                    .statistics()
158                    .get(stat)
159                    .and_then(|v| v.map(|s| s.into_value()).into_inexact().transpose())
160                    .map(|sv| (stat, sv))
161            })
162            .collect::<Vec<_>>();
163        st.combine_sets(
164            &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()),
165        )
166    })
167}