Skip to main content

sqlglot_rust/optimizer/
mod.rs

//! Query optimization passes.
//!
//! Inspired by Python sqlglot's optimizer module. Currently implements:
//! - Constant folding (e.g., `1 + 2` → `3`)
//! - Boolean simplification (e.g., `TRUE AND x` → `x`)
//! - Dead predicate elimination (e.g., dropping `WHERE TRUE`)
//! - Subquery unnesting / decorrelation (EXISTS, IN → JOINs)
//! - Predicate pushdown (WHERE → derived tables / JOIN ON)
//! - Column qualification (qualify_columns — resolve `*`, add table qualifiers)
//! - Type annotation (annotate_types — infer SQL types for all AST nodes)
//! - Column lineage (lineage — trace data flow from source to output columns)
//!
//! The first five run as part of [`optimize`]; the remaining three are
//! standalone modules that callers invoke directly.
//!
//! Future optimizations:
//! - Join reordering
//! - Column pruning
16
17pub mod annotate_types;
18pub mod lineage;
19pub mod pushdown_predicates;
20pub mod qualify_columns;
21pub mod scope_analysis;
22pub mod unnest_subqueries;
23
24use crate::ast::*;
25use crate::errors::Result;
26
27/// Optimize a SQL statement by applying transformation passes.
28pub fn optimize(statement: Statement) -> Result<Statement> {
29    let mut stmt = statement;
30    stmt = fold_constants(stmt);
31    stmt = simplify_booleans(stmt);
32    stmt = unnest_subqueries::unnest_subqueries(stmt);
33    stmt = pushdown_predicates::pushdown_predicates(stmt);
34    Ok(stmt)
35}
36
37/// Fold constant expressions (e.g., `1 + 2` → `3`).
38fn fold_constants(statement: Statement) -> Statement {
39    match statement {
40        Statement::Select(mut sel) => {
41            if let Some(wh) = sel.where_clause {
42                sel.where_clause = Some(fold_expr(wh));
43            }
44            if let Some(having) = sel.having {
45                sel.having = Some(fold_expr(having));
46            }
47            for item in &mut sel.columns {
48                if let SelectItem::Expr { expr, .. } = item {
49                    *expr = fold_expr(expr.clone());
50                }
51            }
52            Statement::Select(sel)
53        }
54        other => other,
55    }
56}
57
58fn fold_expr(expr: Expr) -> Expr {
59    match expr {
60        Expr::BinaryOp { left, op, right } => {
61            let left = fold_expr(*left);
62            let right = fold_expr(*right);
63
64            // Try numeric folding
65            if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
66                if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
67                    let result = match op {
68                        BinaryOperator::Plus => Some(lv + rv),
69                        BinaryOperator::Minus => Some(lv - rv),
70                        BinaryOperator::Multiply => Some(lv * rv),
71                        BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
72                        BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
73                        _ => None,
74                    };
75                    if let Some(val) = result {
76                        // Emit integer if it's a whole number
77                        if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
78                            return Expr::Number(format!("{}", val as i64));
79                        }
80                        return Expr::Number(format!("{val}"));
81                    }
82
83                    // Try boolean folding for comparison
84                    let cmp = match op {
85                        BinaryOperator::Eq => Some(lv == rv),
86                        BinaryOperator::Neq => Some(lv != rv),
87                        BinaryOperator::Lt => Some(lv < rv),
88                        BinaryOperator::Gt => Some(lv > rv),
89                        BinaryOperator::LtEq => Some(lv <= rv),
90                        BinaryOperator::GtEq => Some(lv >= rv),
91                        _ => None,
92                    };
93                    if let Some(val) = cmp {
94                        return Expr::Boolean(val);
95                    }
96                }
97            }
98
99            // String concatenation folding
100            if matches!(op, BinaryOperator::Concat) {
101                match (&left, &right) {
102                    (Expr::NationalStringLiteral(l), Expr::NationalStringLiteral(r)) => {
103                        return Expr::NationalStringLiteral(format!("{l}{r}"));
104                    }
105                    (Expr::NationalStringLiteral(l), Expr::StringLiteral(r))
106                    | (Expr::StringLiteral(l), Expr::NationalStringLiteral(r)) => {
107                        return Expr::NationalStringLiteral(format!("{l}{r}"));
108                    }
109                    (Expr::StringLiteral(l), Expr::StringLiteral(r)) => {
110                        return Expr::StringLiteral(format!("{l}{r}"));
111                    }
112                    _ => {}
113                }
114            }
115
116            Expr::BinaryOp {
117                left: Box::new(left),
118                op,
119                right: Box::new(right),
120            }
121        }
122        Expr::UnaryOp {
123            op: UnaryOperator::Minus,
124            expr,
125        } => {
126            let inner = fold_expr(*expr);
127            if let Expr::Number(ref n) = inner {
128                if let Ok(v) = n.parse::<f64>() {
129                    let neg = -v;
130                    if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
131                        return Expr::Number(format!("{}", neg as i64));
132                    }
133                    return Expr::Number(format!("{neg}"));
134                }
135            }
136            Expr::UnaryOp {
137                op: UnaryOperator::Minus,
138                expr: Box::new(inner),
139            }
140        }
141        Expr::Nested(inner) => {
142            let folded = fold_expr(*inner);
143            if folded.is_literal() {
144                folded
145            } else {
146                Expr::Nested(Box::new(folded))
147            }
148        }
149        other => other,
150    }
151}
152
153/// Simplify boolean expressions.
154fn simplify_booleans(statement: Statement) -> Statement {
155    match statement {
156        Statement::Select(mut sel) => {
157            // Simplify boolean expressions in SELECT columns
158            for item in &mut sel.columns {
159                if let SelectItem::Expr { expr, .. } = item {
160                    *expr = simplify_bool_expr(expr.clone());
161                }
162            }
163            if let Some(wh) = sel.where_clause {
164                let simplified = simplify_bool_expr(wh);
165                // WHERE TRUE → no WHERE clause
166                if simplified == Expr::Boolean(true) {
167                    sel.where_clause = None;
168                } else {
169                    sel.where_clause = Some(simplified);
170                }
171            }
172            if let Some(having) = sel.having {
173                let simplified = simplify_bool_expr(having);
174                if simplified == Expr::Boolean(true) {
175                    sel.having = None;
176                } else {
177                    sel.having = Some(simplified);
178                }
179            }
180            Statement::Select(sel)
181        }
182        other => other,
183    }
184}
185
186fn simplify_bool_expr(expr: Expr) -> Expr {
187    match expr {
188        Expr::BinaryOp {
189            left,
190            op: BinaryOperator::And,
191            right,
192        } => {
193            let left = simplify_bool_expr(*left);
194            let right = simplify_bool_expr(*right);
195            match (&left, &right) {
196                (Expr::Boolean(true), _) => right,
197                (_, Expr::Boolean(true)) => left,
198                (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
199                _ => Expr::BinaryOp {
200                    left: Box::new(left),
201                    op: BinaryOperator::And,
202                    right: Box::new(right),
203                },
204            }
205        }
206        Expr::BinaryOp {
207            left,
208            op: BinaryOperator::Or,
209            right,
210        } => {
211            let left = simplify_bool_expr(*left);
212            let right = simplify_bool_expr(*right);
213            match (&left, &right) {
214                (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
215                (Expr::Boolean(false), _) => right,
216                (_, Expr::Boolean(false)) => left,
217                _ => Expr::BinaryOp {
218                    left: Box::new(left),
219                    op: BinaryOperator::Or,
220                    right: Box::new(right),
221                },
222            }
223        }
224        Expr::UnaryOp {
225            op: UnaryOperator::Not,
226            expr,
227        } => {
228            let inner = simplify_bool_expr(*expr);
229            match inner {
230                Expr::Boolean(b) => Expr::Boolean(!b),
231                Expr::UnaryOp {
232                    op: UnaryOperator::Not,
233                    expr: inner2,
234                } => *inner2,
235                other => Expr::UnaryOp {
236                    op: UnaryOperator::Not,
237                    expr: Box::new(other),
238                },
239            }
240        }
241        Expr::Nested(inner) => {
242            let simplified = simplify_bool_expr(*inner);
243            if simplified.is_literal() {
244                simplified
245            } else {
246                Expr::Nested(Box::new(simplified))
247            }
248        }
249        other => other,
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and run the full optimizer pipeline over it.
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    /// Optimize `sql` and evaluate to its `SELECT` body, failing the test if
    /// the optimizer returned any other statement kind. This is a macro rather
    /// than a function so the tests never have to name the select type.
    macro_rules! optimized_select {
        ($sql:expr) => {
            match optimize_sql($sql) {
                Statement::Select(sel) => sel,
                _ => panic!("expected a SELECT statement for {}", $sql),
            }
        };
    }

    /// Assert that the first projection of `sql` optimizes to `expected`.
    fn assert_first_column(sql: &str, expected: Expr) {
        let sel = optimized_select!(sql);
        match &sel.columns[0] {
            SelectItem::Expr { expr, .. } => assert_eq!(*expr, expected),
            _ => panic!("expected an expression as the first select item of {sql}"),
        }
    }

    #[test]
    fn test_constant_folding() {
        assert_first_column("SELECT 1 + 2 FROM t", Expr::Number("3".to_string()));
    }

    #[test]
    fn test_nested_constant_folding() {
        assert_first_column("SELECT 1 + 2 * 3 FROM t", Expr::Number("7".to_string()));
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        let sel = optimized_select!("SELECT x FROM t WHERE TRUE");
        assert!(sel.where_clause.is_none());
    }

    #[test]
    fn test_constant_comparison_removes_where() {
        // `1 < 2` folds to TRUE, which boolean simplification then drops.
        let sel = optimized_select!("SELECT x FROM t WHERE 1 < 2");
        assert!(sel.where_clause.is_none());
    }

    #[test]
    fn test_boolean_simplification_or_true() {
        let sel = optimized_select!("SELECT x FROM t WHERE x > 1 OR TRUE");
        assert!(sel.where_clause.is_none());
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        let sel = optimized_select!("SELECT x FROM t WHERE TRUE AND x > 1");
        // Should simplify to just `x > 1`.
        assert!(matches!(
            &sel.where_clause,
            Some(Expr::BinaryOp {
                op: BinaryOperator::Gt,
                ..
            })
        ));
    }

    #[test]
    fn test_double_negation() {
        let sel = optimized_select!("SELECT x FROM t WHERE NOT NOT x > 1");
        // Should simplify to `x > 1`, with no NOT left at the top.
        assert!(matches!(
            &sel.where_clause,
            Some(Expr::BinaryOp {
                op: BinaryOperator::Gt,
                ..
            })
        ));
    }
}