sql_cli/query_plan/
order_by_alias_transformer.rs

//! ORDER BY clause alias transformer
//!
//! This transformer rewrites ORDER BY clauses that reference aggregate functions
//! so they use the aliases from the SELECT clause instead.
//!
//! # Problem
//!
//! Users often write queries like:
//! ```sql
//! SELECT region, SUM(sales_amount) AS total
//! FROM sales
//! GROUP BY region
//! ORDER BY SUM(sales_amount) DESC
//! ```
//!
//! This fails because the parser treats the ORDER BY item `SUM(sales_amount)` as a
//! column name, and no such column exists.
//!
//! # Solution
//!
//! The transformer rewrites the query to:
//! ```sql
//! SELECT region, SUM(sales_amount) AS total
//! FROM sales
//! GROUP BY region
//! ORDER BY total DESC
//! ```
//!
//! # Algorithm
//!
//! 1. Find the aggregate functions in the SELECT clause and their aliases.
//! 2. Scan the ORDER BY clause for items that match one of those aggregates.
//! 3. Replace each match with the corresponding alias from SELECT.
//!
//! # Note
//!
//! Matching works on strings rather than on the AST. Each ORDER BY expression is
//! rendered to text and compared case-insensitively with the rendered SELECT
//! aggregates. This covers two kinds of ORDER BY items:
//! - items the parser produced as a column literally named like `SUM(sales_amount)`;
//! - items it parsed as function calls.
39
40use crate::query_plan::pipeline::ASTTransformer;
41use crate::sql::parser::ast::{ColumnRef, SelectItem, SelectStatement, SqlExpression};
42use anyhow::Result;
43use std::collections::HashMap;
44use tracing::debug;
45
/// Transformer that rewrites ORDER BY items repeating a SELECT aggregate
/// (e.g. `ORDER BY SUM(x)`) so they reference that aggregate's alias instead.
#[derive(Debug)]
pub struct OrderByAliasTransformer {
    /// Number of aliases generated so far. It is used to build unique names
    /// of the form `__orderby_agg_N` for aggregates that have no alias.
    alias_counter: usize,
}
51
52impl OrderByAliasTransformer {
53    pub fn new() -> Self {
54        Self { alias_counter: 0 }
55    }
56
57    /// Check if an expression is an aggregate function
58    fn is_aggregate_function(expr: &SqlExpression) -> bool {
59        matches!(
60            expr,
61            SqlExpression::FunctionCall { name, .. }
62                if matches!(
63                    name.to_uppercase().as_str(),
64                    "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT_DISTINCT"
65                )
66        )
67    }
68
69    /// Generate a unique alias name
70    fn generate_alias(&mut self) -> String {
71        self.alias_counter += 1;
72        format!("__orderby_agg_{}", self.alias_counter)
73    }
74
75    /// Normalize an aggregate expression to match against ORDER BY column names
76    ///
77    /// ORDER BY might have strings like "SUM(sales_amount)" which the parser
78    /// treats as a column name. We need to match these against actual aggregates.
79    /// Returns uppercase version for case-insensitive matching.
80    fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
81        match expr {
82            SqlExpression::FunctionCall { name, args, .. } => {
83                let args_str = args
84                    .iter()
85                    .map(|arg| match arg {
86                        SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
87                        // Special case: COUNT('*') should match COUNT(*)
88                        SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
89                        SqlExpression::StringLiteral(s) => format!("'{}'", s).to_uppercase(),
90                        SqlExpression::NumberLiteral(n) => n.to_uppercase(),
91                        _ => format!("{:?}", arg).to_uppercase(), // Fallback for complex args
92                    })
93                    .collect::<Vec<_>>()
94                    .join(", ");
95                format!("{}({})", name.to_uppercase(), args_str)
96            }
97            _ => String::new(),
98        }
99    }
100
101    /// Extract aggregate functions from SELECT clause and build mapping
102    /// Returns: (normalized_expr -> alias, normalized_expr -> needs_alias_flag)
103    fn build_aggregate_map(
104        &mut self,
105        select_items: &mut Vec<SelectItem>,
106    ) -> HashMap<String, String> {
107        let mut aggregate_map = HashMap::new();
108
109        for item in select_items.iter_mut() {
110            if let SelectItem::Expression { expr, alias, .. } = item {
111                if Self::is_aggregate_function(expr) {
112                    let normalized = Self::normalize_aggregate_expr(expr);
113
114                    // If no alias exists, generate one
115                    if alias.is_empty() {
116                        *alias = self.generate_alias();
117                        debug!(
118                            "Generated alias '{}' for aggregate in ORDER BY: {}",
119                            alias, normalized
120                        );
121                    }
122
123                    debug!("Mapped aggregate '{}' to alias '{}'", normalized, alias);
124                    aggregate_map.insert(normalized, alias.clone());
125                }
126            }
127        }
128
129        aggregate_map
130    }
131
132    /// Convert an expression to a string representation for pattern matching
133    /// This is a simplified version that handles common cases
134    fn expression_to_string(expr: &SqlExpression) -> String {
135        match expr {
136            SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
137            // Special case: StringLiteral("*") should render as * (for COUNT(*))
138            SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
139            SqlExpression::StringLiteral(s) => format!("'{}'", s),
140            SqlExpression::FunctionCall { name, args, .. } => {
141                let args_str = args
142                    .iter()
143                    .map(|arg| Self::expression_to_string(arg))
144                    .collect::<Vec<_>>()
145                    .join(", ");
146                format!("{}({})", name.to_uppercase(), args_str)
147            }
148            _ => "expr".to_string(), // Fallback for complex expressions
149        }
150    }
151
152    /// Check if an ORDER BY column matches an aggregate pattern
153    /// Returns the normalized aggregate string if it matches
154    fn extract_aggregate_from_order_column(column_name: &str) -> Option<String> {
155        // Check if column name looks like an aggregate function call
156        // e.g., "SUM(sales_amount)" or "COUNT(*)"
157        let upper = column_name.to_uppercase();
158
159        if (upper.starts_with("COUNT(") && upper.ends_with(')'))
160            || (upper.starts_with("SUM(") && upper.ends_with(')'))
161            || (upper.starts_with("AVG(") && upper.ends_with(')'))
162            || (upper.starts_with("MIN(") && upper.ends_with(')'))
163            || (upper.starts_with("MAX(") && upper.ends_with(')'))
164            || (upper.starts_with("COUNT_DISTINCT(") && upper.ends_with(')'))
165        {
166            // Normalize to uppercase for matching
167            Some(upper)
168        } else {
169            None
170        }
171    }
172}
173
174impl Default for OrderByAliasTransformer {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
180impl ASTTransformer for OrderByAliasTransformer {
181    fn name(&self) -> &str {
182        "OrderByAliasTransformer"
183    }
184
185    fn description(&self) -> &str {
186        "Rewrites ORDER BY aggregate expressions to use SELECT aliases"
187    }
188
189    fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
190        // Only process if there's an ORDER BY clause
191        if stmt.order_by.is_none() {
192            return Ok(stmt);
193        }
194
195        // Step 1: Build mapping of aggregates to aliases
196        let aggregate_map = self.build_aggregate_map(&mut stmt.select_items);
197
198        if aggregate_map.is_empty() {
199            // No aggregates in SELECT, nothing to do
200            return Ok(stmt);
201        }
202
203        // Step 2: Rewrite ORDER BY expressions
204        if let Some(order_by) = stmt.order_by.as_mut() {
205            let mut modified = false;
206
207            for order_col in order_by.iter_mut() {
208                // Convert expression to string representation for pattern matching
209                let expr_str = Self::expression_to_string(&order_col.expr);
210
211                // Check if this looks like an aggregate function call
212                if let Some(normalized) = Self::extract_aggregate_from_order_column(&expr_str) {
213                    // Try to find matching aggregate in map
214                    if let Some(alias) = aggregate_map.get(&normalized) {
215                        debug!("Rewriting ORDER BY '{}' to use alias '{}'", expr_str, alias);
216                        // Replace expression with simple column reference to the alias
217                        order_col.expr = SqlExpression::Column(ColumnRef::unquoted(alias.clone()));
218                        modified = true;
219                    }
220                }
221            }
222
223            if modified {
224                debug!(
225                    "Rewrote ORDER BY to use {} aggregate alias(es)",
226                    aggregate_map.len()
227                );
228            }
229        }
230
231        Ok(stmt)
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, QuoteStyle};

    /// Build an unquoted, unqualified column reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef {
            name: name.to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: None,
        })
    }

    /// Build a non-DISTINCT function call expression.
    fn call(name: &str, args: Vec<SqlExpression>) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: name.to_string(),
            args,
            distinct: false,
        }
    }

    #[test]
    fn test_extract_aggregate_from_order_column() {
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUM(sales_amount)"),
            Some("SUM(SALES_AMOUNT)".to_string())
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("COUNT(*)"),
            Some("COUNT(*)".to_string())
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("region"),
            None
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("total"),
            None
        );
    }

    #[test]
    fn test_extract_aggregate_is_case_insensitive_and_strict() {
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("sum(x)"),
            Some("SUM(X)".to_string())
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("COUNT_DISTINCT(user_id)"),
            Some("COUNT_DISTINCT(USER_ID)".to_string())
        );

        // Missing closing parenthesis.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUM(x"),
            None
        );

        // Starting with an aggregate name is not enough.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUMMARY(x)"),
            None
        );
    }

    #[test]
    fn test_normalize_aggregate_expr() {
        let expr = call("SUM", vec![column("sales_amount")]);

        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&expr),
            "SUM(SALES_AMOUNT)"
        );
    }

    #[test]
    fn test_normalize_count_star_and_non_function() {
        let count_star = call("COUNT", vec![SqlExpression::StringLiteral("*".to_string())]);
        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&count_star),
            "COUNT(*)"
        );

        // Non-function expressions have no aggregate key.
        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&column("region")),
            ""
        );
    }

    #[test]
    fn test_is_aggregate_function() {
        assert!(OrderByAliasTransformer::is_aggregate_function(&call(
            "SUM",
            vec![]
        )));
        assert!(!OrderByAliasTransformer::is_aggregate_function(&call(
            "UPPER",
            vec![]
        )));

        // Function names are matched case-insensitively.
        assert!(OrderByAliasTransformer::is_aggregate_function(&call(
            "count",
            vec![]
        )));

        // A column that happens to be named like an aggregate is not a call.
        assert!(!OrderByAliasTransformer::is_aggregate_function(&column(
            "sum"
        )));
    }
}