sqlglot_rust/optimizer/
mod.rs1pub mod qualify_columns;
16pub mod scope_analysis;
17pub mod unnest_subqueries;
18
19use crate::ast::*;
20use crate::errors::Result;
21
22pub fn optimize(statement: Statement) -> Result<Statement> {
24 let mut stmt = statement;
25 stmt = fold_constants(stmt);
26 stmt = simplify_booleans(stmt);
27 stmt = unnest_subqueries::unnest_subqueries(stmt);
28 Ok(stmt)
29}
30
31fn fold_constants(statement: Statement) -> Statement {
33 match statement {
34 Statement::Select(mut sel) => {
35 if let Some(wh) = sel.where_clause {
36 sel.where_clause = Some(fold_expr(wh));
37 }
38 if let Some(having) = sel.having {
39 sel.having = Some(fold_expr(having));
40 }
41 for item in &mut sel.columns {
42 if let SelectItem::Expr { expr, .. } = item {
43 *expr = fold_expr(expr.clone());
44 }
45 }
46 Statement::Select(sel)
47 }
48 other => other,
49 }
50}
51
52fn fold_expr(expr: Expr) -> Expr {
53 match expr {
54 Expr::BinaryOp { left, op, right } => {
55 let left = fold_expr(*left);
56 let right = fold_expr(*right);
57
58 if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
60 if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
61 let result = match op {
62 BinaryOperator::Plus => Some(lv + rv),
63 BinaryOperator::Minus => Some(lv - rv),
64 BinaryOperator::Multiply => Some(lv * rv),
65 BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
66 BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
67 _ => None,
68 };
69 if let Some(val) = result {
70 if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
72 return Expr::Number(format!("{}", val as i64));
73 }
74 return Expr::Number(format!("{val}"));
75 }
76
77 let cmp = match op {
79 BinaryOperator::Eq => Some(lv == rv),
80 BinaryOperator::Neq => Some(lv != rv),
81 BinaryOperator::Lt => Some(lv < rv),
82 BinaryOperator::Gt => Some(lv > rv),
83 BinaryOperator::LtEq => Some(lv <= rv),
84 BinaryOperator::GtEq => Some(lv >= rv),
85 _ => None,
86 };
87 if let Some(val) = cmp {
88 return Expr::Boolean(val);
89 }
90 }
91 }
92
93 if matches!(op, BinaryOperator::Concat) {
95 if let (Expr::StringLiteral(l), Expr::StringLiteral(r)) = (&left, &right) {
96 return Expr::StringLiteral(format!("{l}{r}"));
97 }
98 }
99
100 Expr::BinaryOp {
101 left: Box::new(left),
102 op,
103 right: Box::new(right),
104 }
105 }
106 Expr::UnaryOp {
107 op: UnaryOperator::Minus,
108 expr,
109 } => {
110 let inner = fold_expr(*expr);
111 if let Expr::Number(ref n) = inner {
112 if let Ok(v) = n.parse::<f64>() {
113 let neg = -v;
114 if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
115 return Expr::Number(format!("{}", neg as i64));
116 }
117 return Expr::Number(format!("{neg}"));
118 }
119 }
120 Expr::UnaryOp {
121 op: UnaryOperator::Minus,
122 expr: Box::new(inner),
123 }
124 }
125 Expr::Nested(inner) => {
126 let folded = fold_expr(*inner);
127 if folded.is_literal() {
128 folded
129 } else {
130 Expr::Nested(Box::new(folded))
131 }
132 }
133 other => other,
134 }
135}
136
137fn simplify_booleans(statement: Statement) -> Statement {
139 match statement {
140 Statement::Select(mut sel) => {
141 for item in &mut sel.columns {
143 if let SelectItem::Expr { expr, .. } = item {
144 *expr = simplify_bool_expr(expr.clone());
145 }
146 }
147 if let Some(wh) = sel.where_clause {
148 let simplified = simplify_bool_expr(wh);
149 if simplified == Expr::Boolean(true) {
151 sel.where_clause = None;
152 } else {
153 sel.where_clause = Some(simplified);
154 }
155 }
156 if let Some(having) = sel.having {
157 let simplified = simplify_bool_expr(having);
158 if simplified == Expr::Boolean(true) {
159 sel.having = None;
160 } else {
161 sel.having = Some(simplified);
162 }
163 }
164 Statement::Select(sel)
165 }
166 other => other,
167 }
168}
169
170fn simplify_bool_expr(expr: Expr) -> Expr {
171 match expr {
172 Expr::BinaryOp {
173 left,
174 op: BinaryOperator::And,
175 right,
176 } => {
177 let left = simplify_bool_expr(*left);
178 let right = simplify_bool_expr(*right);
179 match (&left, &right) {
180 (Expr::Boolean(true), _) => right,
181 (_, Expr::Boolean(true)) => left,
182 (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
183 _ => Expr::BinaryOp {
184 left: Box::new(left),
185 op: BinaryOperator::And,
186 right: Box::new(right),
187 },
188 }
189 }
190 Expr::BinaryOp {
191 left,
192 op: BinaryOperator::Or,
193 right,
194 } => {
195 let left = simplify_bool_expr(*left);
196 let right = simplify_bool_expr(*right);
197 match (&left, &right) {
198 (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
199 (Expr::Boolean(false), _) => right,
200 (_, Expr::Boolean(false)) => left,
201 _ => Expr::BinaryOp {
202 left: Box::new(left),
203 op: BinaryOperator::Or,
204 right: Box::new(right),
205 },
206 }
207 }
208 Expr::UnaryOp {
209 op: UnaryOperator::Not,
210 expr,
211 } => {
212 let inner = simplify_bool_expr(*expr);
213 match inner {
214 Expr::Boolean(b) => Expr::Boolean(!b),
215 Expr::UnaryOp {
216 op: UnaryOperator::Not,
217 expr: inner2,
218 } => *inner2,
219 other => Expr::UnaryOp {
220 op: UnaryOperator::Not,
221 expr: Box::new(other),
222 },
223 }
224 }
225 Expr::Nested(inner) => {
226 let simplified = simplify_bool_expr(*inner);
227 if simplified.is_literal() {
228 simplified
229 } else {
230 Expr::Nested(Box::new(simplified))
231 }
232 }
233 other => other,
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` as a single statement and runs the full optimizer pipeline on it.
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    /// Optimizes `sql` and returns the WHERE clause of the resulting SELECT.
    ///
    /// Panics (failing the test) if the result is not a SELECT, rather than letting an
    /// `if let` silently skip every assertion.
    fn optimized_where(sql: &str) -> Option<Expr> {
        match optimize_sql(sql) {
            Statement::Select(sel) => sel.where_clause,
            _ => panic!("expected a SELECT statement for {sql:?}"),
        }
    }

    /// Optimizes `sql` and returns the first select-list expression of the resulting SELECT.
    ///
    /// Panics if the result is not a SELECT or its first item is not an expression.
    fn optimized_first_column(sql: &str) -> Expr {
        match optimize_sql(sql) {
            Statement::Select(sel) => match sel.columns.into_iter().next() {
                Some(SelectItem::Expr { expr, .. }) => expr,
                _ => panic!("expected an expression as the first select item of {sql:?}"),
            },
            _ => panic!("expected a SELECT statement for {sql:?}"),
        }
    }

    #[test]
    fn test_constant_folding() {
        assert_eq!(
            optimized_first_column("SELECT 1 + 2 FROM t"),
            Expr::Number("3".to_string())
        );
    }

    #[test]
    fn test_constant_folding_through_parentheses() {
        assert_eq!(
            optimized_first_column("SELECT (2 + 1) * 2 FROM t"),
            Expr::Number("6".to_string())
        );
    }

    #[test]
    fn test_constant_folding_exact_division() {
        assert_eq!(
            optimized_first_column("SELECT 6 / 2 FROM t"),
            Expr::Number("3".to_string())
        );
    }

    #[test]
    fn test_folded_true_comparison_removes_where() {
        assert!(optimized_where("SELECT x FROM t WHERE 1 = 1").is_none());
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        assert!(optimized_where("SELECT x FROM t WHERE TRUE").is_none());
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        let wh = optimized_where("SELECT x FROM t WHERE TRUE AND x > 1");
        assert!(
            matches!(wh, Some(Expr::BinaryOp { op: BinaryOperator::Gt, .. })),
            "expected `x > 1`, got {wh:?}"
        );
    }

    #[test]
    fn test_boolean_simplification_or_false() {
        let wh = optimized_where("SELECT x FROM t WHERE FALSE OR x > 1");
        assert!(
            matches!(wh, Some(Expr::BinaryOp { op: BinaryOperator::Gt, .. })),
            "expected `x > 1`, got {wh:?}"
        );
    }

    #[test]
    fn test_double_negation() {
        let wh = optimized_where("SELECT x FROM t WHERE NOT NOT x > 1");
        assert!(
            matches!(wh, Some(Expr::BinaryOp { op: BinaryOperator::Gt, .. })),
            "expected `x > 1` with both NOTs removed, got {wh:?}"
        );
    }
}