sqlglot_rust/optimizer/
mod.rs1pub mod annotate_types;
17pub mod qualify_columns;
18pub mod scope_analysis;
19pub mod unnest_subqueries;
20
21use crate::ast::*;
22use crate::errors::Result;
23
24pub fn optimize(statement: Statement) -> Result<Statement> {
26 let mut stmt = statement;
27 stmt = fold_constants(stmt);
28 stmt = simplify_booleans(stmt);
29 stmt = unnest_subqueries::unnest_subqueries(stmt);
30 Ok(stmt)
31}
32
33fn fold_constants(statement: Statement) -> Statement {
35 match statement {
36 Statement::Select(mut sel) => {
37 if let Some(wh) = sel.where_clause {
38 sel.where_clause = Some(fold_expr(wh));
39 }
40 if let Some(having) = sel.having {
41 sel.having = Some(fold_expr(having));
42 }
43 for item in &mut sel.columns {
44 if let SelectItem::Expr { expr, .. } = item {
45 *expr = fold_expr(expr.clone());
46 }
47 }
48 Statement::Select(sel)
49 }
50 other => other,
51 }
52}
53
54fn fold_expr(expr: Expr) -> Expr {
55 match expr {
56 Expr::BinaryOp { left, op, right } => {
57 let left = fold_expr(*left);
58 let right = fold_expr(*right);
59
60 if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
62 if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
63 let result = match op {
64 BinaryOperator::Plus => Some(lv + rv),
65 BinaryOperator::Minus => Some(lv - rv),
66 BinaryOperator::Multiply => Some(lv * rv),
67 BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
68 BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
69 _ => None,
70 };
71 if let Some(val) = result {
72 if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
74 return Expr::Number(format!("{}", val as i64));
75 }
76 return Expr::Number(format!("{val}"));
77 }
78
79 let cmp = match op {
81 BinaryOperator::Eq => Some(lv == rv),
82 BinaryOperator::Neq => Some(lv != rv),
83 BinaryOperator::Lt => Some(lv < rv),
84 BinaryOperator::Gt => Some(lv > rv),
85 BinaryOperator::LtEq => Some(lv <= rv),
86 BinaryOperator::GtEq => Some(lv >= rv),
87 _ => None,
88 };
89 if let Some(val) = cmp {
90 return Expr::Boolean(val);
91 }
92 }
93 }
94
95 if matches!(op, BinaryOperator::Concat) {
97 if let (Expr::StringLiteral(l), Expr::StringLiteral(r)) = (&left, &right) {
98 return Expr::StringLiteral(format!("{l}{r}"));
99 }
100 }
101
102 Expr::BinaryOp {
103 left: Box::new(left),
104 op,
105 right: Box::new(right),
106 }
107 }
108 Expr::UnaryOp {
109 op: UnaryOperator::Minus,
110 expr,
111 } => {
112 let inner = fold_expr(*expr);
113 if let Expr::Number(ref n) = inner {
114 if let Ok(v) = n.parse::<f64>() {
115 let neg = -v;
116 if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
117 return Expr::Number(format!("{}", neg as i64));
118 }
119 return Expr::Number(format!("{neg}"));
120 }
121 }
122 Expr::UnaryOp {
123 op: UnaryOperator::Minus,
124 expr: Box::new(inner),
125 }
126 }
127 Expr::Nested(inner) => {
128 let folded = fold_expr(*inner);
129 if folded.is_literal() {
130 folded
131 } else {
132 Expr::Nested(Box::new(folded))
133 }
134 }
135 other => other,
136 }
137}
138
139fn simplify_booleans(statement: Statement) -> Statement {
141 match statement {
142 Statement::Select(mut sel) => {
143 for item in &mut sel.columns {
145 if let SelectItem::Expr { expr, .. } = item {
146 *expr = simplify_bool_expr(expr.clone());
147 }
148 }
149 if let Some(wh) = sel.where_clause {
150 let simplified = simplify_bool_expr(wh);
151 if simplified == Expr::Boolean(true) {
153 sel.where_clause = None;
154 } else {
155 sel.where_clause = Some(simplified);
156 }
157 }
158 if let Some(having) = sel.having {
159 let simplified = simplify_bool_expr(having);
160 if simplified == Expr::Boolean(true) {
161 sel.having = None;
162 } else {
163 sel.having = Some(simplified);
164 }
165 }
166 Statement::Select(sel)
167 }
168 other => other,
169 }
170}
171
172fn simplify_bool_expr(expr: Expr) -> Expr {
173 match expr {
174 Expr::BinaryOp {
175 left,
176 op: BinaryOperator::And,
177 right,
178 } => {
179 let left = simplify_bool_expr(*left);
180 let right = simplify_bool_expr(*right);
181 match (&left, &right) {
182 (Expr::Boolean(true), _) => right,
183 (_, Expr::Boolean(true)) => left,
184 (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
185 _ => Expr::BinaryOp {
186 left: Box::new(left),
187 op: BinaryOperator::And,
188 right: Box::new(right),
189 },
190 }
191 }
192 Expr::BinaryOp {
193 left,
194 op: BinaryOperator::Or,
195 right,
196 } => {
197 let left = simplify_bool_expr(*left);
198 let right = simplify_bool_expr(*right);
199 match (&left, &right) {
200 (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
201 (Expr::Boolean(false), _) => right,
202 (_, Expr::Boolean(false)) => left,
203 _ => Expr::BinaryOp {
204 left: Box::new(left),
205 op: BinaryOperator::Or,
206 right: Box::new(right),
207 },
208 }
209 }
210 Expr::UnaryOp {
211 op: UnaryOperator::Not,
212 expr,
213 } => {
214 let inner = simplify_bool_expr(*expr);
215 match inner {
216 Expr::Boolean(b) => Expr::Boolean(!b),
217 Expr::UnaryOp {
218 op: UnaryOperator::Not,
219 expr: inner2,
220 } => *inner2,
221 other => Expr::UnaryOp {
222 op: UnaryOperator::Not,
223 expr: Box::new(other),
224 },
225 }
226 }
227 Expr::Nested(inner) => {
228 let simplified = simplify_bool_expr(*inner);
229 if simplified.is_literal() {
230 simplified
231 } else {
232 Expr::Nested(Box::new(simplified))
233 }
234 }
235 other => other,
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` as a single statement and runs the full optimizer over it.
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    /// Returns the `WHERE` clause of the optimized `SELECT`. It panics on any other
    /// statement kind, so an assertion can never be skipped silently.
    fn optimized_where(sql: &str) -> Option<Expr> {
        match optimize_sql(sql) {
            Statement::Select(sel) => sel.where_clause,
            _ => panic!("expected a SELECT statement for {sql:?}"),
        }
    }

    /// Returns the first projected expression of the optimized `SELECT`. It panics on any
    /// other statement kind or item shape.
    fn optimized_first_column(sql: &str) -> Expr {
        let sel = match optimize_sql(sql) {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement for {sql:?}"),
        };
        match &sel.columns[0] {
            SelectItem::Expr { expr, .. } => expr.clone(),
            _ => panic!("expected an expression as the first select item of {sql:?}"),
        }
    }

    #[test]
    fn test_constant_folding() {
        assert_eq!(
            optimized_first_column("SELECT 1 + 2 FROM t"),
            Expr::Number("3".to_string())
        );
    }

    #[test]
    fn test_nested_constant_folding() {
        // The parenthesised sum folds to a literal, so the outer product folds too.
        assert_eq!(
            optimized_first_column("SELECT (1 + 2) * 3 FROM t"),
            Expr::Number("9".to_string())
        );
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        assert!(optimized_where("SELECT x FROM t WHERE TRUE").is_none());
    }

    #[test]
    fn test_folded_tautology_removes_where() {
        // `1 = 1` folds to TRUE, which boolean simplification then removes.
        assert!(optimized_where("SELECT x FROM t WHERE 1 = 1").is_none());
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        // `TRUE AND x > 1` must collapse to just the comparison.
        assert!(matches!(
            optimized_where("SELECT x FROM t WHERE TRUE AND x > 1"),
            Some(Expr::BinaryOp {
                op: BinaryOperator::Gt,
                ..
            })
        ));
    }

    #[test]
    fn test_double_negation() {
        // `NOT NOT (x > 1)` must reduce to the bare comparison.
        assert!(matches!(
            optimized_where("SELECT x FROM t WHERE NOT NOT x > 1"),
            Some(Expr::BinaryOp {
                op: BinaryOperator::Gt,
                ..
            })
        ));
    }
}