sqlglot_rust/optimizer/
mod.rs1pub mod scope_analysis;
15pub mod unnest_subqueries;
16
17use crate::ast::*;
18use crate::errors::Result;
19
20pub fn optimize(statement: Statement) -> Result<Statement> {
22 let mut stmt = statement;
23 stmt = fold_constants(stmt);
24 stmt = simplify_booleans(stmt);
25 stmt = unnest_subqueries::unnest_subqueries(stmt);
26 Ok(stmt)
27}
28
29fn fold_constants(statement: Statement) -> Statement {
31 match statement {
32 Statement::Select(mut sel) => {
33 if let Some(wh) = sel.where_clause {
34 sel.where_clause = Some(fold_expr(wh));
35 }
36 if let Some(having) = sel.having {
37 sel.having = Some(fold_expr(having));
38 }
39 for item in &mut sel.columns {
40 if let SelectItem::Expr { expr, .. } = item {
41 *expr = fold_expr(expr.clone());
42 }
43 }
44 Statement::Select(sel)
45 }
46 other => other,
47 }
48}
49
50fn fold_expr(expr: Expr) -> Expr {
51 match expr {
52 Expr::BinaryOp { left, op, right } => {
53 let left = fold_expr(*left);
54 let right = fold_expr(*right);
55
56 if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
58 if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
59 let result = match op {
60 BinaryOperator::Plus => Some(lv + rv),
61 BinaryOperator::Minus => Some(lv - rv),
62 BinaryOperator::Multiply => Some(lv * rv),
63 BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
64 BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
65 _ => None,
66 };
67 if let Some(val) = result {
68 if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
70 return Expr::Number(format!("{}", val as i64));
71 }
72 return Expr::Number(format!("{val}"));
73 }
74
75 let cmp = match op {
77 BinaryOperator::Eq => Some(lv == rv),
78 BinaryOperator::Neq => Some(lv != rv),
79 BinaryOperator::Lt => Some(lv < rv),
80 BinaryOperator::Gt => Some(lv > rv),
81 BinaryOperator::LtEq => Some(lv <= rv),
82 BinaryOperator::GtEq => Some(lv >= rv),
83 _ => None,
84 };
85 if let Some(val) = cmp {
86 return Expr::Boolean(val);
87 }
88 }
89 }
90
91 if matches!(op, BinaryOperator::Concat) {
93 if let (Expr::StringLiteral(l), Expr::StringLiteral(r)) = (&left, &right) {
94 return Expr::StringLiteral(format!("{l}{r}"));
95 }
96 }
97
98 Expr::BinaryOp {
99 left: Box::new(left),
100 op,
101 right: Box::new(right),
102 }
103 }
104 Expr::UnaryOp {
105 op: UnaryOperator::Minus,
106 expr,
107 } => {
108 let inner = fold_expr(*expr);
109 if let Expr::Number(ref n) = inner {
110 if let Ok(v) = n.parse::<f64>() {
111 let neg = -v;
112 if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
113 return Expr::Number(format!("{}", neg as i64));
114 }
115 return Expr::Number(format!("{neg}"));
116 }
117 }
118 Expr::UnaryOp {
119 op: UnaryOperator::Minus,
120 expr: Box::new(inner),
121 }
122 }
123 Expr::Nested(inner) => {
124 let folded = fold_expr(*inner);
125 if folded.is_literal() {
126 folded
127 } else {
128 Expr::Nested(Box::new(folded))
129 }
130 }
131 other => other,
132 }
133}
134
135fn simplify_booleans(statement: Statement) -> Statement {
137 match statement {
138 Statement::Select(mut sel) => {
139 for item in &mut sel.columns {
141 if let SelectItem::Expr { expr, .. } = item {
142 *expr = simplify_bool_expr(expr.clone());
143 }
144 }
145 if let Some(wh) = sel.where_clause {
146 let simplified = simplify_bool_expr(wh);
147 if simplified == Expr::Boolean(true) {
149 sel.where_clause = None;
150 } else {
151 sel.where_clause = Some(simplified);
152 }
153 }
154 if let Some(having) = sel.having {
155 let simplified = simplify_bool_expr(having);
156 if simplified == Expr::Boolean(true) {
157 sel.having = None;
158 } else {
159 sel.having = Some(simplified);
160 }
161 }
162 Statement::Select(sel)
163 }
164 other => other,
165 }
166}
167
168fn simplify_bool_expr(expr: Expr) -> Expr {
169 match expr {
170 Expr::BinaryOp {
171 left,
172 op: BinaryOperator::And,
173 right,
174 } => {
175 let left = simplify_bool_expr(*left);
176 let right = simplify_bool_expr(*right);
177 match (&left, &right) {
178 (Expr::Boolean(true), _) => right,
179 (_, Expr::Boolean(true)) => left,
180 (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
181 _ => Expr::BinaryOp {
182 left: Box::new(left),
183 op: BinaryOperator::And,
184 right: Box::new(right),
185 },
186 }
187 }
188 Expr::BinaryOp {
189 left,
190 op: BinaryOperator::Or,
191 right,
192 } => {
193 let left = simplify_bool_expr(*left);
194 let right = simplify_bool_expr(*right);
195 match (&left, &right) {
196 (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
197 (Expr::Boolean(false), _) => right,
198 (_, Expr::Boolean(false)) => left,
199 _ => Expr::BinaryOp {
200 left: Box::new(left),
201 op: BinaryOperator::Or,
202 right: Box::new(right),
203 },
204 }
205 }
206 Expr::UnaryOp {
207 op: UnaryOperator::Not,
208 expr,
209 } => {
210 let inner = simplify_bool_expr(*expr);
211 match inner {
212 Expr::Boolean(b) => Expr::Boolean(!b),
213 Expr::UnaryOp {
214 op: UnaryOperator::Not,
215 expr: inner2,
216 } => *inner2,
217 other => Expr::UnaryOp {
218 op: UnaryOperator::Not,
219 expr: Box::new(other),
220 },
221 }
222 }
223 Expr::Nested(inner) => {
224 let simplified = simplify_bool_expr(*inner);
225 if simplified.is_literal() {
226 simplified
227 } else {
228 Expr::Nested(Box::new(simplified))
229 }
230 }
231 other => other,
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` as a single statement and runs it through [`optimize`].
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    /// Optimizes `$sql` and unwraps the resulting `SELECT`.
    ///
    /// Panics on any other statement kind so that a test can never pass merely
    /// because its assertions were skipped.
    macro_rules! optimized_select {
        ($sql:expr) => {
            match optimize_sql($sql) {
                Statement::Select(sel) => sel,
                _ => panic!("expected the optimizer to return a SELECT statement"),
            }
        };
    }

    #[test]
    fn test_constant_folding() {
        let sel = optimized_select!("SELECT 1 + 2 FROM t");
        match &sel.columns[0] {
            SelectItem::Expr { expr, .. } => {
                assert_eq!(*expr, Expr::Number("3".to_string()));
            }
            _ => panic!("expected the first select item to be an expression"),
        }
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        let sel = optimized_select!("SELECT x FROM t WHERE TRUE");
        assert!(sel.where_clause.is_none());
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        let sel = optimized_select!("SELECT x FROM t WHERE TRUE AND x > 1");
        assert!(sel.where_clause.is_some());
        // The redundant `TRUE AND` must be gone, leaving only the comparison.
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::BinaryOp {
                op: BinaryOperator::And,
                ..
            })
        ));
    }

    #[test]
    fn test_double_negation() {
        let sel = optimized_select!("SELECT x FROM t WHERE NOT NOT x > 1");
        assert!(sel.where_clause.is_some());
        // Both negations must have cancelled out.
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::UnaryOp {
                op: UnaryOperator::Not,
                ..
            })
        ));
    }
}