vortex_array/arrays/scalar_fn/
rules.rs1use std::any::Any;
5use std::sync::Arc;
6
7use itertools::Itertools;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::ArrayRef;
12use crate::Canonical;
13use crate::IntoArray;
14use crate::LEGACY_SESSION;
15use crate::VortexSessionExecute;
16use crate::array::ArrayView;
17use crate::arrays::Constant;
18use crate::arrays::ConstantArray;
19use crate::arrays::Filter;
20use crate::arrays::ScalarFnArray;
21use crate::arrays::ScalarFnVTable;
22use crate::arrays::Slice;
23use crate::arrays::StructArray;
24use crate::arrays::scalar_fn::ScalarFnArrayExt;
25use crate::dtype::DType;
26use crate::optimizer::rules::ArrayParentReduceRule;
27use crate::optimizer::rules::ArrayReduceRule;
28use crate::optimizer::rules::ParentRuleSet;
29use crate::optimizer::rules::ReduceRuleSet;
30use crate::scalar_fn::ReduceCtx;
31use crate::scalar_fn::ReduceNode;
32use crate::scalar_fn::ReduceNodeRef;
33use crate::scalar_fn::ScalarFnRef;
34use crate::scalar_fn::fns::pack::Pack;
35use crate::validity::Validity;
36
/// Self-reduction rules registered for [`ScalarFnArray`]: rewrites that replace a scalar-function
/// array by a simpler array without looking at its parent.
///
/// NOTE(review): the list order presumably determines precedence — the structural rewrites
/// (`Pack` → struct, all-constant folding) sit ahead of the scalar function's own generic
/// `reduce` hook. Confirm against the optimizer's rule-application semantics before reordering.
pub(super) const RULES: ReduceRuleSet<ScalarFnVTable> = ReduceRuleSet::new(&[
    &ScalarFnPackToStructRule,
    &ScalarFnConstantRule,
    &ScalarFnAbstractReduceRule,
]);
42
/// Parent-reduction rules registered for [`ScalarFnArray`]: each one pushes a `Filter` or `Slice`
/// parent down into the scalar function's children, so the function is evaluated over fewer rows.
pub(super) const PARENT_RULES: ParentRuleSet<ScalarFnVTable> = ParentRuleSet::new(&[
    ParentRuleSet::lift(&ScalarFnUnaryFilterPushDownRule),
    ParentRuleSet::lift(&ScalarFnSliceReduceRule),
]);
47
48#[derive(Debug)]
50struct ScalarFnPackToStructRule;
51impl ArrayReduceRule<ScalarFnVTable> for ScalarFnPackToStructRule {
52 fn reduce(&self, array: ArrayView<'_, ScalarFnVTable>) -> VortexResult<Option<ArrayRef>> {
53 let Some(pack_options) = array.scalar_fn().as_opt::<Pack>() else {
54 return Ok(None);
55 };
56
57 let validity = match pack_options.nullability {
58 crate::dtype::Nullability::NonNullable => Validity::NonNullable,
59 crate::dtype::Nullability::Nullable => Validity::AllValid,
60 };
61
62 Ok(Some(
63 StructArray::try_new(
64 pack_options.names.clone(),
65 array.children(),
66 array.len(),
67 validity,
68 )?
69 .into_array(),
70 ))
71 }
72}
73
74#[derive(Debug)]
75struct ScalarFnConstantRule;
76impl ArrayReduceRule<ScalarFnVTable> for ScalarFnConstantRule {
77 fn reduce(&self, array: ArrayView<'_, ScalarFnVTable>) -> VortexResult<Option<ArrayRef>> {
78 if !array.children().iter().all(|c| c.is::<Constant>()) {
79 return Ok(None);
80 }
81 if array.is_empty() {
82 Ok(Some(Canonical::empty(array.dtype()).into_array()))
83 } else {
84 let result = array
85 .array()
86 .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())?;
87 Ok(Some(ConstantArray::new(result, array.len()).into_array()))
88 }
89 }
90}
91
92#[derive(Debug)]
93struct ScalarFnSliceReduceRule;
94impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnSliceReduceRule {
95 type Parent = Slice;
96
97 fn reduce_parent(
98 &self,
99 array: ArrayView<'_, ScalarFnVTable>,
100 parent: ArrayView<'_, Slice>,
101 _child_idx: usize,
102 ) -> VortexResult<Option<ArrayRef>> {
103 let range = parent.slice_range();
104
105 let children: Vec<_> = array
106 .iter_children()
107 .map(|c| c.slice(range.clone()))
108 .collect::<VortexResult<_>>()?;
109
110 Ok(Some(
111 ScalarFnArray::try_new(array.scalar_fn().clone(), children, range.len())?.into_array(),
112 ))
113 }
114}
115
116#[derive(Debug)]
117struct ScalarFnAbstractReduceRule;
118impl ArrayReduceRule<ScalarFnVTable> for ScalarFnAbstractReduceRule {
119 fn reduce(&self, array: ArrayView<'_, ScalarFnVTable>) -> VortexResult<Option<ArrayRef>> {
120 if let Some(reduced) = array
121 .scalar_fn()
122 .reduce(array.as_ref(), &ArrayReduceCtx { len: array.len() })?
123 {
124 return Ok(Some(
125 reduced
126 .as_any()
127 .downcast_ref::<ArrayRef>()
128 .vortex_expect("ReduceNode is not an ArrayRef")
129 .clone(),
130 ));
131 }
132 Ok(None)
133 }
134}
135
136impl ReduceNode for ArrayRef {
137 fn as_any(&self) -> &dyn Any {
138 self
139 }
140
141 fn node_dtype(&self) -> VortexResult<DType> {
142 Ok(self.dtype().clone())
143 }
144
145 fn scalar_fn(&self) -> Option<&ScalarFnRef> {
146 self.as_opt::<ScalarFnVTable>()
147 .map(|a| a.data().scalar_fn())
148 }
149
150 fn child(&self, idx: usize) -> ReduceNodeRef {
151 Arc::new(self.nth_child(idx).vortex_expect("child idx out of bounds"))
152 }
153
154 fn child_count(&self) -> usize {
155 self.nchildren()
156 }
157}
158
159struct ArrayReduceCtx {
160 len: usize,
162}
163impl ReduceCtx for ArrayReduceCtx {
164 fn new_node(
165 &self,
166 scalar_fn: ScalarFnRef,
167 children: &[ReduceNodeRef],
168 ) -> VortexResult<ReduceNodeRef> {
169 Ok(Arc::new(
170 ScalarFnArray::try_new(
171 scalar_fn,
172 children
173 .iter()
174 .map(|c| {
175 c.as_any()
176 .downcast_ref::<ArrayRef>()
177 .vortex_expect("ReduceNode is not an ArrayRef")
178 .clone()
179 })
180 .collect(),
181 self.len,
182 )?
183 .into_array(),
184 ))
185 }
186}
187
188#[derive(Debug)]
189struct ScalarFnUnaryFilterPushDownRule;
190
191impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnUnaryFilterPushDownRule {
192 type Parent = Filter;
193
194 fn reduce_parent(
195 &self,
196 child: ArrayView<'_, ScalarFnVTable>,
197 parent: ArrayView<'_, Filter>,
198 _child_idx: usize,
199 ) -> VortexResult<Option<ArrayRef>> {
200 if child
203 .iter_children()
204 .filter(|c| !c.is::<Constant>())
205 .count()
206 == 1
207 {
208 let new_children: Vec<_> = child
209 .iter_children()
210 .map(|c| match c.as_opt::<Constant>() {
211 Some(array) => {
212 Ok(ConstantArray::new(array.scalar().clone(), parent.len()).into_array())
213 }
214 None => c.filter(parent.filter_mask().clone()),
215 })
216 .try_collect()?;
217
218 let new_array =
219 ScalarFnArray::try_new(child.scalar_fn().clone(), new_children, parent.len())?
220 .into_array();
221
222 return Ok(Some(new_array));
223 }
224
225 Ok(None)
226 }
227}
228
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use crate::array::IntoArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::cast;
    use crate::expr::is_null;
    use crate::expr::root;

    /// Applies `is_null` to a chunked array whose first chunk is an *empty* constant array and
    /// whose second chunk holds a single (cast-to-nullable) `u64`.
    ///
    /// NOTE(review): presumably a regression test for constant folding on zero-length inputs
    /// (the `is_empty` branch of `ScalarFnConstantRule`) — confirm.
    #[test]
    fn test_empty_constants() {
        let nullable_u64 = DType::Primitive(PType::U64, Nullability::Nullable);
        let array = ChunkedArray::try_new(
            vec![
                ConstantArray::new(Some(1u64), 0).into_array(),
                PrimitiveArray::from_iter(vec![2u64])
                    .into_array()
                    .apply(&cast(root(), nullable_u64.clone()))
                    .vortex_expect("casted"),
            ],
            nullable_u64,
        )
        .vortex_expect("construction")
        .into_array();
        // 0 rows from the empty constant chunk + 1 row from the primitive chunk.
        assert_eq!(array.len(), 1);

        let result = array
            .apply(&is_null(root()))
            .vortex_expect("expr evaluation");
        // `is_null` yields one flag per input row.
        assert_eq!(result.len(), 1);
    }
}