vortex_array/arrays/scalar_fn/rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::Arc;
6
7use itertools::Itertools;
8use vortex_dtype::DType;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_scalar::Scalar;
13use vortex_vector::Datum;
14use vortex_vector::VectorOps;
15use vortex_vector::datum_matches_dtype;
16
17use crate::Array;
18use crate::ArrayRef;
19use crate::ArrayVisitor;
20use crate::IntoArray;
21use crate::arrays::ConstantArray;
22use crate::arrays::ConstantVTable;
23use crate::arrays::FilterArray;
24use crate::arrays::FilterVTable;
25use crate::arrays::ScalarFnArray;
26use crate::arrays::ScalarFnVTable;
27use crate::arrays::StructArray;
28use crate::expr::ExecutionArgs;
29use crate::expr::Pack;
30use crate::expr::ReduceCtx;
31use crate::expr::ReduceNode;
32use crate::expr::ReduceNodeRef;
33use crate::expr::ScalarFn;
34use crate::matchers::Exact;
35use crate::optimizer::rules::ArrayParentReduceRule;
36use crate::optimizer::rules::ArrayReduceRule;
37use crate::optimizer::rules::ParentRuleSet;
38use crate::optimizer::rules::ReduceRuleSet;
39use crate::validity::Validity;
40
41pub(super) const RULES: ReduceRuleSet<ScalarFnVTable> = ReduceRuleSet::new(&[
42    &ScalarFnPackToStructRule,
43    &ScalarFnConstantRule,
44    &ScalarFnAbstractReduceRule,
45]);
46
47pub(super) const PARENT_RULES: ParentRuleSet<ScalarFnVTable> =
48    ParentRuleSet::new(&[ParentRuleSet::lift(&ScalarFnUnaryFilterPushDownRule)]);
49
50/// Converts a ScalarFnArray with Pack into a StructArray directly.
51#[derive(Debug)]
52struct ScalarFnPackToStructRule;
53impl ArrayReduceRule<ScalarFnVTable> for ScalarFnPackToStructRule {
54    fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
55        let Some(pack_options) = array.scalar_fn.as_opt::<Pack>() else {
56            return Ok(None);
57        };
58
59        let validity = match pack_options.nullability {
60            vortex_dtype::Nullability::NonNullable => Validity::NonNullable,
61            vortex_dtype::Nullability::Nullable => Validity::AllValid,
62        };
63
64        Ok(Some(
65            StructArray::try_new(
66                pack_options.names.clone(),
67                array.children.clone(),
68                array.len,
69                validity,
70            )?
71            .into_array(),
72        ))
73    }
74}
75
76#[derive(Debug)]
77struct ScalarFnConstantRule;
78impl ArrayReduceRule<ScalarFnVTable> for ScalarFnConstantRule {
79    fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
80        if !array.children.iter().all(|c| c.is::<ConstantVTable>()) {
81            return Ok(None);
82        }
83
84        let input_datums: Vec<_> = array
85            .children
86            .iter()
87            .map(|c| c.as_::<ConstantVTable>().scalar().to_vector_scalar())
88            .map(Datum::Scalar)
89            .collect();
90        let input_dtypes = array.children.iter().map(|c| c.dtype().clone()).collect();
91
92        let result = array.scalar_fn.execute(ExecutionArgs {
93            datums: input_datums,
94            dtypes: input_dtypes,
95            row_count: array.len,
96            return_dtype: array.dtype.clone(),
97        })?;
98        vortex_ensure!(
99            datum_matches_dtype(&result, &array.dtype),
100            "Scalar function {} result does not match expected dtype",
101            array.scalar_fn
102        );
103
104        let result = match result {
105            Datum::Scalar(s) => s,
106            Datum::Vector(v) => {
107                tracing::info!(
108                    "Scalar function {} returned vector from execution over all scalar inputs",
109                    array.scalar_fn,
110                );
111                v.scalar_at(0)
112            }
113        };
114
115        Ok(Some(
116            ConstantArray::new(Scalar::from_vector_scalar(result, &array.dtype)?, array.len)
117                .into_array(),
118        ))
119    }
120}
121
122#[derive(Debug)]
123struct ScalarFnAbstractReduceRule;
124impl ArrayReduceRule<ScalarFnVTable> for ScalarFnAbstractReduceRule {
125    fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
126        if let Some(reduced) = array.scalar_fn.reduce(
127            // Blergh, re-boxing
128            &array.to_array(),
129            &ArrayReduceCtx { len: array.len },
130        )? {
131            return Ok(Some(
132                reduced
133                    .as_any()
134                    .downcast_ref::<ArrayRef>()
135                    .vortex_expect("ReduceNode is not an ArrayRef")
136                    .clone(),
137            ));
138        }
139        Ok(None)
140    }
141}
142
143impl ReduceNode for ArrayRef {
144    fn as_any(&self) -> &dyn Any {
145        self
146    }
147
148    fn node_dtype(&self) -> VortexResult<DType> {
149        Ok(self.as_ref().dtype().clone())
150    }
151
152    fn scalar_fn(&self) -> Option<&ScalarFn> {
153        self.as_opt::<ScalarFnVTable>().map(|a| a.scalar_fn())
154    }
155
156    fn child(&self, idx: usize) -> ReduceNodeRef {
157        Arc::new(<dyn Array>::children(self)[idx].clone())
158    }
159
160    fn child_count(&self) -> usize {
161        self.nchildren()
162    }
163}
164
165struct ArrayReduceCtx {
166    // The length of the array being reduced
167    len: usize,
168}
169impl ReduceCtx for ArrayReduceCtx {
170    fn new_node(
171        &self,
172        scalar_fn: ScalarFn,
173        children: &[ReduceNodeRef],
174    ) -> VortexResult<ReduceNodeRef> {
175        Ok(Arc::new(
176            ScalarFnArray::try_new(
177                scalar_fn,
178                children
179                    .iter()
180                    .map(|c| {
181                        c.as_any()
182                            .downcast_ref::<ArrayRef>()
183                            .vortex_expect("ReduceNode is not an ArrayRef")
184                            .clone()
185                    })
186                    .collect(),
187                self.len,
188            )?
189            .into_array(),
190        ))
191    }
192}
193
194#[derive(Debug)]
195struct ScalarFnUnaryFilterPushDownRule;
196
197impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnUnaryFilterPushDownRule {
198    type Parent = Exact<FilterVTable>;
199
200    fn parent(&self) -> Self::Parent {
201        Exact::from(&FilterVTable)
202    }
203
204    fn reduce_parent(
205        &self,
206        child: &ScalarFnArray,
207        parent: &FilterArray,
208        _child_idx: usize,
209    ) -> VortexResult<Option<ArrayRef>> {
210        // If we only have one non-constant child, then it is _always_ cheaper to push down the
211        // filter over the children of the scalar function array.
212        if child
213            .children
214            .iter()
215            .filter(|c| !c.is::<ConstantVTable>())
216            .count()
217            == 1
218        {
219            let new_children: Vec<_> = child
220                .children
221                .iter()
222                .map(|c| match c.as_opt::<ConstantVTable>() {
223                    Some(array) => {
224                        Ok(ConstantArray::new(array.scalar().clone(), parent.len()).into_array())
225                    }
226                    None => c.filter(parent.filter_mask().clone()),
227                })
228                .try_collect()?;
229
230            let new_array =
231                ScalarFnArray::try_new(child.scalar_fn.clone(), new_children, parent.len())?
232                    .into_array();
233
234            return Ok(Some(new_array));
235        }
236
237        Ok(None)
238    }
239}