Skip to main content

vortex_array/expr/
optimize.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::cell::RefCell;
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_utils::aliases::hash_map::HashMap;
13
14use crate::dtype::DType;
15use crate::expr::Expression;
16use crate::expr::transform::match_between::find_between;
17use crate::scalar_fn::ReduceCtx;
18use crate::scalar_fn::ReduceNode;
19use crate::scalar_fn::ReduceNodeRef;
20use crate::scalar_fn::ScalarFnRef;
21use crate::scalar_fn::SimplifyCtx;
22use crate::scalar_fn::fns::root::Root;
23
24impl Expression {
25    /// Optimize the root expression node only, iterating to convergence.
26    ///
27    /// This applies optimization rules repeatedly until no more changes occur:
28    /// 1. `simplify_untyped` - type-independent simplifications
29    /// 2. `simplify` - type-aware simplifications
30    /// 3. `reduce` - abstract reduction rules via `ReduceNode`/`ReduceCtx`
31    pub fn optimize(&self, scope: &DType) -> VortexResult<Expression> {
32        Ok(self
33            .clone()
34            .try_optimize(scope)?
35            .unwrap_or_else(|| self.clone()))
36    }
37
38    /// Try to optimize the root expression node only, returning None if no optimizations applied.
39    pub fn try_optimize(&self, scope: &DType) -> VortexResult<Option<Expression>> {
40        let cache = SimplifyCache {
41            scope,
42            dtype_cache: RefCell::new(HashMap::new()),
43        };
44        let reduce_ctx = ExpressionReduceCtx {
45            scope: scope.clone(),
46        };
47
48        let mut current = self.clone();
49        let mut any_optimizations = false;
50        let mut loop_counter = 0;
51
52        loop {
53            if loop_counter > 100 {
54                vortex_error::vortex_bail!(
55                    "Exceeded maximum optimization iterations (possible infinite loop)"
56                );
57            }
58            loop_counter += 1;
59
60            let mut changed = false;
61
62            // Try simplify_untyped
63            if let Some(simplified) = current.scalar_fn().simplify_untyped(&current)? {
64                current = simplified;
65                changed = true;
66                any_optimizations = true;
67            }
68
69            // Try simplify (typed)
70            if let Some(simplified) = current.scalar_fn().simplify(&current, &cache)? {
71                current = simplified;
72                changed = true;
73                any_optimizations = true;
74            }
75
76            // Try reduce via ReduceNode/ReduceCtx
77            let reduce_node = ExpressionReduceNode {
78                expression: current.clone(),
79                scope: scope.clone(),
80            };
81            if let Some(reduced) = current.scalar_fn().reduce(&reduce_node, &reduce_ctx)? {
82                let reduced_expr = reduced
83                    .as_any()
84                    .downcast_ref::<ExpressionReduceNode>()
85                    .vortex_expect("ReduceNode not an ExpressionReduceNode")
86                    .expression
87                    .clone();
88                current = reduced_expr;
89                changed = true;
90                any_optimizations = true;
91            }
92
93            if !changed {
94                break;
95            }
96        }
97
98        if any_optimizations {
99            Ok(Some(current))
100        } else {
101            Ok(None)
102        }
103    }
104
105    /// Optimize the entire expression tree recursively.
106    ///
107    /// Optimizes children first (bottom-up), then optimizes the root.
108    pub fn optimize_recursive(&self, scope: &DType) -> VortexResult<Expression> {
109        Ok(self
110            .clone()
111            .try_optimize_recursive(scope)?
112            .unwrap_or_else(|| self.clone()))
113    }
114
115    /// Try to optimize the entire expression tree recursively.
116    pub fn try_optimize_recursive(&self, scope: &DType) -> VortexResult<Option<Expression>> {
117        let mut current = self.clone();
118        let mut any_optimizations = false;
119
120        // First optimize the root
121        if let Some(optimized) = current.clone().try_optimize(scope)? {
122            current = optimized;
123            any_optimizations = true;
124        }
125
126        // Then recursively optimize children
127        let mut new_children = Vec::with_capacity(current.children().len());
128        let mut any_child_optimized = false;
129        for child in current.children().iter() {
130            if let Some(optimized) = child.try_optimize_recursive(scope)? {
131                new_children.push(optimized);
132                any_child_optimized = true;
133            } else {
134                new_children.push(child.clone());
135            }
136        }
137
138        if any_child_optimized {
139            current = current.with_children(new_children)?;
140            any_optimizations = true;
141
142            // After updating children, try to optimize root again
143            if let Some(optimized) = current.clone().try_optimize(scope)? {
144                current = optimized;
145            }
146        }
147
148        // TODO(ngates): remove the "between" optimization, or rewrite it to not always convert
149        //  to CNF?
150        let current = find_between(current);
151
152        if any_optimizations {
153            Ok(Some(current))
154        } else {
155            Ok(None)
156        }
157    }
158
159    /// Simplify the expression, returning a potentially new expression.
160    ///
161    /// Deprecated: Use [`Expression::optimize_recursive`] instead, which iterates to convergence.
162    #[deprecated(note = "Use Expression::optimize_recursive instead")]
163    pub fn simplify(&self, scope: &DType) -> VortexResult<Expression> {
164        self.optimize_recursive(scope)
165    }
166
167    /// Simplify the expression without type information.
168    ///
169    /// Deprecated: Use [`Expression::optimize_recursive`] instead.
170    #[deprecated(note = "Use Expression::optimize_recursive instead")]
171    pub fn simplify_untyped(&self) -> VortexResult<Expression> {
172        // For backwards compat, do a single bottom-up pass of untyped simplification
173        fn inner(expr: &Expression) -> VortexResult<Option<Expression>> {
174            let children: Vec<_> = expr.children().iter().map(inner).try_collect()?;
175
176            if children.iter().any(|c| c.is_some()) {
177                let new_children: Vec<_> = children
178                    .into_iter()
179                    .zip(expr.children().iter())
180                    .map(|(new_c, old_c)| new_c.unwrap_or_else(|| old_c.clone()))
181                    .collect();
182
183                let new_expr = expr.clone().with_children(new_children)?;
184                Ok(Some(
185                    new_expr
186                        .scalar_fn()
187                        .simplify_untyped(&new_expr)?
188                        .unwrap_or(new_expr),
189                ))
190            } else {
191                expr.scalar_fn().simplify_untyped(expr)
192            }
193        }
194
195        let simplified = self
196            .scalar_fn()
197            .simplify_untyped(self)?
198            .unwrap_or_else(|| self.clone());
199
200        let simplified = inner(&simplified)?.unwrap_or(simplified);
201        let simplified = find_between(simplified);
202
203        Ok(simplified)
204    }
205}
206
207struct SimplifyCache<'a> {
208    scope: &'a DType,
209    dtype_cache: RefCell<HashMap<Expression, DType>>,
210}
211
212impl SimplifyCtx for SimplifyCache<'_> {
213    fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
214        // If the expression is "root", return the scope dtype
215        if expr.is::<Root>() {
216            return Ok(self.scope.clone());
217        }
218
219        if let Some(dtype) = self.dtype_cache.borrow().get(expr) {
220            return Ok(dtype.clone());
221        }
222
223        // Otherwise, compute dtype from children
224        let input_dtypes: Vec<_> = expr
225            .children()
226            .iter()
227            .map(|c| self.return_dtype(c))
228            .try_collect()?;
229        let dtype = expr.deref().return_dtype(&input_dtypes)?;
230        self.dtype_cache
231            .borrow_mut()
232            .insert(expr.clone(), dtype.clone());
233
234        Ok(dtype)
235    }
236}
237
238struct ExpressionReduceNode {
239    expression: Expression,
240    scope: DType,
241}
242
243impl ReduceNode for ExpressionReduceNode {
244    fn as_any(&self) -> &dyn Any {
245        self
246    }
247
248    fn node_dtype(&self) -> VortexResult<DType> {
249        self.expression.return_dtype(&self.scope)
250    }
251
252    fn scalar_fn(&self) -> Option<&ScalarFnRef> {
253        Some(self.expression.scalar_fn())
254    }
255
256    fn child(&self, idx: usize) -> ReduceNodeRef {
257        Arc::new(ExpressionReduceNode {
258            expression: self.expression.child(idx).clone(),
259            scope: self.scope.clone(),
260        })
261    }
262
263    fn child_count(&self) -> usize {
264        self.expression.children().len()
265    }
266}
267
268struct ExpressionReduceCtx {
269    scope: DType,
270}
271impl ReduceCtx for ExpressionReduceCtx {
272    fn new_node(
273        &self,
274        scalar_fn: ScalarFnRef,
275        children: &[ReduceNodeRef],
276    ) -> VortexResult<ReduceNodeRef> {
277        let expression = Expression::try_new(
278            scalar_fn,
279            children
280                .iter()
281                .map(|c| {
282                    c.as_any()
283                        .downcast_ref::<ExpressionReduceNode>()
284                        .vortex_expect("ReduceNode not an ExpressionReduceNode")
285                        .expression
286                        .clone()
287                })
288                .collect::<Vec<_>>(),
289        )?;
290
291        Ok(Arc::new(ExpressionReduceNode {
292            expression,
293            scope: self.scope.clone(),
294        }))
295    }
296}