vortex_array/expr/
optimize.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::cell::RefCell;
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_dtype::DType;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_utils::aliases::hash_map::HashMap;
14
15use crate::expr::Expression;
16use crate::expr::ReduceCtx;
17use crate::expr::ReduceNode;
18use crate::expr::ReduceNodeRef;
19use crate::expr::Root;
20use crate::expr::ScalarFn;
21use crate::expr::SimplifyCtx;
22use crate::expr::transform::match_between::find_between;
23
24impl Expression {
25    /// Optimize the root expression node only, iterating to convergence.
26    ///
27    /// This applies optimization rules repeatedly until no more changes occur:
28    /// 1. `simplify_untyped` - type-independent simplifications
29    /// 2. `simplify` - type-aware simplifications
30    /// 3. `reduce` - abstract reduction rules via `ReduceNode`/`ReduceCtx`
31    pub fn optimize(&self, scope: &DType) -> VortexResult<Expression> {
32        Ok(self
33            .clone()
34            .try_optimize(scope)?
35            .unwrap_or_else(|| self.clone()))
36    }
37
38    /// Try to optimize the root expression node only, returning None if no optimizations applied.
39    pub fn try_optimize(&self, scope: &DType) -> VortexResult<Option<Expression>> {
40        let cache = SimplifyCache {
41            scope,
42            dtype_cache: RefCell::new(HashMap::new()),
43        };
44        let reduce_ctx = ExpressionReduceCtx {
45            scope: scope.clone(),
46        };
47
48        let mut current = self.clone();
49        let mut any_optimizations = false;
50        let mut loop_counter = 0;
51
52        loop {
53            if loop_counter > 100 {
54                vortex_error::vortex_bail!(
55                    "Exceeded maximum optimization iterations (possible infinite loop)"
56                );
57            }
58            loop_counter += 1;
59
60            let mut changed = false;
61
62            // Try simplify_untyped
63            if let Some(simplified) = current.vtable().as_dyn().simplify_untyped(&current)? {
64                current = simplified;
65                changed = true;
66                any_optimizations = true;
67            }
68
69            // Try simplify (typed)
70            if let Some(simplified) = current.vtable().as_dyn().simplify(&current, &cache)? {
71                current = simplified;
72                changed = true;
73                any_optimizations = true;
74            }
75
76            // Try reduce via ReduceNode/ReduceCtx
77            let reduce_node = ExpressionReduceNode {
78                expression: current.clone(),
79                scope: scope.clone(),
80            };
81            if let Some(reduced) = current.scalar_fn().reduce(&reduce_node, &reduce_ctx)? {
82                let reduced_expr = reduced
83                    .as_any()
84                    .downcast_ref::<ExpressionReduceNode>()
85                    .vortex_expect("ReduceNode not an ExpressionReduceNode")
86                    .expression
87                    .clone();
88                current = reduced_expr;
89                changed = true;
90                any_optimizations = true;
91            }
92
93            if !changed {
94                break;
95            }
96        }
97
98        if any_optimizations {
99            Ok(Some(current))
100        } else {
101            Ok(None)
102        }
103    }
104
105    /// Optimize the entire expression tree recursively.
106    ///
107    /// Optimizes children first (bottom-up), then optimizes the root.
108    pub fn optimize_recursive(&self, scope: &DType) -> VortexResult<Expression> {
109        Ok(self
110            .clone()
111            .try_optimize_recursive(scope)?
112            .unwrap_or_else(|| self.clone()))
113    }
114
115    /// Try to optimize the entire expression tree recursively.
116    pub fn try_optimize_recursive(&self, scope: &DType) -> VortexResult<Option<Expression>> {
117        let mut current = self.clone();
118        let mut any_optimizations = false;
119
120        // First optimize the root
121        if let Some(optimized) = current.clone().try_optimize(scope)? {
122            current = optimized;
123            any_optimizations = true;
124        }
125
126        // Then recursively optimize children
127        let mut new_children = Vec::with_capacity(current.children().len());
128        let mut any_child_optimized = false;
129        for child in current.children().iter() {
130            if let Some(optimized) = child.try_optimize_recursive(scope)? {
131                new_children.push(optimized);
132                any_child_optimized = true;
133            } else {
134                new_children.push(child.clone());
135            }
136        }
137
138        if any_child_optimized {
139            current = current.with_children(new_children)?;
140            any_optimizations = true;
141
142            // After updating children, try to optimize root again
143            if let Some(optimized) = current.clone().try_optimize(scope)? {
144                current = optimized;
145            }
146        }
147
148        // TODO(ngates): remove the "between" optimization, or rewrite it to not always convert
149        //  to CNF?
150        let current = find_between(current);
151
152        if any_optimizations {
153            Ok(Some(current))
154        } else {
155            Ok(None)
156        }
157    }
158
159    /// Simplify the expression, returning a potentially new expression.
160    ///
161    /// Deprecated: Use [`Expression::optimize_recursive`] instead, which iterates to convergence.
162    #[deprecated(note = "Use Expression::optimize_recursive instead")]
163    pub fn simplify(&self, scope: &DType) -> VortexResult<Expression> {
164        self.optimize_recursive(scope)
165    }
166
167    /// Simplify the expression without type information.
168    ///
169    /// Deprecated: Use [`Expression::optimize_recursive`] instead.
170    #[deprecated(note = "Use Expression::optimize_recursive instead")]
171    pub fn simplify_untyped(&self) -> VortexResult<Expression> {
172        // For backwards compat, do a single bottom-up pass of untyped simplification
173        fn inner(expr: &Expression) -> VortexResult<Option<Expression>> {
174            let children: Vec<_> = expr.children().iter().map(inner).try_collect()?;
175
176            if children.iter().any(|c| c.is_some()) {
177                let new_children: Vec<_> = children
178                    .into_iter()
179                    .zip(expr.children().iter())
180                    .map(|(new_c, old_c)| new_c.unwrap_or_else(|| old_c.clone()))
181                    .collect();
182
183                let new_expr = expr.clone().with_children(new_children)?;
184                Ok(Some(
185                    new_expr
186                        .vtable()
187                        .as_dyn()
188                        .simplify_untyped(&new_expr)?
189                        .unwrap_or(new_expr),
190                ))
191            } else {
192                expr.vtable().as_dyn().simplify_untyped(expr)
193            }
194        }
195
196        let simplified = self
197            .vtable()
198            .as_dyn()
199            .simplify_untyped(self)?
200            .unwrap_or_else(|| self.clone());
201
202        let simplified = inner(&simplified)?.unwrap_or(simplified);
203        let simplified = find_between(simplified);
204
205        Ok(simplified)
206    }
207}
208
209struct SimplifyCache<'a> {
210    scope: &'a DType,
211    dtype_cache: RefCell<HashMap<Expression, DType>>,
212}
213
214impl SimplifyCtx for SimplifyCache<'_> {
215    fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
216        // If the expression is "root", return the scope dtype
217        if expr.is::<Root>() {
218            return Ok(self.scope.clone());
219        }
220
221        if let Some(dtype) = self.dtype_cache.borrow().get(expr) {
222            return Ok(dtype.clone());
223        }
224
225        // Otherwise, compute dtype from children
226        let input_dtypes: Vec<_> = expr
227            .children()
228            .iter()
229            .map(|c| self.return_dtype(c))
230            .try_collect()?;
231        let dtype = expr.deref().return_dtype(&input_dtypes)?;
232        self.dtype_cache
233            .borrow_mut()
234            .insert(expr.clone(), dtype.clone());
235
236        Ok(dtype)
237    }
238}
239
240struct ExpressionReduceNode {
241    expression: Expression,
242    scope: DType,
243}
244
245impl ReduceNode for ExpressionReduceNode {
246    fn as_any(&self) -> &dyn Any {
247        self
248    }
249
250    fn node_dtype(&self) -> VortexResult<DType> {
251        self.expression.return_dtype(&self.scope)
252    }
253
254    fn scalar_fn(&self) -> Option<&ScalarFn> {
255        Some(self.expression.scalar_fn())
256    }
257
258    fn child(&self, idx: usize) -> ReduceNodeRef {
259        Arc::new(ExpressionReduceNode {
260            expression: self.expression.child(idx).clone(),
261            scope: self.scope.clone(),
262        })
263    }
264
265    fn child_count(&self) -> usize {
266        self.expression.children().len()
267    }
268}
269
270struct ExpressionReduceCtx {
271    scope: DType,
272}
273impl ReduceCtx for ExpressionReduceCtx {
274    fn new_node(
275        &self,
276        scalar_fn: ScalarFn,
277        children: &[ReduceNodeRef],
278    ) -> VortexResult<ReduceNodeRef> {
279        let expression = Expression::try_new(
280            scalar_fn,
281            children
282                .iter()
283                .map(|c| {
284                    c.as_any()
285                        .downcast_ref::<ExpressionReduceNode>()
286                        .vortex_expect("ReduceNode not an ExpressionReduceNode")
287                        .expression
288                        .clone()
289                })
290                .collect::<Vec<_>>(),
291        )?;
292
293        Ok(Arc::new(ExpressionReduceNode {
294            expression,
295            scope: self.scope.clone(),
296        }))
297    }
298}