Skip to main content

sql_cli/query_plan/
order_by_alias_transformer.rs

1//! ORDER BY clause alias transformer
2//!
3//! This transformer rewrites ORDER BY clauses that reference aggregate functions
4//! to use the aliases from the SELECT clause instead.
5//!
6//! # Problem
7//!
8//! Users often write queries like:
9//! ```sql
10//! SELECT region, SUM(sales_amount) AS total
11//! FROM sales
12//! GROUP BY region
13//! ORDER BY SUM(sales_amount) DESC
14//! ```
15//!
16//! This fails because the parser treats `SUM(sales_amount)` as a column name "SUM"
17//! which doesn't exist.
18//!
19//! # Solution
20//!
21//! The transformer rewrites to:
22//! ```sql
23//! SELECT region, SUM(sales_amount) AS total
24//! FROM sales
25//! GROUP BY region
26//! ORDER BY total DESC
27//! ```
28//!
29//! # Algorithm
30//!
31//! 1. Find all aggregate functions in SELECT clause and their aliases
32//! 2. Scan ORDER BY clause for column names that match aggregate patterns
33//! 3. Replace with the corresponding alias from SELECT
34//!
35//! # Note
36//!
37//! This transformer works at the string level since ORDER BY currently only
38//! supports column names, not full expressions in the AST.
39
40use crate::query_plan::pipeline::ASTTransformer;
41use crate::sql::parser::ast::{
42    CTEType, ColumnRef, SelectItem, SelectStatement, SqlExpression, TableSource,
43};
44use anyhow::Result;
45use std::collections::HashMap;
46use tracing::debug;
47
/// Transformer that rewrites ORDER BY to use aggregate aliases.
///
/// The only state it carries is a counter used to number the aliases it
/// generates for aggregates in SELECT that have no alias of their own.
#[derive(Debug)]
pub struct OrderByAliasTransformer {
    /// Number of aliases generated so far; incremented before each new
    /// `__orderby_agg_N` alias is produced, so the first alias uses `1`.
    alias_counter: usize,
}
53
54impl OrderByAliasTransformer {
55    pub fn new() -> Self {
56        Self { alias_counter: 0 }
57    }
58
59    /// Check if an expression is an aggregate function
60    fn is_aggregate_function(expr: &SqlExpression) -> bool {
61        matches!(
62            expr,
63            SqlExpression::FunctionCall { name, .. }
64                if matches!(
65                    name.to_uppercase().as_str(),
66                    "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT_DISTINCT"
67                )
68        )
69    }
70
71    /// Generate a unique alias name
72    fn generate_alias(&mut self) -> String {
73        self.alias_counter += 1;
74        format!("__orderby_agg_{}", self.alias_counter)
75    }
76
77    /// Normalize an aggregate expression to match against ORDER BY column names
78    ///
79    /// ORDER BY might have strings like "SUM(sales_amount)" which the parser
80    /// treats as a column name. We need to match these against actual aggregates.
81    /// Returns uppercase version for case-insensitive matching.
82    fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
83        match expr {
84            SqlExpression::FunctionCall { name, args, .. } => {
85                let args_str = args
86                    .iter()
87                    .map(|arg| match arg {
88                        SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
89                        // Special case: COUNT('*') should match COUNT(*)
90                        SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
91                        SqlExpression::StringLiteral(s) => format!("'{}'", s).to_uppercase(),
92                        SqlExpression::NumberLiteral(n) => n.to_uppercase(),
93                        _ => format!("{:?}", arg).to_uppercase(), // Fallback for complex args
94                    })
95                    .collect::<Vec<_>>()
96                    .join(", ");
97                format!("{}({})", name.to_uppercase(), args_str)
98            }
99            _ => String::new(),
100        }
101    }
102
103    /// Extract aggregate functions from SELECT clause and build mapping
104    /// Returns: (normalized_expr -> alias, normalized_expr -> needs_alias_flag)
105    fn build_aggregate_map(
106        &mut self,
107        select_items: &mut Vec<SelectItem>,
108    ) -> HashMap<String, String> {
109        let mut aggregate_map = HashMap::new();
110
111        for item in select_items.iter_mut() {
112            if let SelectItem::Expression { expr, alias, .. } = item {
113                if Self::is_aggregate_function(expr) {
114                    let normalized = Self::normalize_aggregate_expr(expr);
115
116                    // If no alias exists, generate one
117                    if alias.is_empty() {
118                        *alias = self.generate_alias();
119                        debug!(
120                            "Generated alias '{}' for aggregate in ORDER BY: {}",
121                            alias, normalized
122                        );
123                    }
124
125                    debug!("Mapped aggregate '{}' to alias '{}'", normalized, alias);
126                    aggregate_map.insert(normalized, alias.clone());
127                }
128            }
129        }
130
131        aggregate_map
132    }
133
134    /// Convert an expression to a string representation for pattern matching
135    /// This is a simplified version that handles common cases
136    fn expression_to_string(expr: &SqlExpression) -> String {
137        match expr {
138            SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
139            // Special case: StringLiteral("*") should render as * (for COUNT(*))
140            SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
141            SqlExpression::StringLiteral(s) => format!("'{}'", s),
142            SqlExpression::FunctionCall { name, args, .. } => {
143                let args_str = args
144                    .iter()
145                    .map(|arg| Self::expression_to_string(arg))
146                    .collect::<Vec<_>>()
147                    .join(", ");
148                format!("{}({})", name.to_uppercase(), args_str)
149            }
150            _ => "expr".to_string(), // Fallback for complex expressions
151        }
152    }
153
154    /// Check if an ORDER BY column matches an aggregate pattern
155    /// Returns the normalized aggregate string if it matches
156    fn extract_aggregate_from_order_column(column_name: &str) -> Option<String> {
157        // Check if column name looks like an aggregate function call
158        // e.g., "SUM(sales_amount)" or "COUNT(*)"
159        let upper = column_name.to_uppercase();
160
161        if (upper.starts_with("COUNT(") && upper.ends_with(')'))
162            || (upper.starts_with("SUM(") && upper.ends_with(')'))
163            || (upper.starts_with("AVG(") && upper.ends_with(')'))
164            || (upper.starts_with("MIN(") && upper.ends_with(')'))
165            || (upper.starts_with("MAX(") && upper.ends_with(')'))
166            || (upper.starts_with("COUNT_DISTINCT(") && upper.ends_with(')'))
167        {
168            // Normalize to uppercase for matching
169            Some(upper)
170        } else {
171            None
172        }
173    }
174}
175
176impl Default for OrderByAliasTransformer {
177    fn default() -> Self {
178        Self::new()
179    }
180}
181
182impl ASTTransformer for OrderByAliasTransformer {
183    fn name(&self) -> &str {
184        "OrderByAliasTransformer"
185    }
186
187    fn description(&self) -> &str {
188        "Rewrites ORDER BY aggregate expressions to use SELECT aliases"
189    }
190
191    fn transform(&mut self, stmt: SelectStatement) -> Result<SelectStatement> {
192        self.transform_statement(stmt)
193    }
194}
195
196impl OrderByAliasTransformer {
197    /// Transform a SelectStatement and recurse into nested SELECT statements
198    /// (CTEs, FROM subqueries, set operations). Mirrors the recursion pattern
199    /// used by HavingAliasTransformer and GroupByAliasExpander.
200    #[allow(deprecated)]
201    fn transform_statement(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
202        // Recurse into CTEs
203        for cte in stmt.ctes.iter_mut() {
204            if let CTEType::Standard(ref mut inner) = cte.cte_type {
205                let taken = std::mem::take(inner);
206                *inner = self.transform_statement(taken)?;
207            }
208        }
209
210        // Recurse into FROM DerivedTable subqueries
211        if let Some(TableSource::DerivedTable { query, .. }) = stmt.from_source.as_mut() {
212            let taken = std::mem::take(query.as_mut());
213            **query = self.transform_statement(taken)?;
214        }
215
216        // Recurse into legacy from_subquery
217        if let Some(subq) = stmt.from_subquery.as_mut() {
218            let taken = std::mem::take(subq.as_mut());
219            **subq = self.transform_statement(taken)?;
220        }
221
222        // Recurse into set operation right-hand sides
223        for (_op, rhs) in stmt.set_operations.iter_mut() {
224            let taken = std::mem::take(rhs.as_mut());
225            **rhs = self.transform_statement(taken)?;
226        }
227
228        // Apply ORDER BY alias rewrite at this level
229        self.apply_rewrite(&mut stmt);
230
231        Ok(stmt)
232    }
233
234    /// Apply ORDER BY alias rewriting to a single statement (no recursion).
235    fn apply_rewrite(&mut self, stmt: &mut SelectStatement) {
236        if stmt.order_by.is_none() {
237            return;
238        }
239
240        // Step 1: Build mapping of aggregates to aliases
241        let aggregate_map = self.build_aggregate_map(&mut stmt.select_items);
242
243        if aggregate_map.is_empty() {
244            return;
245        }
246
247        // Step 2: Rewrite ORDER BY expressions
248        if let Some(order_by) = stmt.order_by.as_mut() {
249            let mut modified = false;
250
251            for order_col in order_by.iter_mut() {
252                let expr_str = Self::expression_to_string(&order_col.expr);
253
254                if let Some(normalized) = Self::extract_aggregate_from_order_column(&expr_str) {
255                    if let Some(alias) = aggregate_map.get(&normalized) {
256                        debug!("Rewriting ORDER BY '{}' to use alias '{}'", expr_str, alias);
257                        order_col.expr = SqlExpression::Column(ColumnRef::unquoted(alias.clone()));
258                        modified = true;
259                    }
260                }
261            }
262
263            if modified {
264                debug!(
265                    "Rewrote ORDER BY to use {} aggregate alias(es)",
266                    aggregate_map.len()
267                );
268            }
269        }
270    }
271}
272
/// Unit tests for the string-level helpers of `OrderByAliasTransformer`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, QuoteStyle};

    /// Build an unquoted, unprefixed column expression for use as an argument.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef {
            name: name.to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: None,
        })
    }

    #[test]
    fn test_extract_aggregate_from_order_column() {
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUM(sales_amount)"),
            Some("SUM(SALES_AMOUNT)".to_string())
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("COUNT(*)"),
            Some("COUNT(*)".to_string())
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("region"),
            None
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("total"),
            None
        );
    }

    #[test]
    fn test_extract_aggregate_edge_cases() {
        // Matching is case-insensitive and the result is upper-cased.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("count_distinct(id)"),
            Some("COUNT_DISTINCT(ID)".to_string())
        );

        // A missing closing parenthesis is not an aggregate call.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUM(x"),
            None
        );

        // Unknown function names and empty input are rejected.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("MEDIAN(x)"),
            None
        );
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column(""),
            None
        );
    }

    #[test]
    fn test_normalize_aggregate_expr() {
        let expr = SqlExpression::FunctionCall {
            name: "SUM".to_string(),
            args: vec![column("sales_amount")],
            distinct: false,
        };

        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&expr),
            "SUM(SALES_AMOUNT)"
        );
    }

    #[test]
    fn test_normalize_aggregate_expr_case_and_multiple_args() {
        let expr = SqlExpression::FunctionCall {
            name: "max".to_string(),
            args: vec![column("Price"), column("cost")],
            distinct: false,
        };

        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&expr),
            "MAX(PRICE, COST)"
        );

        // Anything that is not a function call normalizes to an empty string.
        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&column("region")),
            ""
        );
    }

    #[test]
    fn test_is_aggregate_function() {
        let sum_expr = SqlExpression::FunctionCall {
            name: "SUM".to_string(),
            args: vec![],
            distinct: false,
        };
        assert!(OrderByAliasTransformer::is_aggregate_function(&sum_expr));

        let upper_expr = SqlExpression::FunctionCall {
            name: "UPPER".to_string(),
            args: vec![],
            distinct: false,
        };
        assert!(!OrderByAliasTransformer::is_aggregate_function(&upper_expr));
    }

    #[test]
    fn test_is_aggregate_function_case_and_non_calls() {
        let lower_expr = SqlExpression::FunctionCall {
            name: "count_distinct".to_string(),
            args: vec![],
            distinct: false,
        };
        assert!(OrderByAliasTransformer::is_aggregate_function(&lower_expr));

        // A plain column is never an aggregate.
        assert!(!OrderByAliasTransformer::is_aggregate_function(&column("sum")));
    }
}
335}