vortex_array/arrays/dict/
take.rs1use smallvec::SmallVec;
5use vortex_error::VortexResult;
6
7use super::Dict;
8use crate::ArrayRef;
9use crate::Canonical;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::array::ArrayView;
13use crate::array::VTable;
14use crate::arrays::ConstantArray;
15use crate::arrays::dict::DictArraySlotsExt;
16use crate::expr::stats::Precision;
17use crate::expr::stats::Stat;
18use crate::expr::stats::StatsProvider;
19use crate::expr::stats::StatsProviderExt;
20use crate::kernel::ExecuteParentKernel;
21use crate::matcher::Matcher;
22use crate::optimizer::rules::ArrayParentReduceRule;
23use crate::scalar::Scalar;
24use crate::stats::StatsSet;
25use crate::validity::Validity;
26
27pub trait TakeReduce: VTable {
28 fn take(array: ArrayView<'_, Self>, indices: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
38}
39
40pub trait TakeExecute: VTable {
41 fn take(
50 array: ArrayView<'_, Self>,
51 indices: &ArrayRef,
52 ctx: &mut ExecutionCtx,
53 ) -> VortexResult<Option<ArrayRef>>;
54}
55
56fn precondition<V: VTable>(array: ArrayView<'_, V>, indices: &ArrayRef) -> Option<ArrayRef> {
61 if indices.is_empty() {
63 let result_dtype = array
64 .dtype()
65 .clone()
66 .union_nullability(indices.dtype().nullability());
67 return Some(Canonical::empty(&result_dtype).into_array());
68 }
69
70 if array.is_empty() {
72 return Some(
73 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
74 .into_array(),
75 );
76 }
77
78 None
79}
80
81#[derive(Default, Debug)]
82pub struct TakeReduceAdaptor<V>(pub V);
83
84impl<V> ArrayParentReduceRule<V> for TakeReduceAdaptor<V>
85where
86 V: TakeReduce,
87{
88 type Parent = Dict;
89
90 fn reduce_parent(
91 &self,
92 array: ArrayView<'_, V>,
93 parent: ArrayView<'_, Dict>,
94 child_idx: usize,
95 ) -> VortexResult<Option<ArrayRef>> {
96 if child_idx != 1 {
98 return Ok(None);
99 }
100 if let Some(result) = precondition::<V>(array, parent.codes()) {
101 return Ok(Some(result));
102 }
103 let result = <V as TakeReduce>::take(array, parent.codes())?;
104 if let Some(taken) = &result {
105 propagate_take_stats(array.array(), taken, parent.codes())?;
106 }
107 Ok(result)
108 }
109}
110
111#[derive(Default, Debug)]
112pub struct TakeExecuteAdaptor<V>(pub V);
113
114impl<V> ExecuteParentKernel<V> for TakeExecuteAdaptor<V>
115where
116 V: TakeExecute,
117{
118 type Parent = Dict;
119
120 fn execute_parent(
121 &self,
122 array: ArrayView<'_, V>,
123 parent: <Self::Parent as Matcher>::Match<'_>,
124 child_idx: usize,
125 ctx: &mut ExecutionCtx,
126 ) -> VortexResult<Option<ArrayRef>> {
127 if child_idx != 1 {
129 return Ok(None);
130 }
131 if let Some(result) = precondition::<V>(array, parent.codes()) {
132 return Ok(Some(result));
133 }
134 let result = <V as TakeExecute>::take(array, parent.codes(), ctx)?;
135 if let Some(taken) = &result {
136 propagate_take_stats(array.array(), taken, parent.codes())?;
137 }
138 Ok(result)
139 }
140}
141
142pub(crate) fn propagate_take_stats(
143 source: &ArrayRef,
144 target: &ArrayRef,
145 indices: &ArrayRef,
146) -> VortexResult<()> {
147 let indices_all_valid = matches!(
148 indices.validity()?,
149 Validity::NonNullable | Validity::AllValid
150 );
151 target.statistics().with_mut_typed_stats_set(|mut st| {
152 if indices_all_valid {
153 let is_constant = source.statistics().get_as::<bool>(Stat::IsConstant);
154 if is_constant == Some(Precision::Exact(true)) {
155 st.set(Stat::IsConstant, Precision::exact(true));
157 }
158 }
159 let inexact_min_max = [Stat::Min, Stat::Max]
160 .into_iter()
161 .filter_map(|stat| {
162 source
163 .statistics()
164 .get(stat)
165 .and_then(|v| v.map(|s| s.into_value()).into_inexact().transpose())
166 .map(|sv| (stat, sv))
167 })
168 .collect::<SmallVec<_>>();
169 st.combine_sets(
170 &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()),
171 )
172 })
173}