1use crate::query_plan::pipeline::ASTTransformer;
2use crate::sql::parser::ast::{
3 ColumnRef, Condition, LogicalOp, OrderByItem, PivotAggregate, QuoteStyle, SelectItem,
4 SelectStatement, SqlExpression, TableSource, WhenBranch,
5};
6use anyhow::{anyhow, Result};
7
/// Query-plan stage that rewrites `PIVOT` table sources into plain SQL:
/// aggregated `CASE` expressions combined with a `GROUP BY`.
///
/// The type carries no state; `Debug`/`Clone`/`Copy` are derived so it can be
/// logged and freely duplicated like any other pipeline component.
#[derive(Debug, Clone, Copy)]
pub struct PivotExpander;
36
37impl PivotExpander {
38 pub fn expand(mut statement: SelectStatement) -> Result<SelectStatement> {
40 if let Some(ref from_source) = statement.from_source {
42 match from_source {
43 TableSource::Pivot {
44 source,
45 aggregate,
46 pivot_column,
47 pivot_values,
48 alias,
49 } => {
50 return Self::expand_pivot(
52 source,
53 aggregate,
54 pivot_column,
55 pivot_values,
56 alias,
57 );
58 }
59 TableSource::DerivedTable { query, .. } => {
60 let processed_subquery = Self::expand(*query.clone())?;
62 statement.from_source = Some(TableSource::DerivedTable {
63 query: Box::new(processed_subquery),
64 alias: match from_source {
65 TableSource::DerivedTable { alias, .. } => alias.clone(),
66 _ => String::new(),
67 },
68 });
69 }
70 TableSource::Table(_) => {
71 }
73 }
74 }
75
76 Ok(statement)
77 }
78
79 pub fn expand_pivot(
81 source: &TableSource,
82 aggregate: &PivotAggregate,
83 pivot_column: &str,
84 pivot_values: &[String],
85 alias: &Option<String>,
86 ) -> Result<SelectStatement> {
87 let (base_table, base_alias, base_subquery) = Self::extract_base_source(source)?;
89
90 let group_by_columns = Self::determine_group_by_columns(
93 &base_table,
94 &base_alias,
95 &base_subquery,
96 pivot_column,
97 &aggregate.column,
98 )?;
99
100 let mut select_items = Vec::new();
102
103 for col in &group_by_columns {
105 select_items.push(SelectItem::Column {
106 column: ColumnRef::unquoted(col.clone()),
107 leading_comments: Vec::new(),
108 trailing_comment: None,
109 });
110 }
111
112 for pivot_value in pivot_values {
114 let case_expr = Self::build_pivot_case_expression(
115 pivot_column,
116 pivot_value,
117 &aggregate.column,
118 &aggregate.function,
119 )?;
120
121 select_items.push(SelectItem::Expression {
122 expr: case_expr,
123 alias: pivot_value.clone(),
124 leading_comments: Vec::new(),
125 trailing_comment: None,
126 });
127 }
128
129 let from_source = if let Some(ref table) = base_table {
132 Some(TableSource::Table(table.clone()))
133 } else if let Some(ref subquery) = base_subquery {
134 Some(TableSource::DerivedTable {
135 query: subquery.clone(),
136 alias: base_alias.clone().unwrap_or_default(),
137 })
138 } else {
139 None
140 };
141
142 let mut result = SelectStatement {
143 distinct: false,
144 columns: Vec::new(), select_items,
146 from_source,
147 #[allow(deprecated)]
148 from_table: base_table,
149 #[allow(deprecated)]
150 from_subquery: base_subquery,
151 #[allow(deprecated)]
152 from_function: None,
153 #[allow(deprecated)]
154 from_alias: base_alias.or_else(|| alias.clone()),
155 joins: Vec::new(),
156 where_clause: None,
157 order_by: None,
158 group_by: Some(
159 group_by_columns
160 .iter()
161 .map(|col| SqlExpression::Column(ColumnRef::unquoted(col.clone())))
162 .collect(),
163 ),
164 having: None,
165 qualify: None,
166 limit: None,
167 offset: None,
168 ctes: Vec::new(),
169 into_table: None,
170 set_operations: Vec::new(),
171 leading_comments: Vec::new(),
172 trailing_comment: None,
173 };
174
175 Ok(result)
176 }
177
178 fn extract_base_source(
180 source: &TableSource,
181 ) -> Result<(Option<String>, Option<String>, Option<Box<SelectStatement>>)> {
182 match source {
183 TableSource::Table(name) => Ok((Some(name.clone()), None, None)),
184 TableSource::DerivedTable { query, alias } => {
185 Ok((None, Some(alias.clone()), Some(query.clone())))
186 }
187 TableSource::Pivot { .. } => Err(anyhow!("Nested PIVOT operations are not supported")),
188 }
189 }
190
191 fn determine_group_by_columns(
194 base_table: &Option<String>,
195 base_alias: &Option<String>,
196 base_subquery: &Option<Box<SelectStatement>>,
197 pivot_column: &str,
198 aggregate_column: &str,
199 ) -> Result<Vec<String>> {
200 if let Some(subquery) = base_subquery {
207 let mut columns = Vec::new();
209 for item in &subquery.select_items {
210 match item {
211 SelectItem::Column { column, .. } => {
212 let col_name = column.name.clone();
213 if col_name != pivot_column && col_name != aggregate_column {
214 columns.push(col_name);
215 }
216 }
217 SelectItem::Expression { alias, .. } => {
218 if alias != pivot_column && alias != aggregate_column {
219 columns.push(alias.clone());
220 }
221 }
222 SelectItem::Star { .. } => {
223 return Err(anyhow!(
225 "PIVOT with SELECT * is not supported. Please specify columns explicitly."
226 ));
227 }
228 SelectItem::StarExclude { .. } => {
229 return Err(anyhow!(
230 "PIVOT with SELECT * EXCLUDE is not supported. Please specify columns explicitly."
231 ));
232 }
233 }
234 }
235 Ok(columns)
236 } else {
237 Err(anyhow!(
240 "PIVOT on table sources requires explicit column specification. \
241 Use a subquery: SELECT col1, col2, pivot_col, agg_col FROM table"
242 ))
243 }
244 }
245
246 fn build_pivot_case_expression(
249 pivot_column: &str,
250 pivot_value: &str,
251 aggregate_column: &str,
252 aggregate_function: &str,
253 ) -> Result<SqlExpression> {
254 let case_expr = SqlExpression::CaseExpression {
256 when_branches: vec![WhenBranch {
257 condition: Box::new(SqlExpression::BinaryOp {
258 left: Box::new(SqlExpression::Column(ColumnRef::unquoted(
259 pivot_column.to_string(),
260 ))),
261 op: "=".to_string(),
262 right: Box::new(SqlExpression::StringLiteral(pivot_value.to_string())),
263 }),
264 result: Box::new(SqlExpression::Column(ColumnRef::unquoted(
265 aggregate_column.to_string(),
266 ))),
267 }],
268 else_branch: Some(Box::new(SqlExpression::Null)),
269 };
270
271 let aggregated = SqlExpression::FunctionCall {
273 name: aggregate_function.to_uppercase(),
274 args: vec![case_expr],
275 distinct: false,
276 };
277
278 Ok(aggregated)
279 }
280}
281
282impl ASTTransformer for PivotExpander {
283 fn name(&self) -> &str {
284 "PivotExpander"
285 }
286
287 fn description(&self) -> &str {
288 "Expands PIVOT operations into CASE expressions with GROUP BY"
289 }
290
291 fn transform(&mut self, stmt: SelectStatement) -> Result<SelectStatement> {
292 Self::expand(stmt)
293 }
294}
295
296impl Default for PivotExpander {
297 fn default() -> Self {
298 Self
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture helper: an unquoted, comment-free column projection.
    fn plain_column(name: &str) -> SelectItem {
        SelectItem::Column {
            column: ColumnRef::unquoted(name.to_string()),
            leading_comments: Vec::new(),
            trailing_comment: None,
        }
    }

    #[test]
    fn test_build_pivot_case_expression() {
        let built =
            PivotExpander::build_pivot_case_expression("FoodName", "Sammich", "AmountEaten", "MAX")
                .unwrap();

        // Outer node: the aggregate call wrapping exactly one argument.
        let (func_name, func_args) = match built {
            SqlExpression::FunctionCall { name, args, .. } => (name, args),
            _ => panic!("Expected FunctionCall"),
        };
        assert_eq!(func_name, "MAX");
        assert_eq!(func_args.len(), 1);

        // Inner node: a single-branch CASE with an explicit ELSE.
        if let SqlExpression::CaseExpression {
            when_branches,
            else_branch,
        } = &func_args[0]
        {
            assert_eq!(when_branches.len(), 1);
            assert!(else_branch.is_some());
        } else {
            panic!("Expected CaseExpression inside function call");
        }
    }

    #[test]
    fn test_determine_group_by_columns_with_subquery() {
        let source_query = SelectStatement {
            distinct: false,
            columns: Vec::new(),
            select_items: ["Date", "FoodName", "AmountEaten"]
                .iter()
                .copied()
                .map(plain_column)
                .collect(),
            from_source: None,
            #[allow(deprecated)]
            from_table: Some("food_eaten".to_string()),
            #[allow(deprecated)]
            from_subquery: None,
            #[allow(deprecated)]
            from_function: None,
            #[allow(deprecated)]
            from_alias: None,
            joins: Vec::new(),
            where_clause: None,
            order_by: None,
            group_by: None,
            having: None,
            qualify: None,
            limit: None,
            offset: None,
            ctes: Vec::new(),
            into_table: None,
            set_operations: Vec::new(),
            leading_comments: Vec::new(),
            trailing_comment: None,
        };

        let grouping = PivotExpander::determine_group_by_columns(
            &None,
            &Some("src".to_string()),
            &Some(Box::new(source_query)),
            "FoodName",
            "AmountEaten",
        )
        .unwrap();

        // Only the column that is neither pivoted nor aggregated survives.
        assert_eq!(grouping, vec!["Date".to_string()]);
    }
}