Skip to main content

sqlglot_rust/optimizer/
mod.rs

//! Query optimization passes.
//!
//! Inspired by Python sqlglot's optimizer module. Currently implements:
//! - Constant folding (e.g., `1 + 2` → `3`)
//! - Boolean simplification (e.g., `TRUE AND x` → `x`)
//! - Dead predicate elimination (e.g., removing `WHERE TRUE`)
//!
//! Future optimizations:
//! - Predicate pushdown
//! - Join reordering
//! - Column pruning
//! - Subquery unnesting
13
14use crate::ast::*;
15use crate::errors::Result;
16
17/// Optimize a SQL statement by applying transformation passes.
18pub fn optimize(statement: Statement) -> Result<Statement> {
19    let mut stmt = statement;
20    stmt = fold_constants(stmt);
21    stmt = simplify_booleans(stmt);
22    Ok(stmt)
23}
24
25/// Fold constant expressions (e.g., `1 + 2` → `3`).
26fn fold_constants(statement: Statement) -> Statement {
27    match statement {
28        Statement::Select(mut sel) => {
29            if let Some(wh) = sel.where_clause {
30                sel.where_clause = Some(fold_expr(wh));
31            }
32            if let Some(having) = sel.having {
33                sel.having = Some(fold_expr(having));
34            }
35            for item in &mut sel.columns {
36                if let SelectItem::Expr { expr, .. } = item {
37                    *expr = fold_expr(expr.clone());
38                }
39            }
40            Statement::Select(sel)
41        }
42        other => other,
43    }
44}
45
46fn fold_expr(expr: Expr) -> Expr {
47    match expr {
48        Expr::BinaryOp { left, op, right } => {
49            let left = fold_expr(*left);
50            let right = fold_expr(*right);
51
52            // Try numeric folding
53            if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
54                if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
55                    let result = match op {
56                        BinaryOperator::Plus => Some(lv + rv),
57                        BinaryOperator::Minus => Some(lv - rv),
58                        BinaryOperator::Multiply => Some(lv * rv),
59                        BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
60                        BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
61                        _ => None,
62                    };
63                    if let Some(val) = result {
64                        // Emit integer if it's a whole number
65                        if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
66                            return Expr::Number(format!("{}", val as i64));
67                        }
68                        return Expr::Number(format!("{val}"));
69                    }
70
71                    // Try boolean folding for comparison
72                    let cmp = match op {
73                        BinaryOperator::Eq => Some(lv == rv),
74                        BinaryOperator::Neq => Some(lv != rv),
75                        BinaryOperator::Lt => Some(lv < rv),
76                        BinaryOperator::Gt => Some(lv > rv),
77                        BinaryOperator::LtEq => Some(lv <= rv),
78                        BinaryOperator::GtEq => Some(lv >= rv),
79                        _ => None,
80                    };
81                    if let Some(val) = cmp {
82                        return Expr::Boolean(val);
83                    }
84                }
85            }
86
87            // String concatenation folding
88            if matches!(op, BinaryOperator::Concat) {
89                if let (Expr::StringLiteral(l), Expr::StringLiteral(r)) = (&left, &right) {
90                    return Expr::StringLiteral(format!("{l}{r}"));
91                }
92            }
93
94            Expr::BinaryOp {
95                left: Box::new(left),
96                op,
97                right: Box::new(right),
98            }
99        }
100        Expr::UnaryOp { op: UnaryOperator::Minus, expr } => {
101            let inner = fold_expr(*expr);
102            if let Expr::Number(ref n) = inner {
103                if let Ok(v) = n.parse::<f64>() {
104                    let neg = -v;
105                    if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
106                        return Expr::Number(format!("{}", neg as i64));
107                    }
108                    return Expr::Number(format!("{neg}"));
109                }
110            }
111            Expr::UnaryOp {
112                op: UnaryOperator::Minus,
113                expr: Box::new(inner),
114            }
115        }
116        Expr::Nested(inner) => {
117            let folded = fold_expr(*inner);
118            if folded.is_literal() {
119                folded
120            } else {
121                Expr::Nested(Box::new(folded))
122            }
123        }
124        other => other,
125    }
126}
127
128/// Simplify boolean expressions.
129fn simplify_booleans(statement: Statement) -> Statement {
130    match statement {
131        Statement::Select(mut sel) => {
132            // Simplify boolean expressions in SELECT columns
133            for item in &mut sel.columns {
134                if let SelectItem::Expr { expr, .. } = item {
135                    *expr = simplify_bool_expr(expr.clone());
136                }
137            }
138            if let Some(wh) = sel.where_clause {
139                let simplified = simplify_bool_expr(wh);
140                // WHERE TRUE → no WHERE clause
141                if simplified == Expr::Boolean(true) {
142                    sel.where_clause = None;
143                } else {
144                    sel.where_clause = Some(simplified);
145                }
146            }
147            if let Some(having) = sel.having {
148                let simplified = simplify_bool_expr(having);
149                if simplified == Expr::Boolean(true) {
150                    sel.having = None;
151                } else {
152                    sel.having = Some(simplified);
153                }
154            }
155            Statement::Select(sel)
156        }
157        other => other,
158    }
159}
160
161fn simplify_bool_expr(expr: Expr) -> Expr {
162    match expr {
163        Expr::BinaryOp { left, op: BinaryOperator::And, right } => {
164            let left = simplify_bool_expr(*left);
165            let right = simplify_bool_expr(*right);
166            match (&left, &right) {
167                (Expr::Boolean(true), _) => right,
168                (_, Expr::Boolean(true)) => left,
169                (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
170                _ => Expr::BinaryOp {
171                    left: Box::new(left),
172                    op: BinaryOperator::And,
173                    right: Box::new(right),
174                },
175            }
176        }
177        Expr::BinaryOp { left, op: BinaryOperator::Or, right } => {
178            let left = simplify_bool_expr(*left);
179            let right = simplify_bool_expr(*right);
180            match (&left, &right) {
181                (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
182                (Expr::Boolean(false), _) => right,
183                (_, Expr::Boolean(false)) => left,
184                _ => Expr::BinaryOp {
185                    left: Box::new(left),
186                    op: BinaryOperator::Or,
187                    right: Box::new(right),
188                },
189            }
190        }
191        Expr::UnaryOp { op: UnaryOperator::Not, expr } => {
192            let inner = simplify_bool_expr(*expr);
193            match inner {
194                Expr::Boolean(b) => Expr::Boolean(!b),
195                Expr::UnaryOp { op: UnaryOperator::Not, expr: inner2 } => *inner2,
196                other => Expr::UnaryOp {
197                    op: UnaryOperator::Not,
198                    expr: Box::new(other),
199                },
200            }
201        }
202        Expr::Nested(inner) => {
203            let simplified = simplify_bool_expr(*inner);
204            if simplified.is_literal() {
205                simplified
206            } else {
207                Expr::Nested(Box::new(simplified))
208            }
209        }
210        other => other,
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    // Unwrap a `Statement::Select`, panicking for any other statement kind. A bare
    // `if let` would let a test pass silently when the shape is unexpected.
    macro_rules! expect_select {
        ($stmt:expr) => {
            match $stmt {
                Statement::Select(sel) => sel,
                _ => panic!("expected a SELECT statement"),
            }
        };
    }

    #[test]
    fn test_constant_folding() {
        let sel = expect_select!(optimize_sql("SELECT 1 + 2 FROM t"));
        match &sel.columns[0] {
            SelectItem::Expr { expr, .. } => assert_eq!(*expr, Expr::Number("3".to_string())),
            _ => panic!("expected an expression select item"),
        }
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        let sel = expect_select!(optimize_sql("SELECT x FROM t WHERE TRUE"));
        assert!(sel.where_clause.is_none());
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        let sel = expect_select!(optimize_sql("SELECT x FROM t WHERE TRUE AND x > 1"));
        // Should simplify to just `x > 1`.
        assert!(
            matches!(&sel.where_clause, Some(Expr::BinaryOp { op: BinaryOperator::Gt, .. })),
            "expected `TRUE AND x > 1` to reduce to `x > 1`"
        );
    }

    #[test]
    fn test_double_negation() {
        let sel = expect_select!(optimize_sql("SELECT x FROM t WHERE NOT NOT x > 1"));
        assert!(sel.where_clause.is_some());
        assert!(
            !matches!(&sel.where_clause, Some(Expr::UnaryOp { op: UnaryOperator::Not, .. })),
            "double negation should be removed"
        );
    }

    #[test]
    fn test_where_tautology_removed() {
        // `1 = 1` folds to TRUE, which boolean simplification then drops.
        let sel = expect_select!(optimize_sql("SELECT x FROM t WHERE 1 = 1"));
        assert!(sel.where_clause.is_none());
    }

    #[test]
    fn test_where_false_is_kept() {
        let sel = expect_select!(optimize_sql("SELECT x FROM t WHERE FALSE"));
        assert_eq!(sel.where_clause, Some(Expr::Boolean(false)));
    }
}