// vortex_array/arrays/dict/take.rs
use vortex_error::VortexResult;

use super::DictArray;
use super::DictVTable;
use crate::Array;
use crate::ArrayRef;
use crate::Canonical;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::ConstantArray;
use crate::expr::stats::Precision;
use crate::expr::stats::Stat;
use crate::expr::stats::StatsProvider;
use crate::expr::stats::StatsProviderExt;
use crate::kernel::ExecuteParentKernel;
use crate::matcher::Matcher;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::scalar::Scalar;
use crate::stats::StatsSet;
use crate::vtable::VTable;
25pub trait TakeReduce: VTable {
26 fn take(array: &Self::Array, indices: &dyn Array) -> VortexResult<Option<ArrayRef>>;
36}
38pub trait TakeExecute: VTable {
39 fn take(
48 array: &Self::Array,
49 indices: &dyn Array,
50 ctx: &mut ExecutionCtx,
51 ) -> VortexResult<Option<ArrayRef>>;
52}
54fn precondition<V: VTable>(array: &V::Array, indices: &dyn Array) -> Option<ArrayRef> {
59 if indices.is_empty() {
61 let result_dtype = array
62 .dtype()
63 .clone()
64 .union_nullability(indices.dtype().nullability());
65 return Some(Canonical::empty(&result_dtype).into_array());
66 }
67
68 if array.is_empty() {
70 return Some(
71 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
72 .into_array(),
73 );
74 }
75
76 None
77}
/// Wraps a [`TakeReduce`] implementor `V` so it can be used as an [`ArrayParentReduceRule`]
/// whose parent is a [`DictArray`].
#[derive(Debug, Default)]
pub struct TakeReduceAdaptor<V>(pub V);
82impl<V> ArrayParentReduceRule<V> for TakeReduceAdaptor<V>
83where
84 V: TakeReduce,
85{
86 type Parent = DictVTable;
87
88 fn reduce_parent(
89 &self,
90 array: &V::Array,
91 parent: &DictArray,
92 child_idx: usize,
93 ) -> VortexResult<Option<ArrayRef>> {
94 if child_idx != 1 {
96 return Ok(None);
97 }
98 if let Some(result) = precondition::<V>(array, parent.codes()) {
99 return Ok(Some(result));
100 }
101 let result = <V as TakeReduce>::take(array, parent.codes())?;
102 if let Some(ref taken) = result {
103 propagate_take_stats(&**array, taken.as_ref(), parent.codes())?;
104 }
105 Ok(result)
106 }
107}
/// Wraps a [`TakeExecute`] implementor `V` so it can be used as an [`ExecuteParentKernel`]
/// whose parent is a [`DictArray`].
#[derive(Debug, Default)]
pub struct TakeExecuteAdaptor<V>(pub V);
112impl<V> ExecuteParentKernel<V> for TakeExecuteAdaptor<V>
113where
114 V: TakeExecute,
115{
116 type Parent = DictVTable;
117
118 fn execute_parent(
119 &self,
120 array: &V::Array,
121 parent: <Self::Parent as Matcher>::Match<'_>,
122 child_idx: usize,
123 ctx: &mut ExecutionCtx,
124 ) -> VortexResult<Option<ArrayRef>> {
125 if child_idx != 1 {
127 return Ok(None);
128 }
129 if let Some(result) = precondition::<V>(array, parent.codes()) {
130 return Ok(Some(result));
131 }
132 let result = <V as TakeExecute>::take(array, parent.codes(), ctx)?;
133 if let Some(ref taken) = result {
134 propagate_take_stats(&**array, taken.as_ref(), parent.codes())?;
135 }
136 Ok(result)
137 }
138}
140pub(crate) fn propagate_take_stats(
141 source: &dyn Array,
142 target: &dyn Array,
143 indices: &dyn Array,
144) -> VortexResult<()> {
145 target.statistics().with_mut_typed_stats_set(|mut st| {
146 if indices.all_valid().unwrap_or(false) {
147 let is_constant = source.statistics().get_as::<bool>(Stat::IsConstant);
148 if is_constant == Some(Precision::Exact(true)) {
149 st.set(Stat::IsConstant, Precision::exact(true));
151 }
152 }
153 let inexact_min_max = [Stat::Min, Stat::Max]
154 .into_iter()
155 .filter_map(|stat| {
156 source
157 .statistics()
158 .get(stat)
159 .and_then(|v| v.map(|s| s.into_value()).into_inexact().transpose())
160 .map(|sv| (stat, sv))
161 })
162 .collect::<Vec<_>>();
163 st.combine_sets(
164 &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()),
165 )
166 })
167}