Skip to main content

sqlglot_rust/optimizer/
mod.rs

1//! Query optimization passes.
2//!
3//! Inspired by Python sqlglot's optimizer module. Currently implements:
4//! - Constant folding (e.g., `1 + 2` → `3`)
5//! - Boolean simplification (e.g., `TRUE AND x` → `x`)
6//! - Dead predicate elimination (e.g., `WHERE TRUE`)
7//! - Subquery unnesting / decorrelation (EXISTS, IN → JOINs)
8//!
9//! Future optimizations:
10//! - Predicate pushdown
11//! - Join reordering
12//! - Column pruning
13
14pub mod unnest_subqueries;
15
16use crate::ast::*;
17use crate::errors::Result;
18
19/// Optimize a SQL statement by applying transformation passes.
20pub fn optimize(statement: Statement) -> Result<Statement> {
21    let mut stmt = statement;
22    stmt = fold_constants(stmt);
23    stmt = simplify_booleans(stmt);
24    stmt = unnest_subqueries::unnest_subqueries(stmt);
25    Ok(stmt)
26}
27
28/// Fold constant expressions (e.g., `1 + 2` → `3`).
29fn fold_constants(statement: Statement) -> Statement {
30    match statement {
31        Statement::Select(mut sel) => {
32            if let Some(wh) = sel.where_clause {
33                sel.where_clause = Some(fold_expr(wh));
34            }
35            if let Some(having) = sel.having {
36                sel.having = Some(fold_expr(having));
37            }
38            for item in &mut sel.columns {
39                if let SelectItem::Expr { expr, .. } = item {
40                    *expr = fold_expr(expr.clone());
41                }
42            }
43            Statement::Select(sel)
44        }
45        other => other,
46    }
47}
48
49fn fold_expr(expr: Expr) -> Expr {
50    match expr {
51        Expr::BinaryOp { left, op, right } => {
52            let left = fold_expr(*left);
53            let right = fold_expr(*right);
54
55            // Try numeric folding
56            if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
57                if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
58                    let result = match op {
59                        BinaryOperator::Plus => Some(lv + rv),
60                        BinaryOperator::Minus => Some(lv - rv),
61                        BinaryOperator::Multiply => Some(lv * rv),
62                        BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
63                        BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
64                        _ => None,
65                    };
66                    if let Some(val) = result {
67                        // Emit integer if it's a whole number
68                        if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
69                            return Expr::Number(format!("{}", val as i64));
70                        }
71                        return Expr::Number(format!("{val}"));
72                    }
73
74                    // Try boolean folding for comparison
75                    let cmp = match op {
76                        BinaryOperator::Eq => Some(lv == rv),
77                        BinaryOperator::Neq => Some(lv != rv),
78                        BinaryOperator::Lt => Some(lv < rv),
79                        BinaryOperator::Gt => Some(lv > rv),
80                        BinaryOperator::LtEq => Some(lv <= rv),
81                        BinaryOperator::GtEq => Some(lv >= rv),
82                        _ => None,
83                    };
84                    if let Some(val) = cmp {
85                        return Expr::Boolean(val);
86                    }
87                }
88            }
89
90            // String concatenation folding
91            if matches!(op, BinaryOperator::Concat) {
92                if let (Expr::StringLiteral(l), Expr::StringLiteral(r)) = (&left, &right) {
93                    return Expr::StringLiteral(format!("{l}{r}"));
94                }
95            }
96
97            Expr::BinaryOp {
98                left: Box::new(left),
99                op,
100                right: Box::new(right),
101            }
102        }
103        Expr::UnaryOp { op: UnaryOperator::Minus, expr } => {
104            let inner = fold_expr(*expr);
105            if let Expr::Number(ref n) = inner {
106                if let Ok(v) = n.parse::<f64>() {
107                    let neg = -v;
108                    if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
109                        return Expr::Number(format!("{}", neg as i64));
110                    }
111                    return Expr::Number(format!("{neg}"));
112                }
113            }
114            Expr::UnaryOp {
115                op: UnaryOperator::Minus,
116                expr: Box::new(inner),
117            }
118        }
119        Expr::Nested(inner) => {
120            let folded = fold_expr(*inner);
121            if folded.is_literal() {
122                folded
123            } else {
124                Expr::Nested(Box::new(folded))
125            }
126        }
127        other => other,
128    }
129}
130
131/// Simplify boolean expressions.
132fn simplify_booleans(statement: Statement) -> Statement {
133    match statement {
134        Statement::Select(mut sel) => {
135            // Simplify boolean expressions in SELECT columns
136            for item in &mut sel.columns {
137                if let SelectItem::Expr { expr, .. } = item {
138                    *expr = simplify_bool_expr(expr.clone());
139                }
140            }
141            if let Some(wh) = sel.where_clause {
142                let simplified = simplify_bool_expr(wh);
143                // WHERE TRUE → no WHERE clause
144                if simplified == Expr::Boolean(true) {
145                    sel.where_clause = None;
146                } else {
147                    sel.where_clause = Some(simplified);
148                }
149            }
150            if let Some(having) = sel.having {
151                let simplified = simplify_bool_expr(having);
152                if simplified == Expr::Boolean(true) {
153                    sel.having = None;
154                } else {
155                    sel.having = Some(simplified);
156                }
157            }
158            Statement::Select(sel)
159        }
160        other => other,
161    }
162}
163
164fn simplify_bool_expr(expr: Expr) -> Expr {
165    match expr {
166        Expr::BinaryOp { left, op: BinaryOperator::And, right } => {
167            let left = simplify_bool_expr(*left);
168            let right = simplify_bool_expr(*right);
169            match (&left, &right) {
170                (Expr::Boolean(true), _) => right,
171                (_, Expr::Boolean(true)) => left,
172                (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
173                _ => Expr::BinaryOp {
174                    left: Box::new(left),
175                    op: BinaryOperator::And,
176                    right: Box::new(right),
177                },
178            }
179        }
180        Expr::BinaryOp { left, op: BinaryOperator::Or, right } => {
181            let left = simplify_bool_expr(*left);
182            let right = simplify_bool_expr(*right);
183            match (&left, &right) {
184                (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
185                (Expr::Boolean(false), _) => right,
186                (_, Expr::Boolean(false)) => left,
187                _ => Expr::BinaryOp {
188                    left: Box::new(left),
189                    op: BinaryOperator::Or,
190                    right: Box::new(right),
191                },
192            }
193        }
194        Expr::UnaryOp { op: UnaryOperator::Not, expr } => {
195            let inner = simplify_bool_expr(*expr);
196            match inner {
197                Expr::Boolean(b) => Expr::Boolean(!b),
198                Expr::UnaryOp { op: UnaryOperator::Not, expr: inner2 } => *inner2,
199                other => Expr::UnaryOp {
200                    op: UnaryOperator::Not,
201                    expr: Box::new(other),
202                },
203            }
204        }
205        Expr::Nested(inner) => {
206            let simplified = simplify_bool_expr(*inner);
207            if simplified.is_literal() {
208                simplified
209            } else {
210                Expr::Nested(Box::new(simplified))
211            }
212        }
213        other => other,
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql`, run the optimizer and return the rewritten statement.
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    // Each test panics on an unexpected statement shape instead of wrapping
    // its assertions in `if let`, which would let a mismatch pass silently.

    #[test]
    fn test_constant_folding() {
        let sel = match optimize_sql("SELECT 1 + 2 FROM t") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        match &sel.columns[0] {
            SelectItem::Expr { expr, .. } => {
                assert_eq!(*expr, Expr::Number("3".to_string()));
            }
            _ => panic!("expected an expression select item"),
        }
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        let sel = match optimize_sql("SELECT x FROM t WHERE TRUE") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        assert!(sel.where_clause.is_none());
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        let sel = match optimize_sql("SELECT x FROM t WHERE TRUE AND x > 1") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        // Should simplify to just x > 1
        assert!(sel.where_clause.is_some());
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::BinaryOp { op: BinaryOperator::And, .. })
        ));
    }

    #[test]
    fn test_double_negation() {
        let sel = match optimize_sql("SELECT x FROM t WHERE NOT NOT x > 1") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        // Should simplify to x > 1 (no NOT)
        assert!(sel.where_clause.is_some());
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::UnaryOp { op: UnaryOperator::Not, .. })
        ));
    }
}