Skip to main content

vortex_array/expr/
optimize.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::cell::RefCell;
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_utils::aliases::hash_map::HashMap;
13
14use crate::dtype::DType;
15use crate::expr::Expression;
16use crate::expr::transform::match_between::find_between;
17use crate::scalar_fn::ReduceCtx;
18use crate::scalar_fn::ReduceNode;
19use crate::scalar_fn::ReduceNodeRef;
20use crate::scalar_fn::ScalarFnRef;
21use crate::scalar_fn::SimplifyCtx;
22use crate::scalar_fn::fns::root::Root;
23
24impl Expression {
25    /// Optimize the root expression node only, iterating to convergence.
26    ///
27    /// This applies optimization rules repeatedly until no more changes occur:
28    /// 1. `simplify_untyped` - type-independent simplifications
29    /// 2. `simplify` - type-aware simplifications
30    /// 3. `reduce` - abstract reduction rules via `ReduceNode`/`ReduceCtx`
31    pub fn optimize(&self, scope: &DType) -> VortexResult<Expression> {
32        let cache = SimplifyCache {
33            scope,
34            dtype_cache: RefCell::new(HashMap::new()),
35        };
36        Ok(self
37            .clone()
38            .try_optimize(scope, &cache)?
39            .unwrap_or_else(|| self.clone()))
40    }
41
42    /// Try to optimize the root expression node only, returning None if no optimizations applied.
43    fn try_optimize(
44        &self,
45        scope: &DType,
46        cache: &SimplifyCache<'_>,
47    ) -> VortexResult<Option<Expression>> {
48        let reduce_ctx = ExpressionReduceCtx {
49            scope: scope.clone(),
50        };
51
52        let mut current = self.clone();
53        let mut any_optimizations = false;
54        let mut loop_counter = 0;
55
56        loop {
57            if loop_counter > 100 {
58                vortex_error::vortex_bail!(
59                    "Exceeded maximum optimization iterations (possible infinite loop)"
60                );
61            }
62            loop_counter += 1;
63
64            let mut changed = false;
65
66            // Try simplify_untyped
67            if let Some(simplified) = current.scalar_fn().simplify_untyped(&current)? {
68                current = simplified;
69                changed = true;
70                any_optimizations = true;
71            }
72
73            // Try simplify (typed)
74            if let Some(simplified) = current.scalar_fn().simplify(&current, cache)? {
75                current = simplified;
76                changed = true;
77                any_optimizations = true;
78            }
79
80            // Try reduce via ReduceNode/ReduceCtx
81            let reduce_node = ExpressionReduceNode {
82                expression: current.clone(),
83                scope: scope.clone(),
84            };
85            if let Some(reduced) = current.scalar_fn().reduce(&reduce_node, &reduce_ctx)? {
86                let reduced_expr = reduced
87                    .as_any()
88                    .downcast_ref::<ExpressionReduceNode>()
89                    .vortex_expect("ReduceNode not an ExpressionReduceNode")
90                    .expression
91                    .clone();
92                current = reduced_expr;
93                changed = true;
94                any_optimizations = true;
95            }
96
97            if !changed {
98                break;
99            }
100        }
101
102        if any_optimizations {
103            Ok(Some(current))
104        } else {
105            Ok(None)
106        }
107    }
108
109    /// Optimize the entire expression tree recursively.
110    ///
111    /// Optimizes children first (bottom-up), then optimizes the root.
112    pub fn optimize_recursive(&self, scope: &DType) -> VortexResult<Expression> {
113        Ok(self
114            .clone()
115            .try_optimize_recursive(scope)?
116            .unwrap_or_else(|| self.clone()))
117    }
118
119    /// Try to optimize the entire expression tree recursively.
120    pub fn try_optimize_recursive(&self, scope: &DType) -> VortexResult<Option<Expression>> {
121        let cache = SimplifyCache {
122            scope,
123            dtype_cache: RefCell::new(HashMap::new()),
124        };
125        let result = self.try_optimize_recursive_inner(scope, &cache)?;
126
127        // Apply the between optimization once at the top level only.
128        // TODO(ngates): remove the "between" optimization, or rewrite it to not always convert
129        //  to CNF?
130        Ok(Some(find_between(result.unwrap_or_else(|| self.clone()))))
131    }
132
133    fn try_optimize_recursive_inner(
134        &self,
135        scope: &DType,
136        cache: &SimplifyCache<'_>,
137    ) -> VortexResult<Option<Expression>> {
138        let mut current = self.clone();
139        let mut any_optimizations = false;
140
141        // First optimize the root
142        if let Some(optimized) = current.clone().try_optimize(scope, cache)? {
143            current = optimized;
144            any_optimizations = true;
145        }
146
147        // Then recursively optimize children
148        let mut new_children = Vec::with_capacity(current.children().len());
149        let mut any_child_optimized = false;
150        for child in current.children().iter() {
151            if let Some(optimized) = child.try_optimize_recursive_inner(scope, cache)? {
152                new_children.push(optimized);
153                any_child_optimized = true;
154            } else {
155                new_children.push(child.clone());
156            }
157        }
158
159        if any_child_optimized {
160            current = current.with_children(new_children)?;
161            any_optimizations = true;
162
163            // After updating children, try to optimize root again
164            if let Some(optimized) = current.clone().try_optimize(scope, cache)? {
165                current = optimized;
166            }
167        }
168
169        if any_optimizations {
170            Ok(Some(current))
171        } else {
172            Ok(None)
173        }
174    }
175
176    /// Simplify the expression, returning a potentially new expression.
177    ///
178    /// Deprecated: Use [`Expression::optimize_recursive`] instead, which iterates to convergence.
179    #[deprecated(note = "Use Expression::optimize_recursive instead")]
180    pub fn simplify(&self, scope: &DType) -> VortexResult<Expression> {
181        self.optimize_recursive(scope)
182    }
183
184    /// Simplify the expression without type information.
185    ///
186    /// Deprecated: Use [`Expression::optimize_recursive`] instead.
187    #[deprecated(note = "Use Expression::optimize_recursive instead")]
188    pub fn simplify_untyped(&self) -> VortexResult<Expression> {
189        // For backwards compat, do a single bottom-up pass of untyped simplification
190        fn inner(expr: &Expression) -> VortexResult<Option<Expression>> {
191            let children: Vec<_> = expr.children().iter().map(inner).try_collect()?;
192
193            if children.iter().any(|c| c.is_some()) {
194                let new_children: Vec<_> = children
195                    .into_iter()
196                    .zip(expr.children().iter())
197                    .map(|(new_c, old_c)| new_c.unwrap_or_else(|| old_c.clone()))
198                    .collect();
199
200                let new_expr = expr.clone().with_children(new_children)?;
201                Ok(Some(
202                    new_expr
203                        .scalar_fn()
204                        .simplify_untyped(&new_expr)?
205                        .unwrap_or(new_expr),
206                ))
207            } else {
208                expr.scalar_fn().simplify_untyped(expr)
209            }
210        }
211
212        let simplified = self
213            .scalar_fn()
214            .simplify_untyped(self)?
215            .unwrap_or_else(|| self.clone());
216
217        let simplified = inner(&simplified)?.unwrap_or(simplified);
218        let simplified = find_between(simplified);
219
220        Ok(simplified)
221    }
222}
223
224struct SimplifyCache<'a> {
225    scope: &'a DType,
226    dtype_cache: RefCell<HashMap<Expression, DType>>,
227}
228
229impl SimplifyCtx for SimplifyCache<'_> {
230    fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
231        // If the expression is "root", return the scope dtype
232        if expr.is::<Root>() {
233            return Ok(self.scope.clone());
234        }
235
236        if let Some(dtype) = self.dtype_cache.borrow().get(expr) {
237            return Ok(dtype.clone());
238        }
239
240        // Otherwise, compute dtype from children
241        let input_dtypes: Vec<_> = expr
242            .children()
243            .iter()
244            .map(|c| self.return_dtype(c))
245            .try_collect()?;
246        let dtype = expr.deref().return_dtype(&input_dtypes)?;
247        self.dtype_cache
248            .borrow_mut()
249            .insert(expr.clone(), dtype.clone());
250
251        Ok(dtype)
252    }
253}
254
255struct ExpressionReduceNode {
256    expression: Expression,
257    scope: DType,
258}
259
260impl ReduceNode for ExpressionReduceNode {
261    fn as_any(&self) -> &dyn Any {
262        self
263    }
264
265    fn node_dtype(&self) -> VortexResult<DType> {
266        self.expression.return_dtype(&self.scope)
267    }
268
269    fn scalar_fn(&self) -> Option<&ScalarFnRef> {
270        Some(self.expression.scalar_fn())
271    }
272
273    fn child(&self, idx: usize) -> ReduceNodeRef {
274        Arc::new(ExpressionReduceNode {
275            expression: self.expression.child(idx).clone(),
276            scope: self.scope.clone(),
277        })
278    }
279
280    fn child_count(&self) -> usize {
281        self.expression.children().len()
282    }
283}
284
285struct ExpressionReduceCtx {
286    scope: DType,
287}
288impl ReduceCtx for ExpressionReduceCtx {
289    fn new_node(
290        &self,
291        scalar_fn: ScalarFnRef,
292        children: &[ReduceNodeRef],
293    ) -> VortexResult<ReduceNodeRef> {
294        let expression = Expression::try_new(
295            scalar_fn,
296            children
297                .iter()
298                .map(|c| {
299                    c.as_any()
300                        .downcast_ref::<ExpressionReduceNode>()
301                        .vortex_expect("ReduceNode not an ExpressionReduceNode")
302                        .expression
303                        .clone()
304                })
305                .collect::<Vec<_>>(),
306        )?;
307
308        Ok(Arc::new(ExpressionReduceNode {
309            expression,
310            scope: self.scope.clone(),
311        }))
312    }
313}
314
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::or;
    use crate::expr::root;

    /// Optimizing `x == 1 OR x == 2` must keep the column reference and both literals, and must
    /// not change the expression's result dtype.
    #[test]
    fn optimize_or_chain_correctness() -> VortexResult<()> {
        let expr = or(
            eq(get_item("x", root()), lit(1i32)),
            eq(get_item("x", root()), lit(2i32)),
        );
        let scope = DType::Struct(
            StructFields::new(
                ["x"].into(),
                vec![DType::Primitive(PType::I32, Nullability::NonNullable)],
            ),
            Nullability::NonNullable,
        );
        let optimized = expr.optimize_recursive(&scope)?;

        // Optimization is a rewrite, so the result type must be preserved.
        assert_eq!(optimized.return_dtype(&scope)?, expr.return_dtype(&scope)?);

        let s = optimized.to_string();
        assert!(s.contains("$.x"), "expected $.x in {s}");
        // Literal rendering may or may not carry a type suffix (e.g. `1i32`), so only the digit
        // itself is checked.
        assert!(s.contains('1'), "expected 1 in {s}");
        assert!(s.contains('2'), "expected 2 in {s}");
        Ok(())
    }
}