// sql_cli/query_plan/expression_lifter.rs

1use crate::query_plan::{QueryPlan, WorkUnit, WorkUnitExpression, WorkUnitType};
2use crate::sql::parser::ast::{
3    CTEType, ColumnRef, SelectItem, SelectStatement, SqlExpression, WhereClause, CTE,
4};
5use std::collections::HashSet;
6
/// Expression lifter that identifies and lifts non-supported WHERE expressions to CTEs.
///
/// The lifter is stateful: every CTE it generates consumes one value of
/// `cte_counter`, so names produced by a single instance never repeat.
#[derive(Debug, Clone)]
pub struct ExpressionLifter {
    /// Counter for generating unique CTE names
    cte_counter: usize,

    /// Set of function names that need to be lifted (looked up by upper-cased name)
    liftable_functions: HashSet<String>,
}
15
16impl ExpressionLifter {
17    /// Create a new expression lifter
18    pub fn new() -> Self {
19        let mut liftable_functions = HashSet::new();
20
21        // Add window functions that need lifting
22        liftable_functions.insert("ROW_NUMBER".to_string());
23        liftable_functions.insert("RANK".to_string());
24        liftable_functions.insert("DENSE_RANK".to_string());
25        liftable_functions.insert("LAG".to_string());
26        liftable_functions.insert("LEAD".to_string());
27        liftable_functions.insert("FIRST_VALUE".to_string());
28        liftable_functions.insert("LAST_VALUE".to_string());
29        liftable_functions.insert("NTH_VALUE".to_string());
30
31        // Add aggregate functions that might need lifting in certain contexts
32        liftable_functions.insert("PERCENTILE_CONT".to_string());
33        liftable_functions.insert("PERCENTILE_DISC".to_string());
34
35        ExpressionLifter {
36            cte_counter: 0,
37            liftable_functions,
38        }
39    }
40
41    /// Generate a unique CTE name
42    fn next_cte_name(&mut self) -> String {
43        self.cte_counter += 1;
44        format!("__lifted_{}", self.cte_counter)
45    }
46
47    /// Check if an expression needs to be lifted
48    pub fn needs_lifting(&self, expr: &SqlExpression) -> bool {
49        match expr {
50            SqlExpression::WindowFunction { .. } => true,
51
52            SqlExpression::FunctionCall { name, .. } => {
53                self.liftable_functions.contains(&name.to_uppercase())
54            }
55
56            SqlExpression::BinaryOp { left, right, .. } => {
57                self.needs_lifting(left) || self.needs_lifting(right)
58            }
59
60            SqlExpression::Not { expr } => self.needs_lifting(expr),
61
62            SqlExpression::InList { expr, values } => {
63                self.needs_lifting(expr) || values.iter().any(|v| self.needs_lifting(v))
64            }
65
66            SqlExpression::NotInList { expr, values } => {
67                self.needs_lifting(expr) || values.iter().any(|v| self.needs_lifting(v))
68            }
69
70            SqlExpression::Between { expr, lower, upper } => {
71                self.needs_lifting(expr) || self.needs_lifting(lower) || self.needs_lifting(upper)
72            }
73
74            SqlExpression::CaseExpression {
75                when_branches,
76                else_branch,
77            } => {
78                when_branches.iter().any(|branch| {
79                    self.needs_lifting(&branch.condition) || self.needs_lifting(&branch.result)
80                }) || else_branch
81                    .as_ref()
82                    .map_or(false, |e| self.needs_lifting(e))
83            }
84
85            SqlExpression::SimpleCaseExpression {
86                expr,
87                when_branches,
88                else_branch,
89            } => {
90                self.needs_lifting(expr)
91                    || when_branches.iter().any(|branch| {
92                        self.needs_lifting(&branch.value) || self.needs_lifting(&branch.result)
93                    })
94                    || else_branch
95                        .as_ref()
96                        .map_or(false, |e| self.needs_lifting(e))
97            }
98
99            _ => false,
100        }
101    }
102
103    /// Analyze WHERE clause and identify expressions to lift
104    pub fn analyze_where_clause(&mut self, where_clause: &WhereClause) -> Vec<LiftableExpression> {
105        let mut liftable = Vec::new();
106
107        // Analyze each condition in the WHERE clause
108        for condition in &where_clause.conditions {
109            if self.needs_lifting(&condition.expr) {
110                liftable.push(LiftableExpression {
111                    expression: condition.expr.clone(),
112                    suggested_name: self.next_cte_name(),
113                    dependencies: Vec::new(), // TODO: Analyze dependencies
114                });
115            }
116        }
117
118        liftable
119    }
120
121    /// Lift expressions from a SELECT statement
122    pub fn lift_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<CTE> {
123        let mut lifted_ctes = Vec::new();
124
125        // First, check for column alias dependencies (e.g., using alias in PARTITION BY)
126        let alias_deps = self.analyze_column_alias_dependencies(stmt);
127        if !alias_deps.is_empty() {
128            let cte = self.lift_column_aliases(stmt, &alias_deps);
129            lifted_ctes.push(cte);
130        }
131
132        // Check WHERE clause for liftable expressions
133        if let Some(ref where_clause) = stmt.where_clause {
134            let liftable = self.analyze_where_clause(where_clause);
135
136            for lift_expr in liftable {
137                // Create a CTE that includes the lifted expression as a computed column
138                let cte_select = SelectStatement {
139                    distinct: false,
140                    columns: vec!["*".to_string()],
141                    select_items: vec![
142                        SelectItem::Star {
143                            leading_comments: vec![],
144                            trailing_comment: None,
145                        },
146                        SelectItem::Expression {
147                            expr: lift_expr.expression.clone(),
148                            alias: "lifted_value".to_string(),
149                            leading_comments: vec![],
150                            trailing_comment: None,
151                        },
152                    ],
153                    from_table: stmt.from_table.clone(),
154                    from_subquery: stmt.from_subquery.clone(),
155                    from_function: stmt.from_function.clone(),
156                    from_alias: stmt.from_alias.clone(),
157                    joins: stmt.joins.clone(),
158                    where_clause: None, // Move simpler parts of WHERE here if possible
159                    order_by: None,
160                    group_by: None,
161                    having: None,
162                    limit: None,
163                    offset: None,
164                    ctes: Vec::new(),
165                    into_table: None,
166                    set_operations: Vec::new(),
167                    leading_comments: vec![],
168                    trailing_comment: None,
169                };
170
171                let cte = CTE {
172                    name: lift_expr.suggested_name.clone(),
173                    column_list: None,
174                    cte_type: CTEType::Standard(cte_select),
175                };
176
177                lifted_ctes.push(cte);
178
179                // Update the main query to reference the CTE
180                stmt.from_table = Some(lift_expr.suggested_name);
181
182                // Replace the complex WHERE expression with a simple column reference
183                use crate::sql::parser::ast::Condition;
184                stmt.where_clause = Some(WhereClause {
185                    conditions: vec![Condition {
186                        expr: SqlExpression::Column(ColumnRef::unquoted(
187                            "lifted_value".to_string(),
188                        )),
189                        connector: None,
190                    }],
191                });
192            }
193        }
194
195        // Add lifted CTEs to the statement
196        stmt.ctes.extend(lifted_ctes.clone());
197
198        lifted_ctes
199    }
200
201    /// Analyze column alias dependencies (e.g., alias used in PARTITION BY)
202    fn analyze_column_alias_dependencies(
203        &self,
204        stmt: &SelectStatement,
205    ) -> Vec<(String, SqlExpression)> {
206        let mut dependencies = Vec::new();
207
208        // Extract all aliases defined in SELECT
209        let mut aliases = std::collections::HashMap::new();
210        for item in &stmt.select_items {
211            if let SelectItem::Expression { expr, alias, .. } = item {
212                aliases.insert(alias.clone(), expr.clone());
213                tracing::debug!("Found alias: {} -> {:?}", alias, expr);
214            }
215        }
216
217        // Check if any aliases are used in window functions
218        for item in &stmt.select_items {
219            if let SelectItem::Expression { expr, .. } = item {
220                if let SqlExpression::WindowFunction { window_spec, .. } = expr {
221                    // Check PARTITION BY
222                    for col in &window_spec.partition_by {
223                        tracing::debug!("Checking PARTITION BY column: {}", col);
224                        if aliases.contains_key(col) {
225                            tracing::debug!(
226                                "Found dependency: {} depends on {:?}",
227                                col,
228                                aliases[col]
229                            );
230                            dependencies.push((col.clone(), aliases[col].clone()));
231                        }
232                    }
233
234                    // Check ORDER BY
235                    for order_col in &window_spec.order_by {
236                        let col = &order_col.column;
237                        if aliases.contains_key(col) {
238                            dependencies.push((col.clone(), aliases[col].clone()));
239                        }
240                    }
241                }
242            }
243        }
244
245        // Remove duplicates
246        dependencies.sort_by(|a, b| a.0.cmp(&b.0));
247        dependencies.dedup_by(|a, b| a.0 == b.0);
248
249        dependencies
250    }
251
252    /// Lift column aliases to a CTE when they're used in the same SELECT
253    fn lift_column_aliases(
254        &mut self,
255        stmt: &mut SelectStatement,
256        deps: &[(String, SqlExpression)],
257    ) -> CTE {
258        let cte_name = self.next_cte_name();
259
260        // Build CTE that computes the aliased columns
261        let mut cte_select_items = vec![SelectItem::Star {
262            leading_comments: vec![],
263            trailing_comment: None,
264        }];
265        for (alias, expr) in deps {
266            cte_select_items.push(SelectItem::Expression {
267                expr: expr.clone(),
268                alias: alias.clone(),
269                leading_comments: vec![],
270                trailing_comment: None,
271            });
272        }
273
274        let cte_select = SelectStatement {
275            distinct: false,
276            columns: vec!["*".to_string()],
277            select_items: cte_select_items,
278            from_table: stmt.from_table.clone(),
279            from_subquery: stmt.from_subquery.clone(),
280            from_function: stmt.from_function.clone(),
281            from_alias: stmt.from_alias.clone(),
282            joins: stmt.joins.clone(),
283            where_clause: stmt.where_clause.clone(),
284            order_by: None,
285            group_by: None,
286            having: None,
287            limit: None,
288            offset: None,
289            ctes: Vec::new(),
290            into_table: None,
291            set_operations: Vec::new(),
292            leading_comments: vec![],
293            trailing_comment: None,
294        };
295
296        // Update the main query to use simple column references
297        let mut new_select_items = Vec::new();
298        for item in &stmt.select_items {
299            match item {
300                SelectItem::Expression { expr: _, alias, .. }
301                    if deps.iter().any(|(a, _)| a == alias) =>
302                {
303                    // Replace with simple column reference
304                    new_select_items.push(SelectItem::Column {
305                        column: ColumnRef::unquoted(alias.clone()),
306                        leading_comments: vec![],
307                        trailing_comment: None,
308                    });
309                }
310                _ => {
311                    new_select_items.push(item.clone());
312                }
313            }
314        }
315
316        stmt.select_items = new_select_items;
317        stmt.from_table = Some(cte_name.clone());
318        stmt.from_subquery = None;
319        stmt.where_clause = None; // Already in the CTE
320
321        CTE {
322            name: cte_name,
323            column_list: None,
324            cte_type: CTEType::Standard(cte_select),
325        }
326    }
327
328    /// Create work units for lifted expressions
329    pub fn create_work_units_for_lifted(
330        &mut self,
331        lifted_ctes: &[CTE],
332        plan: &mut QueryPlan,
333    ) -> Vec<String> {
334        let mut cte_ids = Vec::new();
335
336        for cte in lifted_ctes {
337            let unit_id = format!("cte_{}", cte.name);
338
339            let work_unit = WorkUnit {
340                id: unit_id.clone(),
341                work_type: WorkUnitType::CTE,
342                expression: match &cte.cte_type {
343                    CTEType::Standard(select) => WorkUnitExpression::Select(select.clone()),
344                    CTEType::Web(_) => WorkUnitExpression::Custom("WEB CTE".to_string()),
345                },
346                dependencies: Vec::new(), // CTEs typically don't depend on each other initially
347                parallelizable: true,     // CTEs can often be computed in parallel
348                cost_estimate: None,
349            };
350
351            plan.add_unit(work_unit);
352            cte_ids.push(unit_id);
353        }
354
355        cte_ids
356    }
357}
358
359/// Represents an expression that can be lifted to a CTE
360#[derive(Debug)]
361pub struct LiftableExpression {
362    /// The expression to lift
363    pub expression: SqlExpression,
364
365    /// Suggested name for the CTE
366    pub suggested_name: String,
367
368    /// Dependencies on other CTEs or tables
369    pub dependencies: Vec<String>,
370}
371
372/// Analyze dependencies between expressions
373pub fn analyze_dependencies(expr: &SqlExpression) -> HashSet<String> {
374    let mut deps = HashSet::new();
375
376    match expr {
377        SqlExpression::Column(col) => {
378            deps.insert(col.name.clone());
379        }
380
381        SqlExpression::FunctionCall { args, .. } => {
382            for arg in args {
383                deps.extend(analyze_dependencies(arg));
384            }
385        }
386
387        SqlExpression::WindowFunction {
388            args, window_spec, ..
389        } => {
390            for arg in args {
391                deps.extend(analyze_dependencies(arg));
392            }
393
394            // Add partition and order columns as dependencies
395            for col in &window_spec.partition_by {
396                deps.insert(col.clone());
397            }
398
399            for order_col in &window_spec.order_by {
400                deps.insert(order_col.column.clone());
401            }
402        }
403
404        SqlExpression::BinaryOp { left, right, .. } => {
405            deps.extend(analyze_dependencies(left));
406            deps.extend(analyze_dependencies(right));
407        }
408
409        SqlExpression::CaseExpression {
410            when_branches,
411            else_branch,
412        } => {
413            for branch in when_branches {
414                deps.extend(analyze_dependencies(&branch.condition));
415                deps.extend(analyze_dependencies(&branch.result));
416            }
417
418            if let Some(else_expr) = else_branch {
419                deps.extend(analyze_dependencies(else_expr));
420            }
421        }
422
423        SqlExpression::SimpleCaseExpression {
424            expr,
425            when_branches,
426            else_branch,
427        } => {
428            deps.extend(analyze_dependencies(expr));
429
430            for branch in when_branches {
431                deps.extend(analyze_dependencies(&branch.value));
432                deps.extend(analyze_dependencies(&branch.result));
433            }
434
435            if let Some(else_expr) = else_branch {
436                deps.extend(analyze_dependencies(else_expr));
437            }
438        }
439
440        _ => {}
441    }
442
443    deps
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unquoted column reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// Build a binary operation from two operands and an operator symbol.
    fn binary(left: SqlExpression, op: &str, right: SqlExpression) -> SqlExpression {
        SqlExpression::BinaryOp {
            left: Box::new(left),
            op: op.to_string(),
            right: Box::new(right),
        }
    }

    /// A window function is always reported as needing a lift.
    #[test]
    fn test_needs_lifting_window_function() {
        let empty_spec = crate::sql::parser::ast::WindowSpec {
            partition_by: vec![],
            order_by: vec![],
            frame: None,
        };
        let row_number = SqlExpression::WindowFunction {
            name: "ROW_NUMBER".to_string(),
            args: vec![],
            window_spec: empty_spec,
        };

        assert!(ExpressionLifter::new().needs_lifting(&row_number));
    }

    /// A comparison between a column and a literal does not need a lift.
    #[test]
    fn test_needs_lifting_simple_expression() {
        let comparison = binary(
            column("col1"),
            "=",
            SqlExpression::NumberLiteral("42".to_string()),
        );

        assert!(!ExpressionLifter::new().needs_lifting(&comparison));
    }

    /// Both operands of a binary operation are reported as dependencies.
    #[test]
    fn test_analyze_dependencies() {
        let sum = binary(column("col1"), "+", column("col2"));

        let found = analyze_dependencies(&sum);
        assert!(found.contains("col1"));
        assert!(found.contains("col2"));
        assert_eq!(found.len(), 2);
    }
}