vortex_array/arrays/scalar_fn/
rules.rs1use std::any::Any;
5use std::sync::Arc;
6
7use itertools::Itertools;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::ArrayRef;
12use crate::IntoArray;
13use crate::array::ArrayView;
14use crate::arrays::Constant;
15use crate::arrays::ConstantArray;
16use crate::arrays::Filter;
17use crate::arrays::ScalarFn;
18use crate::arrays::ScalarFnArray;
19use crate::arrays::Slice;
20use crate::arrays::StructArray;
21use crate::arrays::scalar_fn::ScalarFnArrayExt;
22use crate::dtype::DType;
23use crate::optimizer::rules::ArrayParentReduceRule;
24use crate::optimizer::rules::ArrayReduceRule;
25use crate::optimizer::rules::ParentRuleSet;
26use crate::optimizer::rules::ReduceRuleSet;
27use crate::scalar_fn::ReduceCtx;
28use crate::scalar_fn::ReduceNode;
29use crate::scalar_fn::ReduceNodeRef;
30use crate::scalar_fn::ScalarFnRef;
31use crate::scalar_fn::fns::pack::Pack;
32use crate::validity::Validity;
33
34pub(super) const RULES: ReduceRuleSet<ScalarFn> =
35 ReduceRuleSet::new(&[&ScalarFnPackToStructRule, &ScalarFnAbstractReduceRule]);
36
37pub(super) const PARENT_RULES: ParentRuleSet<ScalarFn> = ParentRuleSet::new(&[
38 ParentRuleSet::lift(&ScalarFnUnaryFilterPushDownRule),
39 ParentRuleSet::lift(&ScalarFnSliceReduceRule),
40]);
41
42#[derive(Debug)]
44struct ScalarFnPackToStructRule;
45impl ArrayReduceRule<ScalarFn> for ScalarFnPackToStructRule {
46 fn reduce(&self, array: ArrayView<'_, ScalarFn>) -> VortexResult<Option<ArrayRef>> {
47 let Some(pack_options) = array.scalar_fn().as_opt::<Pack>() else {
48 return Ok(None);
49 };
50
51 let validity = match pack_options.nullability {
52 crate::dtype::Nullability::NonNullable => Validity::NonNullable,
53 crate::dtype::Nullability::Nullable => Validity::AllValid,
54 };
55
56 Ok(Some(
57 StructArray::try_new(
58 pack_options.names.clone(),
59 array.children(),
60 array.len(),
61 validity,
62 )?
63 .into_array(),
64 ))
65 }
66}
67
68#[derive(Debug)]
69struct ScalarFnSliceReduceRule;
70impl ArrayParentReduceRule<ScalarFn> for ScalarFnSliceReduceRule {
71 type Parent = Slice;
72
73 fn reduce_parent(
74 &self,
75 array: ArrayView<'_, ScalarFn>,
76 parent: ArrayView<'_, Slice>,
77 _child_idx: usize,
78 ) -> VortexResult<Option<ArrayRef>> {
79 let range = parent.slice_range();
80
81 let children: Vec<_> = array
82 .iter_children()
83 .map(|c| c.slice(range.clone()))
84 .collect::<VortexResult<_>>()?;
85
86 Ok(Some(
87 ScalarFnArray::try_new(array.scalar_fn().clone(), children, range.len())?.into_array(),
88 ))
89 }
90}
91
92#[derive(Debug)]
93struct ScalarFnAbstractReduceRule;
94impl ArrayReduceRule<ScalarFn> for ScalarFnAbstractReduceRule {
95 fn reduce(&self, array: ArrayView<'_, ScalarFn>) -> VortexResult<Option<ArrayRef>> {
96 if let Some(reduced) = array
97 .scalar_fn()
98 .reduce(array.as_ref(), &ArrayReduceCtx { len: array.len() })?
99 {
100 return Ok(Some(
101 reduced
102 .as_any()
103 .downcast_ref::<ArrayRef>()
104 .vortex_expect("ReduceNode is not an ArrayRef")
105 .clone(),
106 ));
107 }
108 Ok(None)
109 }
110}
111
112impl ReduceNode for ArrayRef {
113 fn as_any(&self) -> &dyn Any {
114 self
115 }
116
117 fn node_dtype(&self) -> VortexResult<DType> {
118 Ok(self.dtype().clone())
119 }
120
121 fn scalar_fn(&self) -> Option<&ScalarFnRef> {
122 self.as_opt::<ScalarFn>().map(|a| a.data().scalar_fn())
123 }
124
125 fn child(&self, idx: usize) -> ReduceNodeRef {
126 Arc::new(self.nth_child(idx).vortex_expect("child idx out of bounds"))
127 }
128
129 fn child_count(&self) -> usize {
130 self.nchildren()
131 }
132}
133
134struct ArrayReduceCtx {
135 len: usize,
137}
138impl ReduceCtx for ArrayReduceCtx {
139 fn new_node(
140 &self,
141 scalar_fn: ScalarFnRef,
142 children: &[ReduceNodeRef],
143 ) -> VortexResult<ReduceNodeRef> {
144 Ok(Arc::new(
145 ScalarFnArray::try_new(
146 scalar_fn,
147 children
148 .iter()
149 .map(|c| {
150 c.as_any()
151 .downcast_ref::<ArrayRef>()
152 .vortex_expect("ReduceNode is not an ArrayRef")
153 .clone()
154 })
155 .collect(),
156 self.len,
157 )?
158 .into_array(),
159 ))
160 }
161}
162
163#[derive(Debug)]
164struct ScalarFnUnaryFilterPushDownRule;
165
166impl ArrayParentReduceRule<ScalarFn> for ScalarFnUnaryFilterPushDownRule {
167 type Parent = Filter;
168
169 fn reduce_parent(
170 &self,
171 child: ArrayView<'_, ScalarFn>,
172 parent: ArrayView<'_, Filter>,
173 _child_idx: usize,
174 ) -> VortexResult<Option<ArrayRef>> {
175 if child
178 .iter_children()
179 .filter(|c| !c.is::<Constant>())
180 .count()
181 == 1
182 {
183 let new_children: Vec<_> = child
184 .iter_children()
185 .map(|c| match c.as_opt::<Constant>() {
186 Some(array) => {
187 Ok(ConstantArray::new(array.scalar().clone(), parent.len()).into_array())
188 }
189 None => c.filter(parent.filter_mask().clone()),
190 })
191 .try_collect()?;
192
193 let new_array =
194 ScalarFnArray::try_new(child.scalar_fn().clone(), new_children, parent.len())?
195 .into_array();
196
197 return Ok(Some(new_array));
198 }
199
200 Ok(None)
201 }
202}
203
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use crate::array::IntoArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::scalar_fn::rules::ConstantArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::cast;
    use crate::expr::is_null;
    use crate::expr::root;

    /// Evaluating `is_null` over a chunked array whose first chunk is an empty constant must
    /// succeed and yield one output row per input row.
    #[test]
    fn test_empty_constants() {
        let array = ChunkedArray::try_new(
            vec![
                // Zero-length constant chunk: the edge case under test.
                ConstantArray::new(Some(1u64), 0).into_array(),
                PrimitiveArray::from_iter(vec![2u64])
                    .into_array()
                    .apply(&cast(
                        root(),
                        DType::Primitive(PType::U64, Nullability::Nullable),
                    ))
                    .vortex_expect("casted"),
            ],
            DType::Primitive(PType::U64, Nullability::Nullable),
        )
        .vortex_expect("construction")
        .into_array();

        let expr = is_null(root());
        let result = array.apply(&expr).vortex_expect("expr evaluation");

        // 0 rows from the empty constant chunk + 1 row from the primitive chunk.
        assert_eq!(result.len(), 1);
    }
}