Skip to main content

vortex_array/arrays/scalar_fn/
rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::Arc;
6
7use itertools::Itertools;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::ArrayRef;
12use crate::Canonical;
13use crate::IntoArray;
14use crate::LEGACY_SESSION;
15use crate::VortexSessionExecute;
16use crate::array::ArrayView;
17use crate::arrays::Constant;
18use crate::arrays::ConstantArray;
19use crate::arrays::Filter;
20use crate::arrays::ScalarFnArray;
21use crate::arrays::ScalarFnVTable;
22use crate::arrays::Slice;
23use crate::arrays::StructArray;
24use crate::arrays::scalar_fn::ScalarFnArrayExt;
25use crate::dtype::DType;
26use crate::optimizer::rules::ArrayParentReduceRule;
27use crate::optimizer::rules::ArrayReduceRule;
28use crate::optimizer::rules::ParentRuleSet;
29use crate::optimizer::rules::ReduceRuleSet;
30use crate::scalar_fn::ReduceCtx;
31use crate::scalar_fn::ReduceNode;
32use crate::scalar_fn::ReduceNodeRef;
33use crate::scalar_fn::ScalarFnRef;
34use crate::scalar_fn::fns::pack::Pack;
35use crate::validity::Validity;
36
/// Reduce rules that rewrite a [`ScalarFnArray`] by looking only at the array itself.
///
/// NOTE(review): the order presumably matters (e.g. a `Pack` is turned into a
/// [`StructArray`] before constant folding is considered); confirm the iteration
/// semantics of `ReduceRuleSet`.
pub(super) const RULES: ReduceRuleSet<ScalarFnVTable> = ReduceRuleSet::new(&[
    &ScalarFnPackToStructRule,
    &ScalarFnConstantRule,
    &ScalarFnAbstractReduceRule,
]);

/// Rules that rewrite the parent of a [`ScalarFnArray`] child: a [`Filter`] or [`Slice`]
/// parent may be pushed down into the scalar function's children.
pub(super) const PARENT_RULES: ParentRuleSet<ScalarFnVTable> = ParentRuleSet::new(&[
    ParentRuleSet::lift(&ScalarFnUnaryFilterPushDownRule),
    ParentRuleSet::lift(&ScalarFnSliceReduceRule),
]);
47
48/// Converts a ScalarFnArray with Pack into a StructArray directly.
49#[derive(Debug)]
50struct ScalarFnPackToStructRule;
51impl ArrayReduceRule<ScalarFnVTable> for ScalarFnPackToStructRule {
52    fn reduce(&self, array: ArrayView<'_, ScalarFnVTable>) -> VortexResult<Option<ArrayRef>> {
53        let Some(pack_options) = array.scalar_fn().as_opt::<Pack>() else {
54            return Ok(None);
55        };
56
57        let validity = match pack_options.nullability {
58            crate::dtype::Nullability::NonNullable => Validity::NonNullable,
59            crate::dtype::Nullability::Nullable => Validity::AllValid,
60        };
61
62        Ok(Some(
63            StructArray::try_new(
64                pack_options.names.clone(),
65                array.children(),
66                array.len(),
67                validity,
68            )?
69            .into_array(),
70        ))
71    }
72}
73
74#[derive(Debug)]
75struct ScalarFnConstantRule;
76impl ArrayReduceRule<ScalarFnVTable> for ScalarFnConstantRule {
77    fn reduce(&self, array: ArrayView<'_, ScalarFnVTable>) -> VortexResult<Option<ArrayRef>> {
78        if !array.children().iter().all(|c| c.is::<Constant>()) {
79            return Ok(None);
80        }
81        if array.is_empty() {
82            Ok(Some(Canonical::empty(array.dtype()).into_array()))
83        } else {
84            let result = array
85                .array()
86                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())?;
87            Ok(Some(ConstantArray::new(result, array.len()).into_array()))
88        }
89    }
90}
91
92#[derive(Debug)]
93struct ScalarFnSliceReduceRule;
94impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnSliceReduceRule {
95    type Parent = Slice;
96
97    fn reduce_parent(
98        &self,
99        array: ArrayView<'_, ScalarFnVTable>,
100        parent: ArrayView<'_, Slice>,
101        _child_idx: usize,
102    ) -> VortexResult<Option<ArrayRef>> {
103        let range = parent.slice_range();
104
105        let children: Vec<_> = array
106            .iter_children()
107            .map(|c| c.slice(range.clone()))
108            .collect::<VortexResult<_>>()?;
109
110        Ok(Some(
111            ScalarFnArray::try_new(array.scalar_fn().clone(), children, range.len())?.into_array(),
112        ))
113    }
114}
115
116#[derive(Debug)]
117struct ScalarFnAbstractReduceRule;
118impl ArrayReduceRule<ScalarFnVTable> for ScalarFnAbstractReduceRule {
119    fn reduce(&self, array: ArrayView<'_, ScalarFnVTable>) -> VortexResult<Option<ArrayRef>> {
120        if let Some(reduced) = array
121            .scalar_fn()
122            .reduce(array.as_ref(), &ArrayReduceCtx { len: array.len() })?
123        {
124            return Ok(Some(
125                reduced
126                    .as_any()
127                    .downcast_ref::<ArrayRef>()
128                    .vortex_expect("ReduceNode is not an ArrayRef")
129                    .clone(),
130            ));
131        }
132        Ok(None)
133    }
134}
135
136impl ReduceNode for ArrayRef {
137    fn as_any(&self) -> &dyn Any {
138        self
139    }
140
141    fn node_dtype(&self) -> VortexResult<DType> {
142        Ok(self.dtype().clone())
143    }
144
145    fn scalar_fn(&self) -> Option<&ScalarFnRef> {
146        self.as_opt::<ScalarFnVTable>()
147            .map(|a| a.data().scalar_fn())
148    }
149
150    fn child(&self, idx: usize) -> ReduceNodeRef {
151        Arc::new(self.nth_child(idx).vortex_expect("child idx out of bounds"))
152    }
153
154    fn child_count(&self) -> usize {
155        self.nchildren()
156    }
157}
158
159struct ArrayReduceCtx {
160    // The length of the array being reduced
161    len: usize,
162}
163impl ReduceCtx for ArrayReduceCtx {
164    fn new_node(
165        &self,
166        scalar_fn: ScalarFnRef,
167        children: &[ReduceNodeRef],
168    ) -> VortexResult<ReduceNodeRef> {
169        Ok(Arc::new(
170            ScalarFnArray::try_new(
171                scalar_fn,
172                children
173                    .iter()
174                    .map(|c| {
175                        c.as_any()
176                            .downcast_ref::<ArrayRef>()
177                            .vortex_expect("ReduceNode is not an ArrayRef")
178                            .clone()
179                    })
180                    .collect(),
181                self.len,
182            )?
183            .into_array(),
184        ))
185    }
186}
187
188#[derive(Debug)]
189struct ScalarFnUnaryFilterPushDownRule;
190
191impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnUnaryFilterPushDownRule {
192    type Parent = Filter;
193
194    fn reduce_parent(
195        &self,
196        child: ArrayView<'_, ScalarFnVTable>,
197        parent: ArrayView<'_, Filter>,
198        _child_idx: usize,
199    ) -> VortexResult<Option<ArrayRef>> {
200        // If we only have one non-constant child, then it is _always_ cheaper to push down the
201        // filter over the children of the scalar function array.
202        if child
203            .iter_children()
204            .filter(|c| !c.is::<Constant>())
205            .count()
206            == 1
207        {
208            let new_children: Vec<_> = child
209                .iter_children()
210                .map(|c| match c.as_opt::<Constant>() {
211                    Some(array) => {
212                        Ok(ConstantArray::new(array.scalar().clone(), parent.len()).into_array())
213                    }
214                    None => c.filter(parent.filter_mask().clone()),
215                })
216                .try_collect()?;
217
218            let new_array =
219                ScalarFnArray::try_new(child.scalar_fn().clone(), new_children, parent.len())?
220                    .into_array();
221
222            return Ok(Some(new_array));
223        }
224
225        Ok(None)
226    }
227}
228
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use crate::array::IntoArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::cast;
    use crate::expr::is_null;
    use crate::expr::root;

    /// Regression test: applying `is_null` over a chunked array whose first chunk is an empty
    /// constant array must succeed.
    ///
    /// NOTE(review): presumably this targets the empty-array branch of
    /// `ScalarFnConstantRule` (an all-constant scalar function of length 0); confirm.
    #[test]
    fn test_empty_constants() {
        let dtype = DType::Primitive(PType::U64, Nullability::Nullable);
        let array = ChunkedArray::try_new(
            vec![
                ConstantArray::new(Some(1u64), 0).into_array(),
                PrimitiveArray::from_iter(vec![2u64])
                    .into_array()
                    .apply(&cast(root(), dtype.clone()))
                    .vortex_expect("casted"),
            ],
            dtype,
        )
        .vortex_expect("construction")
        .into_array();

        let result = array
            .apply(&is_null(root()))
            .vortex_expect("expr evaluation");
        // The empty constant chunk contributes no rows; only the single cast value remains.
        assert_eq!(result.len(), 1);
    }
}