sqlglot_rust/optimizer/
mod.rs1pub mod annotate_types;
17pub mod pushdown_predicates;
18pub mod qualify_columns;
19pub mod scope_analysis;
20pub mod unnest_subqueries;
21
22use crate::ast::*;
23use crate::errors::Result;
24
25pub fn optimize(statement: Statement) -> Result<Statement> {
27 let mut stmt = statement;
28 stmt = fold_constants(stmt);
29 stmt = simplify_booleans(stmt);
30 stmt = unnest_subqueries::unnest_subqueries(stmt);
31 stmt = pushdown_predicates::pushdown_predicates(stmt);
32 Ok(stmt)
33}
34
35fn fold_constants(statement: Statement) -> Statement {
37 match statement {
38 Statement::Select(mut sel) => {
39 if let Some(wh) = sel.where_clause {
40 sel.where_clause = Some(fold_expr(wh));
41 }
42 if let Some(having) = sel.having {
43 sel.having = Some(fold_expr(having));
44 }
45 for item in &mut sel.columns {
46 if let SelectItem::Expr { expr, .. } = item {
47 *expr = fold_expr(expr.clone());
48 }
49 }
50 Statement::Select(sel)
51 }
52 other => other,
53 }
54}
55
56fn fold_expr(expr: Expr) -> Expr {
57 match expr {
58 Expr::BinaryOp { left, op, right } => {
59 let left = fold_expr(*left);
60 let right = fold_expr(*right);
61
62 if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
64 if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
65 let result = match op {
66 BinaryOperator::Plus => Some(lv + rv),
67 BinaryOperator::Minus => Some(lv - rv),
68 BinaryOperator::Multiply => Some(lv * rv),
69 BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
70 BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
71 _ => None,
72 };
73 if let Some(val) = result {
74 if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
76 return Expr::Number(format!("{}", val as i64));
77 }
78 return Expr::Number(format!("{val}"));
79 }
80
81 let cmp = match op {
83 BinaryOperator::Eq => Some(lv == rv),
84 BinaryOperator::Neq => Some(lv != rv),
85 BinaryOperator::Lt => Some(lv < rv),
86 BinaryOperator::Gt => Some(lv > rv),
87 BinaryOperator::LtEq => Some(lv <= rv),
88 BinaryOperator::GtEq => Some(lv >= rv),
89 _ => None,
90 };
91 if let Some(val) = cmp {
92 return Expr::Boolean(val);
93 }
94 }
95 }
96
97 if matches!(op, BinaryOperator::Concat) {
99 if let (Expr::StringLiteral(l), Expr::StringLiteral(r)) = (&left, &right) {
100 return Expr::StringLiteral(format!("{l}{r}"));
101 }
102 }
103
104 Expr::BinaryOp {
105 left: Box::new(left),
106 op,
107 right: Box::new(right),
108 }
109 }
110 Expr::UnaryOp {
111 op: UnaryOperator::Minus,
112 expr,
113 } => {
114 let inner = fold_expr(*expr);
115 if let Expr::Number(ref n) = inner {
116 if let Ok(v) = n.parse::<f64>() {
117 let neg = -v;
118 if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
119 return Expr::Number(format!("{}", neg as i64));
120 }
121 return Expr::Number(format!("{neg}"));
122 }
123 }
124 Expr::UnaryOp {
125 op: UnaryOperator::Minus,
126 expr: Box::new(inner),
127 }
128 }
129 Expr::Nested(inner) => {
130 let folded = fold_expr(*inner);
131 if folded.is_literal() {
132 folded
133 } else {
134 Expr::Nested(Box::new(folded))
135 }
136 }
137 other => other,
138 }
139}
140
141fn simplify_booleans(statement: Statement) -> Statement {
143 match statement {
144 Statement::Select(mut sel) => {
145 for item in &mut sel.columns {
147 if let SelectItem::Expr { expr, .. } = item {
148 *expr = simplify_bool_expr(expr.clone());
149 }
150 }
151 if let Some(wh) = sel.where_clause {
152 let simplified = simplify_bool_expr(wh);
153 if simplified == Expr::Boolean(true) {
155 sel.where_clause = None;
156 } else {
157 sel.where_clause = Some(simplified);
158 }
159 }
160 if let Some(having) = sel.having {
161 let simplified = simplify_bool_expr(having);
162 if simplified == Expr::Boolean(true) {
163 sel.having = None;
164 } else {
165 sel.having = Some(simplified);
166 }
167 }
168 Statement::Select(sel)
169 }
170 other => other,
171 }
172}
173
174fn simplify_bool_expr(expr: Expr) -> Expr {
175 match expr {
176 Expr::BinaryOp {
177 left,
178 op: BinaryOperator::And,
179 right,
180 } => {
181 let left = simplify_bool_expr(*left);
182 let right = simplify_bool_expr(*right);
183 match (&left, &right) {
184 (Expr::Boolean(true), _) => right,
185 (_, Expr::Boolean(true)) => left,
186 (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
187 _ => Expr::BinaryOp {
188 left: Box::new(left),
189 op: BinaryOperator::And,
190 right: Box::new(right),
191 },
192 }
193 }
194 Expr::BinaryOp {
195 left,
196 op: BinaryOperator::Or,
197 right,
198 } => {
199 let left = simplify_bool_expr(*left);
200 let right = simplify_bool_expr(*right);
201 match (&left, &right) {
202 (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
203 (Expr::Boolean(false), _) => right,
204 (_, Expr::Boolean(false)) => left,
205 _ => Expr::BinaryOp {
206 left: Box::new(left),
207 op: BinaryOperator::Or,
208 right: Box::new(right),
209 },
210 }
211 }
212 Expr::UnaryOp {
213 op: UnaryOperator::Not,
214 expr,
215 } => {
216 let inner = simplify_bool_expr(*expr);
217 match inner {
218 Expr::Boolean(b) => Expr::Boolean(!b),
219 Expr::UnaryOp {
220 op: UnaryOperator::Not,
221 expr: inner2,
222 } => *inner2,
223 other => Expr::UnaryOp {
224 op: UnaryOperator::Not,
225 expr: Box::new(other),
226 },
227 }
228 }
229 Expr::Nested(inner) => {
230 let simplified = simplify_bool_expr(*inner);
231 if simplified.is_literal() {
232 simplified
233 } else {
234 Expr::Nested(Box::new(simplified))
235 }
236 }
237 other => other,
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` as a single statement and runs the optimizer over it.
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    // Each test destructures with `match ... => panic!` instead of a bare
    // `if let`, so an unexpected statement or item shape fails the test rather
    // than silently skipping every assertion.

    #[test]
    fn test_constant_folding() {
        let sel = match optimize_sql("SELECT 1 + 2 FROM t") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        match &sel.columns[0] {
            SelectItem::Expr { expr, .. } => {
                assert_eq!(*expr, Expr::Number("3".to_string()));
            }
            _ => panic!("expected an expression select item"),
        }
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        let sel = match optimize_sql("SELECT x FROM t WHERE TRUE") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        assert!(sel.where_clause.is_none());
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        let sel = match optimize_sql("SELECT x FROM t WHERE TRUE AND x > 1") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        assert!(sel.where_clause.is_some());
        // The redundant `TRUE AND` must be gone.
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::BinaryOp {
                op: BinaryOperator::And,
                ..
            })
        ));
    }

    #[test]
    fn test_double_negation() {
        let sel = match optimize_sql("SELECT x FROM t WHERE NOT NOT x > 1") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        assert!(sel.where_clause.is_some());
        // Both negations must have cancelled out.
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::UnaryOp {
                op: UnaryOperator::Not,
                ..
            })
        ));
    }
}