Skip to main content

sql_cli/query_plan/
order_by_alias_transformer.rs

1//! ORDER BY clause alias transformer
2//!
3//! This transformer rewrites ORDER BY clauses that reference aggregate functions
4//! to use the aliases from the SELECT clause instead.
5//!
6//! # Problem
7//!
8//! Users often write queries like:
9//! ```sql
10//! SELECT region, SUM(sales_amount) AS total
11//! FROM sales
12//! GROUP BY region
13//! ORDER BY SUM(sales_amount) DESC
14//! ```
15//!
16//! This fails because the parser treats `SUM(sales_amount)` as a column name "SUM"
17//! which doesn't exist.
18//!
19//! # Solution
20//!
21//! The transformer rewrites to:
22//! ```sql
23//! SELECT region, SUM(sales_amount) AS total
24//! FROM sales
25//! GROUP BY region
26//! ORDER BY total DESC
27//! ```
28//!
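//! # Algorithm
//!
//! 1. Find every aggregate function in the SELECT clause and its alias,
//!    generating an alias for any aggregate that has none
//! 2. Scan the ORDER BY clause for expressions that match one of those aggregates
//! 3. Replace each match with the corresponding alias from SELECT
//! 4. Promote any ORDER BY column or expression that is not visible in the
//!    SELECT output into a hidden SELECT item (named with the
//!    `__hidden_orderby_` prefix) and order by that hidden item instead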
34//!
//! # Note
//!
//! Matching is done on normalized string renderings rather than on AST
//! structure: each SELECT aggregate and each ORDER BY expression is rendered
//! to an upper-cased string, and the two renderings are compared for equality.
39
40use crate::query_plan::pipeline::ASTTransformer;
41use crate::sql::parser::ast::{
42    CTEType, ColumnRef, SelectItem, SelectStatement, SqlExpression, TableSource,
43};
44use anyhow::Result;
45use std::collections::{HashMap, HashSet};
46use tracing::debug;
47
/// Alias prefix for ORDER BY items that get appended to the SELECT list so
/// they are still available after projection. Columns carrying this prefix
/// are removed from the final output once ORDER BY has run.
pub const HIDDEN_ORDERBY_PREFIX: &str = "__hidden_orderby_";
52
/// Transformer that rewrites ORDER BY items to use SELECT aggregate aliases
/// and promotes ORDER BY items missing from the SELECT output into hidden
/// SELECT items.
#[derive(Debug)]
pub struct OrderByAliasTransformer {
    /// Counter behind the generated `__orderby_agg_N` aliases given to
    /// unaliased SELECT aggregates.
    alias_counter: usize,
    /// Counter behind the `__hidden_orderby_N` aliases given to ORDER BY
    /// items promoted into SELECT.
    hidden_counter: usize,
}
60
61impl OrderByAliasTransformer {
62    pub fn new() -> Self {
63        Self {
64            alias_counter: 0,
65            hidden_counter: 0,
66        }
67    }
68
69    /// Check if an expression is an aggregate function
70    fn is_aggregate_function(expr: &SqlExpression) -> bool {
71        matches!(
72            expr,
73            SqlExpression::FunctionCall { name, .. }
74                if matches!(
75                    name.to_uppercase().as_str(),
76                    "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT_DISTINCT"
77                )
78        )
79    }
80
81    /// Generate a unique alias name
82    fn generate_alias(&mut self) -> String {
83        self.alias_counter += 1;
84        format!("__orderby_agg_{}", self.alias_counter)
85    }
86
87    /// Normalize an aggregate expression to match against ORDER BY column names
88    ///
89    /// ORDER BY might have strings like "SUM(sales_amount)" which the parser
90    /// treats as a column name. We need to match these against actual aggregates.
91    /// Returns uppercase version for case-insensitive matching.
92    fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
93        match expr {
94            SqlExpression::FunctionCall { name, args, .. } => {
95                let args_str = args
96                    .iter()
97                    .map(|arg| match arg {
98                        SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
99                        // Special case: COUNT('*') should match COUNT(*)
100                        SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
101                        SqlExpression::StringLiteral(s) => format!("'{}'", s).to_uppercase(),
102                        SqlExpression::NumberLiteral(n) => n.to_uppercase(),
103                        _ => format!("{:?}", arg).to_uppercase(), // Fallback for complex args
104                    })
105                    .collect::<Vec<_>>()
106                    .join(", ");
107                format!("{}({})", name.to_uppercase(), args_str)
108            }
109            _ => String::new(),
110        }
111    }
112
113    /// Extract aggregate functions from SELECT clause and build mapping
114    /// Returns: (normalized_expr -> alias, normalized_expr -> needs_alias_flag)
115    fn build_aggregate_map(
116        &mut self,
117        select_items: &mut Vec<SelectItem>,
118    ) -> HashMap<String, String> {
119        let mut aggregate_map = HashMap::new();
120
121        for item in select_items.iter_mut() {
122            if let SelectItem::Expression { expr, alias, .. } = item {
123                if Self::is_aggregate_function(expr) {
124                    let normalized = Self::normalize_aggregate_expr(expr);
125
126                    // If no alias exists, generate one
127                    if alias.is_empty() {
128                        *alias = self.generate_alias();
129                        debug!(
130                            "Generated alias '{}' for aggregate in ORDER BY: {}",
131                            alias, normalized
132                        );
133                    }
134
135                    debug!("Mapped aggregate '{}' to alias '{}'", normalized, alias);
136                    aggregate_map.insert(normalized, alias.clone());
137                }
138            }
139        }
140
141        aggregate_map
142    }
143
144    /// Convert an expression to a string representation for pattern matching
145    /// This is a simplified version that handles common cases
146    fn expression_to_string(expr: &SqlExpression) -> String {
147        match expr {
148            SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
149            // Special case: StringLiteral("*") should render as * (for COUNT(*))
150            SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
151            SqlExpression::StringLiteral(s) => format!("'{}'", s),
152            SqlExpression::FunctionCall { name, args, .. } => {
153                let args_str = args
154                    .iter()
155                    .map(|arg| Self::expression_to_string(arg))
156                    .collect::<Vec<_>>()
157                    .join(", ");
158                format!("{}({})", name.to_uppercase(), args_str)
159            }
160            _ => "expr".to_string(), // Fallback for complex expressions
161        }
162    }
163
164    /// Check if an ORDER BY column matches an aggregate pattern
165    /// Returns the normalized aggregate string if it matches
166    fn extract_aggregate_from_order_column(column_name: &str) -> Option<String> {
167        // Check if column name looks like an aggregate function call
168        // e.g., "SUM(sales_amount)" or "COUNT(*)"
169        let upper = column_name.to_uppercase();
170
171        if (upper.starts_with("COUNT(") && upper.ends_with(')'))
172            || (upper.starts_with("SUM(") && upper.ends_with(')'))
173            || (upper.starts_with("AVG(") && upper.ends_with(')'))
174            || (upper.starts_with("MIN(") && upper.ends_with(')'))
175            || (upper.starts_with("MAX(") && upper.ends_with(')'))
176            || (upper.starts_with("COUNT_DISTINCT(") && upper.ends_with(')'))
177        {
178            // Normalize to uppercase for matching
179            Some(upper)
180        } else {
181            None
182        }
183    }
184}
185
186impl Default for OrderByAliasTransformer {
187    fn default() -> Self {
188        Self::new()
189    }
190}
191
192impl ASTTransformer for OrderByAliasTransformer {
193    fn name(&self) -> &str {
194        "OrderByAliasTransformer"
195    }
196
197    fn description(&self) -> &str {
198        "Rewrites ORDER BY aggregate expressions to use SELECT aliases"
199    }
200
201    fn transform(&mut self, stmt: SelectStatement) -> Result<SelectStatement> {
202        self.transform_statement(stmt)
203    }
204}
205
206impl OrderByAliasTransformer {
207    /// Transform a SelectStatement and recurse into nested SELECT statements
208    /// (CTEs, FROM subqueries, set operations). Mirrors the recursion pattern
209    /// used by HavingAliasTransformer and GroupByAliasExpander.
210    #[allow(deprecated)]
211    fn transform_statement(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
212        // Recurse into CTEs
213        for cte in stmt.ctes.iter_mut() {
214            if let CTEType::Standard(ref mut inner) = cte.cte_type {
215                let taken = std::mem::take(inner);
216                *inner = self.transform_statement(taken)?;
217            }
218        }
219
220        // Recurse into FROM DerivedTable subqueries
221        if let Some(TableSource::DerivedTable { query, .. }) = stmt.from_source.as_mut() {
222            let taken = std::mem::take(query.as_mut());
223            **query = self.transform_statement(taken)?;
224        }
225
226        // Recurse into legacy from_subquery
227        if let Some(subq) = stmt.from_subquery.as_mut() {
228            let taken = std::mem::take(subq.as_mut());
229            **subq = self.transform_statement(taken)?;
230        }
231
232        // Recurse into set operation right-hand sides
233        for (_op, rhs) in stmt.set_operations.iter_mut() {
234            let taken = std::mem::take(rhs.as_mut());
235            **rhs = self.transform_statement(taken)?;
236        }
237
238        // Apply ORDER BY alias rewrite at this level
239        self.apply_rewrite(&mut stmt);
240
241        Ok(stmt)
242    }
243
244    /// Apply ORDER BY alias rewriting to a single statement (no recursion).
245    fn apply_rewrite(&mut self, stmt: &mut SelectStatement) {
246        if stmt.order_by.is_none() {
247            return;
248        }
249
250        // Step 1: Build mapping of aggregates to aliases
251        let aggregate_map = self.build_aggregate_map(&mut stmt.select_items);
252
253        // Step 2: Rewrite ORDER BY aggregate expressions to use SELECT aliases
254        if !aggregate_map.is_empty() {
255            if let Some(order_by) = stmt.order_by.as_mut() {
256                let mut modified = false;
257
258                for order_col in order_by.iter_mut() {
259                    let expr_str = Self::expression_to_string(&order_col.expr);
260
261                    if let Some(normalized) = Self::extract_aggregate_from_order_column(&expr_str) {
262                        if let Some(alias) = aggregate_map.get(&normalized) {
263                            debug!("Rewriting ORDER BY '{}' to use alias '{}'", expr_str, alias);
264                            order_col.expr =
265                                SqlExpression::Column(ColumnRef::unquoted(alias.clone()));
266                            modified = true;
267                        }
268                    }
269                }
270
271                if modified {
272                    debug!(
273                        "Rewrote ORDER BY to use {} aggregate alias(es)",
274                        aggregate_map.len()
275                    );
276                }
277            }
278        }
279
280        // Step 3: Promote ORDER BY columns that aren't visible after projection.
281        // Without this, `SELECT name AS results ... ORDER BY name` (or any
282        // ORDER BY column not in SELECT) fails because projection narrows the
283        // result columns before ORDER BY can resolve them.
284        self.promote_hidden_order_by_columns(stmt);
285    }
286
287    /// Promote ORDER BY column references that aren't already in the SELECT
288    /// output. Each missing column is appended to SELECT as a hidden
289    /// expression; the ORDER BY ref is rewritten to use the hidden alias.
290    /// Hidden columns get stripped from output by query_engine after sort.
291    fn promote_hidden_order_by_columns(&mut self, stmt: &mut SelectStatement) {
292        let order_by = match stmt.order_by.as_mut() {
293            Some(o) if !o.is_empty() => o,
294            _ => return,
295        };
296
297        // SELECT * (or similar) keeps all source columns visible — no promotion needed.
298        if stmt
299            .select_items
300            .iter()
301            .any(|i| matches!(i, SelectItem::Star { .. } | SelectItem::StarExclude { .. }))
302        {
303            return;
304        }
305
306        // Names exposed by SELECT — output names ORDER BY can already resolve.
307        // For Column items the output name is the column's own name; for
308        // Expression items it's the alias.
309        let mut visible: HashSet<String> = HashSet::new();
310        for item in stmt.select_items.iter() {
311            match item {
312                SelectItem::Column { column, .. } => {
313                    visible.insert(column.name.to_lowercase());
314                }
315                SelectItem::Expression { alias, .. } if !alias.is_empty() => {
316                    visible.insert(alias.to_lowercase());
317                }
318                _ => {}
319            }
320        }
321
322        // Dedup: column refs (by name) and computed expressions (by stringified
323        // form) share hidden aliases when they appear multiple times in ORDER BY.
324        let mut promoted_columns: HashMap<String, String> = HashMap::new();
325        let mut promoted_exprs: HashMap<String, String> = HashMap::new();
326
327        for order_col in order_by.iter_mut() {
328            // Skip if the ORDER BY item is already a Column ref to a visible
329            // SELECT output — no promotion needed.
330            if let SqlExpression::Column(c) = &order_col.expr {
331                if visible.contains(&c.name.to_lowercase()) {
332                    continue;
333                }
334            }
335
336            // Determine the dedup key and clone the expression we'll promote.
337            // For Column refs we use the column name (case-insensitive); for
338            // arbitrary expressions we use the debug-formatted string. Two
339            // semantically-identical exprs may not dedup but that's harmless —
340            // the worst case is one redundant computed column.
341            let expr_to_promote = order_col.expr.clone();
342            let (dedup_key, is_column) = match &expr_to_promote {
343                SqlExpression::Column(c) => (c.name.to_lowercase(), true),
344                other => (format!("{:?}", other), false),
345            };
346
347            let existing_alias = if is_column {
348                promoted_columns.get(&dedup_key).cloned()
349            } else {
350                promoted_exprs.get(&dedup_key).cloned()
351            };
352
353            let hidden_alias = if let Some(alias) = existing_alias {
354                alias
355            } else {
356                self.hidden_counter += 1;
357                let alias = format!("{}{}", HIDDEN_ORDERBY_PREFIX, self.hidden_counter);
358                debug!(
359                    "Promoting ORDER BY expression as hidden SELECT item '{}': {:?}",
360                    alias, expr_to_promote
361                );
362                stmt.select_items.push(SelectItem::Expression {
363                    expr: expr_to_promote,
364                    alias: alias.clone(),
365                    leading_comments: Vec::new(),
366                    trailing_comment: None,
367                });
368                if is_column {
369                    promoted_columns.insert(dedup_key, alias.clone());
370                } else {
371                    promoted_exprs.insert(dedup_key, alias.clone());
372                }
373                alias
374            };
375
376            order_col.expr = SqlExpression::Column(ColumnRef::unquoted(hidden_alias));
377        }
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, QuoteStyle};

    /// Build a non-DISTINCT function call with the given name and arguments.
    fn call(name: &str, args: Vec<SqlExpression>) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: name.to_string(),
            args,
            distinct: false,
        }
    }

    #[test]
    fn test_extract_aggregate_from_order_column() {
        // (rendered ORDER BY item, expected normalized aggregate)
        let cases: [(&str, Option<&str>); 4] = [
            ("SUM(sales_amount)", Some("SUM(SALES_AMOUNT)")),
            ("COUNT(*)", Some("COUNT(*)")),
            ("region", None),
            ("total", None),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(
                OrderByAliasTransformer::extract_aggregate_from_order_column(input),
                expected.map(str::to_string),
                "input: {}",
                input
            );
        }
    }

    #[test]
    fn test_normalize_aggregate_expr() {
        let sales_amount = SqlExpression::Column(ColumnRef {
            name: "sales_amount".to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: None,
        });
        let sum = call("SUM", vec![sales_amount]);

        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&sum),
            "SUM(SALES_AMOUNT)"
        );
    }

    #[test]
    fn test_is_aggregate_function() {
        assert!(OrderByAliasTransformer::is_aggregate_function(&call("SUM", vec![])));
        assert!(!OrderByAliasTransformer::is_aggregate_function(&call("UPPER", vec![])));
    }
}