Skip to main content

sql_cli/query_plan/expression_lifter.rs

1use crate::query_plan::{QueryPlan, WorkUnit, WorkUnitExpression, WorkUnitType};
2use crate::sql::parser::ast::{
3    CTEType, ColumnRef, SelectItem, SelectStatement, SqlExpression, WhereClause, CTE,
4};
5use std::collections::HashSet;
6
/// Expression lifter that identifies and lifts non-supported WHERE expressions to CTEs.
///
/// The lifter is stateful: every generated CTE name consumes one value of an
/// internal counter, so names produced by a single instance never repeat.
#[derive(Debug)]
pub struct ExpressionLifter {
    /// Counter for generating unique CTE names (incremented before each use).
    cte_counter: usize,

    /// Upper-case names of the functions whose calls need to be lifted.
    liftable_functions: HashSet<String>,
}
15
16impl ExpressionLifter {
17    /// Create a new expression lifter
18    pub fn new() -> Self {
19        let mut liftable_functions = HashSet::new();
20
21        // Add window functions that need lifting
22        liftable_functions.insert("ROW_NUMBER".to_string());
23        liftable_functions.insert("RANK".to_string());
24        liftable_functions.insert("DENSE_RANK".to_string());
25        liftable_functions.insert("LAG".to_string());
26        liftable_functions.insert("LEAD".to_string());
27        liftable_functions.insert("FIRST_VALUE".to_string());
28        liftable_functions.insert("LAST_VALUE".to_string());
29        liftable_functions.insert("NTH_VALUE".to_string());
30
31        // Add aggregate functions that might need lifting in certain contexts
32        liftable_functions.insert("PERCENTILE_CONT".to_string());
33        liftable_functions.insert("PERCENTILE_DISC".to_string());
34
35        ExpressionLifter {
36            cte_counter: 0,
37            liftable_functions,
38        }
39    }
40
41    /// Generate a unique CTE name
42    fn next_cte_name(&mut self) -> String {
43        self.cte_counter += 1;
44        format!("__lifted_{}", self.cte_counter)
45    }
46
47    /// Check if an expression needs to be lifted
48    pub fn needs_lifting(&self, expr: &SqlExpression) -> bool {
49        match expr {
50            SqlExpression::WindowFunction { .. } => true,
51
52            SqlExpression::FunctionCall { name, .. } => {
53                self.liftable_functions.contains(&name.to_uppercase())
54            }
55
56            SqlExpression::BinaryOp { left, right, .. } => {
57                self.needs_lifting(left) || self.needs_lifting(right)
58            }
59
60            SqlExpression::Not { expr } => self.needs_lifting(expr),
61
62            SqlExpression::InList { expr, values } => {
63                self.needs_lifting(expr) || values.iter().any(|v| self.needs_lifting(v))
64            }
65
66            SqlExpression::NotInList { expr, values } => {
67                self.needs_lifting(expr) || values.iter().any(|v| self.needs_lifting(v))
68            }
69
70            SqlExpression::Between { expr, lower, upper } => {
71                self.needs_lifting(expr) || self.needs_lifting(lower) || self.needs_lifting(upper)
72            }
73
74            SqlExpression::CaseExpression {
75                when_branches,
76                else_branch,
77            } => {
78                when_branches.iter().any(|branch| {
79                    self.needs_lifting(&branch.condition) || self.needs_lifting(&branch.result)
80                }) || else_branch
81                    .as_ref()
82                    .map_or(false, |e| self.needs_lifting(e))
83            }
84
85            SqlExpression::SimpleCaseExpression {
86                expr,
87                when_branches,
88                else_branch,
89            } => {
90                self.needs_lifting(expr)
91                    || when_branches.iter().any(|branch| {
92                        self.needs_lifting(&branch.value) || self.needs_lifting(&branch.result)
93                    })
94                    || else_branch
95                        .as_ref()
96                        .map_or(false, |e| self.needs_lifting(e))
97            }
98
99            _ => false,
100        }
101    }
102
103    /// Analyze WHERE clause and identify expressions to lift
104    pub fn analyze_where_clause(&mut self, where_clause: &WhereClause) -> Vec<LiftableExpression> {
105        let mut liftable = Vec::new();
106
107        // Analyze each condition in the WHERE clause
108        for condition in &where_clause.conditions {
109            if self.needs_lifting(&condition.expr) {
110                liftable.push(LiftableExpression {
111                    expression: condition.expr.clone(),
112                    suggested_name: self.next_cte_name(),
113                    dependencies: Vec::new(), // TODO: Analyze dependencies
114                });
115            }
116        }
117
118        liftable
119    }
120
121    /// Lift expressions from a SELECT statement
122    pub fn lift_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<CTE> {
123        let mut lifted_ctes = Vec::new();
124
125        // First, check for column alias dependencies (e.g., using alias in PARTITION BY)
126        let alias_deps = self.analyze_column_alias_dependencies(stmt);
127        if !alias_deps.is_empty() {
128            let cte = self.lift_column_aliases(stmt, &alias_deps);
129            lifted_ctes.push(cte);
130        }
131
132        // Check WHERE clause for liftable expressions
133        if let Some(ref where_clause) = stmt.where_clause {
134            let liftable = self.analyze_where_clause(where_clause);
135
136            for lift_expr in liftable {
137                // Create a CTE that includes the lifted expression as a computed column
138                let cte_select = SelectStatement {
139                    distinct: false,
140                    columns: vec!["*".to_string()],
141                    select_items: vec![
142                        SelectItem::Star {
143                            table_prefix: None,
144                            leading_comments: vec![],
145                            trailing_comment: None,
146                        },
147                        SelectItem::Expression {
148                            expr: lift_expr.expression.clone(),
149                            alias: "lifted_value".to_string(),
150                            leading_comments: vec![],
151                            trailing_comment: None,
152                        },
153                    ],
154                    from_source: stmt.from_source.clone(),
155                    #[allow(deprecated)]
156                    from_table: stmt.from_table.clone(),
157                    #[allow(deprecated)]
158                    from_subquery: stmt.from_subquery.clone(),
159                    #[allow(deprecated)]
160                    from_function: stmt.from_function.clone(),
161                    #[allow(deprecated)]
162                    from_alias: stmt.from_alias.clone(),
163                    joins: stmt.joins.clone(),
164                    where_clause: None, // Move simpler parts of WHERE here if possible
165                    qualify: None,
166                    order_by: None,
167                    group_by: None,
168                    having: None,
169                    limit: None,
170                    offset: None,
171                    ctes: Vec::new(),
172                    into_table: None,
173                    set_operations: Vec::new(),
174                    leading_comments: vec![],
175                    trailing_comment: None,
176                };
177
178                let cte = CTE {
179                    name: lift_expr.suggested_name.clone(),
180                    column_list: None,
181                    cte_type: CTEType::Standard(cte_select),
182                };
183
184                lifted_ctes.push(cte);
185
186                // Update the main query to reference the CTE
187                stmt.from_table = Some(lift_expr.suggested_name);
188
189                // Replace the complex WHERE expression with a simple column reference
190                use crate::sql::parser::ast::Condition;
191                stmt.where_clause = Some(WhereClause {
192                    conditions: vec![Condition {
193                        expr: SqlExpression::Column(ColumnRef::unquoted(
194                            "lifted_value".to_string(),
195                        )),
196                        connector: None,
197                    }],
198                });
199            }
200        }
201
202        // Add lifted CTEs to the statement
203        stmt.ctes.extend(lifted_ctes.clone());
204
205        lifted_ctes
206    }
207
208    /// Analyze column alias dependencies (e.g., alias used in PARTITION BY)
209    fn analyze_column_alias_dependencies(
210        &self,
211        stmt: &SelectStatement,
212    ) -> Vec<(String, SqlExpression)> {
213        let mut dependencies = Vec::new();
214
215        // Extract all aliases defined in SELECT
216        let mut aliases = std::collections::HashMap::new();
217        for item in &stmt.select_items {
218            if let SelectItem::Expression { expr, alias, .. } = item {
219                aliases.insert(alias.clone(), expr.clone());
220                tracing::debug!("Found alias: {} -> {:?}", alias, expr);
221            }
222        }
223
224        // Check if any aliases are used in window functions
225        for item in &stmt.select_items {
226            if let SelectItem::Expression { expr, .. } = item {
227                if let SqlExpression::WindowFunction { window_spec, .. } = expr {
228                    // Check PARTITION BY
229                    for col in &window_spec.partition_by {
230                        tracing::debug!("Checking PARTITION BY column: {}", col);
231                        if aliases.contains_key(col) {
232                            tracing::debug!(
233                                "Found dependency: {} depends on {:?}",
234                                col,
235                                aliases[col]
236                            );
237                            dependencies.push((col.clone(), aliases[col].clone()));
238                        }
239                    }
240
241                    // Check ORDER BY
242                    for order_col in &window_spec.order_by {
243                        // Extract column name from expression
244                        if let SqlExpression::Column(col_ref) = &order_col.expr {
245                            let col = &col_ref.name;
246                            if aliases.contains_key(col) {
247                                dependencies.push((col.clone(), aliases[col].clone()));
248                            }
249                        }
250                    }
251                }
252            }
253        }
254
255        // Check if QUALIFY clause references any window function aliases
256        // QUALIFY is designed to filter on window function results, so if it references
257        // an alias that's a window function, we need to lift that window function to a CTE
258        if let Some(ref qualify_expr) = stmt.qualify {
259            tracing::debug!("Checking QUALIFY clause for window function aliases");
260            let qualify_column_refs = extract_column_references(qualify_expr);
261
262            for col_name in qualify_column_refs {
263                tracing::debug!("QUALIFY references column: {}", col_name);
264                if let Some(expr) = aliases.get(&col_name) {
265                    // Check if this alias is a window function
266                    if matches!(expr, SqlExpression::WindowFunction { .. }) {
267                        tracing::debug!(
268                            "QUALIFY references window function alias: {} -> {:?}",
269                            col_name,
270                            expr
271                        );
272                        dependencies.push((col_name.clone(), expr.clone()));
273                    }
274                }
275            }
276        }
277
278        // Remove duplicates
279        dependencies.sort_by(|a, b| a.0.cmp(&b.0));
280        dependencies.dedup_by(|a, b| a.0 == b.0);
281
282        dependencies
283    }
284
285    /// Lift column aliases to a CTE when they're used in the same SELECT
286    fn lift_column_aliases(
287        &mut self,
288        stmt: &mut SelectStatement,
289        deps: &[(String, SqlExpression)],
290    ) -> CTE {
291        let cte_name = self.next_cte_name();
292
293        // Build CTE that computes the aliased columns
294        let mut cte_select_items = vec![SelectItem::Star {
295            table_prefix: None,
296            leading_comments: vec![],
297            trailing_comment: None,
298        }];
299        for (alias, expr) in deps {
300            cte_select_items.push(SelectItem::Expression {
301                expr: expr.clone(),
302                alias: alias.clone(),
303                leading_comments: vec![],
304                trailing_comment: None,
305            });
306        }
307
308        let cte_select = SelectStatement {
309            distinct: false,
310            columns: vec!["*".to_string()],
311            select_items: cte_select_items,
312            from_source: stmt.from_source.clone(),
313            #[allow(deprecated)]
314            from_table: stmt.from_table.clone(),
315            #[allow(deprecated)]
316            from_subquery: stmt.from_subquery.clone(),
317            #[allow(deprecated)]
318            from_function: stmt.from_function.clone(),
319            #[allow(deprecated)]
320            from_alias: stmt.from_alias.clone(),
321            joins: stmt.joins.clone(),
322            where_clause: stmt.where_clause.clone(),
323            order_by: None,
324            group_by: None,
325            having: None,
326            limit: None,
327            offset: None,
328            ctes: Vec::new(),
329            into_table: None,
330            set_operations: Vec::new(),
331            leading_comments: vec![],
332            trailing_comment: None,
333            qualify: None,
334        };
335
336        // Update the main query to use simple column references
337        let mut new_select_items = Vec::new();
338        for item in &stmt.select_items {
339            match item {
340                SelectItem::Expression { expr: _, alias, .. }
341                    if deps.iter().any(|(a, _)| a == alias) =>
342                {
343                    // Replace with simple column reference
344                    new_select_items.push(SelectItem::Column {
345                        column: ColumnRef::unquoted(alias.clone()),
346                        leading_comments: vec![],
347                        trailing_comment: None,
348                    });
349                }
350                _ => {
351                    new_select_items.push(item.clone());
352                }
353            }
354        }
355
356        stmt.select_items = new_select_items;
357        // Set from_source to reference the CTE (preferred)
358        stmt.from_source = Some(crate::sql::parser::ast::TableSource::Table(
359            cte_name.clone(),
360        ));
361        // Also set deprecated field for backward compatibility
362        #[allow(deprecated)]
363        {
364            stmt.from_table = Some(cte_name.clone());
365            stmt.from_subquery = None;
366        }
367        stmt.where_clause = None; // Already in the CTE
368
369        CTE {
370            name: cte_name,
371            column_list: None,
372            cte_type: CTEType::Standard(cte_select),
373        }
374    }
375
376    /// Create work units for lifted expressions
377    pub fn create_work_units_for_lifted(
378        &mut self,
379        lifted_ctes: &[CTE],
380        plan: &mut QueryPlan,
381    ) -> Vec<String> {
382        let mut cte_ids = Vec::new();
383
384        for cte in lifted_ctes {
385            let unit_id = format!("cte_{}", cte.name);
386
387            let work_unit = WorkUnit {
388                id: unit_id.clone(),
389                work_type: WorkUnitType::CTE,
390                expression: match &cte.cte_type {
391                    CTEType::Standard(select) => WorkUnitExpression::Select(select.clone()),
392                    CTEType::Web(_) => WorkUnitExpression::Custom("WEB CTE".to_string()),
393                    CTEType::File(_) => WorkUnitExpression::Custom("FILE CTE".to_string()),
394                },
395                dependencies: Vec::new(), // CTEs typically don't depend on each other initially
396                parallelizable: true,     // CTEs can often be computed in parallel
397                cost_estimate: None,
398            };
399
400            plan.add_unit(work_unit);
401            cte_ids.push(unit_id);
402        }
403
404        cte_ids
405    }
406}
407
408/// Represents an expression that can be lifted to a CTE
409#[derive(Debug)]
410pub struct LiftableExpression {
411    /// The expression to lift
412    pub expression: SqlExpression,
413
414    /// Suggested name for the CTE
415    pub suggested_name: String,
416
417    /// Dependencies on other CTEs or tables
418    pub dependencies: Vec<String>,
419}
420
421/// Analyze dependencies between expressions
422/// Extract all column references from an expression (used for QUALIFY analysis)
423fn extract_column_references(expr: &SqlExpression) -> HashSet<String> {
424    let mut refs = HashSet::new();
425
426    match expr {
427        SqlExpression::Column(col_ref) => {
428            refs.insert(col_ref.name.clone());
429        }
430
431        SqlExpression::BinaryOp { left, right, .. } => {
432            refs.extend(extract_column_references(left));
433            refs.extend(extract_column_references(right));
434        }
435
436        SqlExpression::Not { expr } => {
437            refs.extend(extract_column_references(expr));
438        }
439
440        SqlExpression::Between { expr, lower, upper } => {
441            refs.extend(extract_column_references(expr));
442            refs.extend(extract_column_references(lower));
443            refs.extend(extract_column_references(upper));
444        }
445
446        SqlExpression::InList { expr, values } | SqlExpression::NotInList { expr, values } => {
447            refs.extend(extract_column_references(expr));
448            for val in values {
449                refs.extend(extract_column_references(val));
450            }
451        }
452
453        SqlExpression::FunctionCall { args, .. } | SqlExpression::WindowFunction { args, .. } => {
454            for arg in args {
455                refs.extend(extract_column_references(arg));
456            }
457        }
458
459        SqlExpression::CaseExpression {
460            when_branches,
461            else_branch,
462        } => {
463            for branch in when_branches {
464                refs.extend(extract_column_references(&branch.condition));
465                refs.extend(extract_column_references(&branch.result));
466            }
467            if let Some(else_expr) = else_branch {
468                refs.extend(extract_column_references(else_expr));
469            }
470        }
471
472        SqlExpression::SimpleCaseExpression {
473            expr,
474            when_branches,
475            else_branch,
476        } => {
477            refs.extend(extract_column_references(expr));
478            for branch in when_branches {
479                refs.extend(extract_column_references(&branch.value));
480                refs.extend(extract_column_references(&branch.result));
481            }
482            if let Some(else_expr) = else_branch {
483                refs.extend(extract_column_references(else_expr));
484            }
485        }
486
487        // Literals and other expressions don't contain column references
488        _ => {}
489    }
490
491    refs
492}
493
494pub fn analyze_dependencies(expr: &SqlExpression) -> HashSet<String> {
495    let mut deps = HashSet::new();
496
497    match expr {
498        SqlExpression::Column(col) => {
499            deps.insert(col.name.clone());
500        }
501
502        SqlExpression::FunctionCall { args, .. } => {
503            for arg in args {
504                deps.extend(analyze_dependencies(arg));
505            }
506        }
507
508        SqlExpression::WindowFunction {
509            args, window_spec, ..
510        } => {
511            for arg in args {
512                deps.extend(analyze_dependencies(arg));
513            }
514
515            // Add partition and order columns as dependencies
516            for col in &window_spec.partition_by {
517                deps.insert(col.clone());
518            }
519
520            for order_col in &window_spec.order_by {
521                // Extract column name from expression
522                if let SqlExpression::Column(col_ref) = &order_col.expr {
523                    deps.insert(col_ref.name.clone());
524                }
525            }
526        }
527
528        SqlExpression::BinaryOp { left, right, .. } => {
529            deps.extend(analyze_dependencies(left));
530            deps.extend(analyze_dependencies(right));
531        }
532
533        SqlExpression::CaseExpression {
534            when_branches,
535            else_branch,
536        } => {
537            for branch in when_branches {
538                deps.extend(analyze_dependencies(&branch.condition));
539                deps.extend(analyze_dependencies(&branch.result));
540            }
541
542            if let Some(else_expr) = else_branch {
543                deps.extend(analyze_dependencies(else_expr));
544            }
545        }
546
547        SqlExpression::SimpleCaseExpression {
548            expr,
549            when_branches,
550            else_branch,
551        } => {
552            deps.extend(analyze_dependencies(expr));
553
554            for branch in when_branches {
555                deps.extend(analyze_dependencies(&branch.value));
556                deps.extend(analyze_dependencies(&branch.result));
557            }
558
559            if let Some(else_expr) = else_branch {
560                deps.extend(analyze_dependencies(else_expr));
561            }
562        }
563
564        _ => {}
565    }
566
567    deps
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an unquoted column reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// A bare window function must always be reported as liftable.
    #[test]
    fn test_needs_lifting_window_function() {
        let empty_spec = crate::sql::parser::ast::WindowSpec {
            partition_by: vec![],
            order_by: vec![],
            frame: None,
        };
        let row_number = SqlExpression::WindowFunction {
            name: "ROW_NUMBER".to_string(),
            args: vec![],
            window_spec: empty_spec,
        };

        let lifter = ExpressionLifter::new();
        assert!(lifter.needs_lifting(&row_number));
    }

    /// A comparison between a column and a literal must not be lifted.
    #[test]
    fn test_needs_lifting_simple_expression() {
        let comparison = SqlExpression::BinaryOp {
            left: Box::new(column("col1")),
            op: "=".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("42".to_string())),
        };

        let lifter = ExpressionLifter::new();
        assert!(!lifter.needs_lifting(&comparison));
    }

    /// Both operands of a binary expression are reported as dependencies.
    #[test]
    fn test_analyze_dependencies() {
        let sum = SqlExpression::BinaryOp {
            left: Box::new(column("col1")),
            op: "+".to_string(),
            right: Box::new(column("col2")),
        };

        let deps = analyze_dependencies(&sum);
        assert!(deps.contains("col1"));
        assert!(deps.contains("col2"));
        assert_eq!(deps.len(), 2);
    }
}