vortex_array/arrays/scalar_fn/
rules.rs1use std::any::Any;
5use std::sync::Arc;
6
7use itertools::Itertools;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::ArrayRef;
12use crate::Canonical;
13use crate::DynArray;
14use crate::IntoArray;
15use crate::arrays::Constant;
16use crate::arrays::ConstantArray;
17use crate::arrays::Filter;
18use crate::arrays::FilterArray;
19use crate::arrays::ScalarFnArray;
20use crate::arrays::ScalarFnVTable;
21use crate::arrays::Slice;
22use crate::arrays::SliceArray;
23use crate::arrays::StructArray;
24use crate::dtype::DType;
25use crate::optimizer::rules::ArrayParentReduceRule;
26use crate::optimizer::rules::ArrayReduceRule;
27use crate::optimizer::rules::ParentRuleSet;
28use crate::optimizer::rules::ReduceRuleSet;
29use crate::scalar_fn::ReduceCtx;
30use crate::scalar_fn::ReduceNode;
31use crate::scalar_fn::ReduceNodeRef;
32use crate::scalar_fn::ScalarFnRef;
33use crate::scalar_fn::fns::pack::Pack;
34use crate::validity::Validity;
35
/// Self-reduction rules for `ScalarFnArray`, each of which inspects only the array itself.
///
/// NOTE(review): the order of entries is presumably the order in which the optimizer tries
/// them (e.g. `Pack` → struct conversion is listed before generic constant folding); confirm
/// against `ReduceRuleSet` before reordering.
pub(super) const RULES: ReduceRuleSet<ScalarFnVTable> = ReduceRuleSet::new(&[
    &ScalarFnPackToStructRule,
    &ScalarFnConstantRule,
    &ScalarFnAbstractReduceRule,
]);

/// Parent-reduction rules for `ScalarFnArray`: they fire when a `ScalarFnArray` is the child
/// of a `Filter` or `Slice` parent, pushing that parent operation down into its children.
pub(super) const PARENT_RULES: ParentRuleSet<ScalarFnVTable> = ParentRuleSet::new(&[
    ParentRuleSet::lift(&ScalarFnUnaryFilterPushDownRule),
    ParentRuleSet::lift(&ScalarFnSliceReduceRule),
]);
46
47#[derive(Debug)]
49struct ScalarFnPackToStructRule;
50impl ArrayReduceRule<ScalarFnVTable> for ScalarFnPackToStructRule {
51 fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
52 let Some(pack_options) = array.scalar_fn().as_opt::<Pack>() else {
53 return Ok(None);
54 };
55
56 let validity = match pack_options.nullability {
57 crate::dtype::Nullability::NonNullable => Validity::NonNullable,
58 crate::dtype::Nullability::Nullable => Validity::AllValid,
59 };
60
61 Ok(Some(
62 StructArray::try_new(
63 pack_options.names.clone(),
64 array.children.clone(),
65 array.len,
66 validity,
67 )?
68 .into_array(),
69 ))
70 }
71}
72
73#[derive(Debug)]
74struct ScalarFnConstantRule;
75impl ArrayReduceRule<ScalarFnVTable> for ScalarFnConstantRule {
76 fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
77 if !array.children.iter().all(|c| c.is::<Constant>()) {
78 return Ok(None);
79 }
80 if array.is_empty() {
81 Ok(Some(Canonical::empty(array.dtype()).into_array()))
82 } else {
83 let result = array.scalar_at(0)?;
84 Ok(Some(ConstantArray::new(result, array.len).into_array()))
85 }
86 }
87}
88
89#[derive(Debug)]
90struct ScalarFnSliceReduceRule;
91impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnSliceReduceRule {
92 type Parent = Slice;
93
94 fn reduce_parent(
95 &self,
96 array: &ScalarFnArray,
97 parent: &SliceArray,
98 _child_idx: usize,
99 ) -> VortexResult<Option<ArrayRef>> {
100 let range = parent.slice_range();
101
102 let children: Vec<_> = array
103 .children()
104 .iter()
105 .map(|c| c.slice(range.clone()))
106 .collect::<VortexResult<_>>()?;
107
108 Ok(Some(
109 ScalarFnArray {
110 vtable: array.vtable.clone(),
111 dtype: array.dtype.clone(),
112 len: range.len(),
113 children,
114 stats: Default::default(),
115 }
116 .into_array(),
117 ))
118 }
119}
120
121#[derive(Debug)]
122struct ScalarFnAbstractReduceRule;
123impl ArrayReduceRule<ScalarFnVTable> for ScalarFnAbstractReduceRule {
124 fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
125 if let Some(reduced) = array
126 .scalar_fn()
127 .reduce(array, &ArrayReduceCtx { len: array.len })?
128 {
129 return Ok(Some(
130 reduced
131 .as_any()
132 .downcast_ref::<ArrayRef>()
133 .vortex_expect("ReduceNode is not an ArrayRef")
134 .clone(),
135 ));
136 }
137 Ok(None)
138 }
139}
140
141impl ReduceNode for ScalarFnArray {
142 fn as_any(&self) -> &dyn Any {
143 self
144 }
145
146 fn node_dtype(&self) -> VortexResult<DType> {
147 Ok(self.dtype().clone())
148 }
149
150 #[allow(clippy::same_name_method)]
151 fn scalar_fn(&self) -> Option<&ScalarFnRef> {
152 Some(ScalarFnArray::scalar_fn(self))
153 }
154
155 fn child(&self, idx: usize) -> ReduceNodeRef {
156 Arc::new(self.children()[idx].clone())
157 }
158
159 fn child_count(&self) -> usize {
160 self.children.len()
161 }
162}
163
/// `ReduceNode` for a type-erased array handle.
///
/// Every method except `as_any` forwards to the `ReduceNode` implementation reached
/// through `self.as_ref()`.
impl ReduceNode for ArrayRef {
    // Returns the `ArrayRef` handle itself, not the inner array. Callers elsewhere in this
    // file rely on this to recover the handle with `downcast_ref::<ArrayRef>()`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn node_dtype(&self) -> VortexResult<DType> {
        self.as_ref().node_dtype()
    }

    fn scalar_fn(&self) -> Option<&ScalarFnRef> {
        self.as_ref().scalar_fn()
    }

    fn child(&self, idx: usize) -> ReduceNodeRef {
        self.as_ref().child(idx)
    }

    fn child_count(&self) -> usize {
        self.as_ref().child_count()
    }
}
185
186struct ArrayReduceCtx {
187 len: usize,
189}
190impl ReduceCtx for ArrayReduceCtx {
191 fn new_node(
192 &self,
193 scalar_fn: ScalarFnRef,
194 children: &[ReduceNodeRef],
195 ) -> VortexResult<ReduceNodeRef> {
196 Ok(Arc::new(
197 ScalarFnArray::try_new(
198 scalar_fn,
199 children
200 .iter()
201 .map(|c| {
202 c.as_any()
203 .downcast_ref::<ArrayRef>()
204 .vortex_expect("ReduceNode is not an ArrayRef")
205 .clone()
206 })
207 .collect(),
208 self.len,
209 )?
210 .into_array(),
211 ))
212 }
213}
214
215#[derive(Debug)]
216struct ScalarFnUnaryFilterPushDownRule;
217
218impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnUnaryFilterPushDownRule {
219 type Parent = Filter;
220
221 fn reduce_parent(
222 &self,
223 child: &ScalarFnArray,
224 parent: &FilterArray,
225 _child_idx: usize,
226 ) -> VortexResult<Option<ArrayRef>> {
227 if child
230 .children
231 .iter()
232 .filter(|c| !c.is::<Constant>())
233 .count()
234 == 1
235 {
236 let new_children: Vec<_> = child
237 .children
238 .iter()
239 .map(|c| match c.as_opt::<Constant>() {
240 Some(array) => {
241 Ok(ConstantArray::new(array.scalar().clone(), parent.len()).into_array())
242 }
243 None => c.filter(parent.filter_mask().clone()),
244 })
245 .try_collect()?;
246
247 let new_array =
248 ScalarFnArray::try_new(child.scalar_fn().clone(), new_children, parent.len())?
249 .into_array();
250
251 return Ok(Some(new_array));
252 }
253
254 Ok(None)
255 }
256}
257
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use crate::array::IntoArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::scalar_fn::rules::ConstantArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::cast;
    use crate::expr::is_null;
    use crate::expr::root;

    /// Evaluates `is_null` over a `ChunkedArray` whose first chunk is an empty
    /// `ConstantArray`. The test passes as long as evaluation neither errors nor panics.
    #[test]
    fn test_empty_constants() {
        let nullable_u64 = DType::Primitive(PType::U64, Nullability::Nullable);

        let empty_constant_chunk = ConstantArray::new(Some(1u64), 0).into_array();
        let cast_chunk = PrimitiveArray::from_iter(vec![2u64])
            .into_array()
            .apply(&cast(root(), nullable_u64.clone()))
            .vortex_expect("casted");

        let chunked = ChunkedArray::try_new(vec![empty_constant_chunk, cast_chunk], nullable_u64)
            .vortex_expect("construction")
            .into_array();

        chunked
            .apply(&is_null(root()))
            .vortex_expect("expr evaluation");
    }
}