vortex_array/expr/
optimize.rs1use std::any::Any;
5use std::cell::RefCell;
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_dtype::DType;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_utils::aliases::hash_map::HashMap;
14
15use crate::expr::Expression;
16use crate::expr::ReduceCtx;
17use crate::expr::ReduceNode;
18use crate::expr::ReduceNodeRef;
19use crate::expr::Root;
20use crate::expr::ScalarFn;
21use crate::expr::SimplifyCtx;
22use crate::expr::transform::match_between::find_between;
23
24impl Expression {
25 pub fn optimize(&self, scope: &DType) -> VortexResult<Expression> {
32 Ok(self
33 .clone()
34 .try_optimize(scope)?
35 .unwrap_or_else(|| self.clone()))
36 }
37
38 pub fn try_optimize(&self, scope: &DType) -> VortexResult<Option<Expression>> {
40 let cache = SimplifyCache {
41 scope,
42 dtype_cache: RefCell::new(HashMap::new()),
43 };
44 let reduce_ctx = ExpressionReduceCtx {
45 scope: scope.clone(),
46 };
47
48 let mut current = self.clone();
49 let mut any_optimizations = false;
50 let mut loop_counter = 0;
51
52 loop {
53 if loop_counter > 100 {
54 vortex_error::vortex_bail!(
55 "Exceeded maximum optimization iterations (possible infinite loop)"
56 );
57 }
58 loop_counter += 1;
59
60 let mut changed = false;
61
62 if let Some(simplified) = current.vtable().as_dyn().simplify_untyped(¤t)? {
64 current = simplified;
65 changed = true;
66 any_optimizations = true;
67 }
68
69 if let Some(simplified) = current.vtable().as_dyn().simplify(¤t, &cache)? {
71 current = simplified;
72 changed = true;
73 any_optimizations = true;
74 }
75
76 let reduce_node = ExpressionReduceNode {
78 expression: current.clone(),
79 scope: scope.clone(),
80 };
81 if let Some(reduced) = current.scalar_fn().reduce(&reduce_node, &reduce_ctx)? {
82 let reduced_expr = reduced
83 .as_any()
84 .downcast_ref::<ExpressionReduceNode>()
85 .vortex_expect("ReduceNode not an ExpressionReduceNode")
86 .expression
87 .clone();
88 current = reduced_expr;
89 changed = true;
90 any_optimizations = true;
91 }
92
93 if !changed {
94 break;
95 }
96 }
97
98 if any_optimizations {
99 Ok(Some(current))
100 } else {
101 Ok(None)
102 }
103 }
104
105 pub fn optimize_recursive(&self, scope: &DType) -> VortexResult<Expression> {
109 Ok(self
110 .clone()
111 .try_optimize_recursive(scope)?
112 .unwrap_or_else(|| self.clone()))
113 }
114
115 pub fn try_optimize_recursive(&self, scope: &DType) -> VortexResult<Option<Expression>> {
117 let mut current = self.clone();
118 let mut any_optimizations = false;
119
120 if let Some(optimized) = current.clone().try_optimize(scope)? {
122 current = optimized;
123 any_optimizations = true;
124 }
125
126 let mut new_children = Vec::with_capacity(current.children().len());
128 let mut any_child_optimized = false;
129 for child in current.children().iter() {
130 if let Some(optimized) = child.try_optimize_recursive(scope)? {
131 new_children.push(optimized);
132 any_child_optimized = true;
133 } else {
134 new_children.push(child.clone());
135 }
136 }
137
138 if any_child_optimized {
139 current = current.with_children(new_children)?;
140 any_optimizations = true;
141
142 if let Some(optimized) = current.clone().try_optimize(scope)? {
144 current = optimized;
145 }
146 }
147
148 let current = find_between(current);
151
152 if any_optimizations {
153 Ok(Some(current))
154 } else {
155 Ok(None)
156 }
157 }
158
159 #[deprecated(note = "Use Expression::optimize_recursive instead")]
163 pub fn simplify(&self, scope: &DType) -> VortexResult<Expression> {
164 self.optimize_recursive(scope)
165 }
166
167 #[deprecated(note = "Use Expression::optimize_recursive instead")]
171 pub fn simplify_untyped(&self) -> VortexResult<Expression> {
172 fn inner(expr: &Expression) -> VortexResult<Option<Expression>> {
174 let children: Vec<_> = expr.children().iter().map(inner).try_collect()?;
175
176 if children.iter().any(|c| c.is_some()) {
177 let new_children: Vec<_> = children
178 .into_iter()
179 .zip(expr.children().iter())
180 .map(|(new_c, old_c)| new_c.unwrap_or_else(|| old_c.clone()))
181 .collect();
182
183 let new_expr = expr.clone().with_children(new_children)?;
184 Ok(Some(
185 new_expr
186 .vtable()
187 .as_dyn()
188 .simplify_untyped(&new_expr)?
189 .unwrap_or(new_expr),
190 ))
191 } else {
192 expr.vtable().as_dyn().simplify_untyped(expr)
193 }
194 }
195
196 let simplified = self
197 .vtable()
198 .as_dyn()
199 .simplify_untyped(self)?
200 .unwrap_or_else(|| self.clone());
201
202 let simplified = inner(&simplified)?.unwrap_or(simplified);
203 let simplified = find_between(simplified);
204
205 Ok(simplified)
206 }
207}
208
209struct SimplifyCache<'a> {
210 scope: &'a DType,
211 dtype_cache: RefCell<HashMap<Expression, DType>>,
212}
213
214impl SimplifyCtx for SimplifyCache<'_> {
215 fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
216 if expr.is::<Root>() {
218 return Ok(self.scope.clone());
219 }
220
221 if let Some(dtype) = self.dtype_cache.borrow().get(expr) {
222 return Ok(dtype.clone());
223 }
224
225 let input_dtypes: Vec<_> = expr
227 .children()
228 .iter()
229 .map(|c| self.return_dtype(c))
230 .try_collect()?;
231 let dtype = expr.deref().return_dtype(&input_dtypes)?;
232 self.dtype_cache
233 .borrow_mut()
234 .insert(expr.clone(), dtype.clone());
235
236 Ok(dtype)
237 }
238}
239
240struct ExpressionReduceNode {
241 expression: Expression,
242 scope: DType,
243}
244
245impl ReduceNode for ExpressionReduceNode {
246 fn as_any(&self) -> &dyn Any {
247 self
248 }
249
250 fn node_dtype(&self) -> VortexResult<DType> {
251 self.expression.return_dtype(&self.scope)
252 }
253
254 fn scalar_fn(&self) -> Option<&ScalarFn> {
255 Some(self.expression.scalar_fn())
256 }
257
258 fn child(&self, idx: usize) -> ReduceNodeRef {
259 Arc::new(ExpressionReduceNode {
260 expression: self.expression.child(idx).clone(),
261 scope: self.scope.clone(),
262 })
263 }
264
265 fn child_count(&self) -> usize {
266 self.expression.children().len()
267 }
268}
269
270struct ExpressionReduceCtx {
271 scope: DType,
272}
273impl ReduceCtx for ExpressionReduceCtx {
274 fn new_node(
275 &self,
276 scalar_fn: ScalarFn,
277 children: &[ReduceNodeRef],
278 ) -> VortexResult<ReduceNodeRef> {
279 let expression = Expression::try_new(
280 scalar_fn,
281 children
282 .iter()
283 .map(|c| {
284 c.as_any()
285 .downcast_ref::<ExpressionReduceNode>()
286 .vortex_expect("ReduceNode not an ExpressionReduceNode")
287 .expression
288 .clone()
289 })
290 .collect::<Vec<_>>(),
291 )?;
292
293 Ok(Arc::new(ExpressionReduceNode {
294 expression,
295 scope: self.scope.clone(),
296 }))
297 }
298}