Skip to main content

vortex_array/arrays/scalar_fn/
rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::Arc;
6
7use itertools::Itertools;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::ArrayRef;
12use crate::IntoArray;
13use crate::array::ArrayView;
14use crate::arrays::Constant;
15use crate::arrays::ConstantArray;
16use crate::arrays::Filter;
17use crate::arrays::ScalarFn;
18use crate::arrays::ScalarFnArray;
19use crate::arrays::Slice;
20use crate::arrays::StructArray;
21use crate::arrays::scalar_fn::ScalarFnArrayExt;
22use crate::dtype::DType;
23use crate::optimizer::rules::ArrayParentReduceRule;
24use crate::optimizer::rules::ArrayReduceRule;
25use crate::optimizer::rules::ParentRuleSet;
26use crate::optimizer::rules::ReduceRuleSet;
27use crate::scalar_fn::ReduceCtx;
28use crate::scalar_fn::ReduceNode;
29use crate::scalar_fn::ReduceNodeRef;
30use crate::scalar_fn::ScalarFnRef;
31use crate::scalar_fn::fns::pack::Pack;
32use crate::validity::Validity;
33
/// Reduce rules applied to a [`ScalarFnArray`] on its own, without inspecting its parent.
///
/// Listed in this order: the `Pack` → [`StructArray`] rewrite, followed by the generic
/// delegation to the scalar function's own `reduce` implementation.
pub(super) const RULES: ReduceRuleSet<ScalarFn> =
    ReduceRuleSet::new(&[&ScalarFnPackToStructRule, &ScalarFnAbstractReduceRule]);

/// Parent reduce rules that push a parent `Filter` or `Slice` down through a
/// [`ScalarFnArray`] into its children.
pub(super) const PARENT_RULES: ParentRuleSet<ScalarFn> = ParentRuleSet::new(&[
    ParentRuleSet::lift(&ScalarFnUnaryFilterPushDownRule),
    ParentRuleSet::lift(&ScalarFnSliceReduceRule),
]);
41
42/// Converts a ScalarFnArray with Pack into a StructArray directly.
43#[derive(Debug)]
44struct ScalarFnPackToStructRule;
45impl ArrayReduceRule<ScalarFn> for ScalarFnPackToStructRule {
46    fn reduce(&self, array: ArrayView<'_, ScalarFn>) -> VortexResult<Option<ArrayRef>> {
47        let Some(pack_options) = array.scalar_fn().as_opt::<Pack>() else {
48            return Ok(None);
49        };
50
51        let validity = match pack_options.nullability {
52            crate::dtype::Nullability::NonNullable => Validity::NonNullable,
53            crate::dtype::Nullability::Nullable => Validity::AllValid,
54        };
55
56        Ok(Some(
57            StructArray::try_new(
58                pack_options.names.clone(),
59                array.children(),
60                array.len(),
61                validity,
62            )?
63            .into_array(),
64        ))
65    }
66}
67
68#[derive(Debug)]
69struct ScalarFnSliceReduceRule;
70impl ArrayParentReduceRule<ScalarFn> for ScalarFnSliceReduceRule {
71    type Parent = Slice;
72
73    fn reduce_parent(
74        &self,
75        array: ArrayView<'_, ScalarFn>,
76        parent: ArrayView<'_, Slice>,
77        _child_idx: usize,
78    ) -> VortexResult<Option<ArrayRef>> {
79        let range = parent.slice_range();
80
81        let children: Vec<_> = array
82            .iter_children()
83            .map(|c| c.slice(range.clone()))
84            .collect::<VortexResult<_>>()?;
85
86        Ok(Some(
87            ScalarFnArray::try_new(array.scalar_fn().clone(), children, range.len())?.into_array(),
88        ))
89    }
90}
91
92#[derive(Debug)]
93struct ScalarFnAbstractReduceRule;
94impl ArrayReduceRule<ScalarFn> for ScalarFnAbstractReduceRule {
95    fn reduce(&self, array: ArrayView<'_, ScalarFn>) -> VortexResult<Option<ArrayRef>> {
96        if let Some(reduced) = array
97            .scalar_fn()
98            .reduce(array.as_ref(), &ArrayReduceCtx { len: array.len() })?
99        {
100            return Ok(Some(
101                reduced
102                    .as_any()
103                    .downcast_ref::<ArrayRef>()
104                    .vortex_expect("ReduceNode is not an ArrayRef")
105                    .clone(),
106            ));
107        }
108        Ok(None)
109    }
110}
111
112impl ReduceNode for ArrayRef {
113    fn as_any(&self) -> &dyn Any {
114        self
115    }
116
117    fn node_dtype(&self) -> VortexResult<DType> {
118        Ok(self.dtype().clone())
119    }
120
121    fn scalar_fn(&self) -> Option<&ScalarFnRef> {
122        self.as_opt::<ScalarFn>().map(|a| a.data().scalar_fn())
123    }
124
125    fn child(&self, idx: usize) -> ReduceNodeRef {
126        Arc::new(self.nth_child(idx).vortex_expect("child idx out of bounds"))
127    }
128
129    fn child_count(&self) -> usize {
130        self.nchildren()
131    }
132}
133
134struct ArrayReduceCtx {
135    // The length of the array being reduced
136    len: usize,
137}
138impl ReduceCtx for ArrayReduceCtx {
139    fn new_node(
140        &self,
141        scalar_fn: ScalarFnRef,
142        children: &[ReduceNodeRef],
143    ) -> VortexResult<ReduceNodeRef> {
144        Ok(Arc::new(
145            ScalarFnArray::try_new(
146                scalar_fn,
147                children
148                    .iter()
149                    .map(|c| {
150                        c.as_any()
151                            .downcast_ref::<ArrayRef>()
152                            .vortex_expect("ReduceNode is not an ArrayRef")
153                            .clone()
154                    })
155                    .collect(),
156                self.len,
157            )?
158            .into_array(),
159        ))
160    }
161}
162
163#[derive(Debug)]
164struct ScalarFnUnaryFilterPushDownRule;
165
166impl ArrayParentReduceRule<ScalarFn> for ScalarFnUnaryFilterPushDownRule {
167    type Parent = Filter;
168
169    fn reduce_parent(
170        &self,
171        child: ArrayView<'_, ScalarFn>,
172        parent: ArrayView<'_, Filter>,
173        _child_idx: usize,
174    ) -> VortexResult<Option<ArrayRef>> {
175        // If we only have one non-constant child, then it is _always_ cheaper to push down the
176        // filter over the children of the scalar function array.
177        if child
178            .iter_children()
179            .filter(|c| !c.is::<Constant>())
180            .count()
181            == 1
182        {
183            let new_children: Vec<_> = child
184                .iter_children()
185                .map(|c| match c.as_opt::<Constant>() {
186                    Some(array) => {
187                        Ok(ConstantArray::new(array.scalar().clone(), parent.len()).into_array())
188                    }
189                    None => c.filter(parent.filter_mask().clone()),
190                })
191                .try_collect()?;
192
193            let new_array =
194                ScalarFnArray::try_new(child.scalar_fn().clone(), new_children, parent.len())?
195                    .into_array();
196
197            return Ok(Some(new_array));
198        }
199
200        Ok(None)
201    }
202}
203
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use crate::array::IntoArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::cast;
    use crate::expr::is_null;
    use crate::expr::root;

    /// Evaluating `is_null(root())` over a chunked array must succeed. The array's first chunk
    /// is a zero-length constant, and its second chunk is a primitive array cast to nullable
    /// u64.
    #[test]
    fn test_empty_constants() {
        let nullable_u64 = DType::Primitive(PType::U64, Nullability::Nullable);

        let empty_chunk = ConstantArray::new(Some(1u64), 0).into_array();
        let cast_chunk = PrimitiveArray::from_iter(vec![2u64])
            .into_array()
            .apply(&cast(root(), nullable_u64.clone()))
            .vortex_expect("casted");

        let chunked = ChunkedArray::try_new(vec![empty_chunk, cast_chunk], nullable_u64)
            .vortex_expect("construction")
            .into_array();

        let null_check = is_null(root());
        chunked.apply(&null_check).vortex_expect("expr evaluation");
    }
}