1use crate::query_plan::{QueryPlan, WorkUnit, WorkUnitExpression, WorkUnitType};
2use crate::sql::parser::ast::{
3 CTEType, ColumnRef, SelectItem, SelectStatement, SqlExpression, WhereClause, CTE,
4};
5use std::collections::HashSet;
6
7pub struct ExpressionLifter {
9 cte_counter: usize,
11
12 liftable_functions: HashSet<String>,
14}
15
16impl ExpressionLifter {
17 pub fn new() -> Self {
19 let mut liftable_functions = HashSet::new();
20
21 liftable_functions.insert("ROW_NUMBER".to_string());
23 liftable_functions.insert("RANK".to_string());
24 liftable_functions.insert("DENSE_RANK".to_string());
25 liftable_functions.insert("LAG".to_string());
26 liftable_functions.insert("LEAD".to_string());
27 liftable_functions.insert("FIRST_VALUE".to_string());
28 liftable_functions.insert("LAST_VALUE".to_string());
29 liftable_functions.insert("NTH_VALUE".to_string());
30
31 liftable_functions.insert("PERCENTILE_CONT".to_string());
33 liftable_functions.insert("PERCENTILE_DISC".to_string());
34
35 ExpressionLifter {
36 cte_counter: 0,
37 liftable_functions,
38 }
39 }
40
41 fn next_cte_name(&mut self) -> String {
43 self.cte_counter += 1;
44 format!("__lifted_{}", self.cte_counter)
45 }
46
47 pub fn needs_lifting(&self, expr: &SqlExpression) -> bool {
49 match expr {
50 SqlExpression::WindowFunction { .. } => true,
51
52 SqlExpression::FunctionCall { name, .. } => {
53 self.liftable_functions.contains(&name.to_uppercase())
54 }
55
56 SqlExpression::BinaryOp { left, right, .. } => {
57 self.needs_lifting(left) || self.needs_lifting(right)
58 }
59
60 SqlExpression::Not { expr } => self.needs_lifting(expr),
61
62 SqlExpression::InList { expr, values } => {
63 self.needs_lifting(expr) || values.iter().any(|v| self.needs_lifting(v))
64 }
65
66 SqlExpression::NotInList { expr, values } => {
67 self.needs_lifting(expr) || values.iter().any(|v| self.needs_lifting(v))
68 }
69
70 SqlExpression::Between { expr, lower, upper } => {
71 self.needs_lifting(expr) || self.needs_lifting(lower) || self.needs_lifting(upper)
72 }
73
74 SqlExpression::CaseExpression {
75 when_branches,
76 else_branch,
77 } => {
78 when_branches.iter().any(|branch| {
79 self.needs_lifting(&branch.condition) || self.needs_lifting(&branch.result)
80 }) || else_branch
81 .as_ref()
82 .map_or(false, |e| self.needs_lifting(e))
83 }
84
85 SqlExpression::SimpleCaseExpression {
86 expr,
87 when_branches,
88 else_branch,
89 } => {
90 self.needs_lifting(expr)
91 || when_branches.iter().any(|branch| {
92 self.needs_lifting(&branch.value) || self.needs_lifting(&branch.result)
93 })
94 || else_branch
95 .as_ref()
96 .map_or(false, |e| self.needs_lifting(e))
97 }
98
99 _ => false,
100 }
101 }
102
103 pub fn analyze_where_clause(&mut self, where_clause: &WhereClause) -> Vec<LiftableExpression> {
105 let mut liftable = Vec::new();
106
107 for condition in &where_clause.conditions {
109 if self.needs_lifting(&condition.expr) {
110 liftable.push(LiftableExpression {
111 expression: condition.expr.clone(),
112 suggested_name: self.next_cte_name(),
113 dependencies: Vec::new(), });
115 }
116 }
117
118 liftable
119 }
120
121 pub fn lift_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<CTE> {
123 let mut lifted_ctes = Vec::new();
124
125 let alias_deps = self.analyze_column_alias_dependencies(stmt);
127 if !alias_deps.is_empty() {
128 let cte = self.lift_column_aliases(stmt, &alias_deps);
129 lifted_ctes.push(cte);
130 }
131
132 if let Some(ref where_clause) = stmt.where_clause {
134 let liftable = self.analyze_where_clause(where_clause);
135
136 for lift_expr in liftable {
137 let cte_select = SelectStatement {
139 distinct: false,
140 columns: vec!["*".to_string()],
141 select_items: vec![
142 SelectItem::Star {
143 table_prefix: None,
144 leading_comments: vec![],
145 trailing_comment: None,
146 },
147 SelectItem::Expression {
148 expr: lift_expr.expression.clone(),
149 alias: "lifted_value".to_string(),
150 leading_comments: vec![],
151 trailing_comment: None,
152 },
153 ],
154 from_table: stmt.from_table.clone(),
155 from_subquery: stmt.from_subquery.clone(),
156 from_function: stmt.from_function.clone(),
157 from_alias: stmt.from_alias.clone(),
158 joins: stmt.joins.clone(),
159 where_clause: None, qualify: None,
161 order_by: None,
162 group_by: None,
163 having: None,
164 limit: None,
165 offset: None,
166 ctes: Vec::new(),
167 into_table: None,
168 set_operations: Vec::new(),
169 leading_comments: vec![],
170 trailing_comment: None,
171 };
172
173 let cte = CTE {
174 name: lift_expr.suggested_name.clone(),
175 column_list: None,
176 cte_type: CTEType::Standard(cte_select),
177 };
178
179 lifted_ctes.push(cte);
180
181 stmt.from_table = Some(lift_expr.suggested_name);
183
184 use crate::sql::parser::ast::Condition;
186 stmt.where_clause = Some(WhereClause {
187 conditions: vec![Condition {
188 expr: SqlExpression::Column(ColumnRef::unquoted(
189 "lifted_value".to_string(),
190 )),
191 connector: None,
192 }],
193 });
194 }
195 }
196
197 stmt.ctes.extend(lifted_ctes.clone());
199
200 lifted_ctes
201 }
202
203 fn analyze_column_alias_dependencies(
205 &self,
206 stmt: &SelectStatement,
207 ) -> Vec<(String, SqlExpression)> {
208 let mut dependencies = Vec::new();
209
210 let mut aliases = std::collections::HashMap::new();
212 for item in &stmt.select_items {
213 if let SelectItem::Expression { expr, alias, .. } = item {
214 aliases.insert(alias.clone(), expr.clone());
215 tracing::debug!("Found alias: {} -> {:?}", alias, expr);
216 }
217 }
218
219 for item in &stmt.select_items {
221 if let SelectItem::Expression { expr, .. } = item {
222 if let SqlExpression::WindowFunction { window_spec, .. } = expr {
223 for col in &window_spec.partition_by {
225 tracing::debug!("Checking PARTITION BY column: {}", col);
226 if aliases.contains_key(col) {
227 tracing::debug!(
228 "Found dependency: {} depends on {:?}",
229 col,
230 aliases[col]
231 );
232 dependencies.push((col.clone(), aliases[col].clone()));
233 }
234 }
235
236 for order_col in &window_spec.order_by {
238 if let SqlExpression::Column(col_ref) = &order_col.expr {
240 let col = &col_ref.name;
241 if aliases.contains_key(col) {
242 dependencies.push((col.clone(), aliases[col].clone()));
243 }
244 }
245 }
246 }
247 }
248 }
249
250 if let Some(ref qualify_expr) = stmt.qualify {
254 tracing::debug!("Checking QUALIFY clause for window function aliases");
255 let qualify_column_refs = extract_column_references(qualify_expr);
256
257 for col_name in qualify_column_refs {
258 tracing::debug!("QUALIFY references column: {}", col_name);
259 if let Some(expr) = aliases.get(&col_name) {
260 if matches!(expr, SqlExpression::WindowFunction { .. }) {
262 tracing::debug!(
263 "QUALIFY references window function alias: {} -> {:?}",
264 col_name,
265 expr
266 );
267 dependencies.push((col_name.clone(), expr.clone()));
268 }
269 }
270 }
271 }
272
273 dependencies.sort_by(|a, b| a.0.cmp(&b.0));
275 dependencies.dedup_by(|a, b| a.0 == b.0);
276
277 dependencies
278 }
279
280 fn lift_column_aliases(
282 &mut self,
283 stmt: &mut SelectStatement,
284 deps: &[(String, SqlExpression)],
285 ) -> CTE {
286 let cte_name = self.next_cte_name();
287
288 let mut cte_select_items = vec![SelectItem::Star {
290 table_prefix: None,
291 leading_comments: vec![],
292 trailing_comment: None,
293 }];
294 for (alias, expr) in deps {
295 cte_select_items.push(SelectItem::Expression {
296 expr: expr.clone(),
297 alias: alias.clone(),
298 leading_comments: vec![],
299 trailing_comment: None,
300 });
301 }
302
303 let cte_select = SelectStatement {
304 distinct: false,
305 columns: vec!["*".to_string()],
306 select_items: cte_select_items,
307 from_table: stmt.from_table.clone(),
308 from_subquery: stmt.from_subquery.clone(),
309 from_function: stmt.from_function.clone(),
310 from_alias: stmt.from_alias.clone(),
311 joins: stmt.joins.clone(),
312 where_clause: stmt.where_clause.clone(),
313 order_by: None,
314 group_by: None,
315 having: None,
316 limit: None,
317 offset: None,
318 ctes: Vec::new(),
319 into_table: None,
320 set_operations: Vec::new(),
321 leading_comments: vec![],
322 trailing_comment: None,
323 qualify: None,
324 };
325
326 let mut new_select_items = Vec::new();
328 for item in &stmt.select_items {
329 match item {
330 SelectItem::Expression { expr: _, alias, .. }
331 if deps.iter().any(|(a, _)| a == alias) =>
332 {
333 new_select_items.push(SelectItem::Column {
335 column: ColumnRef::unquoted(alias.clone()),
336 leading_comments: vec![],
337 trailing_comment: None,
338 });
339 }
340 _ => {
341 new_select_items.push(item.clone());
342 }
343 }
344 }
345
346 stmt.select_items = new_select_items;
347 stmt.from_table = Some(cte_name.clone());
348 stmt.from_subquery = None;
349 stmt.where_clause = None; CTE {
352 name: cte_name,
353 column_list: None,
354 cte_type: CTEType::Standard(cte_select),
355 }
356 }
357
358 pub fn create_work_units_for_lifted(
360 &mut self,
361 lifted_ctes: &[CTE],
362 plan: &mut QueryPlan,
363 ) -> Vec<String> {
364 let mut cte_ids = Vec::new();
365
366 for cte in lifted_ctes {
367 let unit_id = format!("cte_{}", cte.name);
368
369 let work_unit = WorkUnit {
370 id: unit_id.clone(),
371 work_type: WorkUnitType::CTE,
372 expression: match &cte.cte_type {
373 CTEType::Standard(select) => WorkUnitExpression::Select(select.clone()),
374 CTEType::Web(_) => WorkUnitExpression::Custom("WEB CTE".to_string()),
375 },
376 dependencies: Vec::new(), parallelizable: true, cost_estimate: None,
379 };
380
381 plan.add_unit(work_unit);
382 cte_ids.push(unit_id);
383 }
384
385 cte_ids
386 }
387}
388
389#[derive(Debug)]
391pub struct LiftableExpression {
392 pub expression: SqlExpression,
394
395 pub suggested_name: String,
397
398 pub dependencies: Vec<String>,
400}
401
402fn extract_column_references(expr: &SqlExpression) -> HashSet<String> {
405 let mut refs = HashSet::new();
406
407 match expr {
408 SqlExpression::Column(col_ref) => {
409 refs.insert(col_ref.name.clone());
410 }
411
412 SqlExpression::BinaryOp { left, right, .. } => {
413 refs.extend(extract_column_references(left));
414 refs.extend(extract_column_references(right));
415 }
416
417 SqlExpression::Not { expr } => {
418 refs.extend(extract_column_references(expr));
419 }
420
421 SqlExpression::Between { expr, lower, upper } => {
422 refs.extend(extract_column_references(expr));
423 refs.extend(extract_column_references(lower));
424 refs.extend(extract_column_references(upper));
425 }
426
427 SqlExpression::InList { expr, values } | SqlExpression::NotInList { expr, values } => {
428 refs.extend(extract_column_references(expr));
429 for val in values {
430 refs.extend(extract_column_references(val));
431 }
432 }
433
434 SqlExpression::FunctionCall { args, .. } | SqlExpression::WindowFunction { args, .. } => {
435 for arg in args {
436 refs.extend(extract_column_references(arg));
437 }
438 }
439
440 SqlExpression::CaseExpression {
441 when_branches,
442 else_branch,
443 } => {
444 for branch in when_branches {
445 refs.extend(extract_column_references(&branch.condition));
446 refs.extend(extract_column_references(&branch.result));
447 }
448 if let Some(else_expr) = else_branch {
449 refs.extend(extract_column_references(else_expr));
450 }
451 }
452
453 SqlExpression::SimpleCaseExpression {
454 expr,
455 when_branches,
456 else_branch,
457 } => {
458 refs.extend(extract_column_references(expr));
459 for branch in when_branches {
460 refs.extend(extract_column_references(&branch.value));
461 refs.extend(extract_column_references(&branch.result));
462 }
463 if let Some(else_expr) = else_branch {
464 refs.extend(extract_column_references(else_expr));
465 }
466 }
467
468 _ => {}
470 }
471
472 refs
473}
474
475pub fn analyze_dependencies(expr: &SqlExpression) -> HashSet<String> {
476 let mut deps = HashSet::new();
477
478 match expr {
479 SqlExpression::Column(col) => {
480 deps.insert(col.name.clone());
481 }
482
483 SqlExpression::FunctionCall { args, .. } => {
484 for arg in args {
485 deps.extend(analyze_dependencies(arg));
486 }
487 }
488
489 SqlExpression::WindowFunction {
490 args, window_spec, ..
491 } => {
492 for arg in args {
493 deps.extend(analyze_dependencies(arg));
494 }
495
496 for col in &window_spec.partition_by {
498 deps.insert(col.clone());
499 }
500
501 for order_col in &window_spec.order_by {
502 if let SqlExpression::Column(col_ref) = &order_col.expr {
504 deps.insert(col_ref.name.clone());
505 }
506 }
507 }
508
509 SqlExpression::BinaryOp { left, right, .. } => {
510 deps.extend(analyze_dependencies(left));
511 deps.extend(analyze_dependencies(right));
512 }
513
514 SqlExpression::CaseExpression {
515 when_branches,
516 else_branch,
517 } => {
518 for branch in when_branches {
519 deps.extend(analyze_dependencies(&branch.condition));
520 deps.extend(analyze_dependencies(&branch.result));
521 }
522
523 if let Some(else_expr) = else_branch {
524 deps.extend(analyze_dependencies(else_expr));
525 }
526 }
527
528 SqlExpression::SimpleCaseExpression {
529 expr,
530 when_branches,
531 else_branch,
532 } => {
533 deps.extend(analyze_dependencies(expr));
534
535 for branch in when_branches {
536 deps.extend(analyze_dependencies(&branch.value));
537 deps.extend(analyze_dependencies(&branch.result));
538 }
539
540 if let Some(else_expr) = else_branch {
541 deps.extend(analyze_dependencies(else_expr));
542 }
543 }
544
545 _ => {}
546 }
547
548 deps
549}
550
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unquoted column-reference expression.
    fn col(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    #[test]
    fn test_needs_lifting_window_function() {
        let row_number = SqlExpression::WindowFunction {
            name: "ROW_NUMBER".to_string(),
            args: Vec::new(),
            window_spec: crate::sql::parser::ast::WindowSpec {
                partition_by: Vec::new(),
                order_by: Vec::new(),
                frame: None,
            },
        };

        assert!(ExpressionLifter::new().needs_lifting(&row_number));
    }

    #[test]
    fn test_needs_lifting_simple_expression() {
        let equality = SqlExpression::BinaryOp {
            left: Box::new(col("col1")),
            op: "=".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("42".to_string())),
        };

        assert!(!ExpressionLifter::new().needs_lifting(&equality));
    }

    #[test]
    fn test_analyze_dependencies() {
        let sum = SqlExpression::BinaryOp {
            left: Box::new(col("col1")),
            op: "+".to_string(),
            right: Box::new(col("col2")),
        };

        let expected: HashSet<String> = ["col1", "col2"]
            .iter()
            .map(|&name| name.to_string())
            .collect();
        assert_eq!(analyze_dependencies(&sum), expected);
    }
}