Skip to main content

sqlglot_rust/optimizer/
mod.rs

1//! Query optimization passes.
2//!
3//! Inspired by Python sqlglot's optimizer module. Currently implements:
4//! - Constant folding (e.g., `1 + 2` → `3`)
5//! - Boolean simplification (e.g., `TRUE AND x` → `x`)
6//! - Dead predicate elimination (e.g., `WHERE TRUE`)
7//! - Subquery unnesting / decorrelation (EXISTS, IN → JOINs)
8//! - Column qualification (qualify_columns — resolve `*`, add table qualifiers)
9//! - Type annotation (annotate_types — infer SQL types for all AST nodes)
10//!
11//! Future optimizations:
12//! - Predicate pushdown
13//! - Join reordering
14//! - Column pruning
15
16pub mod annotate_types;
17pub mod qualify_columns;
18pub mod scope_analysis;
19pub mod unnest_subqueries;
20
21use crate::ast::*;
22use crate::errors::Result;
23
24/// Optimize a SQL statement by applying transformation passes.
25pub fn optimize(statement: Statement) -> Result<Statement> {
26    let mut stmt = statement;
27    stmt = fold_constants(stmt);
28    stmt = simplify_booleans(stmt);
29    stmt = unnest_subqueries::unnest_subqueries(stmt);
30    Ok(stmt)
31}
32
33/// Fold constant expressions (e.g., `1 + 2` → `3`).
34fn fold_constants(statement: Statement) -> Statement {
35    match statement {
36        Statement::Select(mut sel) => {
37            if let Some(wh) = sel.where_clause {
38                sel.where_clause = Some(fold_expr(wh));
39            }
40            if let Some(having) = sel.having {
41                sel.having = Some(fold_expr(having));
42            }
43            for item in &mut sel.columns {
44                if let SelectItem::Expr { expr, .. } = item {
45                    *expr = fold_expr(expr.clone());
46                }
47            }
48            Statement::Select(sel)
49        }
50        other => other,
51    }
52}
53
54fn fold_expr(expr: Expr) -> Expr {
55    match expr {
56        Expr::BinaryOp { left, op, right } => {
57            let left = fold_expr(*left);
58            let right = fold_expr(*right);
59
60            // Try numeric folding
61            if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
62                if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
63                    let result = match op {
64                        BinaryOperator::Plus => Some(lv + rv),
65                        BinaryOperator::Minus => Some(lv - rv),
66                        BinaryOperator::Multiply => Some(lv * rv),
67                        BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
68                        BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
69                        _ => None,
70                    };
71                    if let Some(val) = result {
72                        // Emit integer if it's a whole number
73                        if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
74                            return Expr::Number(format!("{}", val as i64));
75                        }
76                        return Expr::Number(format!("{val}"));
77                    }
78
79                    // Try boolean folding for comparison
80                    let cmp = match op {
81                        BinaryOperator::Eq => Some(lv == rv),
82                        BinaryOperator::Neq => Some(lv != rv),
83                        BinaryOperator::Lt => Some(lv < rv),
84                        BinaryOperator::Gt => Some(lv > rv),
85                        BinaryOperator::LtEq => Some(lv <= rv),
86                        BinaryOperator::GtEq => Some(lv >= rv),
87                        _ => None,
88                    };
89                    if let Some(val) = cmp {
90                        return Expr::Boolean(val);
91                    }
92                }
93            }
94
95            // String concatenation folding
96            if matches!(op, BinaryOperator::Concat) {
97                if let (Expr::StringLiteral(l), Expr::StringLiteral(r)) = (&left, &right) {
98                    return Expr::StringLiteral(format!("{l}{r}"));
99                }
100            }
101
102            Expr::BinaryOp {
103                left: Box::new(left),
104                op,
105                right: Box::new(right),
106            }
107        }
108        Expr::UnaryOp {
109            op: UnaryOperator::Minus,
110            expr,
111        } => {
112            let inner = fold_expr(*expr);
113            if let Expr::Number(ref n) = inner {
114                if let Ok(v) = n.parse::<f64>() {
115                    let neg = -v;
116                    if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
117                        return Expr::Number(format!("{}", neg as i64));
118                    }
119                    return Expr::Number(format!("{neg}"));
120                }
121            }
122            Expr::UnaryOp {
123                op: UnaryOperator::Minus,
124                expr: Box::new(inner),
125            }
126        }
127        Expr::Nested(inner) => {
128            let folded = fold_expr(*inner);
129            if folded.is_literal() {
130                folded
131            } else {
132                Expr::Nested(Box::new(folded))
133            }
134        }
135        other => other,
136    }
137}
138
139/// Simplify boolean expressions.
140fn simplify_booleans(statement: Statement) -> Statement {
141    match statement {
142        Statement::Select(mut sel) => {
143            // Simplify boolean expressions in SELECT columns
144            for item in &mut sel.columns {
145                if let SelectItem::Expr { expr, .. } = item {
146                    *expr = simplify_bool_expr(expr.clone());
147                }
148            }
149            if let Some(wh) = sel.where_clause {
150                let simplified = simplify_bool_expr(wh);
151                // WHERE TRUE → no WHERE clause
152                if simplified == Expr::Boolean(true) {
153                    sel.where_clause = None;
154                } else {
155                    sel.where_clause = Some(simplified);
156                }
157            }
158            if let Some(having) = sel.having {
159                let simplified = simplify_bool_expr(having);
160                if simplified == Expr::Boolean(true) {
161                    sel.having = None;
162                } else {
163                    sel.having = Some(simplified);
164                }
165            }
166            Statement::Select(sel)
167        }
168        other => other,
169    }
170}
171
172fn simplify_bool_expr(expr: Expr) -> Expr {
173    match expr {
174        Expr::BinaryOp {
175            left,
176            op: BinaryOperator::And,
177            right,
178        } => {
179            let left = simplify_bool_expr(*left);
180            let right = simplify_bool_expr(*right);
181            match (&left, &right) {
182                (Expr::Boolean(true), _) => right,
183                (_, Expr::Boolean(true)) => left,
184                (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
185                _ => Expr::BinaryOp {
186                    left: Box::new(left),
187                    op: BinaryOperator::And,
188                    right: Box::new(right),
189                },
190            }
191        }
192        Expr::BinaryOp {
193            left,
194            op: BinaryOperator::Or,
195            right,
196        } => {
197            let left = simplify_bool_expr(*left);
198            let right = simplify_bool_expr(*right);
199            match (&left, &right) {
200                (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
201                (Expr::Boolean(false), _) => right,
202                (_, Expr::Boolean(false)) => left,
203                _ => Expr::BinaryOp {
204                    left: Box::new(left),
205                    op: BinaryOperator::Or,
206                    right: Box::new(right),
207                },
208            }
209        }
210        Expr::UnaryOp {
211            op: UnaryOperator::Not,
212            expr,
213        } => {
214            let inner = simplify_bool_expr(*expr);
215            match inner {
216                Expr::Boolean(b) => Expr::Boolean(!b),
217                Expr::UnaryOp {
218                    op: UnaryOperator::Not,
219                    expr: inner2,
220                } => *inner2,
221                other => Expr::UnaryOp {
222                    op: UnaryOperator::Not,
223                    expr: Box::new(other),
224                },
225            }
226        }
227        Expr::Nested(inner) => {
228            let simplified = simplify_bool_expr(*inner);
229            if simplified.is_literal() {
230                simplified
231            } else {
232                Expr::Nested(Box::new(simplified))
233            }
234        }
235        other => other,
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and run the result through `optimize`, panicking on any error.
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    // Each test panics when the optimized statement has an unexpected shape,
    // so an assertion can never be skipped silently.

    #[test]
    fn test_constant_folding() {
        match optimize_sql("SELECT 1 + 2 FROM t") {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectItem::Expr { expr, .. } => {
                    assert_eq!(*expr, Expr::Number("3".to_string()));
                }
                _ => panic!("expected an expression select item"),
            },
            _ => panic!("expected a SELECT statement"),
        }
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        match optimize_sql("SELECT x FROM t WHERE TRUE") {
            Statement::Select(sel) => assert!(sel.where_clause.is_none()),
            _ => panic!("expected a SELECT statement"),
        }
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        match optimize_sql("SELECT x FROM t WHERE TRUE AND x > 1") {
            Statement::Select(sel) => {
                // Should simplify to just x > 1
                assert!(sel.where_clause.is_some());
                assert!(!matches!(
                    &sel.where_clause,
                    Some(Expr::BinaryOp {
                        op: BinaryOperator::And,
                        ..
                    })
                ));
            }
            _ => panic!("expected a SELECT statement"),
        }
    }

    #[test]
    fn test_double_negation() {
        match optimize_sql("SELECT x FROM t WHERE NOT NOT x > 1") {
            Statement::Select(sel) => {
                // Should simplify to x > 1 (no NOT)
                assert!(sel.where_clause.is_some());
                assert!(!matches!(
                    &sel.where_clause,
                    Some(Expr::UnaryOp {
                        op: UnaryOperator::Not,
                        ..
                    })
                ));
            }
            _ => panic!("expected a SELECT statement"),
        }
    }
}