1use crate::query_plan::{QueryPlan, WorkUnit, WorkUnitExpression, WorkUnitType};
2use crate::sql::parser::ast::{
3 CTEType, SelectItem, SelectStatement, SqlExpression, WhereClause, CTE,
4};
5use std::collections::HashSet;
6
/// Rewrites a `SELECT` so that expressions needing a separate evaluation step
/// (window functions and a fixed set of function calls such as
/// `PERCENTILE_CONT`) are computed in generated CTEs.
#[derive(Debug)]
pub struct ExpressionLifter {
    /// Number of CTE names handed out so far; used to build unique `__lifted_N` names.
    cte_counter: usize,

    /// Upper-case names of function calls that require lifting.
    liftable_functions: HashSet<String>,
}
15
16impl ExpressionLifter {
17 pub fn new() -> Self {
19 let mut liftable_functions = HashSet::new();
20
21 liftable_functions.insert("ROW_NUMBER".to_string());
23 liftable_functions.insert("RANK".to_string());
24 liftable_functions.insert("DENSE_RANK".to_string());
25 liftable_functions.insert("LAG".to_string());
26 liftable_functions.insert("LEAD".to_string());
27 liftable_functions.insert("FIRST_VALUE".to_string());
28 liftable_functions.insert("LAST_VALUE".to_string());
29 liftable_functions.insert("NTH_VALUE".to_string());
30
31 liftable_functions.insert("PERCENTILE_CONT".to_string());
33 liftable_functions.insert("PERCENTILE_DISC".to_string());
34
35 ExpressionLifter {
36 cte_counter: 0,
37 liftable_functions,
38 }
39 }
40
41 fn next_cte_name(&mut self) -> String {
43 self.cte_counter += 1;
44 format!("__lifted_{}", self.cte_counter)
45 }
46
47 pub fn needs_lifting(&self, expr: &SqlExpression) -> bool {
49 match expr {
50 SqlExpression::WindowFunction { .. } => true,
51
52 SqlExpression::FunctionCall { name, .. } => {
53 self.liftable_functions.contains(&name.to_uppercase())
54 }
55
56 SqlExpression::BinaryOp { left, right, .. } => {
57 self.needs_lifting(left) || self.needs_lifting(right)
58 }
59
60 SqlExpression::Not { expr } => self.needs_lifting(expr),
61
62 SqlExpression::InList { expr, values } => {
63 self.needs_lifting(expr) || values.iter().any(|v| self.needs_lifting(v))
64 }
65
66 SqlExpression::NotInList { expr, values } => {
67 self.needs_lifting(expr) || values.iter().any(|v| self.needs_lifting(v))
68 }
69
70 SqlExpression::Between { expr, lower, upper } => {
71 self.needs_lifting(expr) || self.needs_lifting(lower) || self.needs_lifting(upper)
72 }
73
74 SqlExpression::CaseExpression {
75 when_branches,
76 else_branch,
77 } => {
78 when_branches.iter().any(|branch| {
79 self.needs_lifting(&branch.condition) || self.needs_lifting(&branch.result)
80 }) || else_branch
81 .as_ref()
82 .map_or(false, |e| self.needs_lifting(e))
83 }
84
85 SqlExpression::SimpleCaseExpression {
86 expr,
87 when_branches,
88 else_branch,
89 } => {
90 self.needs_lifting(expr)
91 || when_branches.iter().any(|branch| {
92 self.needs_lifting(&branch.value) || self.needs_lifting(&branch.result)
93 })
94 || else_branch
95 .as_ref()
96 .map_or(false, |e| self.needs_lifting(e))
97 }
98
99 _ => false,
100 }
101 }
102
103 pub fn analyze_where_clause(&mut self, where_clause: &WhereClause) -> Vec<LiftableExpression> {
105 let mut liftable = Vec::new();
106
107 for condition in &where_clause.conditions {
109 if self.needs_lifting(&condition.expr) {
110 liftable.push(LiftableExpression {
111 expression: condition.expr.clone(),
112 suggested_name: self.next_cte_name(),
113 dependencies: Vec::new(), });
115 }
116 }
117
118 liftable
119 }
120
121 pub fn lift_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<CTE> {
123 let mut lifted_ctes = Vec::new();
124
125 let alias_deps = self.analyze_column_alias_dependencies(stmt);
127 if !alias_deps.is_empty() {
128 let cte = self.lift_column_aliases(stmt, &alias_deps);
129 lifted_ctes.push(cte);
130 }
131
132 if let Some(ref where_clause) = stmt.where_clause {
134 let liftable = self.analyze_where_clause(where_clause);
135
136 for lift_expr in liftable {
137 let cte_select = SelectStatement {
139 distinct: false,
140 columns: vec!["*".to_string()],
141 select_items: vec![
142 SelectItem::Star,
143 SelectItem::Expression {
144 expr: lift_expr.expression.clone(),
145 alias: "lifted_value".to_string(),
146 },
147 ],
148 from_table: stmt.from_table.clone(),
149 from_subquery: stmt.from_subquery.clone(),
150 from_function: stmt.from_function.clone(),
151 from_alias: stmt.from_alias.clone(),
152 joins: stmt.joins.clone(),
153 where_clause: None, order_by: None,
155 group_by: None,
156 having: None,
157 limit: None,
158 offset: None,
159 ctes: Vec::new(),
160 };
161
162 let cte = CTE {
163 name: lift_expr.suggested_name.clone(),
164 column_list: None,
165 cte_type: CTEType::Standard(cte_select),
166 };
167
168 lifted_ctes.push(cte);
169
170 stmt.from_table = Some(lift_expr.suggested_name);
172
173 use crate::sql::parser::ast::Condition;
175 stmt.where_clause = Some(WhereClause {
176 conditions: vec![Condition {
177 expr: SqlExpression::Column("lifted_value".to_string()),
178 connector: None,
179 }],
180 });
181 }
182 }
183
184 stmt.ctes.extend(lifted_ctes.clone());
186
187 lifted_ctes
188 }
189
190 fn analyze_column_alias_dependencies(
192 &self,
193 stmt: &SelectStatement,
194 ) -> Vec<(String, SqlExpression)> {
195 let mut dependencies = Vec::new();
196
197 let mut aliases = std::collections::HashMap::new();
199 for item in &stmt.select_items {
200 if let SelectItem::Expression { expr, alias } = item {
201 aliases.insert(alias.clone(), expr.clone());
202 tracing::debug!("Found alias: {} -> {:?}", alias, expr);
203 }
204 }
205
206 for item in &stmt.select_items {
208 if let SelectItem::Expression { expr, .. } = item {
209 if let SqlExpression::WindowFunction { window_spec, .. } = expr {
210 for col in &window_spec.partition_by {
212 tracing::debug!("Checking PARTITION BY column: {}", col);
213 if aliases.contains_key(col) {
214 tracing::debug!(
215 "Found dependency: {} depends on {:?}",
216 col,
217 aliases[col]
218 );
219 dependencies.push((col.clone(), aliases[col].clone()));
220 }
221 }
222
223 for order_col in &window_spec.order_by {
225 let col = &order_col.column;
226 if aliases.contains_key(col) {
227 dependencies.push((col.clone(), aliases[col].clone()));
228 }
229 }
230 }
231 }
232 }
233
234 dependencies.sort_by(|a, b| a.0.cmp(&b.0));
236 dependencies.dedup_by(|a, b| a.0 == b.0);
237
238 dependencies
239 }
240
241 fn lift_column_aliases(
243 &mut self,
244 stmt: &mut SelectStatement,
245 deps: &[(String, SqlExpression)],
246 ) -> CTE {
247 let cte_name = self.next_cte_name();
248
249 let mut cte_select_items = vec![SelectItem::Star];
251 for (alias, expr) in deps {
252 cte_select_items.push(SelectItem::Expression {
253 expr: expr.clone(),
254 alias: alias.clone(),
255 });
256 }
257
258 let cte_select = SelectStatement {
259 distinct: false,
260 columns: vec!["*".to_string()],
261 select_items: cte_select_items,
262 from_table: stmt.from_table.clone(),
263 from_subquery: stmt.from_subquery.clone(),
264 from_function: stmt.from_function.clone(),
265 from_alias: stmt.from_alias.clone(),
266 joins: stmt.joins.clone(),
267 where_clause: stmt.where_clause.clone(),
268 order_by: None,
269 group_by: None,
270 having: None,
271 limit: None,
272 offset: None,
273 ctes: Vec::new(),
274 };
275
276 let mut new_select_items = Vec::new();
278 for item in &stmt.select_items {
279 match item {
280 SelectItem::Expression { expr: _, alias }
281 if deps.iter().any(|(a, _)| a == alias) =>
282 {
283 new_select_items.push(SelectItem::Column(alias.clone()));
285 }
286 _ => {
287 new_select_items.push(item.clone());
288 }
289 }
290 }
291
292 stmt.select_items = new_select_items;
293 stmt.from_table = Some(cte_name.clone());
294 stmt.from_subquery = None;
295 stmt.where_clause = None; CTE {
298 name: cte_name,
299 column_list: None,
300 cte_type: CTEType::Standard(cte_select),
301 }
302 }
303
304 pub fn create_work_units_for_lifted(
306 &mut self,
307 lifted_ctes: &[CTE],
308 plan: &mut QueryPlan,
309 ) -> Vec<String> {
310 let mut cte_ids = Vec::new();
311
312 for cte in lifted_ctes {
313 let unit_id = format!("cte_{}", cte.name);
314
315 let work_unit = WorkUnit {
316 id: unit_id.clone(),
317 work_type: WorkUnitType::CTE,
318 expression: match &cte.cte_type {
319 CTEType::Standard(select) => WorkUnitExpression::Select(select.clone()),
320 CTEType::Web(_) => WorkUnitExpression::Custom("WEB CTE".to_string()),
321 },
322 dependencies: Vec::new(), parallelizable: true, cost_estimate: None,
325 };
326
327 plan.add_unit(work_unit);
328 cte_ids.push(unit_id);
329 }
330
331 cte_ids
332 }
333}
334
335#[derive(Debug)]
337pub struct LiftableExpression {
338 pub expression: SqlExpression,
340
341 pub suggested_name: String,
343
344 pub dependencies: Vec<String>,
346}
347
348pub fn analyze_dependencies(expr: &SqlExpression) -> HashSet<String> {
350 let mut deps = HashSet::new();
351
352 match expr {
353 SqlExpression::Column(col) => {
354 deps.insert(col.clone());
355 }
356
357 SqlExpression::FunctionCall { args, .. } => {
358 for arg in args {
359 deps.extend(analyze_dependencies(arg));
360 }
361 }
362
363 SqlExpression::WindowFunction {
364 args, window_spec, ..
365 } => {
366 for arg in args {
367 deps.extend(analyze_dependencies(arg));
368 }
369
370 for col in &window_spec.partition_by {
372 deps.insert(col.clone());
373 }
374
375 for order_col in &window_spec.order_by {
376 deps.insert(order_col.column.clone());
377 }
378 }
379
380 SqlExpression::BinaryOp { left, right, .. } => {
381 deps.extend(analyze_dependencies(left));
382 deps.extend(analyze_dependencies(right));
383 }
384
385 SqlExpression::CaseExpression {
386 when_branches,
387 else_branch,
388 } => {
389 for branch in when_branches {
390 deps.extend(analyze_dependencies(&branch.condition));
391 deps.extend(analyze_dependencies(&branch.result));
392 }
393
394 if let Some(else_expr) = else_branch {
395 deps.extend(analyze_dependencies(else_expr));
396 }
397 }
398
399 SqlExpression::SimpleCaseExpression {
400 expr,
401 when_branches,
402 else_branch,
403 } => {
404 deps.extend(analyze_dependencies(expr));
405
406 for branch in when_branches {
407 deps.extend(analyze_dependencies(&branch.value));
408 deps.extend(analyze_dependencies(&branch.result));
409 }
410
411 if let Some(else_expr) = else_branch {
412 deps.extend(analyze_dependencies(else_expr));
413 }
414 }
415
416 _ => {}
417 }
418
419 deps
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::WindowSpec;

    /// Builds `ROW_NUMBER() OVER (PARTITION BY <cols>)` with no ORDER BY or frame.
    fn row_number_over(partition_by: &[&str]) -> SqlExpression {
        SqlExpression::WindowFunction {
            name: "ROW_NUMBER".to_string(),
            args: vec![],
            window_spec: WindowSpec {
                partition_by: partition_by.iter().map(|col| col.to_string()).collect(),
                order_by: vec![],
                frame: None,
            },
        }
    }

    /// Builds `<left> <op> <right>` with both operands given.
    fn binary(left: SqlExpression, op: &str, right: SqlExpression) -> SqlExpression {
        SqlExpression::BinaryOp {
            left: Box::new(left),
            op: op.to_string(),
            right: Box::new(right),
        }
    }

    #[test]
    fn test_needs_lifting_window_function() {
        let lifter = ExpressionLifter::new();
        assert!(lifter.needs_lifting(&row_number_over(&[])));
    }

    #[test]
    fn test_needs_lifting_simple_expression() {
        let lifter = ExpressionLifter::new();

        let simple_expr = binary(
            SqlExpression::Column("col1".to_string()),
            "=",
            SqlExpression::NumberLiteral("42".to_string()),
        );

        assert!(!lifter.needs_lifting(&simple_expr));
    }

    #[test]
    fn test_needs_lifting_looks_through_not() {
        let lifter = ExpressionLifter::new();

        let negated = SqlExpression::Not {
            expr: Box::new(row_number_over(&[])),
        };

        assert!(lifter.needs_lifting(&negated));
    }

    #[test]
    fn test_next_cte_name_is_sequential() {
        let mut lifter = ExpressionLifter::new();
        assert_eq!(lifter.next_cte_name(), "__lifted_1");
        assert_eq!(lifter.next_cte_name(), "__lifted_2");
    }

    #[test]
    fn test_analyze_dependencies() {
        let expr = binary(
            SqlExpression::Column("col1".to_string()),
            "+",
            SqlExpression::Column("col2".to_string()),
        );

        let deps = analyze_dependencies(&expr);
        assert!(deps.contains("col1"));
        assert!(deps.contains("col2"));
        assert_eq!(deps.len(), 2);
    }

    #[test]
    fn test_analyze_dependencies_window_partition_columns() {
        let deps = analyze_dependencies(&row_number_over(&["dept", "region"]));
        assert!(deps.contains("dept"));
        assert!(deps.contains("region"));
        assert_eq!(deps.len(), 2);
    }
}
469}