vortex_array/arrays/scalar_fn/
rules.rs1use std::any::Any;
5use std::sync::Arc;
6
7use itertools::Itertools;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::ArrayRef;
12use crate::Canonical;
13use crate::IntoArray;
14use crate::array::ArrayView;
15use crate::arrays::Constant;
16use crate::arrays::ConstantArray;
17use crate::arrays::Filter;
18use crate::arrays::ScalarFnArray;
19use crate::arrays::ScalarFnVTable;
20use crate::arrays::Slice;
21use crate::arrays::StructArray;
22use crate::arrays::scalar_fn::ScalarFnArrayExt;
23use crate::dtype::DType;
24use crate::optimizer::rules::ArrayParentReduceRule;
25use crate::optimizer::rules::ArrayReduceRule;
26use crate::optimizer::rules::ParentRuleSet;
27use crate::optimizer::rules::ReduceRuleSet;
28use crate::scalar_fn::ReduceCtx;
29use crate::scalar_fn::ReduceNode;
30use crate::scalar_fn::ReduceNodeRef;
31use crate::scalar_fn::ScalarFnRef;
32use crate::scalar_fn::fns::pack::Pack;
33use crate::validity::Validity;
34
35pub(super) const RULES: ReduceRuleSet<ScalarFnVTable> = ReduceRuleSet::new(&[
36 &ScalarFnPackToStructRule,
37 &ScalarFnConstantRule,
38 &ScalarFnAbstractReduceRule,
39]);
40
41pub(super) const PARENT_RULES: ParentRuleSet<ScalarFnVTable> = ParentRuleSet::new(&[
42 ParentRuleSet::lift(&ScalarFnUnaryFilterPushDownRule),
43 ParentRuleSet::lift(&ScalarFnSliceReduceRule),
44]);
45
46#[derive(Debug)]
48struct ScalarFnPackToStructRule;
49impl ArrayReduceRule<ScalarFnVTable> for ScalarFnPackToStructRule {
50 fn reduce(&self, array: ArrayView<'_, ScalarFnVTable>) -> VortexResult<Option<ArrayRef>> {
51 let Some(pack_options) = array.scalar_fn().as_opt::<Pack>() else {
52 return Ok(None);
53 };
54
55 let validity = match pack_options.nullability {
56 crate::dtype::Nullability::NonNullable => Validity::NonNullable,
57 crate::dtype::Nullability::Nullable => Validity::AllValid,
58 };
59
60 Ok(Some(
61 StructArray::try_new(
62 pack_options.names.clone(),
63 array.children(),
64 array.len(),
65 validity,
66 )?
67 .into_array(),
68 ))
69 }
70}
71
72#[derive(Debug)]
73struct ScalarFnConstantRule;
74impl ArrayReduceRule<ScalarFnVTable> for ScalarFnConstantRule {
75 fn reduce(&self, array: ArrayView<'_, ScalarFnVTable>) -> VortexResult<Option<ArrayRef>> {
76 if !array.children().iter().all(|c| c.is::<Constant>()) {
77 return Ok(None);
78 }
79 if array.is_empty() {
80 Ok(Some(Canonical::empty(array.dtype()).into_array()))
81 } else {
82 let result = array.array().scalar_at(0)?;
83 Ok(Some(ConstantArray::new(result, array.len()).into_array()))
84 }
85 }
86}
87
88#[derive(Debug)]
89struct ScalarFnSliceReduceRule;
90impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnSliceReduceRule {
91 type Parent = Slice;
92
93 fn reduce_parent(
94 &self,
95 array: ArrayView<'_, ScalarFnVTable>,
96 parent: ArrayView<'_, Slice>,
97 _child_idx: usize,
98 ) -> VortexResult<Option<ArrayRef>> {
99 let range = parent.slice_range();
100
101 let children: Vec<_> = array
102 .iter_children()
103 .map(|c| c.slice(range.clone()))
104 .collect::<VortexResult<_>>()?;
105
106 Ok(Some(
107 ScalarFnArray::try_new(array.scalar_fn().clone(), children, range.len())?.into_array(),
108 ))
109 }
110}
111
112#[derive(Debug)]
113struct ScalarFnAbstractReduceRule;
114impl ArrayReduceRule<ScalarFnVTable> for ScalarFnAbstractReduceRule {
115 fn reduce(&self, array: ArrayView<'_, ScalarFnVTable>) -> VortexResult<Option<ArrayRef>> {
116 if let Some(reduced) = array
117 .scalar_fn()
118 .reduce(array.as_ref(), &ArrayReduceCtx { len: array.len() })?
119 {
120 return Ok(Some(
121 reduced
122 .as_any()
123 .downcast_ref::<ArrayRef>()
124 .vortex_expect("ReduceNode is not an ArrayRef")
125 .clone(),
126 ));
127 }
128 Ok(None)
129 }
130}
131
132impl ReduceNode for ArrayRef {
133 fn as_any(&self) -> &dyn Any {
134 self
135 }
136
137 fn node_dtype(&self) -> VortexResult<DType> {
138 Ok(self.dtype().clone())
139 }
140
141 fn scalar_fn(&self) -> Option<&ScalarFnRef> {
142 self.as_opt::<ScalarFnVTable>()
143 .map(|a| a.data().scalar_fn())
144 }
145
146 fn child(&self, idx: usize) -> ReduceNodeRef {
147 Arc::new(self.nth_child(idx).vortex_expect("child idx out of bounds"))
148 }
149
150 fn child_count(&self) -> usize {
151 self.nchildren()
152 }
153}
154
155struct ArrayReduceCtx {
156 len: usize,
158}
159impl ReduceCtx for ArrayReduceCtx {
160 fn new_node(
161 &self,
162 scalar_fn: ScalarFnRef,
163 children: &[ReduceNodeRef],
164 ) -> VortexResult<ReduceNodeRef> {
165 Ok(Arc::new(
166 ScalarFnArray::try_new(
167 scalar_fn,
168 children
169 .iter()
170 .map(|c| {
171 c.as_any()
172 .downcast_ref::<ArrayRef>()
173 .vortex_expect("ReduceNode is not an ArrayRef")
174 .clone()
175 })
176 .collect(),
177 self.len,
178 )?
179 .into_array(),
180 ))
181 }
182}
183
184#[derive(Debug)]
185struct ScalarFnUnaryFilterPushDownRule;
186
187impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnUnaryFilterPushDownRule {
188 type Parent = Filter;
189
190 fn reduce_parent(
191 &self,
192 child: ArrayView<'_, ScalarFnVTable>,
193 parent: ArrayView<'_, Filter>,
194 _child_idx: usize,
195 ) -> VortexResult<Option<ArrayRef>> {
196 if child
199 .iter_children()
200 .filter(|c| !c.is::<Constant>())
201 .count()
202 == 1
203 {
204 let new_children: Vec<_> = child
205 .iter_children()
206 .map(|c| match c.as_opt::<Constant>() {
207 Some(array) => {
208 Ok(ConstantArray::new(array.scalar().clone(), parent.len()).into_array())
209 }
210 None => c.filter(parent.filter_mask().clone()),
211 })
212 .try_collect()?;
213
214 let new_array =
215 ScalarFnArray::try_new(child.scalar_fn().clone(), new_children, parent.len())?
216 .into_array();
217
218 return Ok(Some(new_array));
219 }
220
221 Ok(None)
222 }
223}
224
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use crate::array::IntoArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::scalar_fn::rules::ConstantArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::cast;
    use crate::expr::is_null;
    use crate::expr::root;

    /// Applying `is_null` to a chunked array whose first chunk is a zero-length
    /// constant must evaluate without returning an error.
    #[test]
    fn test_empty_constants() {
        let nullable_u64 = DType::Primitive(PType::U64, Nullability::Nullable);

        // First chunk: a constant array with no rows.
        let empty_constant = ConstantArray::new(Some(1u64), 0).into_array();

        // Second chunk: a one-element primitive array cast to the nullable dtype.
        let casted = PrimitiveArray::from_iter(vec![2u64])
            .into_array()
            .apply(&cast(root(), nullable_u64.clone()))
            .vortex_expect("casted");

        let array = ChunkedArray::try_new(vec![empty_constant, casted], nullable_u64)
            .vortex_expect("construction")
            .into_array();

        let expr = is_null(root());
        array.apply(&expr).vortex_expect("expr evaluation");
    }
}