vortex_array/expr/
optimize.rs1use std::any::Any;
5use std::cell::RefCell;
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_utils::aliases::hash_map::HashMap;
13
14use crate::dtype::DType;
15use crate::expr::Expression;
16use crate::expr::transform::match_between::find_between;
17use crate::scalar_fn::ReduceCtx;
18use crate::scalar_fn::ReduceNode;
19use crate::scalar_fn::ReduceNodeRef;
20use crate::scalar_fn::ScalarFnRef;
21use crate::scalar_fn::SimplifyCtx;
22use crate::scalar_fn::fns::root::Root;
23
24impl Expression {
25 pub fn optimize(&self, scope: &DType) -> VortexResult<Expression> {
32 let cache = SimplifyCache {
33 scope,
34 dtype_cache: RefCell::new(HashMap::new()),
35 };
36 Ok(self
37 .clone()
38 .try_optimize(scope, &cache)?
39 .unwrap_or_else(|| self.clone()))
40 }
41
42 fn try_optimize(
44 &self,
45 scope: &DType,
46 cache: &SimplifyCache<'_>,
47 ) -> VortexResult<Option<Expression>> {
48 let reduce_ctx = ExpressionReduceCtx {
49 scope: scope.clone(),
50 };
51
52 let mut current = self.clone();
53 let mut any_optimizations = false;
54 let mut loop_counter = 0;
55
56 loop {
57 if loop_counter > 100 {
58 vortex_error::vortex_bail!(
59 "Exceeded maximum optimization iterations (possible infinite loop)"
60 );
61 }
62 loop_counter += 1;
63
64 let mut changed = false;
65
66 if let Some(simplified) = current.scalar_fn().simplify_untyped(¤t)? {
68 current = simplified;
69 changed = true;
70 any_optimizations = true;
71 }
72
73 if let Some(simplified) = current.scalar_fn().simplify(¤t, cache)? {
75 current = simplified;
76 changed = true;
77 any_optimizations = true;
78 }
79
80 let reduce_node = ExpressionReduceNode {
82 expression: current.clone(),
83 scope: scope.clone(),
84 };
85 if let Some(reduced) = current.scalar_fn().reduce(&reduce_node, &reduce_ctx)? {
86 let reduced_expr = reduced
87 .as_any()
88 .downcast_ref::<ExpressionReduceNode>()
89 .vortex_expect("ReduceNode not an ExpressionReduceNode")
90 .expression
91 .clone();
92 current = reduced_expr;
93 changed = true;
94 any_optimizations = true;
95 }
96
97 if !changed {
98 break;
99 }
100 }
101
102 if any_optimizations {
103 Ok(Some(current))
104 } else {
105 Ok(None)
106 }
107 }
108
109 pub fn optimize_recursive(&self, scope: &DType) -> VortexResult<Expression> {
113 Ok(self
114 .clone()
115 .try_optimize_recursive(scope)?
116 .unwrap_or_else(|| self.clone()))
117 }
118
119 pub fn try_optimize_recursive(&self, scope: &DType) -> VortexResult<Option<Expression>> {
121 let cache = SimplifyCache {
122 scope,
123 dtype_cache: RefCell::new(HashMap::new()),
124 };
125 let result = self.try_optimize_recursive_inner(scope, &cache)?;
126
127 Ok(Some(find_between(result.unwrap_or_else(|| self.clone()))))
131 }
132
133 fn try_optimize_recursive_inner(
134 &self,
135 scope: &DType,
136 cache: &SimplifyCache<'_>,
137 ) -> VortexResult<Option<Expression>> {
138 let mut current = self.clone();
139 let mut any_optimizations = false;
140
141 if let Some(optimized) = current.clone().try_optimize(scope, cache)? {
143 current = optimized;
144 any_optimizations = true;
145 }
146
147 let mut new_children = Vec::with_capacity(current.children().len());
149 let mut any_child_optimized = false;
150 for child in current.children().iter() {
151 if let Some(optimized) = child.try_optimize_recursive_inner(scope, cache)? {
152 new_children.push(optimized);
153 any_child_optimized = true;
154 } else {
155 new_children.push(child.clone());
156 }
157 }
158
159 if any_child_optimized {
160 current = current.with_children(new_children)?;
161 any_optimizations = true;
162
163 if let Some(optimized) = current.clone().try_optimize(scope, cache)? {
165 current = optimized;
166 }
167 }
168
169 if any_optimizations {
170 Ok(Some(current))
171 } else {
172 Ok(None)
173 }
174 }
175
176 #[deprecated(note = "Use Expression::optimize_recursive instead")]
180 pub fn simplify(&self, scope: &DType) -> VortexResult<Expression> {
181 self.optimize_recursive(scope)
182 }
183
184 #[deprecated(note = "Use Expression::optimize_recursive instead")]
188 pub fn simplify_untyped(&self) -> VortexResult<Expression> {
189 fn inner(expr: &Expression) -> VortexResult<Option<Expression>> {
191 let children: Vec<_> = expr.children().iter().map(inner).try_collect()?;
192
193 if children.iter().any(|c| c.is_some()) {
194 let new_children: Vec<_> = children
195 .into_iter()
196 .zip(expr.children().iter())
197 .map(|(new_c, old_c)| new_c.unwrap_or_else(|| old_c.clone()))
198 .collect();
199
200 let new_expr = expr.clone().with_children(new_children)?;
201 Ok(Some(
202 new_expr
203 .scalar_fn()
204 .simplify_untyped(&new_expr)?
205 .unwrap_or(new_expr),
206 ))
207 } else {
208 expr.scalar_fn().simplify_untyped(expr)
209 }
210 }
211
212 let simplified = self
213 .scalar_fn()
214 .simplify_untyped(self)?
215 .unwrap_or_else(|| self.clone());
216
217 let simplified = inner(&simplified)?.unwrap_or(simplified);
218 let simplified = find_between(simplified);
219
220 Ok(simplified)
221 }
222}
223
224struct SimplifyCache<'a> {
225 scope: &'a DType,
226 dtype_cache: RefCell<HashMap<Expression, DType>>,
227}
228
229impl SimplifyCtx for SimplifyCache<'_> {
230 fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
231 if expr.is::<Root>() {
233 return Ok(self.scope.clone());
234 }
235
236 if let Some(dtype) = self.dtype_cache.borrow().get(expr) {
237 return Ok(dtype.clone());
238 }
239
240 let input_dtypes: Vec<_> = expr
242 .children()
243 .iter()
244 .map(|c| self.return_dtype(c))
245 .try_collect()?;
246 let dtype = expr.deref().return_dtype(&input_dtypes)?;
247 self.dtype_cache
248 .borrow_mut()
249 .insert(expr.clone(), dtype.clone());
250
251 Ok(dtype)
252 }
253}
254
255struct ExpressionReduceNode {
256 expression: Expression,
257 scope: DType,
258}
259
260impl ReduceNode for ExpressionReduceNode {
261 fn as_any(&self) -> &dyn Any {
262 self
263 }
264
265 fn node_dtype(&self) -> VortexResult<DType> {
266 self.expression.return_dtype(&self.scope)
267 }
268
269 fn scalar_fn(&self) -> Option<&ScalarFnRef> {
270 Some(self.expression.scalar_fn())
271 }
272
273 fn child(&self, idx: usize) -> ReduceNodeRef {
274 Arc::new(ExpressionReduceNode {
275 expression: self.expression.child(idx).clone(),
276 scope: self.scope.clone(),
277 })
278 }
279
280 fn child_count(&self) -> usize {
281 self.expression.children().len()
282 }
283}
284
285struct ExpressionReduceCtx {
286 scope: DType,
287}
288impl ReduceCtx for ExpressionReduceCtx {
289 fn new_node(
290 &self,
291 scalar_fn: ScalarFnRef,
292 children: &[ReduceNodeRef],
293 ) -> VortexResult<ReduceNodeRef> {
294 let expression = Expression::try_new(
295 scalar_fn,
296 children
297 .iter()
298 .map(|c| {
299 c.as_any()
300 .downcast_ref::<ExpressionReduceNode>()
301 .vortex_expect("ReduceNode not an ExpressionReduceNode")
302 .expression
303 .clone()
304 })
305 .collect::<Vec<_>>(),
306 )?;
307
308 Ok(Arc::new(ExpressionReduceNode {
309 expression,
310 scope: self.scope.clone(),
311 }))
312 }
313}
314
315#[cfg(test)]
316mod tests {
317 use vortex_error::VortexResult;
318
319 use crate::dtype::DType;
320 use crate::dtype::Nullability;
321 use crate::dtype::PType;
322 use crate::dtype::StructFields;
323 use crate::expr::eq;
324 use crate::expr::get_item;
325 use crate::expr::lit;
326 use crate::expr::or;
327 use crate::expr::root;
328
329 #[test]
330 fn optimize_or_chain_correctness() -> VortexResult<()> {
331 let expr = or(
332 eq(get_item("x", root()), lit(1i32)),
333 eq(get_item("x", root()), lit(2i32)),
334 );
335 let scope = DType::Struct(
336 StructFields::new(
337 ["x"].into(),
338 vec![DType::Primitive(PType::I32, Nullability::NonNullable)],
339 ),
340 Nullability::NonNullable,
341 );
342 let optimized = expr.optimize_recursive(&scope)?;
343
344 let s = optimized.to_string();
345 assert!(s.contains("$.x"), "expected $.x in {s}");
346 assert!(s.contains("1i32") || s.contains('1'), "expected 1 in {s}");
347 assert!(s.contains("2i32") || s.contains('2'), "expected 2 in {s}");
348 Ok(())
349 }
350}