Skip to main content

vortex_array/arrays/dict/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use super::Dict;
7use crate::ArrayRef;
8use crate::Canonical;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::array::ArrayView;
12use crate::array::VTable;
13use crate::arrays::ConstantArray;
14use crate::arrays::dict::DictArraySlotsExt;
15use crate::expr::stats::Precision;
16use crate::expr::stats::Stat;
17use crate::expr::stats::StatsProvider;
18use crate::expr::stats::StatsProviderExt;
19use crate::kernel::ExecuteParentKernel;
20use crate::matcher::Matcher;
21use crate::optimizer::rules::ArrayParentReduceRule;
22use crate::scalar::Scalar;
23use crate::stats::StatsSet;
24use crate::validity::Validity;
25
26pub trait TakeReduce: VTable {
27    /// Take elements from an array at the given indices without reading buffers.
28    ///
29    /// This trait is for take implementations that can operate purely on array metadata and
30    /// structure without needing to read or execute on the underlying buffers. Implementations
31    /// should return `None` if taking requires buffer access.
32    ///
33    /// # Preconditions
34    ///
35    /// The indices are guaranteed to be non-empty.
36    fn take(array: ArrayView<'_, Self>, indices: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
37}
38
39pub trait TakeExecute: VTable {
40    /// Take elements from an array at the given indices, potentially reading buffers.
41    ///
42    /// Unlike [`TakeReduce`], this trait is for take implementations that may need to read
43    /// and execute on the underlying buffers to produce the result.
44    ///
45    /// # Preconditions
46    ///
47    /// The indices are guaranteed to be non-empty.
48    fn take(
49        array: ArrayView<'_, Self>,
50        indices: &ArrayRef,
51        ctx: &mut ExecutionCtx,
52    ) -> VortexResult<Option<ArrayRef>>;
53}
54
55/// Common preconditions for take operations that apply to all arrays.
56///
57/// Returns `Some(result)` if the precondition short-circuits the take operation,
58/// or `None` if the take should proceed normally.
59fn precondition<V: VTable>(array: ArrayView<'_, V>, indices: &ArrayRef) -> Option<ArrayRef> {
60    // Fast-path for empty indices.
61    if indices.is_empty() {
62        let result_dtype = array
63            .dtype()
64            .clone()
65            .union_nullability(indices.dtype().nullability());
66        return Some(Canonical::empty(&result_dtype).into_array());
67    }
68
69    // Fast-path for empty arrays: all indices must be null, return all-invalid result.
70    if array.is_empty() {
71        return Some(
72            ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
73                .into_array(),
74        );
75    }
76
77    None
78}
79
80#[derive(Default, Debug)]
81pub struct TakeReduceAdaptor<V>(pub V);
82
83impl<V> ArrayParentReduceRule<V> for TakeReduceAdaptor<V>
84where
85    V: TakeReduce,
86{
87    type Parent = Dict;
88
89    fn reduce_parent(
90        &self,
91        array: ArrayView<'_, V>,
92        parent: ArrayView<'_, Dict>,
93        child_idx: usize,
94    ) -> VortexResult<Option<ArrayRef>> {
95        // Only handle the values child (index 1), not the codes child (index 0).
96        if child_idx != 1 {
97            return Ok(None);
98        }
99        if let Some(result) = precondition::<V>(array, parent.codes()) {
100            return Ok(Some(result));
101        }
102        let result = <V as TakeReduce>::take(array, parent.codes())?;
103        if let Some(ref taken) = result {
104            propagate_take_stats(array.array(), taken, parent.codes())?;
105        }
106        Ok(result)
107    }
108}
109
110#[derive(Default, Debug)]
111pub struct TakeExecuteAdaptor<V>(pub V);
112
113impl<V> ExecuteParentKernel<V> for TakeExecuteAdaptor<V>
114where
115    V: TakeExecute,
116{
117    type Parent = Dict;
118
119    fn execute_parent(
120        &self,
121        array: ArrayView<'_, V>,
122        parent: <Self::Parent as Matcher>::Match<'_>,
123        child_idx: usize,
124        ctx: &mut ExecutionCtx,
125    ) -> VortexResult<Option<ArrayRef>> {
126        // Only handle the values child (index 1), not the codes child (index 0).
127        if child_idx != 1 {
128            return Ok(None);
129        }
130        if let Some(result) = precondition::<V>(array, parent.codes()) {
131            return Ok(Some(result));
132        }
133        let result = <V as TakeExecute>::take(array, parent.codes(), ctx)?;
134        if let Some(ref taken) = result {
135            propagate_take_stats(array.array(), taken, parent.codes())?;
136        }
137        Ok(result)
138    }
139}
140
141pub(crate) fn propagate_take_stats(
142    source: &ArrayRef,
143    target: &ArrayRef,
144    indices: &ArrayRef,
145) -> VortexResult<()> {
146    let indices_all_valid = matches!(
147        indices.validity()?,
148        Validity::NonNullable | Validity::AllValid
149    );
150    target.statistics().with_mut_typed_stats_set(|mut st| {
151        if indices_all_valid {
152            let is_constant = source.statistics().get_as::<bool>(Stat::IsConstant);
153            if is_constant == Some(Precision::Exact(true)) {
154                // Any combination of elements from a constant array is still const
155                st.set(Stat::IsConstant, Precision::exact(true));
156            }
157        }
158        let inexact_min_max = [Stat::Min, Stat::Max]
159            .into_iter()
160            .filter_map(|stat| {
161                source
162                    .statistics()
163                    .get(stat)
164                    .and_then(|v| v.map(|s| s.into_value()).into_inexact().transpose())
165                    .map(|sv| (stat, sv))
166            })
167            .collect::<Vec<_>>();
168        st.combine_sets(
169            &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()),
170        )
171    })
172}