Skip to main content

sqlglot_rust/optimizer/
mod.rs

1//! Query optimization passes.
2//!
3//! Inspired by Python sqlglot's optimizer module. Currently implements:
4//! - Constant folding (e.g., `1 + 2` → `3`)
5//! - Boolean simplification (e.g., `TRUE AND x` → `x`)
6//! - Dead predicate elimination (e.g., `WHERE TRUE`)
7//! - Subquery unnesting / decorrelation (EXISTS, IN → JOINs)
8//!
9//! Future optimizations:
10//! - Predicate pushdown
11//! - Join reordering
12//! - Column pruning
13
14pub mod scope_analysis;
15pub mod unnest_subqueries;
16
17use crate::ast::*;
18use crate::errors::Result;
19
20/// Optimize a SQL statement by applying transformation passes.
21pub fn optimize(statement: Statement) -> Result<Statement> {
22    let mut stmt = statement;
23    stmt = fold_constants(stmt);
24    stmt = simplify_booleans(stmt);
25    stmt = unnest_subqueries::unnest_subqueries(stmt);
26    Ok(stmt)
27}
28
29/// Fold constant expressions (e.g., `1 + 2` → `3`).
30fn fold_constants(statement: Statement) -> Statement {
31    match statement {
32        Statement::Select(mut sel) => {
33            if let Some(wh) = sel.where_clause {
34                sel.where_clause = Some(fold_expr(wh));
35            }
36            if let Some(having) = sel.having {
37                sel.having = Some(fold_expr(having));
38            }
39            for item in &mut sel.columns {
40                if let SelectItem::Expr { expr, .. } = item {
41                    *expr = fold_expr(expr.clone());
42                }
43            }
44            Statement::Select(sel)
45        }
46        other => other,
47    }
48}
49
50fn fold_expr(expr: Expr) -> Expr {
51    match expr {
52        Expr::BinaryOp { left, op, right } => {
53            let left = fold_expr(*left);
54            let right = fold_expr(*right);
55
56            // Try numeric folding
57            if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
58                if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
59                    let result = match op {
60                        BinaryOperator::Plus => Some(lv + rv),
61                        BinaryOperator::Minus => Some(lv - rv),
62                        BinaryOperator::Multiply => Some(lv * rv),
63                        BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
64                        BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
65                        _ => None,
66                    };
67                    if let Some(val) = result {
68                        // Emit integer if it's a whole number
69                        if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
70                            return Expr::Number(format!("{}", val as i64));
71                        }
72                        return Expr::Number(format!("{val}"));
73                    }
74
75                    // Try boolean folding for comparison
76                    let cmp = match op {
77                        BinaryOperator::Eq => Some(lv == rv),
78                        BinaryOperator::Neq => Some(lv != rv),
79                        BinaryOperator::Lt => Some(lv < rv),
80                        BinaryOperator::Gt => Some(lv > rv),
81                        BinaryOperator::LtEq => Some(lv <= rv),
82                        BinaryOperator::GtEq => Some(lv >= rv),
83                        _ => None,
84                    };
85                    if let Some(val) = cmp {
86                        return Expr::Boolean(val);
87                    }
88                }
89            }
90
91            // String concatenation folding
92            if matches!(op, BinaryOperator::Concat) {
93                if let (Expr::StringLiteral(l), Expr::StringLiteral(r)) = (&left, &right) {
94                    return Expr::StringLiteral(format!("{l}{r}"));
95                }
96            }
97
98            Expr::BinaryOp {
99                left: Box::new(left),
100                op,
101                right: Box::new(right),
102            }
103        }
104        Expr::UnaryOp {
105            op: UnaryOperator::Minus,
106            expr,
107        } => {
108            let inner = fold_expr(*expr);
109            if let Expr::Number(ref n) = inner {
110                if let Ok(v) = n.parse::<f64>() {
111                    let neg = -v;
112                    if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
113                        return Expr::Number(format!("{}", neg as i64));
114                    }
115                    return Expr::Number(format!("{neg}"));
116                }
117            }
118            Expr::UnaryOp {
119                op: UnaryOperator::Minus,
120                expr: Box::new(inner),
121            }
122        }
123        Expr::Nested(inner) => {
124            let folded = fold_expr(*inner);
125            if folded.is_literal() {
126                folded
127            } else {
128                Expr::Nested(Box::new(folded))
129            }
130        }
131        other => other,
132    }
133}
134
135/// Simplify boolean expressions.
136fn simplify_booleans(statement: Statement) -> Statement {
137    match statement {
138        Statement::Select(mut sel) => {
139            // Simplify boolean expressions in SELECT columns
140            for item in &mut sel.columns {
141                if let SelectItem::Expr { expr, .. } = item {
142                    *expr = simplify_bool_expr(expr.clone());
143                }
144            }
145            if let Some(wh) = sel.where_clause {
146                let simplified = simplify_bool_expr(wh);
147                // WHERE TRUE → no WHERE clause
148                if simplified == Expr::Boolean(true) {
149                    sel.where_clause = None;
150                } else {
151                    sel.where_clause = Some(simplified);
152                }
153            }
154            if let Some(having) = sel.having {
155                let simplified = simplify_bool_expr(having);
156                if simplified == Expr::Boolean(true) {
157                    sel.having = None;
158                } else {
159                    sel.having = Some(simplified);
160                }
161            }
162            Statement::Select(sel)
163        }
164        other => other,
165    }
166}
167
168fn simplify_bool_expr(expr: Expr) -> Expr {
169    match expr {
170        Expr::BinaryOp {
171            left,
172            op: BinaryOperator::And,
173            right,
174        } => {
175            let left = simplify_bool_expr(*left);
176            let right = simplify_bool_expr(*right);
177            match (&left, &right) {
178                (Expr::Boolean(true), _) => right,
179                (_, Expr::Boolean(true)) => left,
180                (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
181                _ => Expr::BinaryOp {
182                    left: Box::new(left),
183                    op: BinaryOperator::And,
184                    right: Box::new(right),
185                },
186            }
187        }
188        Expr::BinaryOp {
189            left,
190            op: BinaryOperator::Or,
191            right,
192        } => {
193            let left = simplify_bool_expr(*left);
194            let right = simplify_bool_expr(*right);
195            match (&left, &right) {
196                (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
197                (Expr::Boolean(false), _) => right,
198                (_, Expr::Boolean(false)) => left,
199                _ => Expr::BinaryOp {
200                    left: Box::new(left),
201                    op: BinaryOperator::Or,
202                    right: Box::new(right),
203                },
204            }
205        }
206        Expr::UnaryOp {
207            op: UnaryOperator::Not,
208            expr,
209        } => {
210            let inner = simplify_bool_expr(*expr);
211            match inner {
212                Expr::Boolean(b) => Expr::Boolean(!b),
213                Expr::UnaryOp {
214                    op: UnaryOperator::Not,
215                    expr: inner2,
216                } => *inner2,
217                other => Expr::UnaryOp {
218                    op: UnaryOperator::Not,
219                    expr: Box::new(other),
220                },
221            }
222        }
223        Expr::Nested(inner) => {
224            let simplified = simplify_bool_expr(*inner);
225            if simplified.is_literal() {
226                simplified
227            } else {
228                Expr::Nested(Box::new(simplified))
229            }
230        }
231        other => other,
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` as a single statement and run the optimizer over it.
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    #[test]
    fn test_constant_folding() {
        // Fail loudly on an unexpected shape instead of passing vacuously.
        match optimize_sql("SELECT 1 + 2 FROM t") {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectItem::Expr { expr, .. } => {
                    assert_eq!(*expr, Expr::Number("3".to_string()));
                }
                _ => panic!("expected an expression select item"),
            },
            _ => panic!("expected a SELECT statement"),
        }
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        match optimize_sql("SELECT x FROM t WHERE TRUE") {
            Statement::Select(sel) => assert!(sel.where_clause.is_none()),
            _ => panic!("expected a SELECT statement"),
        }
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        match optimize_sql("SELECT x FROM t WHERE TRUE AND x > 1") {
            Statement::Select(sel) => {
                // Should simplify to just x > 1
                assert!(sel.where_clause.is_some());
                assert!(!matches!(
                    &sel.where_clause,
                    Some(Expr::BinaryOp {
                        op: BinaryOperator::And,
                        ..
                    })
                ));
            }
            _ => panic!("expected a SELECT statement"),
        }
    }

    #[test]
    fn test_double_negation() {
        match optimize_sql("SELECT x FROM t WHERE NOT NOT x > 1") {
            Statement::Select(sel) => {
                // Should simplify to x > 1 (no NOT)
                assert!(sel.where_clause.is_some());
                assert!(!matches!(
                    &sel.where_clause,
                    Some(Expr::UnaryOp {
                        op: UnaryOperator::Not,
                        ..
                    })
                ));
            }
            _ => panic!("expected a SELECT statement"),
        }
    }
}