sqlglot_rust/optimizer/
mod.rs1pub mod annotate_types;
18pub mod lineage;
19pub mod pushdown_predicates;
20pub mod qualify_columns;
21pub mod scope_analysis;
22pub mod unnest_subqueries;
23
24use crate::ast::*;
25use crate::errors::Result;
26
27pub fn optimize(statement: Statement) -> Result<Statement> {
29 let mut stmt = statement;
30 stmt = fold_constants(stmt);
31 stmt = simplify_booleans(stmt);
32 stmt = unnest_subqueries::unnest_subqueries(stmt);
33 stmt = pushdown_predicates::pushdown_predicates(stmt);
34 Ok(stmt)
35}
36
37fn fold_constants(statement: Statement) -> Statement {
39 match statement {
40 Statement::Select(mut sel) => {
41 if let Some(wh) = sel.where_clause {
42 sel.where_clause = Some(fold_expr(wh));
43 }
44 if let Some(having) = sel.having {
45 sel.having = Some(fold_expr(having));
46 }
47 for item in &mut sel.columns {
48 if let SelectItem::Expr { expr, .. } = item {
49 *expr = fold_expr(expr.clone());
50 }
51 }
52 Statement::Select(sel)
53 }
54 other => other,
55 }
56}
57
58fn fold_expr(expr: Expr) -> Expr {
59 match expr {
60 Expr::BinaryOp { left, op, right } => {
61 let left = fold_expr(*left);
62 let right = fold_expr(*right);
63
64 if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
66 if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
67 let result = match op {
68 BinaryOperator::Plus => Some(lv + rv),
69 BinaryOperator::Minus => Some(lv - rv),
70 BinaryOperator::Multiply => Some(lv * rv),
71 BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
72 BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
73 _ => None,
74 };
75 if let Some(val) = result {
76 if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
78 return Expr::Number(format!("{}", val as i64));
79 }
80 return Expr::Number(format!("{val}"));
81 }
82
83 let cmp = match op {
85 BinaryOperator::Eq => Some(lv == rv),
86 BinaryOperator::Neq => Some(lv != rv),
87 BinaryOperator::Lt => Some(lv < rv),
88 BinaryOperator::Gt => Some(lv > rv),
89 BinaryOperator::LtEq => Some(lv <= rv),
90 BinaryOperator::GtEq => Some(lv >= rv),
91 _ => None,
92 };
93 if let Some(val) = cmp {
94 return Expr::Boolean(val);
95 }
96 }
97 }
98
99 if matches!(op, BinaryOperator::Concat) {
101 match (&left, &right) {
102 (Expr::NationalStringLiteral(l), Expr::NationalStringLiteral(r)) => {
103 return Expr::NationalStringLiteral(format!("{l}{r}"));
104 }
105 (Expr::NationalStringLiteral(l), Expr::StringLiteral(r))
106 | (Expr::StringLiteral(l), Expr::NationalStringLiteral(r)) => {
107 return Expr::NationalStringLiteral(format!("{l}{r}"));
108 }
109 (Expr::StringLiteral(l), Expr::StringLiteral(r)) => {
110 return Expr::StringLiteral(format!("{l}{r}"));
111 }
112 _ => {}
113 }
114 }
115
116 Expr::BinaryOp {
117 left: Box::new(left),
118 op,
119 right: Box::new(right),
120 }
121 }
122 Expr::UnaryOp {
123 op: UnaryOperator::Minus,
124 expr,
125 } => {
126 let inner = fold_expr(*expr);
127 if let Expr::Number(ref n) = inner {
128 if let Ok(v) = n.parse::<f64>() {
129 let neg = -v;
130 if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
131 return Expr::Number(format!("{}", neg as i64));
132 }
133 return Expr::Number(format!("{neg}"));
134 }
135 }
136 Expr::UnaryOp {
137 op: UnaryOperator::Minus,
138 expr: Box::new(inner),
139 }
140 }
141 Expr::Nested(inner) => {
142 let folded = fold_expr(*inner);
143 if folded.is_literal() {
144 folded
145 } else {
146 Expr::Nested(Box::new(folded))
147 }
148 }
149 other => other,
150 }
151}
152
153fn simplify_booleans(statement: Statement) -> Statement {
155 match statement {
156 Statement::Select(mut sel) => {
157 for item in &mut sel.columns {
159 if let SelectItem::Expr { expr, .. } = item {
160 *expr = simplify_bool_expr(expr.clone());
161 }
162 }
163 if let Some(wh) = sel.where_clause {
164 let simplified = simplify_bool_expr(wh);
165 if simplified == Expr::Boolean(true) {
167 sel.where_clause = None;
168 } else {
169 sel.where_clause = Some(simplified);
170 }
171 }
172 if let Some(having) = sel.having {
173 let simplified = simplify_bool_expr(having);
174 if simplified == Expr::Boolean(true) {
175 sel.having = None;
176 } else {
177 sel.having = Some(simplified);
178 }
179 }
180 Statement::Select(sel)
181 }
182 other => other,
183 }
184}
185
186fn simplify_bool_expr(expr: Expr) -> Expr {
187 match expr {
188 Expr::BinaryOp {
189 left,
190 op: BinaryOperator::And,
191 right,
192 } => {
193 let left = simplify_bool_expr(*left);
194 let right = simplify_bool_expr(*right);
195 match (&left, &right) {
196 (Expr::Boolean(true), _) => right,
197 (_, Expr::Boolean(true)) => left,
198 (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
199 _ => Expr::BinaryOp {
200 left: Box::new(left),
201 op: BinaryOperator::And,
202 right: Box::new(right),
203 },
204 }
205 }
206 Expr::BinaryOp {
207 left,
208 op: BinaryOperator::Or,
209 right,
210 } => {
211 let left = simplify_bool_expr(*left);
212 let right = simplify_bool_expr(*right);
213 match (&left, &right) {
214 (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
215 (Expr::Boolean(false), _) => right,
216 (_, Expr::Boolean(false)) => left,
217 _ => Expr::BinaryOp {
218 left: Box::new(left),
219 op: BinaryOperator::Or,
220 right: Box::new(right),
221 },
222 }
223 }
224 Expr::UnaryOp {
225 op: UnaryOperator::Not,
226 expr,
227 } => {
228 let inner = simplify_bool_expr(*expr);
229 match inner {
230 Expr::Boolean(b) => Expr::Boolean(!b),
231 Expr::UnaryOp {
232 op: UnaryOperator::Not,
233 expr: inner2,
234 } => *inner2,
235 other => Expr::UnaryOp {
236 op: UnaryOperator::Not,
237 expr: Box::new(other),
238 },
239 }
240 }
241 Expr::Nested(inner) => {
242 let simplified = simplify_bool_expr(*inner);
243 if simplified.is_literal() {
244 simplified
245 } else {
246 Expr::Nested(Box::new(simplified))
247 }
248 }
249 other => other,
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` as a single statement and runs the full optimizer on it.
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    // Optimizes `sql` and unwraps the resulting SELECT. It panics on any other
    // statement kind, so a test can no longer pass vacuously by skipping an
    // `if let` body.
    macro_rules! optimized_select {
        ($sql:expr) => {
            match optimize_sql($sql) {
                Statement::Select(sel) => sel,
                _ => panic!("expected a SELECT statement for {:?}", $sql),
            }
        };
    }

    #[test]
    fn test_constant_folding() {
        let sel = optimized_select!("SELECT 1 + 2 FROM t");
        match &sel.columns[0] {
            SelectItem::Expr { expr, .. } => assert_eq!(*expr, Expr::Number("3".to_string())),
            _ => panic!("expected an expression projection"),
        }
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        let sel = optimized_select!("SELECT x FROM t WHERE TRUE");
        assert!(sel.where_clause.is_none());
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        let sel = optimized_select!("SELECT x FROM t WHERE TRUE AND x > 1");
        assert!(sel.where_clause.is_some());
        // `TRUE AND p` must collapse to `p`, so no AND node may remain on top.
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::BinaryOp {
                op: BinaryOperator::And,
                ..
            })
        ));
    }

    #[test]
    fn test_double_negation() {
        let sel = optimized_select!("SELECT x FROM t WHERE NOT NOT x > 1");
        assert!(sel.where_clause.is_some());
        // `NOT NOT p` must collapse to `p`, so the predicate cannot start with NOT.
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::UnaryOp {
                op: UnaryOperator::Not,
                ..
            })
        ));
    }
}