1use crate::sql::parser::ast::{CTEType, SelectStatement, SqlExpression, CTE};
12use serde::{Deserialize, Serialize};
13
/// One rewrite suggestion produced while analyzing a query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RewriteSuggestion {
    /// Category of the suggested rewrite.
    pub rewrite_type: RewriteType,
    /// Where the issue was found. The analyzer in this file fills it with the
    /// select item's alias, prefixed with `CTE '<name>': ` for findings that
    /// come from inside a CTE.
    pub location: Option<String>,
    /// Human-readable description of the detected problem.
    pub issue: String,
    /// Human-readable advice on how to restructure the query.
    pub suggestion: String,
    /// Fully rewritten SQL, when available. The analyzer in this file always
    /// leaves it as `None`.
    pub rewritten_sql: Option<String>,
    /// Text of a CTE definition that computes the hoisted expression.
    pub suggested_cte: Option<String>,
}
30
/// Categories of rewrite a [`RewriteSuggestion`] can describe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RewriteType {
    /// A complex expression appears as an argument of an aggregate function.
    AggregateExpressionHoist,
    /// A complex expression appears as an argument of a window function.
    WindowExpressionHoist,
    /// NOTE(review): not produced by the code visible in this file; presumably
    /// an expression inside a WHERE clause. Confirm against the producer.
    WhereExpressionHoist,
    /// A complex expression appears as the first argument of LAG or LEAD.
    LagLeadExpressionHoist,
    /// NOTE(review): not produced by the code visible in this file; presumably
    /// an expression inside a JOIN condition. Confirm against the producer.
    JoinConditionHoist,
    /// NOTE(review): not produced by the code visible in this file; presumably
    /// an aggregate nested inside another aggregate. Confirm against the producer.
    NestedAggregateHoist,
}
46
47pub struct QueryRewriter {
49 suggestions: Vec<RewriteSuggestion>,
50}
51
52impl QueryRewriter {
53 pub fn new() -> Self {
54 Self {
55 suggestions: Vec::new(),
56 }
57 }
58
59 pub fn analyze(&mut self, stmt: &SelectStatement) -> Vec<RewriteSuggestion> {
61 self.suggestions.clear();
62
63 self.analyze_select_items(stmt);
65
66 if let Some(where_clause) = &stmt.where_clause {
68 self.analyze_where_clause(where_clause);
69 }
70
71 if let Some(group_by) = &stmt.group_by {
73 self.analyze_group_by(group_by);
74 }
75
76 for cte in &stmt.ctes {
78 self.analyze_cte(cte);
79 }
80
81 self.suggestions.clone()
82 }
83
84 fn analyze_select_items(&mut self, stmt: &SelectStatement) {
86 for item in &stmt.select_items {
87 if let crate::sql::parser::ast::SelectItem::Expression { expr, alias, .. } = item {
88 self.check_expression_for_hoisting(expr, Some(alias));
89 }
90 }
91 }
92
93 fn check_expression_for_hoisting(&mut self, expr: &SqlExpression, context: Option<&str>) {
95 match expr {
96 SqlExpression::WindowFunction { name, args, .. } => {
97 for arg in args {
99 if self.is_complex_expression(arg) {
100 self.suggestions.push(RewriteSuggestion {
101 rewrite_type: RewriteType::WindowExpressionHoist,
102 location: context.map(|s| s.to_string()),
103 issue: format!("Window function {} contains complex expression", name),
104 suggestion: "Hoist the expression to a CTE and reference the column"
105 .to_string(),
106 rewritten_sql: None,
107 suggested_cte: Some(self.generate_hoist_cte(arg, "expr_cte")),
108 });
109 }
110 }
111 }
112 SqlExpression::FunctionCall { name, args, .. } => {
113 if self.is_aggregate_function(name) {
115 for arg in args {
116 if self.is_complex_expression(arg) {
117 self.suggestions.push(RewriteSuggestion {
118 rewrite_type: RewriteType::AggregateExpressionHoist,
119 location: context.map(|s| s.to_string()),
120 issue: format!("Aggregate function {} contains expression: {:?}", name, arg),
121 suggestion: "Create a CTE with the calculated expression, then aggregate the result column".to_string(),
122 rewritten_sql: None,
123 suggested_cte: Some(self.generate_hoist_cte(arg, "calc_cte")),
124 });
125 }
126 }
127 }
128
129 if name == "LAG" || name == "LEAD" {
131 if let Some(first_arg) = args.first() {
132 if self.is_complex_expression(first_arg) {
133 self.suggestions.push(RewriteSuggestion {
134 rewrite_type: RewriteType::LagLeadExpressionHoist,
135 location: context.map(|s| s.to_string()),
136 issue: format!("{} function contains expression instead of column reference", name),
137 suggestion: format!("Calculate expression in a CTE, then apply {} to the result column", name),
138 rewritten_sql: None,
139 suggested_cte: Some(self.generate_hoist_cte(first_arg, "lag_lead_cte")),
140 });
141 }
142 }
143 }
144 }
145 SqlExpression::BinaryOp { left, right, .. } => {
146 self.check_expression_for_hoisting(left, context);
148 self.check_expression_for_hoisting(right, context);
149 }
150 _ => {}
151 }
152 }
153
154 fn is_complex_expression(&self, expr: &SqlExpression) -> bool {
156 !matches!(
157 expr,
158 SqlExpression::Column(_)
159 | SqlExpression::NumberLiteral(_)
160 | SqlExpression::StringLiteral(_)
161 )
162 }
163
164 fn is_aggregate_function(&self, name: &str) -> bool {
166 matches!(
167 name.to_uppercase().as_str(),
168 "SUM" | "AVG" | "COUNT" | "MIN" | "MAX" | "STDDEV" | "VARIANCE" | "MEDIAN"
169 )
170 }
171
172 fn generate_hoist_cte(&self, expr: &SqlExpression, cte_name: &str) -> String {
174 let expr_str = self.expression_to_sql(expr);
175 format!(
176 "{} AS (\n SELECT \n *,\n {} AS calculated_value\n FROM previous_table\n)",
177 cte_name, expr_str
178 )
179 }
180
181 fn expression_to_sql(&self, expr: &SqlExpression) -> String {
183 match expr {
184 SqlExpression::Column(col_ref) => col_ref.to_sql(),
185 SqlExpression::BinaryOp { left, right, op } => {
186 format!(
187 "{} {} {}",
188 self.expression_to_sql(left),
189 op,
190 self.expression_to_sql(right)
191 )
192 }
193 SqlExpression::NumberLiteral(n) => n.clone(),
194 SqlExpression::StringLiteral(s) => format!("'{}'", s),
195 SqlExpression::FunctionCall { name, args, .. } => {
196 let arg_strs: Vec<String> =
197 args.iter().map(|a| self.expression_to_sql(a)).collect();
198 format!("{}({})", name, arg_strs.join(", "))
199 }
200 _ => format!("{:?}", expr), }
202 }
203
204 fn analyze_where_clause(&mut self, _where_clause: &crate::sql::parser::ast::WhereClause) {
205 }
207
208 fn analyze_group_by(&mut self, _group_by: &[SqlExpression]) {
209 }
211
212 fn analyze_cte(&mut self, cte: &CTE) {
213 if let CTEType::Standard(query) = &cte.cte_type {
215 let mut sub_rewriter = QueryRewriter::new();
216 sub_rewriter.analyze(query);
217 for mut suggestion in sub_rewriter.suggestions {
218 suggestion.location = Some(format!(
220 "CTE '{}': {}",
221 cte.name,
222 suggestion.location.unwrap_or_default()
223 ));
224 self.suggestions.push(suggestion);
225 }
226 }
227 }
228
229 pub fn rewrite(&self, _stmt: &SelectStatement) -> Option<SelectStatement> {
231 None
234 }
235}
236
/// Serializable summary of a rewrite analysis run.
#[derive(Debug, Serialize, Deserialize)]
pub struct RewriteAnalysis {
    /// Whether the analysis completed; `from_suggestions` always sets `true`.
    pub success: bool,
    /// Every suggestion the analysis produced.
    pub suggestions: Vec<RewriteSuggestion>,
    /// `true` when at least one suggestion carries rewritten SQL.
    pub can_auto_rewrite: bool,
    /// Rewritten query text, if any; `from_suggestions` leaves it as `None`.
    pub rewritten_query: Option<String>,
}
245
246impl RewriteAnalysis {
247 pub fn from_suggestions(suggestions: Vec<RewriteSuggestion>) -> Self {
248 let can_auto_rewrite = suggestions.iter().any(|s| s.rewritten_sql.is_some());
249 Self {
250 success: true,
251 suggestions,
252 can_auto_rewrite,
253 rewritten_query: None,
254 }
255 }
256}