Skip to main content

vortex_array/arrays/dict/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use smallvec::SmallVec;
5use vortex_error::VortexResult;
6
7use super::Dict;
8use crate::ArrayRef;
9use crate::Canonical;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::array::ArrayView;
13use crate::array::VTable;
14use crate::arrays::ConstantArray;
15use crate::arrays::dict::DictArraySlotsExt;
16use crate::expr::stats::Precision;
17use crate::expr::stats::Stat;
18use crate::expr::stats::StatsProvider;
19use crate::expr::stats::StatsProviderExt;
20use crate::kernel::ExecuteParentKernel;
21use crate::matcher::Matcher;
22use crate::optimizer::rules::ArrayParentReduceRule;
23use crate::scalar::Scalar;
24use crate::stats::StatsSet;
25use crate::validity::Validity;
26
27pub trait TakeReduce: VTable {
28    /// Take elements from an array at the given indices without reading buffers.
29    ///
30    /// This trait is for take implementations that can operate purely on array metadata and
31    /// structure without needing to read or execute on the underlying buffers. Implementations
32    /// should return `None` if taking requires buffer access.
33    ///
34    /// # Preconditions
35    ///
36    /// The indices are guaranteed to be non-empty.
37    fn take(array: ArrayView<'_, Self>, indices: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
38}
39
40pub trait TakeExecute: VTable {
41    /// Take elements from an array at the given indices, potentially reading buffers.
42    ///
43    /// Unlike [`TakeReduce`], this trait is for take implementations that may need to read
44    /// and execute on the underlying buffers to produce the result.
45    ///
46    /// # Preconditions
47    ///
48    /// The indices are guaranteed to be non-empty.
49    fn take(
50        array: ArrayView<'_, Self>,
51        indices: &ArrayRef,
52        ctx: &mut ExecutionCtx,
53    ) -> VortexResult<Option<ArrayRef>>;
54}
55
56/// Common preconditions for take operations that apply to all arrays.
57///
58/// Returns `Some(result)` if the precondition short-circuits the take operation,
59/// or `None` if the take should proceed normally.
60fn precondition<V: VTable>(array: ArrayView<'_, V>, indices: &ArrayRef) -> Option<ArrayRef> {
61    // Fast-path for empty indices.
62    if indices.is_empty() {
63        let result_dtype = array
64            .dtype()
65            .clone()
66            .union_nullability(indices.dtype().nullability());
67        return Some(Canonical::empty(&result_dtype).into_array());
68    }
69
70    // Fast-path for empty arrays: all indices must be null, return all-invalid result.
71    if array.is_empty() {
72        return Some(
73            ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
74                .into_array(),
75        );
76    }
77
78    None
79}
80
81#[derive(Default, Debug)]
82pub struct TakeReduceAdaptor<V>(pub V);
83
84impl<V> ArrayParentReduceRule<V> for TakeReduceAdaptor<V>
85where
86    V: TakeReduce,
87{
88    type Parent = Dict;
89
90    fn reduce_parent(
91        &self,
92        array: ArrayView<'_, V>,
93        parent: ArrayView<'_, Dict>,
94        child_idx: usize,
95    ) -> VortexResult<Option<ArrayRef>> {
96        // Only handle the values child (index 1), not the codes child (index 0).
97        if child_idx != 1 {
98            return Ok(None);
99        }
100        if let Some(result) = precondition::<V>(array, parent.codes()) {
101            return Ok(Some(result));
102        }
103        let result = <V as TakeReduce>::take(array, parent.codes())?;
104        if let Some(taken) = &result {
105            propagate_take_stats(array.array(), taken, parent.codes())?;
106        }
107        Ok(result)
108    }
109}
110
111#[derive(Default, Debug)]
112pub struct TakeExecuteAdaptor<V>(pub V);
113
114impl<V> ExecuteParentKernel<V> for TakeExecuteAdaptor<V>
115where
116    V: TakeExecute,
117{
118    type Parent = Dict;
119
120    fn execute_parent(
121        &self,
122        array: ArrayView<'_, V>,
123        parent: <Self::Parent as Matcher>::Match<'_>,
124        child_idx: usize,
125        ctx: &mut ExecutionCtx,
126    ) -> VortexResult<Option<ArrayRef>> {
127        // Only handle the values child (index 1), not the codes child (index 0).
128        if child_idx != 1 {
129            return Ok(None);
130        }
131        if let Some(result) = precondition::<V>(array, parent.codes()) {
132            return Ok(Some(result));
133        }
134        let result = <V as TakeExecute>::take(array, parent.codes(), ctx)?;
135        if let Some(taken) = &result {
136            propagate_take_stats(array.array(), taken, parent.codes())?;
137        }
138        Ok(result)
139    }
140}
141
142pub(crate) fn propagate_take_stats(
143    source: &ArrayRef,
144    target: &ArrayRef,
145    indices: &ArrayRef,
146) -> VortexResult<()> {
147    let indices_all_valid = matches!(
148        indices.validity()?,
149        Validity::NonNullable | Validity::AllValid
150    );
151    target.statistics().with_mut_typed_stats_set(|mut st| {
152        if indices_all_valid {
153            let is_constant = source.statistics().get_as::<bool>(Stat::IsConstant);
154            if is_constant == Some(Precision::Exact(true)) {
155                // Any combination of elements from a constant array is still const
156                st.set(Stat::IsConstant, Precision::exact(true));
157            }
158        }
159        let inexact_min_max = [Stat::Min, Stat::Max]
160            .into_iter()
161            .filter_map(|stat| {
162                source
163                    .statistics()
164                    .get(stat)
165                    .and_then(|v| v.map(|s| s.into_value()).into_inexact().transpose())
166                    .map(|sv| (stat, sv))
167            })
168            .collect::<SmallVec<_>>();
169        st.combine_sets(
170            &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()),
171        )
172    })
173}