1use crate::query_plan::{QueryPlan, WorkUnit, WorkUnitExpression, WorkUnitType};
2use crate::sql::parser::ast::{
3 CTEType, ColumnRef, SelectItem, SelectStatement, SqlExpression, WhereClause, CTE,
4};
5use std::collections::HashSet;
6
/// Rewrites SELECT statements so that expressions which cannot be evaluated in place
/// (window functions and a fixed set of similar functions) are computed in generated CTEs.
#[derive(Debug)]
pub struct ExpressionLifter {
    /// Number of CTE names handed out so far; used to build unique `__lifted_N` names.
    cte_counter: usize,

    /// Upper-cased names of functions whose calls must be lifted.
    liftable_functions: HashSet<String>,
}
15
16impl ExpressionLifter {
17 pub fn new() -> Self {
19 let mut liftable_functions = HashSet::new();
20
21 liftable_functions.insert("ROW_NUMBER".to_string());
23 liftable_functions.insert("RANK".to_string());
24 liftable_functions.insert("DENSE_RANK".to_string());
25 liftable_functions.insert("LAG".to_string());
26 liftable_functions.insert("LEAD".to_string());
27 liftable_functions.insert("FIRST_VALUE".to_string());
28 liftable_functions.insert("LAST_VALUE".to_string());
29 liftable_functions.insert("NTH_VALUE".to_string());
30
31 liftable_functions.insert("PERCENTILE_CONT".to_string());
33 liftable_functions.insert("PERCENTILE_DISC".to_string());
34
35 ExpressionLifter {
36 cte_counter: 0,
37 liftable_functions,
38 }
39 }
40
41 fn next_cte_name(&mut self) -> String {
43 self.cte_counter += 1;
44 format!("__lifted_{}", self.cte_counter)
45 }
46
47 pub fn needs_lifting(&self, expr: &SqlExpression) -> bool {
49 match expr {
50 SqlExpression::WindowFunction { .. } => true,
51
52 SqlExpression::FunctionCall { name, .. } => {
53 self.liftable_functions.contains(&name.to_uppercase())
54 }
55
56 SqlExpression::BinaryOp { left, right, .. } => {
57 self.needs_lifting(left) || self.needs_lifting(right)
58 }
59
60 SqlExpression::Not { expr } => self.needs_lifting(expr),
61
62 SqlExpression::InList { expr, values } => {
63 self.needs_lifting(expr) || values.iter().any(|v| self.needs_lifting(v))
64 }
65
66 SqlExpression::NotInList { expr, values } => {
67 self.needs_lifting(expr) || values.iter().any(|v| self.needs_lifting(v))
68 }
69
70 SqlExpression::Between { expr, lower, upper } => {
71 self.needs_lifting(expr) || self.needs_lifting(lower) || self.needs_lifting(upper)
72 }
73
74 SqlExpression::CaseExpression {
75 when_branches,
76 else_branch,
77 } => {
78 when_branches.iter().any(|branch| {
79 self.needs_lifting(&branch.condition) || self.needs_lifting(&branch.result)
80 }) || else_branch
81 .as_ref()
82 .map_or(false, |e| self.needs_lifting(e))
83 }
84
85 SqlExpression::SimpleCaseExpression {
86 expr,
87 when_branches,
88 else_branch,
89 } => {
90 self.needs_lifting(expr)
91 || when_branches.iter().any(|branch| {
92 self.needs_lifting(&branch.value) || self.needs_lifting(&branch.result)
93 })
94 || else_branch
95 .as_ref()
96 .map_or(false, |e| self.needs_lifting(e))
97 }
98
99 _ => false,
100 }
101 }
102
103 pub fn analyze_where_clause(&mut self, where_clause: &WhereClause) -> Vec<LiftableExpression> {
105 let mut liftable = Vec::new();
106
107 for condition in &where_clause.conditions {
109 if self.needs_lifting(&condition.expr) {
110 liftable.push(LiftableExpression {
111 expression: condition.expr.clone(),
112 suggested_name: self.next_cte_name(),
113 dependencies: Vec::new(), });
115 }
116 }
117
118 liftable
119 }
120
121 pub fn lift_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<CTE> {
123 let mut lifted_ctes = Vec::new();
124
125 let alias_deps = self.analyze_column_alias_dependencies(stmt);
127 if !alias_deps.is_empty() {
128 let cte = self.lift_column_aliases(stmt, &alias_deps);
129 lifted_ctes.push(cte);
130 }
131
132 if let Some(ref where_clause) = stmt.where_clause {
134 let liftable = self.analyze_where_clause(where_clause);
135
136 for lift_expr in liftable {
137 let cte_select = SelectStatement {
139 distinct: false,
140 columns: vec!["*".to_string()],
141 select_items: vec![
142 SelectItem::Star {
143 table_prefix: None,
144 leading_comments: vec![],
145 trailing_comment: None,
146 },
147 SelectItem::Expression {
148 expr: lift_expr.expression.clone(),
149 alias: "lifted_value".to_string(),
150 leading_comments: vec![],
151 trailing_comment: None,
152 },
153 ],
154 from_source: stmt.from_source.clone(),
155 #[allow(deprecated)]
156 from_table: stmt.from_table.clone(),
157 #[allow(deprecated)]
158 from_subquery: stmt.from_subquery.clone(),
159 #[allow(deprecated)]
160 from_function: stmt.from_function.clone(),
161 #[allow(deprecated)]
162 from_alias: stmt.from_alias.clone(),
163 joins: stmt.joins.clone(),
164 where_clause: None, qualify: None,
166 order_by: None,
167 group_by: None,
168 having: None,
169 limit: None,
170 offset: None,
171 ctes: Vec::new(),
172 into_table: None,
173 set_operations: Vec::new(),
174 leading_comments: vec![],
175 trailing_comment: None,
176 };
177
178 let cte = CTE {
179 name: lift_expr.suggested_name.clone(),
180 column_list: None,
181 cte_type: CTEType::Standard(cte_select),
182 };
183
184 lifted_ctes.push(cte);
185
186 stmt.from_table = Some(lift_expr.suggested_name);
188
189 use crate::sql::parser::ast::Condition;
191 stmt.where_clause = Some(WhereClause {
192 conditions: vec![Condition {
193 expr: SqlExpression::Column(ColumnRef::unquoted(
194 "lifted_value".to_string(),
195 )),
196 connector: None,
197 }],
198 });
199 }
200 }
201
202 stmt.ctes.extend(lifted_ctes.clone());
204
205 lifted_ctes
206 }
207
208 fn analyze_column_alias_dependencies(
210 &self,
211 stmt: &SelectStatement,
212 ) -> Vec<(String, SqlExpression)> {
213 let mut dependencies = Vec::new();
214
215 let mut aliases = std::collections::HashMap::new();
217 for item in &stmt.select_items {
218 if let SelectItem::Expression { expr, alias, .. } = item {
219 aliases.insert(alias.clone(), expr.clone());
220 tracing::debug!("Found alias: {} -> {:?}", alias, expr);
221 }
222 }
223
224 for item in &stmt.select_items {
226 if let SelectItem::Expression { expr, .. } = item {
227 if let SqlExpression::WindowFunction { window_spec, .. } = expr {
228 for col in &window_spec.partition_by {
230 tracing::debug!("Checking PARTITION BY column: {}", col);
231 if aliases.contains_key(col) {
232 tracing::debug!(
233 "Found dependency: {} depends on {:?}",
234 col,
235 aliases[col]
236 );
237 dependencies.push((col.clone(), aliases[col].clone()));
238 }
239 }
240
241 for order_col in &window_spec.order_by {
243 if let SqlExpression::Column(col_ref) = &order_col.expr {
245 let col = &col_ref.name;
246 if aliases.contains_key(col) {
247 dependencies.push((col.clone(), aliases[col].clone()));
248 }
249 }
250 }
251 }
252 }
253 }
254
255 if let Some(ref qualify_expr) = stmt.qualify {
259 tracing::debug!("Checking QUALIFY clause for window function aliases");
260 let qualify_column_refs = extract_column_references(qualify_expr);
261
262 for col_name in qualify_column_refs {
263 tracing::debug!("QUALIFY references column: {}", col_name);
264 if let Some(expr) = aliases.get(&col_name) {
265 if matches!(expr, SqlExpression::WindowFunction { .. }) {
267 tracing::debug!(
268 "QUALIFY references window function alias: {} -> {:?}",
269 col_name,
270 expr
271 );
272 dependencies.push((col_name.clone(), expr.clone()));
273 }
274 }
275 }
276 }
277
278 dependencies.sort_by(|a, b| a.0.cmp(&b.0));
280 dependencies.dedup_by(|a, b| a.0 == b.0);
281
282 dependencies
283 }
284
285 fn lift_column_aliases(
287 &mut self,
288 stmt: &mut SelectStatement,
289 deps: &[(String, SqlExpression)],
290 ) -> CTE {
291 let cte_name = self.next_cte_name();
292
293 let mut cte_select_items = vec![SelectItem::Star {
295 table_prefix: None,
296 leading_comments: vec![],
297 trailing_comment: None,
298 }];
299 for (alias, expr) in deps {
300 cte_select_items.push(SelectItem::Expression {
301 expr: expr.clone(),
302 alias: alias.clone(),
303 leading_comments: vec![],
304 trailing_comment: None,
305 });
306 }
307
308 let cte_select = SelectStatement {
309 distinct: false,
310 columns: vec!["*".to_string()],
311 select_items: cte_select_items,
312 from_source: stmt.from_source.clone(),
313 #[allow(deprecated)]
314 from_table: stmt.from_table.clone(),
315 #[allow(deprecated)]
316 from_subquery: stmt.from_subquery.clone(),
317 #[allow(deprecated)]
318 from_function: stmt.from_function.clone(),
319 #[allow(deprecated)]
320 from_alias: stmt.from_alias.clone(),
321 joins: stmt.joins.clone(),
322 where_clause: stmt.where_clause.clone(),
323 order_by: None,
324 group_by: None,
325 having: None,
326 limit: None,
327 offset: None,
328 ctes: Vec::new(),
329 into_table: None,
330 set_operations: Vec::new(),
331 leading_comments: vec![],
332 trailing_comment: None,
333 qualify: None,
334 };
335
336 let mut new_select_items = Vec::new();
338 for item in &stmt.select_items {
339 match item {
340 SelectItem::Expression { expr: _, alias, .. }
341 if deps.iter().any(|(a, _)| a == alias) =>
342 {
343 new_select_items.push(SelectItem::Column {
345 column: ColumnRef::unquoted(alias.clone()),
346 leading_comments: vec![],
347 trailing_comment: None,
348 });
349 }
350 _ => {
351 new_select_items.push(item.clone());
352 }
353 }
354 }
355
356 stmt.select_items = new_select_items;
357 stmt.from_source = Some(crate::sql::parser::ast::TableSource::Table(
359 cte_name.clone(),
360 ));
361 #[allow(deprecated)]
363 {
364 stmt.from_table = Some(cte_name.clone());
365 stmt.from_subquery = None;
366 }
367 stmt.where_clause = None; CTE {
370 name: cte_name,
371 column_list: None,
372 cte_type: CTEType::Standard(cte_select),
373 }
374 }
375
376 pub fn create_work_units_for_lifted(
378 &mut self,
379 lifted_ctes: &[CTE],
380 plan: &mut QueryPlan,
381 ) -> Vec<String> {
382 let mut cte_ids = Vec::new();
383
384 for cte in lifted_ctes {
385 let unit_id = format!("cte_{}", cte.name);
386
387 let work_unit = WorkUnit {
388 id: unit_id.clone(),
389 work_type: WorkUnitType::CTE,
390 expression: match &cte.cte_type {
391 CTEType::Standard(select) => WorkUnitExpression::Select(select.clone()),
392 CTEType::Web(_) => WorkUnitExpression::Custom("WEB CTE".to_string()),
393 CTEType::File(_) => WorkUnitExpression::Custom("FILE CTE".to_string()),
394 },
395 dependencies: Vec::new(), parallelizable: true, cost_estimate: None,
398 };
399
400 plan.add_unit(work_unit);
401 cte_ids.push(unit_id);
402 }
403
404 cte_ids
405 }
406}
407
408#[derive(Debug)]
410pub struct LiftableExpression {
411 pub expression: SqlExpression,
413
414 pub suggested_name: String,
416
417 pub dependencies: Vec<String>,
419}
420
421fn extract_column_references(expr: &SqlExpression) -> HashSet<String> {
424 let mut refs = HashSet::new();
425
426 match expr {
427 SqlExpression::Column(col_ref) => {
428 refs.insert(col_ref.name.clone());
429 }
430
431 SqlExpression::BinaryOp { left, right, .. } => {
432 refs.extend(extract_column_references(left));
433 refs.extend(extract_column_references(right));
434 }
435
436 SqlExpression::Not { expr } => {
437 refs.extend(extract_column_references(expr));
438 }
439
440 SqlExpression::Between { expr, lower, upper } => {
441 refs.extend(extract_column_references(expr));
442 refs.extend(extract_column_references(lower));
443 refs.extend(extract_column_references(upper));
444 }
445
446 SqlExpression::InList { expr, values } | SqlExpression::NotInList { expr, values } => {
447 refs.extend(extract_column_references(expr));
448 for val in values {
449 refs.extend(extract_column_references(val));
450 }
451 }
452
453 SqlExpression::FunctionCall { args, .. } | SqlExpression::WindowFunction { args, .. } => {
454 for arg in args {
455 refs.extend(extract_column_references(arg));
456 }
457 }
458
459 SqlExpression::CaseExpression {
460 when_branches,
461 else_branch,
462 } => {
463 for branch in when_branches {
464 refs.extend(extract_column_references(&branch.condition));
465 refs.extend(extract_column_references(&branch.result));
466 }
467 if let Some(else_expr) = else_branch {
468 refs.extend(extract_column_references(else_expr));
469 }
470 }
471
472 SqlExpression::SimpleCaseExpression {
473 expr,
474 when_branches,
475 else_branch,
476 } => {
477 refs.extend(extract_column_references(expr));
478 for branch in when_branches {
479 refs.extend(extract_column_references(&branch.value));
480 refs.extend(extract_column_references(&branch.result));
481 }
482 if let Some(else_expr) = else_branch {
483 refs.extend(extract_column_references(else_expr));
484 }
485 }
486
487 _ => {}
489 }
490
491 refs
492}
493
494pub fn analyze_dependencies(expr: &SqlExpression) -> HashSet<String> {
495 let mut deps = HashSet::new();
496
497 match expr {
498 SqlExpression::Column(col) => {
499 deps.insert(col.name.clone());
500 }
501
502 SqlExpression::FunctionCall { args, .. } => {
503 for arg in args {
504 deps.extend(analyze_dependencies(arg));
505 }
506 }
507
508 SqlExpression::WindowFunction {
509 args, window_spec, ..
510 } => {
511 for arg in args {
512 deps.extend(analyze_dependencies(arg));
513 }
514
515 for col in &window_spec.partition_by {
517 deps.insert(col.clone());
518 }
519
520 for order_col in &window_spec.order_by {
521 if let SqlExpression::Column(col_ref) = &order_col.expr {
523 deps.insert(col_ref.name.clone());
524 }
525 }
526 }
527
528 SqlExpression::BinaryOp { left, right, .. } => {
529 deps.extend(analyze_dependencies(left));
530 deps.extend(analyze_dependencies(right));
531 }
532
533 SqlExpression::CaseExpression {
534 when_branches,
535 else_branch,
536 } => {
537 for branch in when_branches {
538 deps.extend(analyze_dependencies(&branch.condition));
539 deps.extend(analyze_dependencies(&branch.result));
540 }
541
542 if let Some(else_expr) = else_branch {
543 deps.extend(analyze_dependencies(else_expr));
544 }
545 }
546
547 SqlExpression::SimpleCaseExpression {
548 expr,
549 when_branches,
550 else_branch,
551 } => {
552 deps.extend(analyze_dependencies(expr));
553
554 for branch in when_branches {
555 deps.extend(analyze_dependencies(&branch.value));
556 deps.extend(analyze_dependencies(&branch.result));
557 }
558
559 if let Some(else_expr) = else_branch {
560 deps.extend(analyze_dependencies(else_expr));
561 }
562 }
563
564 _ => {}
565 }
566
567 deps
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an unquoted column-reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    #[test]
    fn test_needs_lifting_window_function() {
        let row_number = SqlExpression::WindowFunction {
            name: "ROW_NUMBER".to_string(),
            args: Vec::new(),
            window_spec: crate::sql::parser::ast::WindowSpec {
                partition_by: Vec::new(),
                order_by: Vec::new(),
                frame: None,
            },
        };

        assert!(ExpressionLifter::new().needs_lifting(&row_number));
    }

    #[test]
    fn test_needs_lifting_simple_expression() {
        let equality = SqlExpression::BinaryOp {
            left: Box::new(column("col1")),
            op: "=".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("42".to_string())),
        };

        assert!(!ExpressionLifter::new().needs_lifting(&equality));
    }

    #[test]
    fn test_analyze_dependencies() {
        let sum = SqlExpression::BinaryOp {
            left: Box::new(column("col1")),
            op: "+".to_string(),
            right: Box::new(column("col2")),
        };

        let expected: HashSet<String> = ["col1", "col2"].iter().map(|s| s.to_string()).collect();
        assert_eq!(analyze_dependencies(&sum), expected);
    }
}