vortex_array/expr/
optimize.rs1use std::any::Any;
5use std::cell::RefCell;
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_utils::aliases::hash_map::HashMap;
13
14use crate::dtype::DType;
15use crate::expr::Expression;
16use crate::expr::transform::match_between::find_between;
17use crate::scalar_fn::ReduceCtx;
18use crate::scalar_fn::ReduceNode;
19use crate::scalar_fn::ReduceNodeRef;
20use crate::scalar_fn::ScalarFnRef;
21use crate::scalar_fn::SimplifyCtx;
22use crate::scalar_fn::fns::root::Root;
23
24impl Expression {
25 pub fn optimize(&self, scope: &DType) -> VortexResult<Expression> {
32 Ok(self
33 .clone()
34 .try_optimize(scope)?
35 .unwrap_or_else(|| self.clone()))
36 }
37
38 pub fn try_optimize(&self, scope: &DType) -> VortexResult<Option<Expression>> {
40 let cache = SimplifyCache {
41 scope,
42 dtype_cache: RefCell::new(HashMap::new()),
43 };
44 let reduce_ctx = ExpressionReduceCtx {
45 scope: scope.clone(),
46 };
47
48 let mut current = self.clone();
49 let mut any_optimizations = false;
50 let mut loop_counter = 0;
51
52 loop {
53 if loop_counter > 100 {
54 vortex_error::vortex_bail!(
55 "Exceeded maximum optimization iterations (possible infinite loop)"
56 );
57 }
58 loop_counter += 1;
59
60 let mut changed = false;
61
62 if let Some(simplified) = current.scalar_fn().simplify_untyped(¤t)? {
64 current = simplified;
65 changed = true;
66 any_optimizations = true;
67 }
68
69 if let Some(simplified) = current.scalar_fn().simplify(¤t, &cache)? {
71 current = simplified;
72 changed = true;
73 any_optimizations = true;
74 }
75
76 let reduce_node = ExpressionReduceNode {
78 expression: current.clone(),
79 scope: scope.clone(),
80 };
81 if let Some(reduced) = current.scalar_fn().reduce(&reduce_node, &reduce_ctx)? {
82 let reduced_expr = reduced
83 .as_any()
84 .downcast_ref::<ExpressionReduceNode>()
85 .vortex_expect("ReduceNode not an ExpressionReduceNode")
86 .expression
87 .clone();
88 current = reduced_expr;
89 changed = true;
90 any_optimizations = true;
91 }
92
93 if !changed {
94 break;
95 }
96 }
97
98 if any_optimizations {
99 Ok(Some(current))
100 } else {
101 Ok(None)
102 }
103 }
104
105 pub fn optimize_recursive(&self, scope: &DType) -> VortexResult<Expression> {
109 Ok(self
110 .clone()
111 .try_optimize_recursive(scope)?
112 .unwrap_or_else(|| self.clone()))
113 }
114
115 pub fn try_optimize_recursive(&self, scope: &DType) -> VortexResult<Option<Expression>> {
117 let mut current = self.clone();
118 let mut any_optimizations = false;
119
120 if let Some(optimized) = current.clone().try_optimize(scope)? {
122 current = optimized;
123 any_optimizations = true;
124 }
125
126 let mut new_children = Vec::with_capacity(current.children().len());
128 let mut any_child_optimized = false;
129 for child in current.children().iter() {
130 if let Some(optimized) = child.try_optimize_recursive(scope)? {
131 new_children.push(optimized);
132 any_child_optimized = true;
133 } else {
134 new_children.push(child.clone());
135 }
136 }
137
138 if any_child_optimized {
139 current = current.with_children(new_children)?;
140 any_optimizations = true;
141
142 if let Some(optimized) = current.clone().try_optimize(scope)? {
144 current = optimized;
145 }
146 }
147
148 let current = find_between(current);
151
152 if any_optimizations {
153 Ok(Some(current))
154 } else {
155 Ok(None)
156 }
157 }
158
159 #[deprecated(note = "Use Expression::optimize_recursive instead")]
163 pub fn simplify(&self, scope: &DType) -> VortexResult<Expression> {
164 self.optimize_recursive(scope)
165 }
166
167 #[deprecated(note = "Use Expression::optimize_recursive instead")]
171 pub fn simplify_untyped(&self) -> VortexResult<Expression> {
172 fn inner(expr: &Expression) -> VortexResult<Option<Expression>> {
174 let children: Vec<_> = expr.children().iter().map(inner).try_collect()?;
175
176 if children.iter().any(|c| c.is_some()) {
177 let new_children: Vec<_> = children
178 .into_iter()
179 .zip(expr.children().iter())
180 .map(|(new_c, old_c)| new_c.unwrap_or_else(|| old_c.clone()))
181 .collect();
182
183 let new_expr = expr.clone().with_children(new_children)?;
184 Ok(Some(
185 new_expr
186 .scalar_fn()
187 .simplify_untyped(&new_expr)?
188 .unwrap_or(new_expr),
189 ))
190 } else {
191 expr.scalar_fn().simplify_untyped(expr)
192 }
193 }
194
195 let simplified = self
196 .scalar_fn()
197 .simplify_untyped(self)?
198 .unwrap_or_else(|| self.clone());
199
200 let simplified = inner(&simplified)?.unwrap_or(simplified);
201 let simplified = find_between(simplified);
202
203 Ok(simplified)
204 }
205}
206
207struct SimplifyCache<'a> {
208 scope: &'a DType,
209 dtype_cache: RefCell<HashMap<Expression, DType>>,
210}
211
212impl SimplifyCtx for SimplifyCache<'_> {
213 fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
214 if expr.is::<Root>() {
216 return Ok(self.scope.clone());
217 }
218
219 if let Some(dtype) = self.dtype_cache.borrow().get(expr) {
220 return Ok(dtype.clone());
221 }
222
223 let input_dtypes: Vec<_> = expr
225 .children()
226 .iter()
227 .map(|c| self.return_dtype(c))
228 .try_collect()?;
229 let dtype = expr.deref().return_dtype(&input_dtypes)?;
230 self.dtype_cache
231 .borrow_mut()
232 .insert(expr.clone(), dtype.clone());
233
234 Ok(dtype)
235 }
236}
237
238struct ExpressionReduceNode {
239 expression: Expression,
240 scope: DType,
241}
242
243impl ReduceNode for ExpressionReduceNode {
244 fn as_any(&self) -> &dyn Any {
245 self
246 }
247
248 fn node_dtype(&self) -> VortexResult<DType> {
249 self.expression.return_dtype(&self.scope)
250 }
251
252 fn scalar_fn(&self) -> Option<&ScalarFnRef> {
253 Some(self.expression.scalar_fn())
254 }
255
256 fn child(&self, idx: usize) -> ReduceNodeRef {
257 Arc::new(ExpressionReduceNode {
258 expression: self.expression.child(idx).clone(),
259 scope: self.scope.clone(),
260 })
261 }
262
263 fn child_count(&self) -> usize {
264 self.expression.children().len()
265 }
266}
267
268struct ExpressionReduceCtx {
269 scope: DType,
270}
271impl ReduceCtx for ExpressionReduceCtx {
272 fn new_node(
273 &self,
274 scalar_fn: ScalarFnRef,
275 children: &[ReduceNodeRef],
276 ) -> VortexResult<ReduceNodeRef> {
277 let expression = Expression::try_new(
278 scalar_fn,
279 children
280 .iter()
281 .map(|c| {
282 c.as_any()
283 .downcast_ref::<ExpressionReduceNode>()
284 .vortex_expect("ReduceNode not an ExpressionReduceNode")
285 .expression
286 .clone()
287 })
288 .collect::<Vec<_>>(),
289 )?;
290
291 Ok(Arc::new(ExpressionReduceNode {
292 expression,
293 scope: self.scope.clone(),
294 }))
295 }
296}