Skip to main content

sqlglot_rust/optimizer/
mod.rs

1//! Query optimization passes.
2//!
3//! Inspired by Python sqlglot's optimizer module. Currently implements:
4//! - Constant folding (e.g., `1 + 2` → `3`)
5//! - Boolean simplification (e.g., `TRUE AND x` → `x`)
6//! - Dead predicate elimination (e.g., `WHERE TRUE`)
7//! - Subquery unnesting / decorrelation (EXISTS, IN → JOINs)
8//! - Column qualification (qualify_columns — resolve `*`, add table qualifiers)
9//!
10//! Future optimizations:
11//! - Predicate pushdown
12//! - Join reordering
13//! - Column pruning
14
15pub mod qualify_columns;
16pub mod scope_analysis;
17pub mod unnest_subqueries;
18
19use crate::ast::*;
20use crate::errors::Result;
21
22/// Optimize a SQL statement by applying transformation passes.
23pub fn optimize(statement: Statement) -> Result<Statement> {
24    let mut stmt = statement;
25    stmt = fold_constants(stmt);
26    stmt = simplify_booleans(stmt);
27    stmt = unnest_subqueries::unnest_subqueries(stmt);
28    Ok(stmt)
29}
30
31/// Fold constant expressions (e.g., `1 + 2` → `3`).
32fn fold_constants(statement: Statement) -> Statement {
33    match statement {
34        Statement::Select(mut sel) => {
35            if let Some(wh) = sel.where_clause {
36                sel.where_clause = Some(fold_expr(wh));
37            }
38            if let Some(having) = sel.having {
39                sel.having = Some(fold_expr(having));
40            }
41            for item in &mut sel.columns {
42                if let SelectItem::Expr { expr, .. } = item {
43                    *expr = fold_expr(expr.clone());
44                }
45            }
46            Statement::Select(sel)
47        }
48        other => other,
49    }
50}
51
52fn fold_expr(expr: Expr) -> Expr {
53    match expr {
54        Expr::BinaryOp { left, op, right } => {
55            let left = fold_expr(*left);
56            let right = fold_expr(*right);
57
58            // Try numeric folding
59            if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
60                if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
61                    let result = match op {
62                        BinaryOperator::Plus => Some(lv + rv),
63                        BinaryOperator::Minus => Some(lv - rv),
64                        BinaryOperator::Multiply => Some(lv * rv),
65                        BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
66                        BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
67                        _ => None,
68                    };
69                    if let Some(val) = result {
70                        // Emit integer if it's a whole number
71                        if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
72                            return Expr::Number(format!("{}", val as i64));
73                        }
74                        return Expr::Number(format!("{val}"));
75                    }
76
77                    // Try boolean folding for comparison
78                    let cmp = match op {
79                        BinaryOperator::Eq => Some(lv == rv),
80                        BinaryOperator::Neq => Some(lv != rv),
81                        BinaryOperator::Lt => Some(lv < rv),
82                        BinaryOperator::Gt => Some(lv > rv),
83                        BinaryOperator::LtEq => Some(lv <= rv),
84                        BinaryOperator::GtEq => Some(lv >= rv),
85                        _ => None,
86                    };
87                    if let Some(val) = cmp {
88                        return Expr::Boolean(val);
89                    }
90                }
91            }
92
93            // String concatenation folding
94            if matches!(op, BinaryOperator::Concat) {
95                if let (Expr::StringLiteral(l), Expr::StringLiteral(r)) = (&left, &right) {
96                    return Expr::StringLiteral(format!("{l}{r}"));
97                }
98            }
99
100            Expr::BinaryOp {
101                left: Box::new(left),
102                op,
103                right: Box::new(right),
104            }
105        }
106        Expr::UnaryOp {
107            op: UnaryOperator::Minus,
108            expr,
109        } => {
110            let inner = fold_expr(*expr);
111            if let Expr::Number(ref n) = inner {
112                if let Ok(v) = n.parse::<f64>() {
113                    let neg = -v;
114                    if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
115                        return Expr::Number(format!("{}", neg as i64));
116                    }
117                    return Expr::Number(format!("{neg}"));
118                }
119            }
120            Expr::UnaryOp {
121                op: UnaryOperator::Minus,
122                expr: Box::new(inner),
123            }
124        }
125        Expr::Nested(inner) => {
126            let folded = fold_expr(*inner);
127            if folded.is_literal() {
128                folded
129            } else {
130                Expr::Nested(Box::new(folded))
131            }
132        }
133        other => other,
134    }
135}
136
137/// Simplify boolean expressions.
138fn simplify_booleans(statement: Statement) -> Statement {
139    match statement {
140        Statement::Select(mut sel) => {
141            // Simplify boolean expressions in SELECT columns
142            for item in &mut sel.columns {
143                if let SelectItem::Expr { expr, .. } = item {
144                    *expr = simplify_bool_expr(expr.clone());
145                }
146            }
147            if let Some(wh) = sel.where_clause {
148                let simplified = simplify_bool_expr(wh);
149                // WHERE TRUE → no WHERE clause
150                if simplified == Expr::Boolean(true) {
151                    sel.where_clause = None;
152                } else {
153                    sel.where_clause = Some(simplified);
154                }
155            }
156            if let Some(having) = sel.having {
157                let simplified = simplify_bool_expr(having);
158                if simplified == Expr::Boolean(true) {
159                    sel.having = None;
160                } else {
161                    sel.having = Some(simplified);
162                }
163            }
164            Statement::Select(sel)
165        }
166        other => other,
167    }
168}
169
170fn simplify_bool_expr(expr: Expr) -> Expr {
171    match expr {
172        Expr::BinaryOp {
173            left,
174            op: BinaryOperator::And,
175            right,
176        } => {
177            let left = simplify_bool_expr(*left);
178            let right = simplify_bool_expr(*right);
179            match (&left, &right) {
180                (Expr::Boolean(true), _) => right,
181                (_, Expr::Boolean(true)) => left,
182                (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
183                _ => Expr::BinaryOp {
184                    left: Box::new(left),
185                    op: BinaryOperator::And,
186                    right: Box::new(right),
187                },
188            }
189        }
190        Expr::BinaryOp {
191            left,
192            op: BinaryOperator::Or,
193            right,
194        } => {
195            let left = simplify_bool_expr(*left);
196            let right = simplify_bool_expr(*right);
197            match (&left, &right) {
198                (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
199                (Expr::Boolean(false), _) => right,
200                (_, Expr::Boolean(false)) => left,
201                _ => Expr::BinaryOp {
202                    left: Box::new(left),
203                    op: BinaryOperator::Or,
204                    right: Box::new(right),
205                },
206            }
207        }
208        Expr::UnaryOp {
209            op: UnaryOperator::Not,
210            expr,
211        } => {
212            let inner = simplify_bool_expr(*expr);
213            match inner {
214                Expr::Boolean(b) => Expr::Boolean(!b),
215                Expr::UnaryOp {
216                    op: UnaryOperator::Not,
217                    expr: inner2,
218                } => *inner2,
219                other => Expr::UnaryOp {
220                    op: UnaryOperator::Not,
221                    expr: Box::new(other),
222                },
223            }
224        }
225        Expr::Nested(inner) => {
226            let simplified = simplify_bool_expr(*inner);
227            if simplified.is_literal() {
228                simplified
229            } else {
230                Expr::Nested(Box::new(simplified))
231            }
232        }
233        other => other,
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` as a single statement and run it through `optimize`.
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    // Each test panics on an unexpected statement shape instead of silently
    // skipping its assertions, so a parser regression cannot make it pass
    // vacuously.

    #[test]
    fn test_constant_folding() {
        let sel = match optimize_sql("SELECT 1 + 2 FROM t") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        match &sel.columns[0] {
            SelectItem::Expr { expr, .. } => {
                assert_eq!(*expr, Expr::Number("3".to_string()));
            }
            _ => panic!("expected an expression select item"),
        }
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        let sel = match optimize_sql("SELECT x FROM t WHERE TRUE") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        assert!(sel.where_clause.is_none());
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        let sel = match optimize_sql("SELECT x FROM t WHERE TRUE AND x > 1") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        // Should simplify to just x > 1
        assert!(sel.where_clause.is_some());
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::BinaryOp {
                op: BinaryOperator::And,
                ..
            })
        ));
    }

    #[test]
    fn test_double_negation() {
        let sel = match optimize_sql("SELECT x FROM t WHERE NOT NOT x > 1") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        // Should simplify to x > 1 (no NOT)
        assert!(sel.where_clause.is_some());
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::UnaryOp {
                op: UnaryOperator::Not,
                ..
            })
        ));
    }
}