sqlglot_rust/optimizer/
mod.rs1pub mod unnest_subqueries;
15
16use crate::ast::*;
17use crate::errors::Result;
18
19pub fn optimize(statement: Statement) -> Result<Statement> {
21 let mut stmt = statement;
22 stmt = fold_constants(stmt);
23 stmt = simplify_booleans(stmt);
24 stmt = unnest_subqueries::unnest_subqueries(stmt);
25 Ok(stmt)
26}
27
28fn fold_constants(statement: Statement) -> Statement {
30 match statement {
31 Statement::Select(mut sel) => {
32 if let Some(wh) = sel.where_clause {
33 sel.where_clause = Some(fold_expr(wh));
34 }
35 if let Some(having) = sel.having {
36 sel.having = Some(fold_expr(having));
37 }
38 for item in &mut sel.columns {
39 if let SelectItem::Expr { expr, .. } = item {
40 *expr = fold_expr(expr.clone());
41 }
42 }
43 Statement::Select(sel)
44 }
45 other => other,
46 }
47}
48
49fn fold_expr(expr: Expr) -> Expr {
50 match expr {
51 Expr::BinaryOp { left, op, right } => {
52 let left = fold_expr(*left);
53 let right = fold_expr(*right);
54
55 if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
57 if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
58 let result = match op {
59 BinaryOperator::Plus => Some(lv + rv),
60 BinaryOperator::Minus => Some(lv - rv),
61 BinaryOperator::Multiply => Some(lv * rv),
62 BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
63 BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
64 _ => None,
65 };
66 if let Some(val) = result {
67 if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
69 return Expr::Number(format!("{}", val as i64));
70 }
71 return Expr::Number(format!("{val}"));
72 }
73
74 let cmp = match op {
76 BinaryOperator::Eq => Some(lv == rv),
77 BinaryOperator::Neq => Some(lv != rv),
78 BinaryOperator::Lt => Some(lv < rv),
79 BinaryOperator::Gt => Some(lv > rv),
80 BinaryOperator::LtEq => Some(lv <= rv),
81 BinaryOperator::GtEq => Some(lv >= rv),
82 _ => None,
83 };
84 if let Some(val) = cmp {
85 return Expr::Boolean(val);
86 }
87 }
88 }
89
90 if matches!(op, BinaryOperator::Concat) {
92 if let (Expr::StringLiteral(l), Expr::StringLiteral(r)) = (&left, &right) {
93 return Expr::StringLiteral(format!("{l}{r}"));
94 }
95 }
96
97 Expr::BinaryOp {
98 left: Box::new(left),
99 op,
100 right: Box::new(right),
101 }
102 }
103 Expr::UnaryOp { op: UnaryOperator::Minus, expr } => {
104 let inner = fold_expr(*expr);
105 if let Expr::Number(ref n) = inner {
106 if let Ok(v) = n.parse::<f64>() {
107 let neg = -v;
108 if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
109 return Expr::Number(format!("{}", neg as i64));
110 }
111 return Expr::Number(format!("{neg}"));
112 }
113 }
114 Expr::UnaryOp {
115 op: UnaryOperator::Minus,
116 expr: Box::new(inner),
117 }
118 }
119 Expr::Nested(inner) => {
120 let folded = fold_expr(*inner);
121 if folded.is_literal() {
122 folded
123 } else {
124 Expr::Nested(Box::new(folded))
125 }
126 }
127 other => other,
128 }
129}
130
131fn simplify_booleans(statement: Statement) -> Statement {
133 match statement {
134 Statement::Select(mut sel) => {
135 for item in &mut sel.columns {
137 if let SelectItem::Expr { expr, .. } = item {
138 *expr = simplify_bool_expr(expr.clone());
139 }
140 }
141 if let Some(wh) = sel.where_clause {
142 let simplified = simplify_bool_expr(wh);
143 if simplified == Expr::Boolean(true) {
145 sel.where_clause = None;
146 } else {
147 sel.where_clause = Some(simplified);
148 }
149 }
150 if let Some(having) = sel.having {
151 let simplified = simplify_bool_expr(having);
152 if simplified == Expr::Boolean(true) {
153 sel.having = None;
154 } else {
155 sel.having = Some(simplified);
156 }
157 }
158 Statement::Select(sel)
159 }
160 other => other,
161 }
162}
163
164fn simplify_bool_expr(expr: Expr) -> Expr {
165 match expr {
166 Expr::BinaryOp { left, op: BinaryOperator::And, right } => {
167 let left = simplify_bool_expr(*left);
168 let right = simplify_bool_expr(*right);
169 match (&left, &right) {
170 (Expr::Boolean(true), _) => right,
171 (_, Expr::Boolean(true)) => left,
172 (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
173 _ => Expr::BinaryOp {
174 left: Box::new(left),
175 op: BinaryOperator::And,
176 right: Box::new(right),
177 },
178 }
179 }
180 Expr::BinaryOp { left, op: BinaryOperator::Or, right } => {
181 let left = simplify_bool_expr(*left);
182 let right = simplify_bool_expr(*right);
183 match (&left, &right) {
184 (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
185 (Expr::Boolean(false), _) => right,
186 (_, Expr::Boolean(false)) => left,
187 _ => Expr::BinaryOp {
188 left: Box::new(left),
189 op: BinaryOperator::Or,
190 right: Box::new(right),
191 },
192 }
193 }
194 Expr::UnaryOp { op: UnaryOperator::Not, expr } => {
195 let inner = simplify_bool_expr(*expr);
196 match inner {
197 Expr::Boolean(b) => Expr::Boolean(!b),
198 Expr::UnaryOp { op: UnaryOperator::Not, expr: inner2 } => *inner2,
199 other => Expr::UnaryOp {
200 op: UnaryOperator::Not,
201 expr: Box::new(other),
202 },
203 }
204 }
205 Expr::Nested(inner) => {
206 let simplified = simplify_bool_expr(*inner);
207 if simplified.is_literal() {
208 simplified
209 } else {
210 Expr::Nested(Box::new(simplified))
211 }
212 }
213 other => other,
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and runs the full optimizer pipeline over it.
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    /// Optimizes `sql` and returns its WHERE clause. Fails the test if the
    /// optimized statement is not a SELECT.
    fn optimized_where(sql: &str) -> Option<Expr> {
        match optimize_sql(sql) {
            Statement::Select(sel) => sel.where_clause,
            _ => panic!("expected a SELECT statement for {sql}"),
        }
    }

    #[test]
    fn test_constant_folding() {
        match optimize_sql("SELECT 1 + 2 FROM t") {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectItem::Expr { expr, .. } => {
                    assert_eq!(*expr, Expr::Number("3".to_string()))
                }
                _ => panic!("expected an expression in the select list"),
            },
            _ => panic!("expected a SELECT statement"),
        }
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        assert_eq!(optimized_where("SELECT x FROM t WHERE TRUE"), None);
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        let wh = optimized_where("SELECT x FROM t WHERE TRUE AND x > 1");
        // `TRUE AND p` must reduce to exactly `p`.
        assert!(
            matches!(&wh, Some(Expr::BinaryOp { op: BinaryOperator::Gt, .. })),
            "unexpected WHERE: {wh:?}"
        );
    }

    #[test]
    fn test_boolean_simplification_or_true() {
        // `p OR TRUE` is always TRUE, so the WHERE clause disappears.
        assert_eq!(optimized_where("SELECT x FROM t WHERE x > 1 OR TRUE"), None);
    }

    #[test]
    fn test_comparison_folding_removes_where() {
        // `1 < 2` folds to TRUE, which the boolean pass then drops.
        assert_eq!(optimized_where("SELECT x FROM t WHERE 1 < 2"), None);
    }

    #[test]
    fn test_double_negation() {
        let wh = optimized_where("SELECT x FROM t WHERE NOT NOT x > 1");
        // Both NOTs must be gone, leaving the bare comparison.
        assert!(
            matches!(&wh, Some(Expr::BinaryOp { op: BinaryOperator::Gt, .. })),
            "unexpected WHERE: {wh:?}"
        );
    }
}