Skip to main content

sql_cli/query_plan/
pivot_expander.rs

1use crate::query_plan::pipeline::ASTTransformer;
2use crate::sql::parser::ast::{
3    ColumnRef, Condition, LogicalOp, OrderByItem, PivotAggregate, QuoteStyle, SelectItem,
4    SelectStatement, SqlExpression, TableSource, WhenBranch,
5};
6use anyhow::{anyhow, Result};
7
/// PIVOT Expander - Transforms PIVOT syntax into standard SQL with CASE expressions
///
/// This transformer converts SQL Server-style PIVOT operations into standard SQL
/// using CASE expressions and GROUP BY clauses.
///
/// Transformation Example (the exact surface syntax is whatever the parser
/// accepts; the source is written as a subquery because plain table sources
/// are rejected — see "Limitations" below):
/// ```sql
/// -- Input (PIVOT syntax):
/// SELECT * FROM (SELECT Date, FoodName, AmountEaten FROM food_eaten) AS src
/// PIVOT (MAX(AmountEaten) FOR FoodName IN ('Sammich', 'Pickle', 'Apple'))
///
/// -- Output (Standard SQL):
/// SELECT Date,
///     MAX(CASE WHEN FoodName = 'Sammich' THEN AmountEaten ELSE NULL END) AS Sammich,
///     MAX(CASE WHEN FoodName = 'Pickle' THEN AmountEaten ELSE NULL END) AS Pickle,
///     MAX(CASE WHEN FoodName = 'Apple' THEN AmountEaten ELSE NULL END) AS Apple
/// FROM (SELECT Date, FoodName, AmountEaten FROM food_eaten) AS src
/// GROUP BY Date
/// ```
///
/// The algorithm:
/// 1. Detect PIVOT in the FROM clause (directly, or inside a derived table);
///    JOIN sources are not inspected by this transformer
/// 2. Extract PIVOT specification (aggregate function, pivot column, pivot values)
/// 3. Generate CASE expression for each pivot value
/// 4. Wrap each CASE in the aggregate function
/// 5. Determine GROUP BY columns (all source columns except pivot_column and aggregate_column)
/// 6. Build new SelectStatement with CASE expressions and GROUP BY
///
/// Limitations:
/// - The PIVOT source must be a subquery with an explicit column list. A plain
///   table, `SELECT *` and `SELECT * EXCLUDE` are rejected with an error,
///   because the GROUP BY columns cannot be inferred without schema information.
/// - A PIVOT whose source is itself a PIVOT is rejected with an error.
/// - Pivot and aggregate columns are matched against the subquery's column
///   names / aliases with an exact (case-sensitive) string comparison.
pub struct PivotExpander;
36
37impl PivotExpander {
38    /// Transform a SELECT statement, expanding any PIVOT operations
39    pub fn expand(mut statement: SelectStatement) -> Result<SelectStatement> {
40        // Check if FROM contains a PIVOT
41        if let Some(ref from_source) = statement.from_source {
42            match from_source {
43                TableSource::Pivot {
44                    source,
45                    aggregate,
46                    pivot_column,
47                    pivot_values,
48                    alias,
49                } => {
50                    // This is a PIVOT! Expand it to CASE expressions + GROUP BY
51                    return Self::expand_pivot(
52                        source,
53                        aggregate,
54                        pivot_column,
55                        pivot_values,
56                        alias,
57                    );
58                }
59                TableSource::DerivedTable { query, .. } => {
60                    // Recursively process the derived table (subquery)
61                    let processed_subquery = Self::expand(*query.clone())?;
62                    statement.from_source = Some(TableSource::DerivedTable {
63                        query: Box::new(processed_subquery),
64                        alias: match from_source {
65                            TableSource::DerivedTable { alias, .. } => alias.clone(),
66                            _ => String::new(),
67                        },
68                    });
69                }
70                TableSource::Table(_) => {
71                    // Regular table, nothing to expand
72                }
73            }
74        }
75
76        Ok(statement)
77    }
78
79    /// Expand a PIVOT operation into CASE expressions + GROUP BY
80    pub fn expand_pivot(
81        source: &TableSource,
82        aggregate: &PivotAggregate,
83        pivot_column: &str,
84        pivot_values: &[String],
85        alias: &Option<String>,
86    ) -> Result<SelectStatement> {
87        // Extract the base source table/subquery
88        let (base_table, base_alias, base_subquery) = Self::extract_base_source(source)?;
89
90        // Determine columns for GROUP BY
91        // We need all columns from the source except pivot_column and aggregate.column
92        let group_by_columns = Self::determine_group_by_columns(
93            &base_table,
94            &base_alias,
95            &base_subquery,
96            pivot_column,
97            &aggregate.column,
98        )?;
99
100        // Build SELECT items
101        let mut select_items = Vec::new();
102
103        // Add GROUP BY columns to SELECT
104        for col in &group_by_columns {
105            select_items.push(SelectItem::Column {
106                column: ColumnRef::unquoted(col.clone()),
107                leading_comments: Vec::new(),
108                trailing_comment: None,
109            });
110        }
111
112        // Generate CASE expression for each pivot value
113        for pivot_value in pivot_values {
114            let case_expr = Self::build_pivot_case_expression(
115                pivot_column,
116                pivot_value,
117                &aggregate.column,
118                &aggregate.function,
119            )?;
120
121            select_items.push(SelectItem::Expression {
122                expr: case_expr,
123                alias: pivot_value.clone(),
124                leading_comments: Vec::new(),
125                trailing_comment: None,
126            });
127        }
128
129        // Build the transformed statement
130        // Build from_source from the extracted base
131        let from_source = if let Some(ref table) = base_table {
132            Some(TableSource::Table(table.clone()))
133        } else if let Some(ref subquery) = base_subquery {
134            Some(TableSource::DerivedTable {
135                query: subquery.clone(),
136                alias: base_alias.clone().unwrap_or_default(),
137            })
138        } else {
139            None
140        };
141
142        let mut result = SelectStatement {
143            distinct: false,
144            columns: Vec::new(), // Deprecated field
145            select_items,
146            from_source,
147            #[allow(deprecated)]
148            from_table: base_table,
149            #[allow(deprecated)]
150            from_subquery: base_subquery,
151            #[allow(deprecated)]
152            from_function: None,
153            #[allow(deprecated)]
154            from_alias: base_alias.or_else(|| alias.clone()),
155            joins: Vec::new(),
156            where_clause: None,
157            order_by: None,
158            group_by: Some(
159                group_by_columns
160                    .iter()
161                    .map(|col| SqlExpression::Column(ColumnRef::unquoted(col.clone())))
162                    .collect(),
163            ),
164            having: None,
165            qualify: None,
166            limit: None,
167            offset: None,
168            ctes: Vec::new(),
169            into_table: None,
170            set_operations: Vec::new(),
171            leading_comments: Vec::new(),
172            trailing_comment: None,
173        };
174
175        Ok(result)
176    }
177
178    /// Extract the base source from a TableSource (table name or subquery)
179    fn extract_base_source(
180        source: &TableSource,
181    ) -> Result<(Option<String>, Option<String>, Option<Box<SelectStatement>>)> {
182        match source {
183            TableSource::Table(name) => Ok((Some(name.clone()), None, None)),
184            TableSource::DerivedTable { query, alias } => {
185                Ok((None, Some(alias.clone()), Some(query.clone())))
186            }
187            TableSource::Pivot { .. } => Err(anyhow!("Nested PIVOT operations are not supported")),
188        }
189    }
190
191    /// Determine which columns should be in the GROUP BY clause
192    /// These are all columns except the pivot_column and the aggregate_column
193    fn determine_group_by_columns(
194        base_table: &Option<String>,
195        base_alias: &Option<String>,
196        base_subquery: &Option<Box<SelectStatement>>,
197        pivot_column: &str,
198        aggregate_column: &str,
199    ) -> Result<Vec<String>> {
200        // For now, we need to infer columns from the source
201        // This is a simplified implementation - in production, you'd want to:
202        // 1. Query the data source schema
203        // 2. Extract columns from subquery SELECT items
204        // 3. Handle qualified column names
205
206        if let Some(subquery) = base_subquery {
207            // Extract column names from subquery's SELECT items
208            let mut columns = Vec::new();
209            for item in &subquery.select_items {
210                match item {
211                    SelectItem::Column { column, .. } => {
212                        let col_name = column.name.clone();
213                        if col_name != pivot_column && col_name != aggregate_column {
214                            columns.push(col_name);
215                        }
216                    }
217                    SelectItem::Expression { alias, .. } => {
218                        if alias != pivot_column && alias != aggregate_column {
219                            columns.push(alias.clone());
220                        }
221                    }
222                    SelectItem::Star { .. } => {
223                        // Cannot determine columns from *, would need schema info
224                        return Err(anyhow!(
225                            "PIVOT with SELECT * is not supported. Please specify columns explicitly."
226                        ));
227                    }
228                    SelectItem::StarExclude { .. } => {
229                        return Err(anyhow!(
230                            "PIVOT with SELECT * EXCLUDE is not supported. Please specify columns explicitly."
231                        ));
232                    }
233                }
234            }
235            Ok(columns)
236        } else {
237            // For table sources, we'd need schema information
238            // This is a limitation - in production, integrate with schema discovery
239            Err(anyhow!(
240                "PIVOT on table sources requires explicit column specification. \
241                 Use a subquery: SELECT col1, col2, pivot_col, agg_col FROM table"
242            ))
243        }
244    }
245
246    /// Build a CASE expression for a single pivot value
247    /// Example: MAX(CASE WHEN FoodName = 'Sammich' THEN AmountEaten ELSE NULL END)
248    fn build_pivot_case_expression(
249        pivot_column: &str,
250        pivot_value: &str,
251        aggregate_column: &str,
252        aggregate_function: &str,
253    ) -> Result<SqlExpression> {
254        // Build: CASE WHEN pivot_column = 'pivot_value' THEN aggregate_column ELSE NULL END
255        let case_expr = SqlExpression::CaseExpression {
256            when_branches: vec![WhenBranch {
257                condition: Box::new(SqlExpression::BinaryOp {
258                    left: Box::new(SqlExpression::Column(ColumnRef::unquoted(
259                        pivot_column.to_string(),
260                    ))),
261                    op: "=".to_string(),
262                    right: Box::new(SqlExpression::StringLiteral(pivot_value.to_string())),
263                }),
264                result: Box::new(SqlExpression::Column(ColumnRef::unquoted(
265                    aggregate_column.to_string(),
266                ))),
267            }],
268            else_branch: Some(Box::new(SqlExpression::Null)),
269        };
270
271        // Wrap in aggregate function: aggregate_function(case_expr)
272        let aggregated = SqlExpression::FunctionCall {
273            name: aggregate_function.to_uppercase(),
274            args: vec![case_expr],
275            distinct: false,
276        };
277
278        Ok(aggregated)
279    }
280}
281
282impl ASTTransformer for PivotExpander {
283    fn name(&self) -> &str {
284        "PivotExpander"
285    }
286
287    fn description(&self) -> &str {
288        "Expands PIVOT operations into CASE expressions with GROUP BY"
289    }
290
291    fn transform(&mut self, stmt: SelectStatement) -> Result<SelectStatement> {
292        Self::expand(stmt)
293    }
294}
295
296impl Default for PivotExpander {
297    fn default() -> Self {
298        Self
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a bare, comment-free `SELECT` item for an unquoted column.
    fn plain_column(name: &str) -> SelectItem {
        SelectItem::Column {
            column: ColumnRef::unquoted(name.to_string()),
            leading_comments: Vec::new(),
            trailing_comment: None,
        }
    }

    /// Build `SELECT <items> FROM food_eaten` with every other clause empty.
    fn food_eaten_query(select_items: Vec<SelectItem>) -> SelectStatement {
        SelectStatement {
            distinct: false,
            columns: Vec::new(),
            select_items,
            from_source: None,
            #[allow(deprecated)]
            from_table: Some("food_eaten".to_string()),
            #[allow(deprecated)]
            from_subquery: None,
            #[allow(deprecated)]
            from_function: None,
            #[allow(deprecated)]
            from_alias: None,
            joins: Vec::new(),
            where_clause: None,
            order_by: None,
            group_by: None,
            having: None,
            qualify: None,
            limit: None,
            offset: None,
            ctes: Vec::new(),
            into_table: None,
            set_operations: Vec::new(),
            leading_comments: Vec::new(),
            trailing_comment: None,
        }
    }

    #[test]
    fn test_build_pivot_case_expression() {
        let built =
            PivotExpander::build_pivot_case_expression("FoodName", "Sammich", "AmountEaten", "MAX")
                .unwrap();

        // The outermost node must be the aggregate call.
        let (name, args) = match built {
            SqlExpression::FunctionCall { name, args, .. } => (name, args),
            _ => panic!("Expected FunctionCall"),
        };
        assert_eq!(name, "MAX");
        assert_eq!(args.len(), 1);

        // Its single argument must be the CASE with one WHEN and an ELSE.
        match &args[0] {
            SqlExpression::CaseExpression {
                when_branches,
                else_branch,
            } => {
                assert_eq!(when_branches.len(), 1);
                assert!(else_branch.is_some());
            }
            _ => panic!("Expected CaseExpression inside function call"),
        }
    }

    #[test]
    fn test_determine_group_by_columns_with_subquery() {
        // Subquery under test: SELECT Date, FoodName, AmountEaten FROM food_eaten
        let subquery = food_eaten_query(vec![
            plain_column("Date"),
            plain_column("FoodName"),
            plain_column("AmountEaten"),
        ]);

        let columns = PivotExpander::determine_group_by_columns(
            &None,
            &Some("src".to_string()),
            &Some(Box::new(subquery)),
            "FoodName",
            "AmountEaten",
        )
        .unwrap();

        // Only "Date" survives: FoodName is the pivot column and AmountEaten
        // is the aggregated column.
        assert_eq!(columns, vec!["Date".to_string()]);
    }
}
394}