Skip to main content

vortex_array/arrays/scalar_fn/
rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::Arc;
6
7use itertools::Itertools;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::ArrayRef;
12use crate::Canonical;
13use crate::DynArray;
14use crate::IntoArray;
15use crate::arrays::Constant;
16use crate::arrays::ConstantArray;
17use crate::arrays::Filter;
18use crate::arrays::FilterArray;
19use crate::arrays::ScalarFnArray;
20use crate::arrays::ScalarFnVTable;
21use crate::arrays::Slice;
22use crate::arrays::SliceArray;
23use crate::arrays::StructArray;
24use crate::dtype::DType;
25use crate::optimizer::rules::ArrayParentReduceRule;
26use crate::optimizer::rules::ArrayReduceRule;
27use crate::optimizer::rules::ParentRuleSet;
28use crate::optimizer::rules::ReduceRuleSet;
29use crate::scalar_fn::ReduceCtx;
30use crate::scalar_fn::ReduceNode;
31use crate::scalar_fn::ReduceNodeRef;
32use crate::scalar_fn::ScalarFnRef;
33use crate::scalar_fn::fns::pack::Pack;
34use crate::validity::Validity;
35
/// Reduce rules registered for [`ScalarFnVTable`] arrays: the `Pack`-to-struct rewrite,
/// constant folding, and delegation to the scalar function's own `reduce` hook.
// NOTE(review): presumably the rules are tried in the order listed here — confirm against
// `ReduceRuleSet` before reordering.
pub(super) const RULES: ReduceRuleSet<ScalarFnVTable> = ReduceRuleSet::new(&[
    &ScalarFnPackToStructRule,
    &ScalarFnConstantRule,
    &ScalarFnAbstractReduceRule,
]);
41
/// Parent-reduce rules registered for [`ScalarFnVTable`] arrays: one matches a `Filter`
/// parent and one matches a `Slice` parent of a scalar function array.
pub(super) const PARENT_RULES: ParentRuleSet<ScalarFnVTable> = ParentRuleSet::new(&[
    ParentRuleSet::lift(&ScalarFnUnaryFilterPushDownRule),
    ParentRuleSet::lift(&ScalarFnSliceReduceRule),
]);
46
47/// Converts a ScalarFnArray with Pack into a StructArray directly.
48#[derive(Debug)]
49struct ScalarFnPackToStructRule;
50impl ArrayReduceRule<ScalarFnVTable> for ScalarFnPackToStructRule {
51    fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
52        let Some(pack_options) = array.scalar_fn().as_opt::<Pack>() else {
53            return Ok(None);
54        };
55
56        let validity = match pack_options.nullability {
57            crate::dtype::Nullability::NonNullable => Validity::NonNullable,
58            crate::dtype::Nullability::Nullable => Validity::AllValid,
59        };
60
61        Ok(Some(
62            StructArray::try_new(
63                pack_options.names.clone(),
64                array.children.clone(),
65                array.len,
66                validity,
67            )?
68            .into_array(),
69        ))
70    }
71}
72
73#[derive(Debug)]
74struct ScalarFnConstantRule;
75impl ArrayReduceRule<ScalarFnVTable> for ScalarFnConstantRule {
76    fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
77        if !array.children.iter().all(|c| c.is::<Constant>()) {
78            return Ok(None);
79        }
80        if array.is_empty() {
81            Ok(Some(Canonical::empty(array.dtype()).into_array()))
82        } else {
83            let result = array.scalar_at(0)?;
84            Ok(Some(ConstantArray::new(result, array.len).into_array()))
85        }
86    }
87}
88
89#[derive(Debug)]
90struct ScalarFnSliceReduceRule;
91impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnSliceReduceRule {
92    type Parent = Slice;
93
94    fn reduce_parent(
95        &self,
96        array: &ScalarFnArray,
97        parent: &SliceArray,
98        _child_idx: usize,
99    ) -> VortexResult<Option<ArrayRef>> {
100        let range = parent.slice_range();
101
102        let children: Vec<_> = array
103            .children()
104            .iter()
105            .map(|c| c.slice(range.clone()))
106            .collect::<VortexResult<_>>()?;
107
108        Ok(Some(
109            ScalarFnArray {
110                vtable: array.vtable.clone(),
111                dtype: array.dtype.clone(),
112                len: range.len(),
113                children,
114                stats: Default::default(),
115            }
116            .into_array(),
117        ))
118    }
119}
120
121#[derive(Debug)]
122struct ScalarFnAbstractReduceRule;
123impl ArrayReduceRule<ScalarFnVTable> for ScalarFnAbstractReduceRule {
124    fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
125        if let Some(reduced) = array
126            .scalar_fn()
127            .reduce(array, &ArrayReduceCtx { len: array.len })?
128        {
129            return Ok(Some(
130                reduced
131                    .as_any()
132                    .downcast_ref::<ArrayRef>()
133                    .vortex_expect("ReduceNode is not an ArrayRef")
134                    .clone(),
135            ));
136        }
137        Ok(None)
138    }
139}
140
141impl ReduceNode for ScalarFnArray {
142    fn as_any(&self) -> &dyn Any {
143        self
144    }
145
146    fn node_dtype(&self) -> VortexResult<DType> {
147        Ok(self.dtype().clone())
148    }
149
150    #[allow(clippy::same_name_method)]
151    fn scalar_fn(&self) -> Option<&ScalarFnRef> {
152        Some(ScalarFnArray::scalar_fn(self))
153    }
154
155    fn child(&self, idx: usize) -> ReduceNodeRef {
156        Arc::new(self.children()[idx].clone())
157    }
158
159    fn child_count(&self) -> usize {
160        self.children.len()
161    }
162}
163
164impl ReduceNode for ArrayRef {
165    fn as_any(&self) -> &dyn Any {
166        self
167    }
168
169    fn node_dtype(&self) -> VortexResult<DType> {
170        self.as_ref().node_dtype()
171    }
172
173    fn scalar_fn(&self) -> Option<&ScalarFnRef> {
174        self.as_ref().scalar_fn()
175    }
176
177    fn child(&self, idx: usize) -> ReduceNodeRef {
178        self.as_ref().child(idx)
179    }
180
181    fn child_count(&self) -> usize {
182        self.as_ref().child_count()
183    }
184}
185
186struct ArrayReduceCtx {
187    // The length of the array being reduced
188    len: usize,
189}
190impl ReduceCtx for ArrayReduceCtx {
191    fn new_node(
192        &self,
193        scalar_fn: ScalarFnRef,
194        children: &[ReduceNodeRef],
195    ) -> VortexResult<ReduceNodeRef> {
196        Ok(Arc::new(
197            ScalarFnArray::try_new(
198                scalar_fn,
199                children
200                    .iter()
201                    .map(|c| {
202                        c.as_any()
203                            .downcast_ref::<ArrayRef>()
204                            .vortex_expect("ReduceNode is not an ArrayRef")
205                            .clone()
206                    })
207                    .collect(),
208                self.len,
209            )?
210            .into_array(),
211        ))
212    }
213}
214
215#[derive(Debug)]
216struct ScalarFnUnaryFilterPushDownRule;
217
218impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnUnaryFilterPushDownRule {
219    type Parent = Filter;
220
221    fn reduce_parent(
222        &self,
223        child: &ScalarFnArray,
224        parent: &FilterArray,
225        _child_idx: usize,
226    ) -> VortexResult<Option<ArrayRef>> {
227        // If we only have one non-constant child, then it is _always_ cheaper to push down the
228        // filter over the children of the scalar function array.
229        if child
230            .children
231            .iter()
232            .filter(|c| !c.is::<Constant>())
233            .count()
234            == 1
235        {
236            let new_children: Vec<_> = child
237                .children
238                .iter()
239                .map(|c| match c.as_opt::<Constant>() {
240                    Some(array) => {
241                        Ok(ConstantArray::new(array.scalar().clone(), parent.len()).into_array())
242                    }
243                    None => c.filter(parent.filter_mask().clone()),
244                })
245                .try_collect()?;
246
247            let new_array =
248                ScalarFnArray::try_new(child.scalar_fn().clone(), new_children, parent.len())?
249                    .into_array();
250
251            return Ok(Some(new_array));
252        }
253
254        Ok(None)
255    }
256}
257
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use crate::array::IntoArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::scalar_fn::rules::ConstantArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::cast;
    use crate::expr::is_null;
    use crate::expr::root;

    /// Applying `is_null` to a chunked array that contains a zero-length constant chunk must
    /// evaluate without error.
    #[test]
    fn test_empty_constants() {
        let nullable_u64 = DType::Primitive(PType::U64, Nullability::Nullable);

        // First chunk: a constant array of length zero.
        let empty_constant = ConstantArray::new(Some(1u64), 0).into_array();

        // Second chunk: a one-element primitive array cast to the nullable dtype.
        let casted = PrimitiveArray::from_iter(vec![2u64])
            .into_array()
            .apply(&cast(root(), nullable_u64.clone()))
            .vortex_expect("casted");

        let chunked = ChunkedArray::try_new(vec![empty_constant, casted], nullable_u64)
            .vortex_expect("construction")
            .into_array();

        chunked
            .apply(&is_null(root()))
            .vortex_expect("expr evaluation");
    }
}
294}