sqlglot_rust/optimizer/
mod.rs1use crate::ast::*;
15use crate::errors::Result;
16
17pub fn optimize(statement: Statement) -> Result<Statement> {
19 let mut stmt = statement;
20 stmt = fold_constants(stmt);
21 stmt = simplify_booleans(stmt);
22 Ok(stmt)
23}
24
25fn fold_constants(statement: Statement) -> Statement {
27 match statement {
28 Statement::Select(mut sel) => {
29 if let Some(wh) = sel.where_clause {
30 sel.where_clause = Some(fold_expr(wh));
31 }
32 if let Some(having) = sel.having {
33 sel.having = Some(fold_expr(having));
34 }
35 for item in &mut sel.columns {
36 if let SelectItem::Expr { expr, .. } = item {
37 *expr = fold_expr(expr.clone());
38 }
39 }
40 Statement::Select(sel)
41 }
42 other => other,
43 }
44}
45
46fn fold_expr(expr: Expr) -> Expr {
47 match expr {
48 Expr::BinaryOp { left, op, right } => {
49 let left = fold_expr(*left);
50 let right = fold_expr(*right);
51
52 if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
54 if let (Ok(lv), Ok(rv)) = (l.parse::<f64>(), r.parse::<f64>()) {
55 let result = match op {
56 BinaryOperator::Plus => Some(lv + rv),
57 BinaryOperator::Minus => Some(lv - rv),
58 BinaryOperator::Multiply => Some(lv * rv),
59 BinaryOperator::Divide if rv != 0.0 => Some(lv / rv),
60 BinaryOperator::Modulo if rv != 0.0 => Some(lv % rv),
61 _ => None,
62 };
63 if let Some(val) = result {
64 if val.fract() == 0.0 && val.abs() < i64::MAX as f64 {
66 return Expr::Number(format!("{}", val as i64));
67 }
68 return Expr::Number(format!("{val}"));
69 }
70
71 let cmp = match op {
73 BinaryOperator::Eq => Some(lv == rv),
74 BinaryOperator::Neq => Some(lv != rv),
75 BinaryOperator::Lt => Some(lv < rv),
76 BinaryOperator::Gt => Some(lv > rv),
77 BinaryOperator::LtEq => Some(lv <= rv),
78 BinaryOperator::GtEq => Some(lv >= rv),
79 _ => None,
80 };
81 if let Some(val) = cmp {
82 return Expr::Boolean(val);
83 }
84 }
85 }
86
87 if matches!(op, BinaryOperator::Concat) {
89 if let (Expr::StringLiteral(l), Expr::StringLiteral(r)) = (&left, &right) {
90 return Expr::StringLiteral(format!("{l}{r}"));
91 }
92 }
93
94 Expr::BinaryOp {
95 left: Box::new(left),
96 op,
97 right: Box::new(right),
98 }
99 }
100 Expr::UnaryOp { op: UnaryOperator::Minus, expr } => {
101 let inner = fold_expr(*expr);
102 if let Expr::Number(ref n) = inner {
103 if let Ok(v) = n.parse::<f64>() {
104 let neg = -v;
105 if neg.fract() == 0.0 && neg.abs() < i64::MAX as f64 {
106 return Expr::Number(format!("{}", neg as i64));
107 }
108 return Expr::Number(format!("{neg}"));
109 }
110 }
111 Expr::UnaryOp {
112 op: UnaryOperator::Minus,
113 expr: Box::new(inner),
114 }
115 }
116 Expr::Nested(inner) => {
117 let folded = fold_expr(*inner);
118 if folded.is_literal() {
119 folded
120 } else {
121 Expr::Nested(Box::new(folded))
122 }
123 }
124 other => other,
125 }
126}
127
128fn simplify_booleans(statement: Statement) -> Statement {
130 match statement {
131 Statement::Select(mut sel) => {
132 for item in &mut sel.columns {
134 if let SelectItem::Expr { expr, .. } = item {
135 *expr = simplify_bool_expr(expr.clone());
136 }
137 }
138 if let Some(wh) = sel.where_clause {
139 let simplified = simplify_bool_expr(wh);
140 if simplified == Expr::Boolean(true) {
142 sel.where_clause = None;
143 } else {
144 sel.where_clause = Some(simplified);
145 }
146 }
147 if let Some(having) = sel.having {
148 let simplified = simplify_bool_expr(having);
149 if simplified == Expr::Boolean(true) {
150 sel.having = None;
151 } else {
152 sel.having = Some(simplified);
153 }
154 }
155 Statement::Select(sel)
156 }
157 other => other,
158 }
159}
160
161fn simplify_bool_expr(expr: Expr) -> Expr {
162 match expr {
163 Expr::BinaryOp { left, op: BinaryOperator::And, right } => {
164 let left = simplify_bool_expr(*left);
165 let right = simplify_bool_expr(*right);
166 match (&left, &right) {
167 (Expr::Boolean(true), _) => right,
168 (_, Expr::Boolean(true)) => left,
169 (Expr::Boolean(false), _) | (_, Expr::Boolean(false)) => Expr::Boolean(false),
170 _ => Expr::BinaryOp {
171 left: Box::new(left),
172 op: BinaryOperator::And,
173 right: Box::new(right),
174 },
175 }
176 }
177 Expr::BinaryOp { left, op: BinaryOperator::Or, right } => {
178 let left = simplify_bool_expr(*left);
179 let right = simplify_bool_expr(*right);
180 match (&left, &right) {
181 (Expr::Boolean(true), _) | (_, Expr::Boolean(true)) => Expr::Boolean(true),
182 (Expr::Boolean(false), _) => right,
183 (_, Expr::Boolean(false)) => left,
184 _ => Expr::BinaryOp {
185 left: Box::new(left),
186 op: BinaryOperator::Or,
187 right: Box::new(right),
188 },
189 }
190 }
191 Expr::UnaryOp { op: UnaryOperator::Not, expr } => {
192 let inner = simplify_bool_expr(*expr);
193 match inner {
194 Expr::Boolean(b) => Expr::Boolean(!b),
195 Expr::UnaryOp { op: UnaryOperator::Not, expr: inner2 } => *inner2,
196 other => Expr::UnaryOp {
197 op: UnaryOperator::Not,
198 expr: Box::new(other),
199 },
200 }
201 }
202 Expr::Nested(inner) => {
203 let simplified = simplify_bool_expr(*inner);
204 if simplified.is_literal() {
205 simplified
206 } else {
207 Expr::Nested(Box::new(simplified))
208 }
209 }
210 other => other,
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` as a single statement and runs the optimizer over it.
    fn optimize_sql(sql: &str) -> Statement {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        optimize(stmt).unwrap()
    }

    // Every test panics on an unexpected statement or projection shape, rather than
    // silently skipping its assertions through an unmatched `if let`.

    #[test]
    fn test_constant_folding() {
        let sel = match optimize_sql("SELECT 1 + 2 FROM t") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        match &sel.columns[0] {
            SelectItem::Expr { expr, .. } => assert_eq!(*expr, Expr::Number("3".to_string())),
            _ => panic!("expected an expression in the projection"),
        }
    }

    #[test]
    fn test_boolean_simplification_where_true() {
        let sel = match optimize_sql("SELECT x FROM t WHERE TRUE") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        assert!(sel.where_clause.is_none());
    }

    #[test]
    fn test_boolean_simplification_and_true() {
        let sel = match optimize_sql("SELECT x FROM t WHERE TRUE AND x > 1") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        assert!(sel.where_clause.is_some());
        // `TRUE AND p` must collapse to `p`, so no AND node may remain at the top.
        assert!(!matches!(
            &sel.where_clause,
            Some(Expr::BinaryOp { op: BinaryOperator::And, .. })
        ));
    }

    #[test]
    fn test_double_negation() {
        let sel = match optimize_sql("SELECT x FROM t WHERE NOT NOT x > 1") {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        assert!(sel.where_clause.is_some());
        assert!(
            !matches!(
                &sel.where_clause,
                Some(Expr::UnaryOp { op: UnaryOperator::Not, .. })
            ),
            "double negation should be removed"
        );
    }
}