// vortex_array/arrays/scalar_fn/rules.rs
use std::any::Any;
5use std::sync::Arc;
6
7use itertools::Itertools;
8use vortex_dtype::DType;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11
12use crate::Array;
13use crate::ArrayRef;
14use crate::Canonical;
15use crate::IntoArray;
16use crate::arrays::ConstantArray;
17use crate::arrays::ConstantVTable;
18use crate::arrays::FilterArray;
19use crate::arrays::FilterVTable;
20use crate::arrays::ScalarFnArray;
21use crate::arrays::ScalarFnVTable;
22use crate::arrays::SliceReduceAdaptor;
23use crate::arrays::StructArray;
24use crate::expr::Pack;
25use crate::expr::ReduceCtx;
26use crate::expr::ReduceNode;
27use crate::expr::ReduceNodeRef;
28use crate::expr::ScalarFn;
29use crate::optimizer::rules::ArrayParentReduceRule;
30use crate::optimizer::rules::ArrayReduceRule;
31use crate::optimizer::rules::ParentRuleSet;
32use crate::optimizer::rules::ReduceRuleSet;
33use crate::validity::Validity;
34
35pub(super) const RULES: ReduceRuleSet<ScalarFnVTable> = ReduceRuleSet::new(&[
36 &ScalarFnPackToStructRule,
37 &ScalarFnConstantRule,
38 &ScalarFnAbstractReduceRule,
39]);
40
41pub(super) const PARENT_RULES: ParentRuleSet<ScalarFnVTable> = ParentRuleSet::new(&[
42 ParentRuleSet::lift(&ScalarFnUnaryFilterPushDownRule),
43 ParentRuleSet::lift(&SliceReduceAdaptor(ScalarFnVTable)),
44]);
45
46#[derive(Debug)]
48struct ScalarFnPackToStructRule;
49impl ArrayReduceRule<ScalarFnVTable> for ScalarFnPackToStructRule {
50 fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
51 let Some(pack_options) = array.scalar_fn.as_opt::<Pack>() else {
52 return Ok(None);
53 };
54
55 let validity = match pack_options.nullability {
56 vortex_dtype::Nullability::NonNullable => Validity::NonNullable,
57 vortex_dtype::Nullability::Nullable => Validity::AllValid,
58 };
59
60 Ok(Some(
61 StructArray::try_new(
62 pack_options.names.clone(),
63 array.children.clone(),
64 array.len,
65 validity,
66 )?
67 .into_array(),
68 ))
69 }
70}
71
72#[derive(Debug)]
73struct ScalarFnConstantRule;
74impl ArrayReduceRule<ScalarFnVTable> for ScalarFnConstantRule {
75 fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
76 if !array.children.iter().all(|c| c.is::<ConstantVTable>()) {
77 return Ok(None);
78 }
79 if array.is_empty() {
80 Ok(Some(Canonical::empty(array.dtype()).into_array()))
81 } else {
82 let result = array.scalar_at(0)?;
83 Ok(Some(ConstantArray::new(result, array.len).into_array()))
84 }
85 }
86}
87
88#[derive(Debug)]
89struct ScalarFnAbstractReduceRule;
90impl ArrayReduceRule<ScalarFnVTable> for ScalarFnAbstractReduceRule {
91 fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
92 if let Some(reduced) = array
93 .scalar_fn
94 .reduce(array, &ArrayReduceCtx { len: array.len })?
95 {
96 return Ok(Some(
97 reduced
98 .as_any()
99 .downcast_ref::<ArrayRef>()
100 .vortex_expect("ReduceNode is not an ArrayRef")
101 .clone(),
102 ));
103 }
104 Ok(None)
105 }
106}
107
108impl ReduceNode for ScalarFnArray {
109 fn as_any(&self) -> &dyn Any {
110 self
111 }
112
113 fn node_dtype(&self) -> VortexResult<DType> {
114 Ok(self.dtype().clone())
115 }
116
117 #[allow(clippy::same_name_method)]
118 fn scalar_fn(&self) -> Option<&ScalarFn> {
119 Some(ScalarFnArray::scalar_fn(self))
120 }
121
122 fn child(&self, idx: usize) -> ReduceNodeRef {
123 Arc::new(self.children()[idx].clone())
124 }
125
126 fn child_count(&self) -> usize {
127 self.children.len()
128 }
129}
130
131impl ReduceNode for ArrayRef {
132 fn as_any(&self) -> &dyn Any {
133 self
134 }
135
136 fn node_dtype(&self) -> VortexResult<DType> {
137 self.as_ref().node_dtype()
138 }
139
140 fn scalar_fn(&self) -> Option<&ScalarFn> {
141 self.as_ref().scalar_fn()
142 }
143
144 fn child(&self, idx: usize) -> ReduceNodeRef {
145 self.as_ref().child(idx)
146 }
147
148 fn child_count(&self) -> usize {
149 self.as_ref().child_count()
150 }
151}
152
153struct ArrayReduceCtx {
154 len: usize,
156}
157impl ReduceCtx for ArrayReduceCtx {
158 fn new_node(
159 &self,
160 scalar_fn: ScalarFn,
161 children: &[ReduceNodeRef],
162 ) -> VortexResult<ReduceNodeRef> {
163 Ok(Arc::new(
164 ScalarFnArray::try_new(
165 scalar_fn,
166 children
167 .iter()
168 .map(|c| {
169 c.as_any()
170 .downcast_ref::<ArrayRef>()
171 .vortex_expect("ReduceNode is not an ArrayRef")
172 .clone()
173 })
174 .collect(),
175 self.len,
176 )?
177 .into_array(),
178 ))
179 }
180}
181
182#[derive(Debug)]
183struct ScalarFnUnaryFilterPushDownRule;
184
185impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnUnaryFilterPushDownRule {
186 type Parent = FilterVTable;
187
188 fn reduce_parent(
189 &self,
190 child: &ScalarFnArray,
191 parent: &FilterArray,
192 _child_idx: usize,
193 ) -> VortexResult<Option<ArrayRef>> {
194 if child
197 .children
198 .iter()
199 .filter(|c| !c.is::<ConstantVTable>())
200 .count()
201 == 1
202 {
203 let new_children: Vec<_> = child
204 .children
205 .iter()
206 .map(|c| match c.as_opt::<ConstantVTable>() {
207 Some(array) => {
208 Ok(ConstantArray::new(array.scalar().clone(), parent.len()).into_array())
209 }
210 None => c.filter(parent.filter_mask().clone()),
211 })
212 .try_collect()?;
213
214 let new_array =
215 ScalarFnArray::try_new(child.scalar_fn.clone(), new_children, parent.len())?
216 .into_array();
217
218 return Ok(Some(new_array));
219 }
220
221 Ok(None)
222 }
223}
224
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect;

    use crate::array::IntoArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::expr::cast;
    use crate::expr::is_null;
    use crate::expr::root;

    /// Applying an expression to a chunked array that contains a zero-length constant chunk
    /// must evaluate without error.
    #[test]
    fn test_empty_constants() {
        let nullable_u64 = DType::Primitive(PType::U64, Nullability::Nullable);

        // First chunk: a constant array with zero rows.
        let empty_constant = ConstantArray::new(Some(1u64), 0).into_array();

        // Second chunk: a one-element primitive array cast to nullable u64.
        let casted_chunk = PrimitiveArray::from_iter(vec![2u64])
            .into_array()
            .apply(&cast(root(), nullable_u64.clone()))
            .vortex_expect("casted");

        let array = ChunkedArray::try_new(vec![empty_constant, casted_chunk], nullable_u64)
            .vortex_expect("construction")
            .to_array();

        let expr = is_null(root());
        array.apply(&expr).vortex_expect("expr evaluation");
    }
}