sql_cli/sql/query_rewriter.rs

1//! Query Rewriter Module
2//!
//! This module analyzes SQL queries and suggests transformations that make
//! them compatible with the SQL engine's capabilities. Automatic rewriting
//! (`QueryRewriter::rewrite`) is not implemented yet and always returns `None`.
5//!
6//! Main transformations:
7//! - Hoist expressions from aggregate/window functions into CTEs
8//! - Convert complex expressions to simpler forms
9//! - Identify patterns that need rewriting
10
11use crate::sql::parser::ast::{CTEType, SelectStatement, SqlExpression, CTE};
12use serde::{Deserialize, Serialize};
13
14/// Represents a suggested rewrite for a query
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct RewriteSuggestion {
17    /// Type of rewrite needed
18    pub rewrite_type: RewriteType,
19    /// Location in original query (if available)
20    pub location: Option<String>,
21    /// Description of the issue
22    pub issue: String,
23    /// Suggested fix
24    pub suggestion: String,
25    /// The rewritten SQL (if automatic rewrite is possible)
26    pub rewritten_sql: Option<String>,
27    /// CTE that could be added to fix the issue
28    pub suggested_cte: Option<String>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub enum RewriteType {
33    /// Expression in aggregate function needs hoisting
34    AggregateExpressionHoist,
35    /// Expression in window function needs hoisting
36    WindowExpressionHoist,
37    /// Complex WHERE clause expression needs simplification
38    WhereExpressionHoist,
39    /// LAG/LEAD with expression needs hoisting
40    LagLeadExpressionHoist,
41    /// Complex JOIN condition needs simplification
42    JoinConditionHoist,
43    /// Nested aggregate functions
44    NestedAggregateHoist,
45}
46
47/// Analyzes a query and returns rewrite suggestions
48pub struct QueryRewriter {
49    suggestions: Vec<RewriteSuggestion>,
50}
51
52impl QueryRewriter {
53    pub fn new() -> Self {
54        Self {
55            suggestions: Vec::new(),
56        }
57    }
58
59    /// Analyze a query and return suggestions
60    pub fn analyze(&mut self, stmt: &SelectStatement) -> Vec<RewriteSuggestion> {
61        self.suggestions.clear();
62
63        // Analyze SELECT items for complex expressions
64        self.analyze_select_items(stmt);
65
66        // Analyze WHERE clause
67        if let Some(where_clause) = &stmt.where_clause {
68            self.analyze_where_clause(where_clause);
69        }
70
71        // Analyze GROUP BY
72        if let Some(group_by) = &stmt.group_by {
73            self.analyze_group_by(group_by);
74        }
75
76        // Analyze existing CTEs for issues
77        for cte in &stmt.ctes {
78            self.analyze_cte(cte);
79        }
80
81        self.suggestions.clone()
82    }
83
84    /// Check SELECT items for expressions that need hoisting
85    fn analyze_select_items(&mut self, stmt: &SelectStatement) {
86        for item in &stmt.select_items {
87            if let crate::sql::parser::ast::SelectItem::Expression { expr, alias } = item {
88                self.check_expression_for_hoisting(expr, Some(alias));
89            }
90        }
91    }
92
93    /// Check if an expression needs hoisting
94    fn check_expression_for_hoisting(&mut self, expr: &SqlExpression, context: Option<&str>) {
95        match expr {
96            SqlExpression::WindowFunction { name, args, .. } => {
97                // Check if window function has complex expressions
98                for arg in args {
99                    if self.is_complex_expression(arg) {
100                        self.suggestions.push(RewriteSuggestion {
101                            rewrite_type: RewriteType::WindowExpressionHoist,
102                            location: context.map(|s| s.to_string()),
103                            issue: format!("Window function {} contains complex expression", name),
104                            suggestion: "Hoist the expression to a CTE and reference the column"
105                                .to_string(),
106                            rewritten_sql: None,
107                            suggested_cte: Some(self.generate_hoist_cte(arg, "expr_cte")),
108                        });
109                    }
110                }
111            }
112            SqlExpression::FunctionCall { name, args, .. } => {
113                // Check if it's an aggregate function with complex expression
114                if self.is_aggregate_function(name) {
115                    for arg in args {
116                        if self.is_complex_expression(arg) {
117                            self.suggestions.push(RewriteSuggestion {
118                                rewrite_type: RewriteType::AggregateExpressionHoist,
119                                location: context.map(|s| s.to_string()),
120                                issue: format!("Aggregate function {} contains expression: {:?}", name, arg),
121                                suggestion: "Create a CTE with the calculated expression, then aggregate the result column".to_string(),
122                                rewritten_sql: None,
123                                suggested_cte: Some(self.generate_hoist_cte(arg, "calc_cte")),
124                            });
125                        }
126                    }
127                }
128
129                // Check for LAG/LEAD with expressions
130                if name == "LAG" || name == "LEAD" {
131                    if let Some(first_arg) = args.first() {
132                        if self.is_complex_expression(first_arg) {
133                            self.suggestions.push(RewriteSuggestion {
134                                rewrite_type: RewriteType::LagLeadExpressionHoist,
135                                location: context.map(|s| s.to_string()),
136                                issue: format!("{} function contains expression instead of column reference", name),
137                                suggestion: format!("Calculate expression in a CTE, then apply {} to the result column", name),
138                                rewritten_sql: None,
139                                suggested_cte: Some(self.generate_hoist_cte(first_arg, "lag_lead_cte")),
140                            });
141                        }
142                    }
143                }
144            }
145            SqlExpression::BinaryOp { left, right, .. } => {
146                // Recursively check both sides
147                self.check_expression_for_hoisting(left, context);
148                self.check_expression_for_hoisting(right, context);
149            }
150            _ => {}
151        }
152    }
153
154    /// Check if an expression is complex (not just a column reference)
155    fn is_complex_expression(&self, expr: &SqlExpression) -> bool {
156        !matches!(
157            expr,
158            SqlExpression::Column(_)
159                | SqlExpression::NumberLiteral(_)
160                | SqlExpression::StringLiteral(_)
161        )
162    }
163
164    /// Check if a function is an aggregate function
165    fn is_aggregate_function(&self, name: &str) -> bool {
166        matches!(
167            name.to_uppercase().as_str(),
168            "SUM" | "AVG" | "COUNT" | "MIN" | "MAX" | "STDDEV" | "VARIANCE" | "MEDIAN"
169        )
170    }
171
172    /// Generate a suggested CTE for hoisting an expression
173    fn generate_hoist_cte(&self, expr: &SqlExpression, cte_name: &str) -> String {
174        let expr_str = self.expression_to_sql(expr);
175        format!(
176            "{} AS (\n    SELECT \n        *,\n        {} AS calculated_value\n    FROM previous_table\n)",
177            cte_name, expr_str
178        )
179    }
180
181    /// Convert an expression to SQL string
182    fn expression_to_sql(&self, expr: &SqlExpression) -> String {
183        match expr {
184            SqlExpression::Column(col_ref) => col_ref.to_sql(),
185            SqlExpression::BinaryOp { left, right, op } => {
186                format!(
187                    "{} {} {}",
188                    self.expression_to_sql(left),
189                    op,
190                    self.expression_to_sql(right)
191                )
192            }
193            SqlExpression::NumberLiteral(n) => n.clone(),
194            SqlExpression::StringLiteral(s) => format!("'{}'", s),
195            SqlExpression::FunctionCall { name, args, .. } => {
196                let arg_strs: Vec<String> =
197                    args.iter().map(|a| self.expression_to_sql(a)).collect();
198                format!("{}({})", name, arg_strs.join(", "))
199            }
200            _ => format!("{:?}", expr), // Fallback for complex types
201        }
202    }
203
204    fn analyze_where_clause(&mut self, _where_clause: &crate::sql::parser::ast::WhereClause) {
205        // TODO: Analyze WHERE clause for complex expressions
206    }
207
208    fn analyze_group_by(&mut self, _group_by: &[SqlExpression]) {
209        // TODO: Analyze GROUP BY for complex expressions
210    }
211
212    fn analyze_cte(&mut self, cte: &CTE) {
213        // Recursively analyze CTEs
214        if let CTEType::Standard(query) = &cte.cte_type {
215            let mut sub_rewriter = QueryRewriter::new();
216            sub_rewriter.analyze(query);
217            for mut suggestion in sub_rewriter.suggestions {
218                // Prepend CTE name to location
219                suggestion.location = Some(format!(
220                    "CTE '{}': {}",
221                    cte.name,
222                    suggestion.location.unwrap_or_default()
223                ));
224                self.suggestions.push(suggestion);
225            }
226        }
227    }
228
229    /// Attempt to automatically rewrite a query
230    pub fn rewrite(&self, _stmt: &SelectStatement) -> Option<SelectStatement> {
231        // This would implement actual rewriting logic
232        // For now, we just analyze and suggest
233        None
234    }
235}
236
237/// JSON output for CLI integration
238#[derive(Debug, Serialize, Deserialize)]
239pub struct RewriteAnalysis {
240    pub success: bool,
241    pub suggestions: Vec<RewriteSuggestion>,
242    pub can_auto_rewrite: bool,
243    pub rewritten_query: Option<String>,
244}
245
246impl RewriteAnalysis {
247    pub fn from_suggestions(suggestions: Vec<RewriteSuggestion>) -> Self {
248        let can_auto_rewrite = suggestions.iter().any(|s| s.rewritten_sql.is_some());
249        Self {
250            success: true,
251            suggestions,
252            can_auto_rewrite,
253            rewritten_query: None,
254        }
255    }
256}