Skip to main content

vortex_array/arrays/scalar_fn/
rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::Arc;
6
7use itertools::Itertools;
8use vortex_dtype::DType;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11
12use crate::Array;
13use crate::ArrayRef;
14use crate::Canonical;
15use crate::IntoArray;
16use crate::arrays::ConstantArray;
17use crate::arrays::ConstantVTable;
18use crate::arrays::FilterArray;
19use crate::arrays::FilterVTable;
20use crate::arrays::ScalarFnArray;
21use crate::arrays::ScalarFnVTable;
22use crate::arrays::SliceReduceAdaptor;
23use crate::arrays::StructArray;
24use crate::expr::Pack;
25use crate::expr::ReduceCtx;
26use crate::expr::ReduceNode;
27use crate::expr::ReduceNodeRef;
28use crate::expr::ScalarFn;
29use crate::optimizer::rules::ArrayParentReduceRule;
30use crate::optimizer::rules::ArrayReduceRule;
31use crate::optimizer::rules::ParentRuleSet;
32use crate::optimizer::rules::ReduceRuleSet;
33use crate::validity::Validity;
34
/// Reduce rules registered for [`ScalarFnVTable`] arrays: the `Pack`-to-struct rewrite, the
/// all-constant-children fold, and the scalar function's own abstract reduction.
///
/// NOTE(review): the listing order may be significant (presumably rules are tried in order) —
/// confirm against `ReduceRuleSet` before reordering.
pub(super) const RULES: ReduceRuleSet<ScalarFnVTable> = ReduceRuleSet::new(&[
    &ScalarFnPackToStructRule,
    &ScalarFnConstantRule,
    &ScalarFnAbstractReduceRule,
]);
40
/// Parent reduce rules registered for [`ScalarFnVTable`] arrays: pushing a parent filter down
/// into the function's children, and the slice adaptor for this vtable.
///
/// NOTE(review): the listing order may be significant — confirm against `ParentRuleSet`.
pub(super) const PARENT_RULES: ParentRuleSet<ScalarFnVTable> = ParentRuleSet::new(&[
    ParentRuleSet::lift(&ScalarFnUnaryFilterPushDownRule),
    ParentRuleSet::lift(&SliceReduceAdaptor(ScalarFnVTable)),
]);
45
46/// Converts a ScalarFnArray with Pack into a StructArray directly.
47#[derive(Debug)]
48struct ScalarFnPackToStructRule;
49impl ArrayReduceRule<ScalarFnVTable> for ScalarFnPackToStructRule {
50    fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
51        let Some(pack_options) = array.scalar_fn.as_opt::<Pack>() else {
52            return Ok(None);
53        };
54
55        let validity = match pack_options.nullability {
56            vortex_dtype::Nullability::NonNullable => Validity::NonNullable,
57            vortex_dtype::Nullability::Nullable => Validity::AllValid,
58        };
59
60        Ok(Some(
61            StructArray::try_new(
62                pack_options.names.clone(),
63                array.children.clone(),
64                array.len,
65                validity,
66            )?
67            .into_array(),
68        ))
69    }
70}
71
72#[derive(Debug)]
73struct ScalarFnConstantRule;
74impl ArrayReduceRule<ScalarFnVTable> for ScalarFnConstantRule {
75    fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
76        if !array.children.iter().all(|c| c.is::<ConstantVTable>()) {
77            return Ok(None);
78        }
79        if array.is_empty() {
80            Ok(Some(Canonical::empty(array.dtype()).into_array()))
81        } else {
82            let result = array.scalar_at(0)?;
83            Ok(Some(ConstantArray::new(result, array.len).into_array()))
84        }
85    }
86}
87
88#[derive(Debug)]
89struct ScalarFnAbstractReduceRule;
90impl ArrayReduceRule<ScalarFnVTable> for ScalarFnAbstractReduceRule {
91    fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
92        if let Some(reduced) = array
93            .scalar_fn
94            .reduce(array, &ArrayReduceCtx { len: array.len })?
95        {
96            return Ok(Some(
97                reduced
98                    .as_any()
99                    .downcast_ref::<ArrayRef>()
100                    .vortex_expect("ReduceNode is not an ArrayRef")
101                    .clone(),
102            ));
103        }
104        Ok(None)
105    }
106}
107
108impl ReduceNode for ScalarFnArray {
109    fn as_any(&self) -> &dyn Any {
110        self
111    }
112
113    fn node_dtype(&self) -> VortexResult<DType> {
114        Ok(self.dtype().clone())
115    }
116
117    #[allow(clippy::same_name_method)]
118    fn scalar_fn(&self) -> Option<&ScalarFn> {
119        Some(ScalarFnArray::scalar_fn(self))
120    }
121
122    fn child(&self, idx: usize) -> ReduceNodeRef {
123        Arc::new(self.children()[idx].clone())
124    }
125
126    fn child_count(&self) -> usize {
127        self.children.len()
128    }
129}
130
/// Lets an `ArrayRef` act as a node in the abstract reduce tree.
///
/// Every method except `as_any` forwards to the value obtained from `self.as_ref()`.
/// NOTE(review): this relies on `as_ref()` resolving to the underlying array type, which must
/// itself implement `ReduceNode` elsewhere — confirm before changing the receiver expressions.
impl ReduceNode for ArrayRef {
    // Returns the `ArrayRef` itself (not the inner array), so callers can downcast back to
    // `ArrayRef`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn node_dtype(&self) -> VortexResult<DType> {
        self.as_ref().node_dtype()
    }

    fn scalar_fn(&self) -> Option<&ScalarFn> {
        self.as_ref().scalar_fn()
    }

    fn child(&self, idx: usize) -> ReduceNodeRef {
        self.as_ref().child(idx)
    }

    fn child_count(&self) -> usize {
        self.as_ref().child_count()
    }
}
152
153struct ArrayReduceCtx {
154    // The length of the array being reduced
155    len: usize,
156}
157impl ReduceCtx for ArrayReduceCtx {
158    fn new_node(
159        &self,
160        scalar_fn: ScalarFn,
161        children: &[ReduceNodeRef],
162    ) -> VortexResult<ReduceNodeRef> {
163        Ok(Arc::new(
164            ScalarFnArray::try_new(
165                scalar_fn,
166                children
167                    .iter()
168                    .map(|c| {
169                        c.as_any()
170                            .downcast_ref::<ArrayRef>()
171                            .vortex_expect("ReduceNode is not an ArrayRef")
172                            .clone()
173                    })
174                    .collect(),
175                self.len,
176            )?
177            .into_array(),
178        ))
179    }
180}
181
182#[derive(Debug)]
183struct ScalarFnUnaryFilterPushDownRule;
184
185impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnUnaryFilterPushDownRule {
186    type Parent = FilterVTable;
187
188    fn reduce_parent(
189        &self,
190        child: &ScalarFnArray,
191        parent: &FilterArray,
192        _child_idx: usize,
193    ) -> VortexResult<Option<ArrayRef>> {
194        // If we only have one non-constant child, then it is _always_ cheaper to push down the
195        // filter over the children of the scalar function array.
196        if child
197            .children
198            .iter()
199            .filter(|c| !c.is::<ConstantVTable>())
200            .count()
201            == 1
202        {
203            let new_children: Vec<_> = child
204                .children
205                .iter()
206                .map(|c| match c.as_opt::<ConstantVTable>() {
207                    Some(array) => {
208                        Ok(ConstantArray::new(array.scalar().clone(), parent.len()).into_array())
209                    }
210                    None => c.filter(parent.filter_mask().clone()),
211                })
212                .try_collect()?;
213
214            let new_array =
215                ScalarFnArray::try_new(child.scalar_fn.clone(), new_children, parent.len())?
216                    .into_array();
217
218            return Ok(Some(new_array));
219        }
220
221        Ok(None)
222    }
223}
224
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect;

    use crate::array::IntoArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::expr::cast;
    use crate::expr::is_null;
    use crate::expr::root;

    /// Applying `is_null` over a chunked array whose first chunk is a zero-length constant and
    /// whose second chunk is a cast expression must evaluate without error.
    #[test]
    fn test_empty_constants() {
        let nullable_u64 = DType::Primitive(PType::U64, Nullability::Nullable);

        let empty_constant = ConstantArray::new(Some(1u64), 0).into_array();
        let casted = PrimitiveArray::from_iter(vec![2u64])
            .into_array()
            .apply(&cast(root(), nullable_u64.clone()))
            .vortex_expect("casted");

        let array = ChunkedArray::try_new(vec![empty_constant, casted], nullable_u64)
            .vortex_expect("construction")
            .to_array();

        let expr = is_null(root());
        array.apply(&expr).vortex_expect("expr evaluation");
    }
}
261}