1use std::collections::HashSet;
26
27use crate::ast::*;
28
29pub fn pushdown_predicates(statement: Statement) -> Statement {
37 match statement {
38 Statement::Select(sel) => Statement::Select(pushdown_select(sel)),
39 other => other,
40 }
41}
42
43fn pushdown_select(mut sel: SelectStatement) -> SelectStatement {
48 if let Some(from) = &mut sel.from {
51 recurse_into_source(&mut from.source);
52 }
53 for join in &mut sel.joins {
54 recurse_into_source(&mut join.table);
55 }
56 for cte in &mut sel.ctes {
57 *cte.query = pushdown_predicates(*cte.query.clone());
58 }
59
60 let where_clause = match sel.where_clause.take() {
62 Some(w) => w,
63 None => return sel,
64 };
65
66 let predicates = split_conjunction(where_clause);
67 let mut remaining: Vec<Expr> = Vec::new();
68
69 for pred in predicates {
70 if !is_pushable(&pred) {
71 remaining.push(pred);
72 continue;
73 }
74
75 let tables = referenced_tables(&pred);
76
77 let mut pushed = false;
79 if let Some(from) = &mut sel.from {
80 pushed = try_push_into_source(&mut from.source, &pred, &tables);
81 }
82
83 if !pushed {
85 for join in &mut sel.joins {
86 if try_push_into_join(join, &pred, &tables) {
87 pushed = true;
88 break;
89 }
90 }
91 }
92
93 if !pushed {
94 remaining.push(pred);
95 }
96 }
97
98 sel.where_clause = conjoin(remaining);
99 sel
100}
101
102fn try_push_into_source(source: &mut TableSource, pred: &Expr, tables: &HashSet<String>) -> bool {
108 match source {
109 TableSource::Subquery { query, alias } => {
110 let alias_name = match alias {
111 Some(a) => a.clone(),
112 None => return false,
113 };
114
115 if tables.is_empty() || !tables.iter().all(|t| t == &alias_name) {
117 return false;
118 }
119
120 let inner_sel = match query.as_mut() {
122 Statement::Select(sel) => sel,
123 _ => return false,
124 };
125
126 if !is_pushdown_safe_target(inner_sel) {
127 return false;
128 }
129
130 let rewritten = rewrite_predicate_for_derived_table(pred, &alias_name, inner_sel);
133 let rewritten = match rewritten {
134 Some(r) => r,
135 None => return false,
136 };
137
138 inner_sel.where_clause = match inner_sel.where_clause.take() {
140 Some(existing) => Some(Expr::BinaryOp {
141 left: Box::new(existing),
142 op: BinaryOperator::And,
143 right: Box::new(rewritten),
144 }),
145 None => Some(rewritten),
146 };
147
148 true
149 }
150 _ => false,
151 }
152}
153
154fn try_push_into_join(join: &mut JoinClause, pred: &Expr, tables: &HashSet<String>) -> bool {
157 if !matches!(join.join_type, JoinType::Inner | JoinType::Cross) {
160 return false;
161 }
162
163 let join_table = source_alias(&join.table);
165 let join_table = match join_table {
166 Some(t) => t,
167 None => return false,
168 };
169
170 if tables.is_empty() || tables.len() != 1 || !tables.contains(&join_table) {
172 return false;
173 }
174
175 if matches!(join.table, TableSource::Subquery { .. })
177 && try_push_into_source(&mut join.table, pred, tables)
178 {
179 return true;
180 }
181
182 join.on = match join.on.take() {
184 Some(existing) => Some(Expr::BinaryOp {
185 left: Box::new(existing),
186 op: BinaryOperator::And,
187 right: Box::new(pred.clone()),
188 }),
189 None => Some(pred.clone()),
190 };
191
192 true
193}
194
195fn recurse_into_source(source: &mut TableSource) {
201 match source {
202 TableSource::Subquery { query, .. } => {
203 *query = Box::new(pushdown_predicates(*query.clone()));
204 }
205 TableSource::Lateral { source } => {
206 recurse_into_source(source);
207 }
208 TableSource::Pivot { source, .. } | TableSource::Unpivot { source, .. } => {
209 recurse_into_source(source);
210 }
211 _ => {}
212 }
213}
214
215fn split_conjunction(expr: Expr) -> Vec<Expr> {
221 match expr {
222 Expr::BinaryOp {
223 left,
224 op: BinaryOperator::And,
225 right,
226 } => {
227 let mut result = split_conjunction(*left);
228 result.extend(split_conjunction(*right));
229 result
230 }
231 Expr::Nested(inner) => {
232 if matches!(
234 inner.as_ref(),
235 Expr::BinaryOp {
236 op: BinaryOperator::And,
237 ..
238 }
239 ) {
240 split_conjunction(*inner)
241 } else {
242 vec![Expr::Nested(inner)]
243 }
244 }
245 other => vec![other],
246 }
247}
248
249fn conjoin(predicates: Vec<Expr>) -> Option<Expr> {
251 predicates.into_iter().reduce(|a, b| Expr::BinaryOp {
252 left: Box::new(a),
253 op: BinaryOperator::And,
254 right: Box::new(b),
255 })
256}
257
258fn referenced_tables(expr: &Expr) -> HashSet<String> {
260 let mut tables = HashSet::new();
261 expr.walk(&mut |e| {
262 if let Expr::Column { table: Some(t), .. } = e {
263 tables.insert(t.clone());
264 }
265 true
266 });
267 tables
268}
269
270fn is_pushable(expr: &Expr) -> bool {
278 let mut safe = true;
279 expr.walk(&mut |e| {
280 if !safe {
281 return false;
282 }
283 match e {
284 Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => {
286 safe = false;
287 false
288 }
289 Expr::Function { name, .. } if is_aggregate_function(name) => {
291 safe = false;
292 false
293 }
294 Expr::Function { over: Some(_), .. } | Expr::TypedFunction { over: Some(_), .. } => {
296 safe = false;
297 false
298 }
299 Expr::Function { name, .. } if is_nondeterministic(name) => {
301 safe = false;
302 false
303 }
304 Expr::TypedFunction {
305 func: TypedFunction::CurrentTimestamp,
306 ..
307 } => {
308 safe = false;
309 false
310 }
311 _ => true,
312 }
313 });
314 safe
315}
316
317fn is_pushdown_safe_target(sel: &SelectStatement) -> bool {
321 if sel.limit.is_some() || sel.offset.is_some() || sel.fetch_first.is_some() || sel.distinct {
322 return false;
323 }
324 for item in &sel.columns {
326 if let SelectItem::Expr { expr, .. } = item {
327 if contains_window_function(expr) {
328 return false;
329 }
330 }
331 }
332 true
333}
334
335fn contains_window_function(expr: &Expr) -> bool {
337 let mut has_window = false;
338 expr.walk(&mut |e| {
339 if has_window {
340 return false;
341 }
342 match e {
343 Expr::Function { over: Some(_), .. } | Expr::TypedFunction { over: Some(_), .. } => {
344 has_window = true;
345 false
346 }
347 _ => true,
348 }
349 });
350 has_window
351}
352
/// Case-insensitively reports whether `name` is a known aggregate function.
///
/// The list spans several dialects (standard SQL, PostgreSQL, MySQL, T-SQL, BigQuery,
/// Snowflake, Spark/Hive, Trino/Presto, DuckDB/ClickHouse). The pushdown safety checks rely
/// on it, so a missing name can let an aggregate expression be moved into a WHERE clause.
/// False positives only make the optimizer more conservative.
fn is_aggregate_function(name: &str) -> bool {
    matches!(
        name.to_uppercase().as_str(),
        // Core / standard SQL
        "COUNT"
            | "SUM"
            | "AVG"
            | "MIN"
            | "MAX"
            | "EVERY"
            | "ANY_VALUE"
            | "SOME"
            | "STDDEV"
            | "STDDEV_POP"
            | "STDDEV_SAMP"
            | "VARIANCE"
            | "VAR_POP"
            | "VAR_SAMP"
            | "CORR"
            | "COVAR_POP"
            | "COVAR_SAMP"
            | "REGR_SLOPE"
            | "REGR_INTERCEPT"
            | "REGR_COUNT"
            | "REGR_R2"
            | "REGR_AVGX"
            | "REGR_AVGY"
            | "REGR_SXX"
            | "REGR_SYY"
            | "REGR_SXY"
            | "PERCENTILE_CONT"
            | "PERCENTILE_DISC"
            | "XMLAGG"
            | "JSON_ARRAYAGG"
            | "JSON_OBJECTAGG"
            // String / array / JSON collectors
            | "GROUP_CONCAT"
            | "STRING_AGG"
            | "ARRAY_AGG"
            | "ARRAY_CONCAT_AGG"
            | "LISTAGG"
            | "COLLECT_LIST"
            | "COLLECT_SET"
            | "JSON_AGG"
            | "JSONB_AGG"
            | "JSON_OBJECT_AGG"
            | "JSONB_OBJECT_AGG"
            // Bitwise / boolean
            | "BIT_AND"
            | "BIT_OR"
            | "BIT_XOR"
            | "BOOL_AND"
            | "BOOL_OR"
            | "LOGICAL_AND"
            | "LOGICAL_OR"
            // Counting variants
            | "COUNT_BIG"
            | "COUNT_IF"
            | "COUNTIF"
            // Statistics / ordering-based
            | "MEDIAN"
            | "MODE"
            | "KURTOSIS"
            | "SKEWNESS"
            | "STDEV"
            | "STDEVP"
            | "VAR"
            | "VARP"
            | "CHECKSUM_AGG"
            | "ARBITRARY"
            | "MAX_BY"
            | "MIN_BY"
            | "ARG_MAX"
            | "ARG_MIN"
            // Approximate aggregates
            | "APPROX_COUNT_DISTINCT"
            | "APPROX_DISTINCT"
            | "APPROX_PERCENTILE"
            | "APPROX_QUANTILES"
            | "APPROX_TOP_COUNT"
            | "HLL_COUNT"
    )
}
391
/// Case-insensitively reports whether `name` is one of the recognized non-deterministic
/// functions (random-number, UUID, and Oracle system-clock functions).
fn is_nondeterministic(name: &str) -> bool {
    const VOLATILE: [&str; 7] = [
        "RAND",
        "RANDOM",
        "UUID",
        "NEWID",
        "GEN_RANDOM_UUID",
        "SYSDATE",
        "SYSTIMESTAMP",
    ];
    let upper = name.to_uppercase();
    VOLATILE.contains(&upper.as_str())
}
398
399fn rewrite_predicate_for_derived_table(
416 pred: &Expr,
417 outer_alias: &str,
418 inner_sel: &SelectStatement,
419) -> Option<Expr> {
420 let column_map = build_column_map(inner_sel);
422
423 let mut can_rewrite = true;
425 pred.walk(&mut |e| {
426 if !can_rewrite {
427 return false;
428 }
429 if let Expr::Column {
430 table: Some(t),
431 name,
432 ..
433 } = e
434 {
435 if t == outer_alias && !column_map.contains_key(name.as_str()) {
436 can_rewrite = false;
437 }
438 }
439 can_rewrite
440 });
441
442 if !can_rewrite {
443 return None;
444 }
445
446 if !inner_sel.group_by.is_empty() {
449 let grouped_names: HashSet<String> = inner_sel
450 .group_by
451 .iter()
452 .filter_map(|e| match e {
453 Expr::Column { name, .. } => Some(name.clone()),
454 _ => None,
455 })
456 .collect();
457
458 let mut all_grouped = true;
459 pred.walk(&mut |e| {
460 if !all_grouped {
461 return false;
462 }
463 if let Expr::Column {
464 table: Some(t),
465 name,
466 ..
467 } = e
468 {
469 if t == outer_alias {
470 if let Some(inner_expr) = column_map.get(name.as_str()) {
471 let inner_name = match inner_expr {
472 Expr::Column { name: n, .. } => n.clone(),
473 _ => name.clone(),
474 };
475 if !grouped_names.contains(&inner_name) {
476 all_grouped = false;
477 }
478 }
479 }
480 }
481 all_grouped
482 });
483
484 if !all_grouped {
485 return None;
486 }
487 }
488
489 let rewritten = pred.clone().transform(&|e| match e {
491 Expr::Column {
492 table: Some(ref t),
493 ref name,
494 ..
495 } if t == outer_alias => {
496 if let Some(inner_expr) = column_map.get(name.as_str()) {
497 inner_expr.clone()
498 } else {
499 e
500 }
501 }
502 other => other,
503 });
504
505 Some(rewritten)
506}
507
508fn build_column_map(sel: &SelectStatement) -> std::collections::HashMap<&str, Expr> {
515 let mut map = std::collections::HashMap::new();
516
517 for item in &sel.columns {
518 match item {
519 SelectItem::Expr {
520 expr:
521 Expr::Column {
522 name,
523 table,
524 quote_style,
525 table_quote_style,
526 },
527 alias,
528 } => {
529 let output_name = alias.as_deref().unwrap_or(name.as_str());
530 map.insert(
531 output_name,
532 Expr::Column {
533 table: table.clone(),
534 name: name.clone(),
535 quote_style: *quote_style,
536 table_quote_style: *table_quote_style,
537 },
538 );
539 }
540 SelectItem::Expr { expr, alias } => {
541 if let Some(alias) = alias {
542 map.insert(alias.as_str(), expr.clone());
543 }
544 }
545 SelectItem::Wildcard | SelectItem::QualifiedWildcard { .. } => {
546 }
550 }
551 }
552
553 map
554}
555
556fn source_alias(source: &TableSource) -> Option<String> {
558 match source {
559 TableSource::Table(t) => Some(t.alias.clone().unwrap_or_else(|| t.name.clone())),
560 TableSource::Subquery { alias, .. } => alias.clone(),
561 TableSource::TableFunction { alias, .. } => alias.clone(),
562 TableSource::Unnest { alias, .. } => alias.clone(),
563 TableSource::Lateral { source } => source_alias(source),
564 TableSource::Pivot { alias, .. } | TableSource::Unpivot { alias, .. } => alias.clone(),
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unquoted column reference, optionally qualified by `table`.
    fn column(table: Option<&str>, name: &str) -> Expr {
        Expr::Column {
            table: table.map(String::from),
            name: name.into(),
            quote_style: QuoteStyle::None,
            table_quote_style: QuoteStyle::None,
        }
    }

    /// Builds `left AND right`.
    fn and(left: Expr, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op: BinaryOperator::And,
            right: Box::new(right),
        }
    }

    /// A `SELECT` with every clause empty.
    fn empty_select() -> SelectStatement {
        SelectStatement {
            ctes: vec![],
            distinct: false,
            top: None,
            columns: vec![],
            from: None,
            joins: vec![],
            where_clause: None,
            group_by: vec![],
            having: None,
            order_by: vec![],
            limit: None,
            offset: None,
            fetch_first: None,
            qualify: None,
            window_definitions: vec![],
        }
    }

    #[test]
    fn test_split_conjunction_single() {
        assert_eq!(split_conjunction(Expr::Boolean(true)).len(), 1);
    }

    #[test]
    fn test_split_conjunction_and() {
        let parts = split_conjunction(and(Expr::Boolean(true), Expr::Boolean(false)));
        assert_eq!(parts.len(), 2);
    }

    #[test]
    fn test_split_conjunction_nested_and() {
        // (a AND b) AND c
        let expr = and(and(column(None, "a"), column(None, "b")), column(None, "c"));
        assert_eq!(split_conjunction(expr).len(), 3);
    }

    #[test]
    fn test_conjoin_empty() {
        assert!(conjoin(vec![]).is_none());
    }

    #[test]
    fn test_conjoin_single() {
        assert_eq!(
            conjoin(vec![Expr::Boolean(true)]),
            Some(Expr::Boolean(true))
        );
    }

    #[test]
    fn test_is_pushable_simple_comparison() {
        let expr = Expr::BinaryOp {
            left: Box::new(column(Some("t"), "x")),
            op: BinaryOperator::Gt,
            right: Box::new(Expr::Number("5".into())),
        };
        assert!(is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_aggregate() {
        let count_star = Expr::Function {
            name: "COUNT".into(),
            args: vec![Expr::Star],
            distinct: false,
            filter: None,
            over: None,
        };
        assert!(!is_pushable(&count_star));
    }

    #[test]
    fn test_is_pushable_rejects_window() {
        let row_number = Expr::Function {
            name: "ROW_NUMBER".into(),
            args: vec![],
            distinct: false,
            filter: None,
            over: Some(WindowSpec {
                window_ref: None,
                partition_by: vec![],
                order_by: vec![],
                frame: None,
            }),
        };
        assert!(!is_pushable(&row_number));
    }

    #[test]
    fn test_is_pushable_rejects_subquery() {
        let exists = Expr::Exists {
            subquery: Box::new(Statement::Select(empty_select())),
            negated: false,
        };
        assert!(!is_pushable(&exists));
    }

    #[test]
    fn test_referenced_tables() {
        let expr = Expr::BinaryOp {
            left: Box::new(column(Some("a"), "x")),
            op: BinaryOperator::Eq,
            right: Box::new(column(Some("b"), "y")),
        };
        let tables = referenced_tables(&expr);
        assert_eq!(tables.len(), 2);
        assert!(tables.contains("a"));
        assert!(tables.contains("b"));
    }
}