Skip to main content

sqlglot_rust/optimizer/
mod.rs

1//! Query optimization passes.
2//!
3//! Inspired by Python sqlglot's optimizer module. Currently implements:
4//! - Constant folding (e.g., `1 + 2` → `3`)
5//! - Boolean simplification (e.g., `TRUE AND x` → `x`)
6//! - Dead predicate elimination (e.g., `WHERE TRUE`)
7//! - Subquery unnesting / decorrelation (EXISTS, IN → JOINs)
8//! - Predicate pushdown (WHERE → derived tables / JOIN ON)
9//! - Column qualification (qualify_columns — resolve `*`, add table qualifiers)
10//! - Type annotation (annotate_types — infer SQL types for all AST nodes)
11//!
12//! Future optimizations:
13//! - Join reordering
14//! - Column pruning
15
16pub mod annotate_types;
17pub mod pushdown_predicates;
18pub mod qualify_columns;
19pub mod scope_analysis;
20pub mod unnest_subqueries;
21
22use crate::ast::*;
23use crate::errors::Result;
24
25/// Optimize a SQL statement by applying transformation passes.
26pub fn optimize(statement: Statement) -> Result<Statement> {
27    let mut stmt = statement;
28    stmt = fold_constants(stmt);
29    stmt = simplify_booleans(stmt);
30    stmt = unnest_subqueries::unnest_subqueries(stmt);
31    stmt = pushdown_predicates::pushdown_predicates(stmt);
32    Ok(stmt)
33}
34
35/// Fold constant expressions (e.g., `1 + 2` → `3`).
36fn fold_constants(statement: Statement) -> Statement {
37    match statement {
38        Statement::Select(mut sel) => {
39            if let Some(wh) = sel.where_clause {
40                sel.where_clause = Some(fold_expr(wh));
41            }
42            if let Some(having) = sel.having {
43                sel.having = Some(fold_expr(having));
44            }
45            for item in &mut sel.columns {
46                if let SelectItem::Expr { expr, .. } = item {
47                    *expr = fold_expr(expr.clone());
48                }
49            }
50            Statement::Select(sel)
51        }
52        other => other,
53    }
54}
55
56fn fold_expr(expr: Expr) -> Expr {
57    match expr {
58        Expr::BinaryOp { left, op, right } => {
59            let left = fold_expr(*left);
60            let right = fold_expr(*right);
61
62            // Try numeric folding
63            if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
64                if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
65                    let result = match op {
66                        BinaryOperator::Plus => Some(lv + rv),
67                        BinaryOperator::Minus => Some(lv - rv),
68                        BinaryOperator::Multiply => Some(lv * rv),
69                        BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
70                        BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
71                        _ => None,
72                    };
73                    if let Some(val) = result {
74                        // Emit integer if it's a whole number
75                        if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
76                            return Expr::Number(format!("{}", val as i64));
77                        }
78                        return Expr::Number(format!("{val}"));
79                    }
80
81                    // Try boolean folding for comparison
82                    let cmp = match op {
83                        BinaryOperator::Eq => Some(lv == rv),
84                        BinaryOperator::Neq => Some(lv != rv),
85                        BinaryOperator::Lt => Some(lv < rv),
86                        BinaryOperator::Gt => Some(lv > rv),
87                        BinaryOperator::LtEq => Some(lv <= rv),
88                        BinaryOperator::GtEq => Some(lv >= rv),
89                        _ => None,
90                    };
91                    if let Some(val) = cmp {
92                        return Expr::Boolean(val);
93                    }
94                }
95            }
96
97            // String concatenation folding
98            if matches!(op, BinaryOperator::Concat) {
99                if let (Expr::StringLiteral(l), Expr::StringLiteral(r)) = (&left, &right) {
100                    return Expr::StringLiteral(format!("{l}{r}"));
101                }
102            }
103
104            Expr::BinaryOp {
105                left: Box::new(left),
106                op,
107                right: Box::new(right),
108            }
109        }
110        Expr::UnaryOp {
111            op: UnaryOperator::Minus,
112            expr,
113        } => {
114            let inner = fold_expr(*expr);
115            if let Expr::Number(ref n) = inner {
116                if let Ok(v) = n.parse::<f64>() {
117                    let neg = -v;
118                    if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
119                        return Expr::Number(format!("{}", neg as i64));
120                    }
121                    return Expr::Number(format!("{neg}"));
122                }
123            }
124            Expr::UnaryOp {
125                op: UnaryOperator::Minus,
126                expr: Box::new(inner),
127            }
128        }
129        Expr::Nested(inner) => {
130            let folded = fold_expr(*inner);
131            if folded.is_literal() {
132                folded
133            } else {
134                Expr::Nested(Box::new(folded))
135            }
136        }
137        other => other,
138    }
139}
140
141/// Simplify boolean expressions.
142fn simplify_booleans(statement: Statement) -> Statement {
143    match statement {
144        Statement::Select(mut sel) => {
145            // Simplify boolean expressions in SELECT columns
146            for item in &mut sel.columns {
147                if let SelectItem::Expr { expr, .. } = item {
148                    *expr = simplify_bool_expr(expr.clone());
149                }
150            }
151            if let Some(wh) = sel.where_clause {
152                let simplified = simplify_bool_expr(wh);
153                // WHERE TRUE → no WHERE clause
154                if simplified == Expr::Boolean(true) {
155                    sel.where_clause = None;
156                } else {
157                    sel.where_clause = Some(simplified);
158                }
159            }
160            if let Some(having) = sel.having {
161                let simplified = simplify_bool_expr(having);
162                if simplified == Expr::Boolean(true) {
163                    sel.having = None;
164                } else {
165                    sel.having = Some(simplified);
166                }
167            }
168            Statement::Select(sel)
169        }
170        other => other,
171    }
172}
173
174fn simplify_bool_expr(expr: Expr) -> Expr {
175    match expr {
176        Expr::BinaryOp {
177            left,
178            op: BinaryOperator::And,
179            right,
180        } => {
181            let left = simplify_bool_expr(*left);
182            let right = simplify_bool_expr(*right);
183            match (&left, &right) {
184                (Expr::Boolean(true), _) => right,
185                (_, Expr::Boolean(true)) => left,
186                (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
187                _ => Expr::BinaryOp {
188                    left: Box::new(left),
189                    op: BinaryOperator::And,
190                    right: Box::new(right),
191                },
192            }
193        }
194        Expr::BinaryOp {
195            left,
196            op: BinaryOperator::Or,
197            right,
198        } => {
199            let left = simplify_bool_expr(*left);
200            let right = simplify_bool_expr(*right);
201            match (&left, &right) {
202                (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
203                (Expr::Boolean(false), _) => right,
204                (_, Expr::Boolean(false)) => left,
205                _ => Expr::BinaryOp {
206                    left: Box::new(left),
207                    op: BinaryOperator::Or,
208                    right: Box::new(right),
209                },
210            }
211        }
212        Expr::UnaryOp {
213            op: UnaryOperator::Not,
214            expr,
215        } => {
216            let inner = simplify_bool_expr(*expr);
217            match inner {
218                Expr::Boolean(b) => Expr::Boolean(!b),
219                Expr::UnaryOp {
220                    op: UnaryOperator::Not,
221                    expr: inner2,
222                } => *inner2,
223                other => Expr::UnaryOp {
224                    op: UnaryOperator::Not,
225                    expr: Box::new(other),
226                },
227            }
228        }
229        Expr::Nested(inner) => {
230            let simplified = simplify_bool_expr(*inner);
231            if simplified.is_literal() {
232                simplified
233            } else {
234                Expr::Nested(Box::new(simplified))
235            }
236        }
237        other => other,
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` as a single statement and run it through `optimize`.
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    // Each test fails loudly when the optimizer returns an unexpected shape;
    // a bare `if let` would let the test pass without checking anything.

    #[test]
    fn test_constant_folding() {
        let sel = match optimize_sql("SELECT 1 + 2 FROM t") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        match &sel.columns[0] {
            SelectItem::Expr { expr, .. } => {
                assert_eq!(*expr, Expr::Number("3".to_string()));
            }
            _ => panic!("expected the first select item to be an expression"),
        }
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        let sel = match optimize_sql("SELECT x FROM t WHERE TRUE") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        assert!(sel.where_clause.is_none());
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        let sel = match optimize_sql("SELECT x FROM t WHERE TRUE AND x > 1") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        // Should simplify to just x > 1
        assert!(sel.where_clause.is_some());
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::BinaryOp {
                op: BinaryOperator::And,
                ..
            })
        ));
    }

    #[test]
    fn test_double_negation() {
        let sel = match optimize_sql("SELECT x FROM t WHERE NOT NOT x > 1") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        // Should simplify to x > 1 (no NOT left at the top of the predicate)
        assert!(sel.where_clause.is_some());
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::UnaryOp {
                op: UnaryOperator::Not,
                ..
            })
        ));
    }
}