1use std::collections::HashSet;
26
27use crate::ast::*;
28
29pub fn pushdown_predicates(statement: Statement) -> Statement {
37 match statement {
38 Statement::Select(sel) => Statement::Select(pushdown_select(sel)),
39 other => other,
40 }
41}
42
43fn pushdown_select(mut sel: SelectStatement) -> SelectStatement {
48 if let Some(from) = &mut sel.from {
51 recurse_into_source(&mut from.source);
52 }
53 for join in &mut sel.joins {
54 recurse_into_source(&mut join.table);
55 }
56 for cte in &mut sel.ctes {
57 *cte.query = pushdown_predicates(*cte.query.clone());
58 }
59
60 let where_clause = match sel.where_clause.take() {
62 Some(w) => w,
63 None => return sel,
64 };
65
66 let predicates = split_conjunction(where_clause);
67 let mut remaining: Vec<Expr> = Vec::new();
68
69 for pred in predicates {
70 if !is_pushable(&pred) {
71 remaining.push(pred);
72 continue;
73 }
74
75 let tables = referenced_tables(&pred);
76
77 let mut pushed = false;
79 if let Some(from) = &mut sel.from {
80 pushed = try_push_into_source(&mut from.source, &pred, &tables);
81 }
82
83 if !pushed {
85 for join in &mut sel.joins {
86 if try_push_into_join(join, &pred, &tables) {
87 pushed = true;
88 break;
89 }
90 }
91 }
92
93 if !pushed {
94 remaining.push(pred);
95 }
96 }
97
98 sel.where_clause = conjoin(remaining);
99 sel
100}
101
102fn try_push_into_source(source: &mut TableSource, pred: &Expr, tables: &HashSet<String>) -> bool {
108 match source {
109 TableSource::Subquery { query, alias, .. } => {
110 let alias_name = match alias {
111 Some(a) => a.clone(),
112 None => return false,
113 };
114
115 if tables.is_empty() || !tables.iter().all(|t| t == &alias_name) {
117 return false;
118 }
119
120 let inner_sel = match query.as_mut() {
122 Statement::Select(sel) => sel,
123 _ => return false,
124 };
125
126 if !is_pushdown_safe_target(inner_sel) {
127 return false;
128 }
129
130 let rewritten = rewrite_predicate_for_derived_table(pred, &alias_name, inner_sel);
133 let rewritten = match rewritten {
134 Some(r) => r,
135 None => return false,
136 };
137
138 inner_sel.where_clause = match inner_sel.where_clause.take() {
140 Some(existing) => Some(Expr::BinaryOp {
141 left: Box::new(existing),
142 op: BinaryOperator::And,
143 right: Box::new(rewritten),
144 }),
145 None => Some(rewritten),
146 };
147
148 true
149 }
150 _ => false,
151 }
152}
153
154fn try_push_into_join(join: &mut JoinClause, pred: &Expr, tables: &HashSet<String>) -> bool {
157 if !matches!(join.join_type, JoinType::Inner | JoinType::Cross) {
160 return false;
161 }
162
163 let join_table = source_alias(&join.table);
165 let join_table = match join_table {
166 Some(t) => t,
167 None => return false,
168 };
169
170 if tables.is_empty() || tables.len() != 1 || !tables.contains(&join_table) {
172 return false;
173 }
174
175 if matches!(join.table, TableSource::Subquery { .. })
177 && try_push_into_source(&mut join.table, pred, tables)
178 {
179 return true;
180 }
181
182 join.on = match join.on.take() {
184 Some(existing) => Some(Expr::BinaryOp {
185 left: Box::new(existing),
186 op: BinaryOperator::And,
187 right: Box::new(pred.clone()),
188 }),
189 None => Some(pred.clone()),
190 };
191
192 true
193}
194
195fn recurse_into_source(source: &mut TableSource) {
201 match source {
202 TableSource::Subquery { query, .. } => {
203 *query = Box::new(pushdown_predicates(*query.clone()));
204 }
205 TableSource::Lateral { source } => {
206 recurse_into_source(source);
207 }
208 TableSource::Pivot { source, .. } | TableSource::Unpivot { source, .. } => {
209 recurse_into_source(source);
210 }
211 _ => {}
212 }
213}
214
215fn split_conjunction(expr: Expr) -> Vec<Expr> {
221 match expr {
222 Expr::BinaryOp {
223 left,
224 op: BinaryOperator::And,
225 right,
226 } => {
227 let mut result = split_conjunction(*left);
228 result.extend(split_conjunction(*right));
229 result
230 }
231 Expr::Nested(inner) => {
232 if matches!(
234 inner.as_ref(),
235 Expr::BinaryOp {
236 op: BinaryOperator::And,
237 ..
238 }
239 ) {
240 split_conjunction(*inner)
241 } else {
242 vec![Expr::Nested(inner)]
243 }
244 }
245 other => vec![other],
246 }
247}
248
249fn conjoin(predicates: Vec<Expr>) -> Option<Expr> {
251 predicates.into_iter().reduce(|a, b| Expr::BinaryOp {
252 left: Box::new(a),
253 op: BinaryOperator::And,
254 right: Box::new(b),
255 })
256}
257
258fn referenced_tables(expr: &Expr) -> HashSet<String> {
260 let mut tables = HashSet::new();
261 expr.walk(&mut |e| {
262 if let Expr::Column { table: Some(t), .. } = e {
263 tables.insert(t.clone());
264 }
265 true
266 });
267 tables
268}
269
270fn is_pushable(expr: &Expr) -> bool {
278 let mut safe = true;
279 expr.walk(&mut |e| {
280 if !safe {
281 return false;
282 }
283 match e {
284 Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => {
286 safe = false;
287 false
288 }
289 Expr::Function { name, .. } if is_aggregate_function(name) => {
291 safe = false;
292 false
293 }
294 Expr::Function { over: Some(_), .. } | Expr::TypedFunction { over: Some(_), .. } => {
296 safe = false;
297 false
298 }
299 Expr::Function { name, .. } if is_nondeterministic(name) => {
301 safe = false;
302 false
303 }
304 Expr::TypedFunction {
305 func: TypedFunction::CurrentTimestamp,
306 ..
307 } => {
308 safe = false;
309 false
310 }
311 _ => true,
312 }
313 });
314 safe
315}
316
317fn is_pushdown_safe_target(sel: &SelectStatement) -> bool {
321 if sel.limit.is_some() || sel.offset.is_some() || sel.fetch_first.is_some() || sel.distinct {
322 return false;
323 }
324 for item in &sel.columns {
326 if let SelectItem::Expr { expr, .. } = item {
327 if contains_window_function(expr) {
328 return false;
329 }
330 }
331 }
332 true
333}
334
335fn contains_window_function(expr: &Expr) -> bool {
337 let mut has_window = false;
338 expr.walk(&mut |e| {
339 if has_window {
340 return false;
341 }
342 match e {
343 Expr::Function { over: Some(_), .. } | Expr::TypedFunction { over: Some(_), .. } => {
344 has_window = true;
345 false
346 }
347 _ => true,
348 }
349 });
350 has_window
351}
352
/// Reports whether `name` is one of the aggregate function names this pass
/// knows about. The comparison is case-insensitive (the name is upper-cased
/// before the lookup).
fn is_aggregate_function(name: &str) -> bool {
    const AGGREGATES: &[&str] = &[
        "COUNT",
        "SUM",
        "AVG",
        "MIN",
        "MAX",
        "GROUP_CONCAT",
        "STRING_AGG",
        "ARRAY_AGG",
        "LISTAGG",
        "STDDEV",
        "STDDEV_POP",
        "STDDEV_SAMP",
        "VARIANCE",
        "VAR_POP",
        "VAR_SAMP",
        "EVERY",
        "ANY_VALUE",
        "SOME",
        "BIT_AND",
        "BIT_OR",
        "BIT_XOR",
        "BOOL_AND",
        "BOOL_OR",
        "CORR",
        "COVAR_POP",
        "COVAR_SAMP",
        "REGR_SLOPE",
        "REGR_INTERCEPT",
        "PERCENTILE_CONT",
        "PERCENTILE_DISC",
        "APPROX_COUNT_DISTINCT",
        "HLL_COUNT",
        "APPROX_DISTINCT",
    ];

    let upper = name.to_uppercase();
    AGGREGATES.contains(&upper.as_str())
}
391
/// Reports whether `name` is one of the function names this pass treats as
/// nondeterministic. The comparison is case-insensitive (the name is
/// upper-cased before the lookup).
fn is_nondeterministic(name: &str) -> bool {
    const NONDETERMINISTIC: &[&str] = &[
        "RAND",
        "RANDOM",
        "UUID",
        "NEWID",
        "GEN_RANDOM_UUID",
        "SYSDATE",
        "SYSTIMESTAMP",
    ];

    let upper = name.to_uppercase();
    NONDETERMINISTIC.contains(&upper.as_str())
}
398
399fn rewrite_predicate_for_derived_table(
416 pred: &Expr,
417 outer_alias: &str,
418 inner_sel: &SelectStatement,
419) -> Option<Expr> {
420 let column_map = build_column_map(inner_sel);
422
423 let mut can_rewrite = true;
425 pred.walk(&mut |e| {
426 if !can_rewrite {
427 return false;
428 }
429 if let Expr::Column {
430 table: Some(t),
431 name,
432 ..
433 } = e
434 {
435 if t == outer_alias && !column_map.contains_key(name.as_str()) {
436 can_rewrite = false;
437 }
438 }
439 can_rewrite
440 });
441
442 if !can_rewrite {
443 return None;
444 }
445
446 if !inner_sel.group_by.is_empty() {
449 let grouped_names: HashSet<String> = inner_sel
450 .group_by
451 .iter()
452 .filter_map(|e| match e {
453 Expr::Column { name, .. } => Some(name.clone()),
454 _ => None,
455 })
456 .collect();
457
458 let mut all_grouped = true;
459 pred.walk(&mut |e| {
460 if !all_grouped {
461 return false;
462 }
463 if let Expr::Column {
464 table: Some(t),
465 name,
466 ..
467 } = e
468 {
469 if t == outer_alias {
470 if let Some(inner_expr) = column_map.get(name.as_str()) {
471 let inner_name = match inner_expr {
472 Expr::Column { name: n, .. } => n.clone(),
473 _ => name.clone(),
474 };
475 if !grouped_names.contains(&inner_name) {
476 all_grouped = false;
477 }
478 }
479 }
480 }
481 all_grouped
482 });
483
484 if !all_grouped {
485 return None;
486 }
487 }
488
489 let rewritten = pred.clone().transform(&|e| match e {
491 Expr::Column {
492 table: Some(ref t),
493 ref name,
494 ..
495 } if t == outer_alias => {
496 if let Some(inner_expr) = column_map.get(name.as_str()) {
497 inner_expr.clone()
498 } else {
499 e
500 }
501 }
502 other => other,
503 });
504
505 Some(rewritten)
506}
507
508fn build_column_map(sel: &SelectStatement) -> std::collections::HashMap<&str, Expr> {
515 let mut map = std::collections::HashMap::new();
516
517 for item in &sel.columns {
518 match item {
519 SelectItem::Expr {
520 expr:
521 Expr::Column {
522 name,
523 table,
524 quote_style,
525 table_quote_style,
526 },
527 alias,
528 ..
529 } => {
530 let output_name = alias.as_deref().unwrap_or(name.as_str());
531 map.insert(
532 output_name,
533 Expr::Column {
534 table: table.clone(),
535 name: name.clone(),
536 quote_style: *quote_style,
537 table_quote_style: *table_quote_style,
538 },
539 );
540 }
541 SelectItem::Expr { expr, alias, .. } => {
542 if let Some(alias) = alias {
543 map.insert(alias.as_str(), expr.clone());
544 }
545 }
546 SelectItem::Wildcard | SelectItem::QualifiedWildcard { .. } => {
547 }
551 }
552 }
553
554 map
555}
556
557fn source_alias(source: &TableSource) -> Option<String> {
559 match source {
560 TableSource::Table(t) => Some(t.alias.clone().unwrap_or_else(|| t.name.clone())),
561 TableSource::Subquery { alias, .. } => alias.clone(),
562 TableSource::TableFunction { alias, .. } => alias.clone(),
563 TableSource::Unnest { alias, .. } => alias.clone(),
564 TableSource::Lateral { source } => source_alias(source),
565 TableSource::Pivot { alias, .. } | TableSource::Unpivot { alias, .. } => alias.clone(),
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unquoted column reference, optionally qualified by `table`.
    fn column(table: Option<&str>, name: &str) -> Expr {
        Expr::Column {
            table: table.map(Into::into),
            name: name.into(),
            quote_style: QuoteStyle::None,
            table_quote_style: QuoteStyle::None,
        }
    }

    /// Builds `left <op> right`.
    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Builds a SELECT with every clause empty.
    fn empty_select() -> SelectStatement {
        SelectStatement {
            comments: vec![],
            ctes: vec![],
            distinct: false,
            top: None,
            columns: vec![],
            from: None,
            joins: vec![],
            where_clause: None,
            group_by: vec![],
            having: None,
            order_by: vec![],
            limit: None,
            offset: None,
            fetch_first: None,
            qualify: None,
            window_definitions: vec![],
        }
    }

    #[test]
    fn test_split_conjunction_single() {
        let parts = split_conjunction(Expr::Boolean(true));
        assert_eq!(parts.len(), 1);
    }

    #[test]
    fn test_split_conjunction_and() {
        let expr = binary(
            Expr::Boolean(true),
            BinaryOperator::And,
            Expr::Boolean(false),
        );
        let parts = split_conjunction(expr);
        assert_eq!(parts.len(), 2);
    }

    #[test]
    fn test_split_conjunction_nested_and() {
        // (a AND b) AND c
        let a_and_b = binary(
            column(None, "a"),
            BinaryOperator::And,
            column(None, "b"),
        );
        let expr = binary(a_and_b, BinaryOperator::And, column(None, "c"));
        let parts = split_conjunction(expr);
        assert_eq!(parts.len(), 3);
    }

    #[test]
    fn test_conjoin_empty() {
        assert!(conjoin(vec![]).is_none());
    }

    #[test]
    fn test_conjoin_single() {
        let r = conjoin(vec![Expr::Boolean(true)]);
        assert_eq!(r, Some(Expr::Boolean(true)));
    }

    #[test]
    fn test_is_pushable_simple_comparison() {
        let expr = binary(
            column(Some("t"), "x"),
            BinaryOperator::Gt,
            Expr::Number("5".into()),
        );
        assert!(is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_aggregate() {
        let expr = Expr::Function {
            name: "COUNT".into(),
            args: vec![Expr::Star],
            distinct: false,
            filter: None,
            over: None,
        };
        assert!(!is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_window() {
        let expr = Expr::Function {
            name: "ROW_NUMBER".into(),
            args: vec![],
            distinct: false,
            filter: None,
            over: Some(WindowSpec {
                window_ref: None,
                partition_by: vec![],
                order_by: vec![],
                frame: None,
            }),
        };
        assert!(!is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_subquery() {
        let expr = Expr::Exists {
            subquery: Box::new(Statement::Select(empty_select())),
            negated: false,
        };
        assert!(!is_pushable(&expr));
    }

    #[test]
    fn test_referenced_tables() {
        let expr = binary(
            column(Some("a"), "x"),
            BinaryOperator::Eq,
            column(Some("b"), "y"),
        );
        let tables = referenced_tables(&expr);
        assert_eq!(tables.len(), 2);
        assert!(tables.contains("a"));
        assert!(tables.contains("b"));
    }
}