1use crate::query_plan::{QueryPlan, WorkUnit, WorkUnitExpression, WorkUnitType};
2use crate::sql::parser::ast::{
3 CTEType, ColumnRef, SelectItem, SelectStatement, SqlExpression, WhereClause, CTE,
4};
5use std::collections::HashSet;
6
/// Rewrites `SELECT` statements by moving expressions that need lifting
/// (window functions and calls to a fixed set of functions) into generated
/// CTEs, see [`ExpressionLifter::lift_expressions`].
#[derive(Debug)]
pub struct ExpressionLifter {
    /// Counter behind the generated `__lifted_N` CTE names; only ever grows.
    cte_counter: usize,

    /// Upper-cased function names whose presence forces an expression to be lifted.
    liftable_functions: HashSet<String>,
}
15
16impl ExpressionLifter {
17 pub fn new() -> Self {
19 let mut liftable_functions = HashSet::new();
20
21 liftable_functions.insert("ROW_NUMBER".to_string());
23 liftable_functions.insert("RANK".to_string());
24 liftable_functions.insert("DENSE_RANK".to_string());
25 liftable_functions.insert("LAG".to_string());
26 liftable_functions.insert("LEAD".to_string());
27 liftable_functions.insert("FIRST_VALUE".to_string());
28 liftable_functions.insert("LAST_VALUE".to_string());
29 liftable_functions.insert("NTH_VALUE".to_string());
30
31 liftable_functions.insert("PERCENTILE_CONT".to_string());
33 liftable_functions.insert("PERCENTILE_DISC".to_string());
34
35 ExpressionLifter {
36 cte_counter: 0,
37 liftable_functions,
38 }
39 }
40
41 fn next_cte_name(&mut self) -> String {
43 self.cte_counter += 1;
44 format!("__lifted_{}", self.cte_counter)
45 }
46
47 pub fn needs_lifting(&self, expr: &SqlExpression) -> bool {
49 match expr {
50 SqlExpression::WindowFunction { .. } => true,
51
52 SqlExpression::FunctionCall { name, .. } => {
53 self.liftable_functions.contains(&name.to_uppercase())
54 }
55
56 SqlExpression::BinaryOp { left, right, .. } => {
57 self.needs_lifting(left) || self.needs_lifting(right)
58 }
59
60 SqlExpression::Not { expr } => self.needs_lifting(expr),
61
62 SqlExpression::InList { expr, values } => {
63 self.needs_lifting(expr) || values.iter().any(|v| self.needs_lifting(v))
64 }
65
66 SqlExpression::NotInList { expr, values } => {
67 self.needs_lifting(expr) || values.iter().any(|v| self.needs_lifting(v))
68 }
69
70 SqlExpression::Between { expr, lower, upper } => {
71 self.needs_lifting(expr) || self.needs_lifting(lower) || self.needs_lifting(upper)
72 }
73
74 SqlExpression::CaseExpression {
75 when_branches,
76 else_branch,
77 } => {
78 when_branches.iter().any(|branch| {
79 self.needs_lifting(&branch.condition) || self.needs_lifting(&branch.result)
80 }) || else_branch
81 .as_ref()
82 .map_or(false, |e| self.needs_lifting(e))
83 }
84
85 SqlExpression::SimpleCaseExpression {
86 expr,
87 when_branches,
88 else_branch,
89 } => {
90 self.needs_lifting(expr)
91 || when_branches.iter().any(|branch| {
92 self.needs_lifting(&branch.value) || self.needs_lifting(&branch.result)
93 })
94 || else_branch
95 .as_ref()
96 .map_or(false, |e| self.needs_lifting(e))
97 }
98
99 _ => false,
100 }
101 }
102
103 pub fn analyze_where_clause(&mut self, where_clause: &WhereClause) -> Vec<LiftableExpression> {
105 let mut liftable = Vec::new();
106
107 for condition in &where_clause.conditions {
109 if self.needs_lifting(&condition.expr) {
110 liftable.push(LiftableExpression {
111 expression: condition.expr.clone(),
112 suggested_name: self.next_cte_name(),
113 dependencies: Vec::new(), });
115 }
116 }
117
118 liftable
119 }
120
121 pub fn lift_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<CTE> {
123 let mut lifted_ctes = Vec::new();
124
125 let alias_deps = self.analyze_column_alias_dependencies(stmt);
127 if !alias_deps.is_empty() {
128 let cte = self.lift_column_aliases(stmt, &alias_deps);
129 lifted_ctes.push(cte);
130 }
131
132 if let Some(ref where_clause) = stmt.where_clause {
134 let liftable = self.analyze_where_clause(where_clause);
135
136 for lift_expr in liftable {
137 let cte_select = SelectStatement {
139 distinct: false,
140 columns: vec!["*".to_string()],
141 select_items: vec![
142 SelectItem::Star {
143 table_prefix: None,
144 leading_comments: vec![],
145 trailing_comment: None,
146 },
147 SelectItem::Expression {
148 expr: lift_expr.expression.clone(),
149 alias: "lifted_value".to_string(),
150 leading_comments: vec![],
151 trailing_comment: None,
152 },
153 ],
154 from_table: stmt.from_table.clone(),
155 from_subquery: stmt.from_subquery.clone(),
156 from_function: stmt.from_function.clone(),
157 from_alias: stmt.from_alias.clone(),
158 joins: stmt.joins.clone(),
159 where_clause: None, order_by: None,
161 group_by: None,
162 having: None,
163 limit: None,
164 offset: None,
165 ctes: Vec::new(),
166 into_table: None,
167 set_operations: Vec::new(),
168 leading_comments: vec![],
169 trailing_comment: None,
170 };
171
172 let cte = CTE {
173 name: lift_expr.suggested_name.clone(),
174 column_list: None,
175 cte_type: CTEType::Standard(cte_select),
176 };
177
178 lifted_ctes.push(cte);
179
180 stmt.from_table = Some(lift_expr.suggested_name);
182
183 use crate::sql::parser::ast::Condition;
185 stmt.where_clause = Some(WhereClause {
186 conditions: vec![Condition {
187 expr: SqlExpression::Column(ColumnRef::unquoted(
188 "lifted_value".to_string(),
189 )),
190 connector: None,
191 }],
192 });
193 }
194 }
195
196 stmt.ctes.extend(lifted_ctes.clone());
198
199 lifted_ctes
200 }
201
202 fn analyze_column_alias_dependencies(
204 &self,
205 stmt: &SelectStatement,
206 ) -> Vec<(String, SqlExpression)> {
207 let mut dependencies = Vec::new();
208
209 let mut aliases = std::collections::HashMap::new();
211 for item in &stmt.select_items {
212 if let SelectItem::Expression { expr, alias, .. } = item {
213 aliases.insert(alias.clone(), expr.clone());
214 tracing::debug!("Found alias: {} -> {:?}", alias, expr);
215 }
216 }
217
218 for item in &stmt.select_items {
220 if let SelectItem::Expression { expr, .. } = item {
221 if let SqlExpression::WindowFunction { window_spec, .. } = expr {
222 for col in &window_spec.partition_by {
224 tracing::debug!("Checking PARTITION BY column: {}", col);
225 if aliases.contains_key(col) {
226 tracing::debug!(
227 "Found dependency: {} depends on {:?}",
228 col,
229 aliases[col]
230 );
231 dependencies.push((col.clone(), aliases[col].clone()));
232 }
233 }
234
235 for order_col in &window_spec.order_by {
237 let col = &order_col.column;
238 if aliases.contains_key(col) {
239 dependencies.push((col.clone(), aliases[col].clone()));
240 }
241 }
242 }
243 }
244 }
245
246 dependencies.sort_by(|a, b| a.0.cmp(&b.0));
248 dependencies.dedup_by(|a, b| a.0 == b.0);
249
250 dependencies
251 }
252
253 fn lift_column_aliases(
255 &mut self,
256 stmt: &mut SelectStatement,
257 deps: &[(String, SqlExpression)],
258 ) -> CTE {
259 let cte_name = self.next_cte_name();
260
261 let mut cte_select_items = vec![SelectItem::Star {
263 table_prefix: None,
264 leading_comments: vec![],
265 trailing_comment: None,
266 }];
267 for (alias, expr) in deps {
268 cte_select_items.push(SelectItem::Expression {
269 expr: expr.clone(),
270 alias: alias.clone(),
271 leading_comments: vec![],
272 trailing_comment: None,
273 });
274 }
275
276 let cte_select = SelectStatement {
277 distinct: false,
278 columns: vec!["*".to_string()],
279 select_items: cte_select_items,
280 from_table: stmt.from_table.clone(),
281 from_subquery: stmt.from_subquery.clone(),
282 from_function: stmt.from_function.clone(),
283 from_alias: stmt.from_alias.clone(),
284 joins: stmt.joins.clone(),
285 where_clause: stmt.where_clause.clone(),
286 order_by: None,
287 group_by: None,
288 having: None,
289 limit: None,
290 offset: None,
291 ctes: Vec::new(),
292 into_table: None,
293 set_operations: Vec::new(),
294 leading_comments: vec![],
295 trailing_comment: None,
296 };
297
298 let mut new_select_items = Vec::new();
300 for item in &stmt.select_items {
301 match item {
302 SelectItem::Expression { expr: _, alias, .. }
303 if deps.iter().any(|(a, _)| a == alias) =>
304 {
305 new_select_items.push(SelectItem::Column {
307 column: ColumnRef::unquoted(alias.clone()),
308 leading_comments: vec![],
309 trailing_comment: None,
310 });
311 }
312 _ => {
313 new_select_items.push(item.clone());
314 }
315 }
316 }
317
318 stmt.select_items = new_select_items;
319 stmt.from_table = Some(cte_name.clone());
320 stmt.from_subquery = None;
321 stmt.where_clause = None; CTE {
324 name: cte_name,
325 column_list: None,
326 cte_type: CTEType::Standard(cte_select),
327 }
328 }
329
330 pub fn create_work_units_for_lifted(
332 &mut self,
333 lifted_ctes: &[CTE],
334 plan: &mut QueryPlan,
335 ) -> Vec<String> {
336 let mut cte_ids = Vec::new();
337
338 for cte in lifted_ctes {
339 let unit_id = format!("cte_{}", cte.name);
340
341 let work_unit = WorkUnit {
342 id: unit_id.clone(),
343 work_type: WorkUnitType::CTE,
344 expression: match &cte.cte_type {
345 CTEType::Standard(select) => WorkUnitExpression::Select(select.clone()),
346 CTEType::Web(_) => WorkUnitExpression::Custom("WEB CTE".to_string()),
347 },
348 dependencies: Vec::new(), parallelizable: true, cost_estimate: None,
351 };
352
353 plan.add_unit(work_unit);
354 cte_ids.push(unit_id);
355 }
356
357 cte_ids
358 }
359}
360
361#[derive(Debug)]
363pub struct LiftableExpression {
364 pub expression: SqlExpression,
366
367 pub suggested_name: String,
369
370 pub dependencies: Vec<String>,
372}
373
374pub fn analyze_dependencies(expr: &SqlExpression) -> HashSet<String> {
376 let mut deps = HashSet::new();
377
378 match expr {
379 SqlExpression::Column(col) => {
380 deps.insert(col.name.clone());
381 }
382
383 SqlExpression::FunctionCall { args, .. } => {
384 for arg in args {
385 deps.extend(analyze_dependencies(arg));
386 }
387 }
388
389 SqlExpression::WindowFunction {
390 args, window_spec, ..
391 } => {
392 for arg in args {
393 deps.extend(analyze_dependencies(arg));
394 }
395
396 for col in &window_spec.partition_by {
398 deps.insert(col.clone());
399 }
400
401 for order_col in &window_spec.order_by {
402 deps.insert(order_col.column.clone());
403 }
404 }
405
406 SqlExpression::BinaryOp { left, right, .. } => {
407 deps.extend(analyze_dependencies(left));
408 deps.extend(analyze_dependencies(right));
409 }
410
411 SqlExpression::CaseExpression {
412 when_branches,
413 else_branch,
414 } => {
415 for branch in when_branches {
416 deps.extend(analyze_dependencies(&branch.condition));
417 deps.extend(analyze_dependencies(&branch.result));
418 }
419
420 if let Some(else_expr) = else_branch {
421 deps.extend(analyze_dependencies(else_expr));
422 }
423 }
424
425 SqlExpression::SimpleCaseExpression {
426 expr,
427 when_branches,
428 else_branch,
429 } => {
430 deps.extend(analyze_dependencies(expr));
431
432 for branch in when_branches {
433 deps.extend(analyze_dependencies(&branch.value));
434 deps.extend(analyze_dependencies(&branch.result));
435 }
436
437 if let Some(else_expr) = else_branch {
438 deps.extend(analyze_dependencies(else_expr));
439 }
440 }
441
442 _ => {}
443 }
444
445 deps
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an unquoted column-reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    #[test]
    fn test_needs_lifting_window_function() {
        let spec = crate::sql::parser::ast::WindowSpec {
            partition_by: vec![],
            order_by: vec![],
            frame: None,
        };
        let row_number = SqlExpression::WindowFunction {
            name: "ROW_NUMBER".to_string(),
            args: vec![],
            window_spec: spec,
        };

        assert!(ExpressionLifter::new().needs_lifting(&row_number));
    }

    #[test]
    fn test_needs_lifting_simple_expression() {
        let equality = SqlExpression::BinaryOp {
            left: Box::new(column("col1")),
            op: "=".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("42".to_string())),
        };

        assert!(!ExpressionLifter::new().needs_lifting(&equality));
    }

    #[test]
    fn test_analyze_dependencies() {
        let sum = SqlExpression::BinaryOp {
            left: Box::new(column("col1")),
            op: "+".to_string(),
            right: Box::new(column("col2")),
        };

        let found = analyze_dependencies(&sum);
        assert_eq!(found.len(), 2);
        for expected in ["col1", "col2"] {
            assert!(found.contains(expected));
        }
    }
}