vortex_array/arrays/scalar_fn/
rules.rs1use std::any::Any;
5use std::sync::Arc;
6
7use itertools::Itertools;
8use vortex_dtype::DType;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_scalar::Scalar;
13use vortex_vector::Datum;
14use vortex_vector::VectorOps;
15use vortex_vector::datum_matches_dtype;
16
17use crate::Array;
18use crate::ArrayRef;
19use crate::ArrayVisitor;
20use crate::IntoArray;
21use crate::arrays::ConstantArray;
22use crate::arrays::ConstantVTable;
23use crate::arrays::FilterArray;
24use crate::arrays::FilterVTable;
25use crate::arrays::ScalarFnArray;
26use crate::arrays::ScalarFnVTable;
27use crate::arrays::StructArray;
28use crate::expr::ExecutionArgs;
29use crate::expr::Pack;
30use crate::expr::ReduceCtx;
31use crate::expr::ReduceNode;
32use crate::expr::ReduceNodeRef;
33use crate::expr::ScalarFn;
34use crate::matchers::Exact;
35use crate::optimizer::rules::ArrayParentReduceRule;
36use crate::optimizer::rules::ArrayReduceRule;
37use crate::optimizer::rules::ParentRuleSet;
38use crate::optimizer::rules::ReduceRuleSet;
39use crate::validity::Validity;
40
/// Reduce rules registered for [`ScalarFnVTable`] arrays.
///
/// The set lists, in this order: [`ScalarFnPackToStructRule`], [`ScalarFnConstantRule`] and
/// [`ScalarFnAbstractReduceRule`].
pub(super) const RULES: ReduceRuleSet<ScalarFnVTable> = ReduceRuleSet::new(&[
    &ScalarFnPackToStructRule,
    &ScalarFnConstantRule,
    &ScalarFnAbstractReduceRule,
]);
46
/// Parent-reduce rules registered for [`ScalarFnVTable`] arrays.
///
/// Currently holds a single entry: [`ScalarFnUnaryFilterPushDownRule`], wrapped with
/// `ParentRuleSet::lift`.
pub(super) const PARENT_RULES: ParentRuleSet<ScalarFnVTable> =
    ParentRuleSet::new(&[ParentRuleSet::lift(&ScalarFnUnaryFilterPushDownRule)]);
49
50#[derive(Debug)]
52struct ScalarFnPackToStructRule;
53impl ArrayReduceRule<ScalarFnVTable> for ScalarFnPackToStructRule {
54 fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
55 let Some(pack_options) = array.scalar_fn.as_opt::<Pack>() else {
56 return Ok(None);
57 };
58
59 let validity = match pack_options.nullability {
60 vortex_dtype::Nullability::NonNullable => Validity::NonNullable,
61 vortex_dtype::Nullability::Nullable => Validity::AllValid,
62 };
63
64 Ok(Some(
65 StructArray::try_new(
66 pack_options.names.clone(),
67 array.children.clone(),
68 array.len,
69 validity,
70 )?
71 .into_array(),
72 ))
73 }
74}
75
76#[derive(Debug)]
77struct ScalarFnConstantRule;
78impl ArrayReduceRule<ScalarFnVTable> for ScalarFnConstantRule {
79 fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
80 if !array.children.iter().all(|c| c.is::<ConstantVTable>()) {
81 return Ok(None);
82 }
83
84 let input_datums: Vec<_> = array
85 .children
86 .iter()
87 .map(|c| c.as_::<ConstantVTable>().scalar().to_vector_scalar())
88 .map(Datum::Scalar)
89 .collect();
90 let input_dtypes = array.children.iter().map(|c| c.dtype().clone()).collect();
91
92 let result = array.scalar_fn.execute(ExecutionArgs {
93 datums: input_datums,
94 dtypes: input_dtypes,
95 row_count: array.len,
96 return_dtype: array.dtype.clone(),
97 })?;
98 vortex_ensure!(
99 datum_matches_dtype(&result, &array.dtype),
100 "Scalar function {} result does not match expected dtype",
101 array.scalar_fn
102 );
103
104 let result = match result {
105 Datum::Scalar(s) => s,
106 Datum::Vector(v) => {
107 tracing::info!(
108 "Scalar function {} returned vector from execution over all scalar inputs",
109 array.scalar_fn,
110 );
111 v.scalar_at(0)
112 }
113 };
114
115 Ok(Some(
116 ConstantArray::new(Scalar::from_vector_scalar(result, &array.dtype)?, array.len)
117 .into_array(),
118 ))
119 }
120}
121
122#[derive(Debug)]
123struct ScalarFnAbstractReduceRule;
124impl ArrayReduceRule<ScalarFnVTable> for ScalarFnAbstractReduceRule {
125 fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
126 if let Some(reduced) = array.scalar_fn.reduce(
127 &array.to_array(),
129 &ArrayReduceCtx { len: array.len },
130 )? {
131 return Ok(Some(
132 reduced
133 .as_any()
134 .downcast_ref::<ArrayRef>()
135 .vortex_expect("ReduceNode is not an ArrayRef")
136 .clone(),
137 ));
138 }
139 Ok(None)
140 }
141}
142
143impl ReduceNode for ArrayRef {
144 fn as_any(&self) -> &dyn Any {
145 self
146 }
147
148 fn node_dtype(&self) -> VortexResult<DType> {
149 Ok(self.as_ref().dtype().clone())
150 }
151
152 fn scalar_fn(&self) -> Option<&ScalarFn> {
153 self.as_opt::<ScalarFnVTable>().map(|a| a.scalar_fn())
154 }
155
156 fn child(&self, idx: usize) -> ReduceNodeRef {
157 Arc::new(<dyn Array>::children(self)[idx].clone())
158 }
159
160 fn child_count(&self) -> usize {
161 self.nchildren()
162 }
163}
164
165struct ArrayReduceCtx {
166 len: usize,
168}
169impl ReduceCtx for ArrayReduceCtx {
170 fn new_node(
171 &self,
172 scalar_fn: ScalarFn,
173 children: &[ReduceNodeRef],
174 ) -> VortexResult<ReduceNodeRef> {
175 Ok(Arc::new(
176 ScalarFnArray::try_new(
177 scalar_fn,
178 children
179 .iter()
180 .map(|c| {
181 c.as_any()
182 .downcast_ref::<ArrayRef>()
183 .vortex_expect("ReduceNode is not an ArrayRef")
184 .clone()
185 })
186 .collect(),
187 self.len,
188 )?
189 .into_array(),
190 ))
191 }
192}
193
194#[derive(Debug)]
195struct ScalarFnUnaryFilterPushDownRule;
196
197impl ArrayParentReduceRule<ScalarFnVTable> for ScalarFnUnaryFilterPushDownRule {
198 type Parent = Exact<FilterVTable>;
199
200 fn parent(&self) -> Self::Parent {
201 Exact::from(&FilterVTable)
202 }
203
204 fn reduce_parent(
205 &self,
206 child: &ScalarFnArray,
207 parent: &FilterArray,
208 _child_idx: usize,
209 ) -> VortexResult<Option<ArrayRef>> {
210 if child
213 .children
214 .iter()
215 .filter(|c| !c.is::<ConstantVTable>())
216 .count()
217 == 1
218 {
219 let new_children: Vec<_> = child
220 .children
221 .iter()
222 .map(|c| match c.as_opt::<ConstantVTable>() {
223 Some(array) => {
224 Ok(ConstantArray::new(array.scalar().clone(), parent.len()).into_array())
225 }
226 None => c.filter(parent.filter_mask().clone()),
227 })
228 .try_collect()?;
229
230 let new_array =
231 ScalarFnArray::try_new(child.scalar_fn.clone(), new_children, parent.len())?
232 .into_array();
233
234 return Ok(Some(new_array));
235 }
236
237 Ok(None)
238 }
239}