slac/
optimizer.rs

1//! Transformation routines to optimize an [`Expression`] AST.
2
3use crate::environment::{Environment, FunctionResult};
4use crate::{Expression, Operator, Result, execute};
5
6use crate::stdlib::common::TERNARY_IF_THEN;
7
8/// Recursivly transforms ternary function calls into [`Expression::Ternary`].
9/// Three parameter [`crate::stdlib::common::if_then`] calls are transformed
10/// into a [`Operator::TernaryCondition`];
11///
12/// # Remarks
13///
14/// While the [`crate::stdlib::common::if_then`] is eagerly evaluated, the
15/// [`Expression::Ternary`] supports short-circuit evaluation in the `TreeWalkingInterpreter`.
16pub fn transform_ternary(expression: &mut Expression, found_const: &mut bool) {
17    match expression {
18        Expression::Unary { right, operator: _ } => {
19            transform_ternary(right, found_const);
20        }
21        Expression::Binary {
22            left,
23            right,
24            operator: _,
25        } => {
26            transform_ternary(left, found_const);
27            transform_ternary(right, found_const);
28        }
29        Expression::Ternary {
30            left,
31            middle,
32            right,
33            operator: _,
34        } => {
35            transform_ternary(left, found_const);
36            transform_ternary(middle, found_const);
37            transform_ternary(right, found_const);
38        }
39        Expression::Array { expressions } => {
40            for expr in expressions {
41                transform_ternary(expr, found_const);
42            }
43        }
44        Expression::Call { name, params } if (name == TERNARY_IF_THEN) => {
45            if let [left, middle, right] = params.as_slice() {
46                *found_const = true;
47                *expression = Expression::Ternary {
48                    left: Box::new(left.clone()),
49                    middle: Box::new(middle.clone()),
50                    right: Box::new(right.clone()),
51                    operator: Operator::TernaryCondition,
52                }
53            } else {
54                for expr in params {
55                    transform_ternary(expr, found_const);
56                }
57            }
58        }
59        Expression::Call { name: _, params } => {
60            for expr in params {
61                transform_ternary(expr, found_const);
62            }
63        }
64        _ => (),
65    }
66}
67
68fn expressions_are_const(expressions: &[Expression]) -> bool {
69    expressions
70        .iter()
71        .all(|e| matches!(e, Expression::Literal { value: _ }))
72}
73
74/// Evaluates [`Expression::Unary`], [`Expression::Binary`] [`Expression::Array`] into a single
75/// [`Expression::Literal`] if all arguments are also an [`Expression::Literal`].
76///
77/// Evaluates [`Operator::TernaryCondition`] [`Expression::Ternary`] into either
78/// the second or third argument, if the first argument is a [`Expression::Literal`].
79///
80/// Evaluates [`Expression::Call`] into a single [`Expression::Literal`] if all parameters
81/// are [`Expression::Literal`] and the function is a pure function.
82///
83/// # Errors
84///
85/// Will return [`crate::Error`] if constant evaluation is not possible.
86pub fn fold_constants(
87    env: &impl Environment,
88    expression: &mut Expression,
89    found_const: &mut bool,
90) -> Result<()> {
91    match expression {
92        Expression::Unary { right, operator: _ } => match right.as_ref() {
93            Expression::Literal { value: _ } => {
94                *found_const = true;
95                *expression = Expression::Literal {
96                    value: execute(env, expression)?,
97                }
98            }
99            _ => fold_constants(env, right, found_const)?,
100        },
101        Expression::Binary {
102            left,
103            right,
104            operator: _,
105        } => {
106            if let (Expression::Literal { value: _ }, Expression::Literal { value: _ }) =
107                (left.as_ref(), right.as_ref())
108            {
109                *found_const = true;
110                *expression = Expression::Literal {
111                    value: execute(env, expression)?,
112                };
113            } else {
114                fold_constants(env, left, found_const)?;
115                fold_constants(env, right, found_const)?;
116            }
117        }
118        Expression::Ternary {
119            left,
120            middle,
121            right,
122            operator,
123        } => {
124            if let (Expression::Literal { value: left }, Operator::TernaryCondition) =
125                (left.as_ref(), operator)
126            {
127                *found_const = true;
128                if left.as_bool() {
129                    *expression = *middle.clone();
130                } else {
131                    *expression = *right.clone();
132                }
133            } else {
134                fold_constants(env, left, found_const)?;
135                fold_constants(env, middle, found_const)?;
136                fold_constants(env, right, found_const)?;
137            }
138        }
139        Expression::Array { expressions } if expressions_are_const(expressions) => {
140            *found_const = true;
141            *expression = Expression::Literal {
142                value: execute(env, expression)?,
143            };
144        }
145        Expression::Array { expressions } => {
146            for expr in expressions {
147                fold_constants(env, expr, found_const)?;
148            }
149        }
150
151        Expression::Call { name, params } if expressions_are_const(params) => {
152            match env.function_exists(name, params.len()) {
153                // only inline pure functions
154                FunctionResult::Exists { pure } if pure => {
155                    *found_const = true;
156                    *expression = Expression::Literal {
157                        value: execute(env, expression)?,
158                    };
159                }
160                _ => (),
161            }
162        }
163        Expression::Call { name: _, params } => {
164            for expr in params {
165                fold_constants(env, expr, found_const)?;
166            }
167        }
168        _ => (),
169    };
170
171    Ok(())
172}
173
174/// Transforms an [`Expression`] tree by applying [`transform_ternary`] and
175/// [`fold_constants`] in a loop until no further optimization is possible.
176///
177/// # Errors
178///
179/// Will return [`crate::Error`] if constant evaluation is not possible.
180pub fn optimize(env: &impl Environment, expression: &mut Expression) -> Result<()> {
181    let mut found_const = false;
182
183    loop {
184        transform_ternary(expression, &mut found_const);
185        fold_constants(env, expression, &mut found_const)?;
186
187        if found_const {
188            found_const = false; // repeat until no further folding is possible
189        } else {
190            return Ok(());
191        }
192    }
193}
194
#[cfg(test)]
mod test {

    use super::{optimize, transform_ternary};
    use crate::stdlib::common::TERNARY_IF_THEN;
    use crate::stdlib::extend_environment;
    use crate::{Expression, Operator, StaticEnvironment, Value};

    /// Builds a numeric literal expression.
    fn number(n: f64) -> Expression {
        Expression::Literal {
            value: Value::Number(n),
        }
    }

    /// Builds a boolean literal expression.
    fn boolean(b: bool) -> Expression {
        Expression::Literal {
            value: Value::Boolean(b),
        }
    }

    /// Builds a call to the function `name` with the given parameters.
    fn call(name: &str, params: Vec<Expression>) -> Expression {
        Expression::Call {
            name: String::from(name),
            params,
        }
    }

    /// Builds a `TernaryCondition` ternary expression.
    fn ternary(left: Expression, middle: Expression, right: Expression) -> Expression {
        Expression::Ternary {
            left: Box::new(left),
            middle: Box::new(middle),
            right: Box::new(right),
            operator: Operator::TernaryCondition,
        }
    }

    /// Wraps `right` in a unary minus.
    fn negate(right: Expression) -> Expression {
        Expression::Unary {
            right: Box::new(right),
            operator: Operator::Minus,
        }
    }

    /// Optimizes `expr` against an empty default environment and returns the result.
    fn optimized(mut expr: Expression) -> Expression {
        optimize(&StaticEnvironment::default(), &mut expr).unwrap();
        expr
    }

    #[test]
    fn ternary_flat() {
        let mut expr = call(
            TERNARY_IF_THEN,
            vec![boolean(true), number(1.0), number(2.0)],
        );
        let expected = ternary(boolean(true), number(1.0), number(2.0));

        transform_ternary(&mut expr, &mut false);

        assert_eq!(expected, expr);
    }

    #[test]
    fn ternary_nested() {
        let mut expr = negate(call(
            TERNARY_IF_THEN,
            vec![boolean(true), number(1.0), number(2.0)],
        ));
        let expected = negate(ternary(boolean(true), number(1.0), number(2.0)));

        transform_ternary(&mut expr, &mut false);

        assert_eq!(expected, expr);
    }

    #[test]
    fn fold_const_flat_binary() {
        let expr = Expression::Binary {
            left: Box::new(number(10.0)),
            right: Box::new(number(5.0)),
            operator: Operator::Plus,
        };

        assert_eq!(number(15.0), optimized(expr));
    }

    #[test]
    fn fold_const_flat_unary() {
        assert_eq!(number(-5.0), optimized(negate(number(5.0))));
        assert_eq!(number(5.0), optimized(negate(negate(number(5.0)))));
    }

    #[test]
    fn fold_const_ternary() {
        let expr = ternary(boolean(true), number(1.0), number(2.0));

        assert_eq!(number(1.0), optimized(expr));
    }

    #[test]
    fn fold_vectors() {
        let expr = Expression::Array {
            expressions: vec![negate(call(
                TERNARY_IF_THEN,
                vec![
                    boolean(true),
                    call(TERNARY_IF_THEN, vec![boolean(true), number(3.0)]),
                    number(2.0),
                ],
            ))],
        };
        let expected = Expression::Array {
            expressions: vec![negate(call(
                TERNARY_IF_THEN,
                vec![boolean(true), number(3.0)],
            ))],
        };

        assert_eq!(expected, optimized(expr));
    }

    #[test]
    fn fold_array() {
        let expr = Expression::Array {
            expressions: vec![boolean(true), boolean(false)],
        };
        let expected = Expression::Literal {
            value: Value::Array(vec![Value::Boolean(true), Value::Boolean(false)]),
        };

        assert_eq!(expected, optimized(expr));
    }

    #[test]
    fn fold_pure_function() {
        let mut expr = call("max", vec![number(10.0), number(20.0)]);

        let mut env = StaticEnvironment::default();
        extend_environment(&mut env);

        optimize(&env, &mut expr).unwrap();

        assert_eq!(number(20.0), expr);
    }
}