1use std::collections::HashSet;
26
27use crate::ast::*;
28
29pub fn pushdown_predicates(statement: Statement) -> Statement {
37 match statement {
38 Statement::Select(sel) => Statement::Select(pushdown_select(sel)),
39 other => other,
40 }
41}
42
43fn pushdown_select(mut sel: SelectStatement) -> SelectStatement {
48 if let Some(from) = &mut sel.from {
51 recurse_into_source(&mut from.source);
52 }
53 for join in &mut sel.joins {
54 recurse_into_source(&mut join.table);
55 }
56 for cte in &mut sel.ctes {
57 *cte.query = pushdown_predicates(*cte.query.clone());
58 }
59
60 let where_clause = match sel.where_clause.take() {
62 Some(w) => w,
63 None => return sel,
64 };
65
66 let predicates = split_conjunction(where_clause);
67 let mut remaining: Vec<Expr> = Vec::new();
68
69 for pred in predicates {
70 if !is_pushable(&pred) {
71 remaining.push(pred);
72 continue;
73 }
74
75 let tables = referenced_tables(&pred);
76
77 let mut pushed = false;
79 if let Some(from) = &mut sel.from {
80 pushed = try_push_into_source(&mut from.source, &pred, &tables);
81 }
82
83 if !pushed {
85 for join in &mut sel.joins {
86 if try_push_into_join(join, &pred, &tables) {
87 pushed = true;
88 break;
89 }
90 }
91 }
92
93 if !pushed {
94 remaining.push(pred);
95 }
96 }
97
98 sel.where_clause = conjoin(remaining);
99 sel
100}
101
102fn try_push_into_source(
108 source: &mut TableSource,
109 pred: &Expr,
110 tables: &HashSet<String>,
111) -> bool {
112 match source {
113 TableSource::Subquery { query, alias } => {
114 let alias_name = match alias {
115 Some(a) => a.clone(),
116 None => return false,
117 };
118
119 if tables.is_empty() || !tables.iter().all(|t| t == &alias_name) {
121 return false;
122 }
123
124 let inner_sel = match query.as_mut() {
126 Statement::Select(sel) => sel,
127 _ => return false,
128 };
129
130 if !is_pushdown_safe_target(inner_sel) {
131 return false;
132 }
133
134 let rewritten = rewrite_predicate_for_derived_table(pred, &alias_name, inner_sel);
137 let rewritten = match rewritten {
138 Some(r) => r,
139 None => return false,
140 };
141
142 inner_sel.where_clause = match inner_sel.where_clause.take() {
144 Some(existing) => Some(Expr::BinaryOp {
145 left: Box::new(existing),
146 op: BinaryOperator::And,
147 right: Box::new(rewritten),
148 }),
149 None => Some(rewritten),
150 };
151
152 true
153 }
154 _ => false,
155 }
156}
157
158fn try_push_into_join(
161 join: &mut JoinClause,
162 pred: &Expr,
163 tables: &HashSet<String>,
164) -> bool {
165 if !matches!(join.join_type, JoinType::Inner | JoinType::Cross) {
168 return false;
169 }
170
171 let join_table = source_alias(&join.table);
173 let join_table = match join_table {
174 Some(t) => t,
175 None => return false,
176 };
177
178 if tables.is_empty() || tables.len() != 1 || !tables.contains(&join_table) {
180 return false;
181 }
182
183 if matches!(join.table, TableSource::Subquery { .. })
185 && try_push_into_source(&mut join.table, pred, tables)
186 {
187 return true;
188 }
189
190 join.on = match join.on.take() {
192 Some(existing) => Some(Expr::BinaryOp {
193 left: Box::new(existing),
194 op: BinaryOperator::And,
195 right: Box::new(pred.clone()),
196 }),
197 None => Some(pred.clone()),
198 };
199
200 true
201}
202
203fn recurse_into_source(source: &mut TableSource) {
209 match source {
210 TableSource::Subquery { query, .. } => {
211 *query = Box::new(pushdown_predicates(*query.clone()));
212 }
213 TableSource::Lateral { source } => {
214 recurse_into_source(source);
215 }
216 TableSource::Pivot { source, .. } | TableSource::Unpivot { source, .. } => {
217 recurse_into_source(source);
218 }
219 _ => {}
220 }
221}
222
223fn split_conjunction(expr: Expr) -> Vec<Expr> {
229 match expr {
230 Expr::BinaryOp {
231 left,
232 op: BinaryOperator::And,
233 right,
234 } => {
235 let mut result = split_conjunction(*left);
236 result.extend(split_conjunction(*right));
237 result
238 }
239 Expr::Nested(inner) => {
240 if matches!(
242 inner.as_ref(),
243 Expr::BinaryOp {
244 op: BinaryOperator::And,
245 ..
246 }
247 ) {
248 split_conjunction(*inner)
249 } else {
250 vec![Expr::Nested(inner)]
251 }
252 }
253 other => vec![other],
254 }
255}
256
257fn conjoin(predicates: Vec<Expr>) -> Option<Expr> {
259 predicates.into_iter().reduce(|a, b| Expr::BinaryOp {
260 left: Box::new(a),
261 op: BinaryOperator::And,
262 right: Box::new(b),
263 })
264}
265
266fn referenced_tables(expr: &Expr) -> HashSet<String> {
268 let mut tables = HashSet::new();
269 expr.walk(&mut |e| {
270 if let Expr::Column {
271 table: Some(t), ..
272 } = e
273 {
274 tables.insert(t.clone());
275 }
276 true
277 });
278 tables
279}
280
281fn is_pushable(expr: &Expr) -> bool {
289 let mut safe = true;
290 expr.walk(&mut |e| {
291 if !safe {
292 return false;
293 }
294 match e {
295 Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => {
297 safe = false;
298 false
299 }
300 Expr::Function { name, .. } if is_aggregate_function(name) => {
302 safe = false;
303 false
304 }
305 Expr::Function {
307 over: Some(_), ..
308 }
309 | Expr::TypedFunction {
310 over: Some(_), ..
311 } => {
312 safe = false;
313 false
314 }
315 Expr::Function { name, .. } if is_nondeterministic(name) => {
317 safe = false;
318 false
319 }
320 Expr::TypedFunction {
321 func: TypedFunction::CurrentTimestamp,
322 ..
323 } => {
324 safe = false;
325 false
326 }
327 _ => true,
328 }
329 });
330 safe
331}
332
333fn is_pushdown_safe_target(sel: &SelectStatement) -> bool {
337 if sel.limit.is_some() || sel.offset.is_some() || sel.fetch_first.is_some() || sel.distinct {
338 return false;
339 }
340 for item in &sel.columns {
342 if let SelectItem::Expr { expr, .. } = item {
343 if contains_window_function(expr) {
344 return false;
345 }
346 }
347 }
348 true
349}
350
351fn contains_window_function(expr: &Expr) -> bool {
353 let mut has_window = false;
354 expr.walk(&mut |e| {
355 if has_window {
356 return false;
357 }
358 match e {
359 Expr::Function {
360 over: Some(_), ..
361 }
362 | Expr::TypedFunction {
363 over: Some(_), ..
364 } => {
365 has_window = true;
366 false
367 }
368 _ => true,
369 }
370 });
371 has_window
372}
373
/// Reports whether `name`, compared case-insensitively, is one of the
/// aggregate function names known to this pass.
fn is_aggregate_function(name: &str) -> bool {
    const AGGREGATES: &[&str] = &[
        "COUNT",
        "SUM",
        "AVG",
        "MIN",
        "MAX",
        "GROUP_CONCAT",
        "STRING_AGG",
        "ARRAY_AGG",
        "LISTAGG",
        "STDDEV",
        "STDDEV_POP",
        "STDDEV_SAMP",
        "VARIANCE",
        "VAR_POP",
        "VAR_SAMP",
        "EVERY",
        "ANY_VALUE",
        "SOME",
        "BIT_AND",
        "BIT_OR",
        "BIT_XOR",
        "BOOL_AND",
        "BOOL_OR",
        "CORR",
        "COVAR_POP",
        "COVAR_SAMP",
        "REGR_SLOPE",
        "REGR_INTERCEPT",
        "PERCENTILE_CONT",
        "PERCENTILE_DISC",
        "APPROX_COUNT_DISTINCT",
        "HLL_COUNT",
        "APPROX_DISTINCT",
    ];

    let upper = name.to_uppercase();
    AGGREGATES.contains(&upper.as_str())
}
390
/// Reports whether `name`, compared case-insensitively, is one of the
/// function names this pass treats as non-deterministic.
fn is_nondeterministic(name: &str) -> bool {
    const NONDETERMINISTIC: &[&str] = &[
        "RAND",
        "RANDOM",
        "UUID",
        "NEWID",
        "GEN_RANDOM_UUID",
        "SYSDATE",
        "SYSTIMESTAMP",
    ];

    let upper = name.to_uppercase();
    NONDETERMINISTIC.contains(&upper.as_str())
}
398
399fn rewrite_predicate_for_derived_table(
416 pred: &Expr,
417 outer_alias: &str,
418 inner_sel: &SelectStatement,
419) -> Option<Expr> {
420 let column_map = build_column_map(inner_sel);
422
423 let mut can_rewrite = true;
425 pred.walk(&mut |e| {
426 if !can_rewrite {
427 return false;
428 }
429 if let Expr::Column {
430 table: Some(t),
431 name,
432 ..
433 } = e
434 {
435 if t == outer_alias && !column_map.contains_key(name.as_str()) {
436 can_rewrite = false;
437 }
438 }
439 can_rewrite
440 });
441
442 if !can_rewrite {
443 return None;
444 }
445
446 if !inner_sel.group_by.is_empty() {
449 let grouped_names: HashSet<String> = inner_sel
450 .group_by
451 .iter()
452 .filter_map(|e| match e {
453 Expr::Column { name, .. } => Some(name.clone()),
454 _ => None,
455 })
456 .collect();
457
458 let mut all_grouped = true;
459 pred.walk(&mut |e| {
460 if !all_grouped {
461 return false;
462 }
463 if let Expr::Column {
464 table: Some(t),
465 name,
466 ..
467 } = e
468 {
469 if t == outer_alias {
470 if let Some(inner_expr) = column_map.get(name.as_str()) {
471 let inner_name = match inner_expr {
472 Expr::Column { name: n, .. } => n.clone(),
473 _ => name.clone(),
474 };
475 if !grouped_names.contains(&inner_name) {
476 all_grouped = false;
477 }
478 }
479 }
480 }
481 all_grouped
482 });
483
484 if !all_grouped {
485 return None;
486 }
487 }
488
489 let rewritten = pred.clone().transform(&|e| match e {
491 Expr::Column {
492 table: Some(ref t),
493 ref name,
494 ..
495 } if t == outer_alias => {
496 if let Some(inner_expr) = column_map.get(name.as_str()) {
497 inner_expr.clone()
498 } else {
499 e
500 }
501 }
502 other => other,
503 });
504
505 Some(rewritten)
506}
507
508fn build_column_map(sel: &SelectStatement) -> std::collections::HashMap<&str, Expr> {
515 let mut map = std::collections::HashMap::new();
516
517 for item in &sel.columns {
518 match item {
519 SelectItem::Expr {
520 expr: Expr::Column { name, table, quote_style, table_quote_style },
521 alias,
522 } => {
523 let output_name = alias.as_deref().unwrap_or(name.as_str());
524 map.insert(
525 output_name,
526 Expr::Column {
527 table: table.clone(),
528 name: name.clone(),
529 quote_style: *quote_style,
530 table_quote_style: *table_quote_style,
531 },
532 );
533 }
534 SelectItem::Expr { expr, alias } => {
535 if let Some(alias) = alias {
536 map.insert(alias.as_str(), expr.clone());
537 }
538 }
539 SelectItem::Wildcard | SelectItem::QualifiedWildcard { .. } => {
540 }
544 }
545 }
546
547 map
548}
549
550fn source_alias(source: &TableSource) -> Option<String> {
552 match source {
553 TableSource::Table(t) => Some(t.alias.clone().unwrap_or_else(|| t.name.clone())),
554 TableSource::Subquery { alias, .. } => alias.clone(),
555 TableSource::TableFunction { alias, .. } => alias.clone(),
556 TableSource::Unnest { alias, .. } => alias.clone(),
557 TableSource::Lateral { source } => source_alias(source),
558 TableSource::Pivot { alias, .. } | TableSource::Unpivot { alias, .. } => alias.clone(),
559 }
560}
561
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unquoted column reference, qualified by `table` when given.
    fn column(table: Option<&str>, name: &str) -> Expr {
        Expr::Column {
            table: table.map(String::from),
            name: name.into(),
            quote_style: QuoteStyle::None,
            table_quote_style: QuoteStyle::None,
        }
    }

    /// Builds `left <op> right`.
    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Builds a non-distinct, unfiltered function call.
    fn call(name: &str, args: Vec<Expr>, over: Option<WindowSpec>) -> Expr {
        Expr::Function {
            name: name.into(),
            args,
            distinct: false,
            filter: None,
            over,
        }
    }

    #[test]
    fn test_split_conjunction_single() {
        let parts = split_conjunction(Expr::Boolean(true));
        assert_eq!(parts.len(), 1);
    }

    #[test]
    fn test_split_conjunction_and() {
        let expr = binary(
            Expr::Boolean(true),
            BinaryOperator::And,
            Expr::Boolean(false),
        );
        assert_eq!(split_conjunction(expr).len(), 2);
    }

    #[test]
    fn test_split_conjunction_nested_and() {
        // (a AND b) AND c
        let a_and_b = binary(column(None, "a"), BinaryOperator::And, column(None, "b"));
        let expr = binary(a_and_b, BinaryOperator::And, column(None, "c"));
        assert_eq!(split_conjunction(expr).len(), 3);
    }

    #[test]
    fn test_conjoin_empty() {
        assert!(conjoin(vec![]).is_none());
    }

    #[test]
    fn test_conjoin_single() {
        let r = conjoin(vec![Expr::Boolean(true)]);
        assert_eq!(r, Some(Expr::Boolean(true)));
    }

    #[test]
    fn test_is_pushable_simple_comparison() {
        let expr = binary(
            column(Some("t"), "x"),
            BinaryOperator::Gt,
            Expr::Number("5".into()),
        );
        assert!(is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_aggregate() {
        let expr = call("COUNT", vec![Expr::Star], None);
        assert!(!is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_window() {
        let window = WindowSpec {
            window_ref: None,
            partition_by: vec![],
            order_by: vec![],
            frame: None,
        };
        let expr = call("ROW_NUMBER", vec![], Some(window));
        assert!(!is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_subquery() {
        let empty_select = SelectStatement {
            ctes: vec![],
            distinct: false,
            top: None,
            columns: vec![],
            from: None,
            joins: vec![],
            where_clause: None,
            group_by: vec![],
            having: None,
            order_by: vec![],
            limit: None,
            offset: None,
            fetch_first: None,
            qualify: None,
            window_definitions: vec![],
        };
        let expr = Expr::Exists {
            subquery: Box::new(Statement::Select(empty_select)),
            negated: false,
        };
        assert!(!is_pushable(&expr));
    }

    #[test]
    fn test_referenced_tables() {
        let expr = binary(
            column(Some("a"), "x"),
            BinaryOperator::Eq,
            column(Some("b"), "y"),
        );
        let tables = referenced_tables(&expr);
        assert_eq!(tables.len(), 2);
        assert!(tables.contains("a"));
        assert!(tables.contains("b"));
    }
}