vortex_array/arrays/dict/
take.rs1use vortex_error::VortexResult;
5
6use super::Dict;
7use crate::ArrayRef;
8use crate::Canonical;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::array::ArrayView;
12use crate::array::VTable;
13use crate::arrays::ConstantArray;
14use crate::arrays::dict::DictArraySlotsExt;
15use crate::expr::stats::Precision;
16use crate::expr::stats::Stat;
17use crate::expr::stats::StatsProvider;
18use crate::expr::stats::StatsProviderExt;
19use crate::kernel::ExecuteParentKernel;
20use crate::matcher::Matcher;
21use crate::optimizer::rules::ArrayParentReduceRule;
22use crate::scalar::Scalar;
23use crate::stats::StatsSet;
24use crate::validity::Validity;
25
26pub trait TakeReduce: VTable {
27 fn take(array: ArrayView<'_, Self>, indices: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
37}
38
39pub trait TakeExecute: VTable {
40 fn take(
49 array: ArrayView<'_, Self>,
50 indices: &ArrayRef,
51 ctx: &mut ExecutionCtx,
52 ) -> VortexResult<Option<ArrayRef>>;
53}
54
55fn precondition<V: VTable>(array: ArrayView<'_, V>, indices: &ArrayRef) -> Option<ArrayRef> {
60 if indices.is_empty() {
62 let result_dtype = array
63 .dtype()
64 .clone()
65 .union_nullability(indices.dtype().nullability());
66 return Some(Canonical::empty(&result_dtype).into_array());
67 }
68
69 if array.is_empty() {
71 return Some(
72 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
73 .into_array(),
74 );
75 }
76
77 None
78}
79
/// Wraps a [`TakeReduce`] vtable so its take can be applied as a parent-reduce rule when
/// the array is a child of a [`Dict`] parent.
#[derive(Debug, Default)]
pub struct TakeReduceAdaptor<V>(pub V);
82
83impl<V> ArrayParentReduceRule<V> for TakeReduceAdaptor<V>
84where
85 V: TakeReduce,
86{
87 type Parent = Dict;
88
89 fn reduce_parent(
90 &self,
91 array: ArrayView<'_, V>,
92 parent: ArrayView<'_, Dict>,
93 child_idx: usize,
94 ) -> VortexResult<Option<ArrayRef>> {
95 if child_idx != 1 {
97 return Ok(None);
98 }
99 if let Some(result) = precondition::<V>(array, parent.codes()) {
100 return Ok(Some(result));
101 }
102 let result = <V as TakeReduce>::take(array, parent.codes())?;
103 if let Some(ref taken) = result {
104 propagate_take_stats(array.array(), taken, parent.codes())?;
105 }
106 Ok(result)
107 }
108}
109
/// Wraps a [`TakeExecute`] vtable so its take can run as a parent-execution kernel when
/// the array is a child of a [`Dict`] parent.
#[derive(Debug, Default)]
pub struct TakeExecuteAdaptor<V>(pub V);
112
113impl<V> ExecuteParentKernel<V> for TakeExecuteAdaptor<V>
114where
115 V: TakeExecute,
116{
117 type Parent = Dict;
118
119 fn execute_parent(
120 &self,
121 array: ArrayView<'_, V>,
122 parent: <Self::Parent as Matcher>::Match<'_>,
123 child_idx: usize,
124 ctx: &mut ExecutionCtx,
125 ) -> VortexResult<Option<ArrayRef>> {
126 if child_idx != 1 {
128 return Ok(None);
129 }
130 if let Some(result) = precondition::<V>(array, parent.codes()) {
131 return Ok(Some(result));
132 }
133 let result = <V as TakeExecute>::take(array, parent.codes(), ctx)?;
134 if let Some(ref taken) = result {
135 propagate_take_stats(array.array(), taken, parent.codes())?;
136 }
137 Ok(result)
138 }
139}
140
141pub(crate) fn propagate_take_stats(
142 source: &ArrayRef,
143 target: &ArrayRef,
144 indices: &ArrayRef,
145) -> VortexResult<()> {
146 let indices_all_valid = matches!(
147 indices.validity()?,
148 Validity::NonNullable | Validity::AllValid
149 );
150 target.statistics().with_mut_typed_stats_set(|mut st| {
151 if indices_all_valid {
152 let is_constant = source.statistics().get_as::<bool>(Stat::IsConstant);
153 if is_constant == Some(Precision::Exact(true)) {
154 st.set(Stat::IsConstant, Precision::exact(true));
156 }
157 }
158 let inexact_min_max = [Stat::Min, Stat::Max]
159 .into_iter()
160 .filter_map(|stat| {
161 source
162 .statistics()
163 .get(stat)
164 .and_then(|v| v.map(|s| s.into_value()).into_inexact().transpose())
165 .map(|sv| (stat, sv))
166 })
167 .collect::<Vec<_>>();
168 st.combine_sets(
169 &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()),
170 )
171 })
172}