Skip to main content

sqlglot_rust/optimizer/
pushdown_predicates.rs

/// Push WHERE predicates into subqueries, derived tables, and JOIN conditions.
///
/// This is a standard query optimization that reduces the data processed by
/// inner queries. The pass splits AND-conjunctions and pushes each conjunct
/// as far down the query tree as safety rules allow.
///
/// ## Supported rewrites
///
/// | Pattern | Rewrite |
/// |---|---|
/// | `SELECT … FROM (SELECT … FROM t) AS s WHERE s.x > 5` | Push `x > 5` into derived table WHERE |
/// | `SELECT … FROM a JOIN b ON … WHERE b.x > 5` | Move `b.x > 5` into the JOIN ON of `b` (inner and cross joins) |
/// | `WITH c AS (SELECT …) SELECT … FROM c WHERE c.x > 5` | Optimize the CTE body recursively; the outer predicate itself is not moved into the CTE body |
///
/// ## Safety
///
/// The pass does **not** push predicates:
/// - Through LIMIT / OFFSET / FETCH FIRST
/// - Through DISTINCT
/// - Through GROUP BY (unless the predicate only references grouped columns)
/// - Through window functions in the SELECT list
/// - When a predicate contains non-deterministic functions (RAND, RANDOM, etc.)
/// - When a predicate contains aggregate functions
/// - When a predicate contains subqueries
25use std::collections::HashSet;
26
27use crate::ast::*;
28
29// ═══════════════════════════════════════════════════════════════════════
30// Public API
31// ═══════════════════════════════════════════════════════════════════════
32
33/// Apply predicate pushdown to a statement.
34///
35/// Returns the statement unchanged if no predicates can be pushed down.
36pub fn pushdown_predicates(statement: Statement) -> Statement {
37    match statement {
38        Statement::Select(sel) => Statement::Select(pushdown_select(sel)),
39        other => other,
40    }
41}
42
43// ═══════════════════════════════════════════════════════════════════════
44// Core logic
45// ═══════════════════════════════════════════════════════════════════════
46
47fn pushdown_select(mut sel: SelectStatement) -> SelectStatement {
48    // First, recursively pushdown in any nested derived tables and CTEs,
49    // regardless of whether *this* level has a WHERE clause.
50    if let Some(from) = &mut sel.from {
51        recurse_into_source(&mut from.source);
52    }
53    for join in &mut sel.joins {
54        recurse_into_source(&mut join.table);
55    }
56    for cte in &mut sel.ctes {
57        *cte.query = pushdown_predicates(*cte.query.clone());
58    }
59
60    // Now try to push this level's WHERE predicates down.
61    let where_clause = match sel.where_clause.take() {
62        Some(w) => w,
63        None => return sel,
64    };
65
66    let predicates = split_conjunction(where_clause);
67    let mut remaining: Vec<Expr> = Vec::new();
68
69    for pred in predicates {
70        if !is_pushable(&pred) {
71            remaining.push(pred);
72            continue;
73        }
74
75        let tables = referenced_tables(&pred);
76
77        // Try pushing into FROM derived table
78        let mut pushed = false;
79        if let Some(from) = &mut sel.from {
80            pushed = try_push_into_source(&mut from.source, &pred, &tables);
81        }
82
83        // Try pushing into JOIN ON conditions (inner joins only)
84        if !pushed {
85            for join in &mut sel.joins {
86                if try_push_into_join(join, &pred, &tables) {
87                    pushed = true;
88                    break;
89                }
90            }
91        }
92
93        if !pushed {
94            remaining.push(pred);
95        }
96    }
97
98    sel.where_clause = conjoin(remaining);
99    sel
100}
101
102// ═══════════════════════════════════════════════════════════════════════
103// Push into derived tables (subqueries in FROM)
104// ═══════════════════════════════════════════════════════════════════════
105
106/// Try to push a predicate into a table source. Returns true if pushed.
107fn try_push_into_source(source: &mut TableSource, pred: &Expr, tables: &HashSet<String>) -> bool {
108    match source {
109        TableSource::Subquery { query, alias } => {
110            let alias_name = match alias {
111                Some(a) => a.clone(),
112                None => return false,
113            };
114
115            // Predicate must reference only this derived table alias
116            if tables.is_empty() || !tables.iter().all(|t| t == &alias_name) {
117                return false;
118            }
119
120            // Check the inner query is a simple SELECT we can push into
121            let inner_sel = match query.as_mut() {
122                Statement::Select(sel) => sel,
123                _ => return false,
124            };
125
126            if !is_pushdown_safe_target(inner_sel) {
127                return false;
128            }
129
130            // Rewrite column references: strip the outer alias qualifier
131            // and map to inner column names using the SELECT list.
132            let rewritten = rewrite_predicate_for_derived_table(pred, &alias_name, inner_sel);
133            let rewritten = match rewritten {
134                Some(r) => r,
135                None => return false,
136            };
137
138            // Push into the inner WHERE
139            inner_sel.where_clause = match inner_sel.where_clause.take() {
140                Some(existing) => Some(Expr::BinaryOp {
141                    left: Box::new(existing),
142                    op: BinaryOperator::And,
143                    right: Box::new(rewritten),
144                }),
145                None => Some(rewritten),
146            };
147
148            true
149        }
150        _ => false,
151    }
152}
153
154/// Try to push a predicate into a JOIN's ON condition.
155/// Only safe for INNER and CROSS joins.
156fn try_push_into_join(join: &mut JoinClause, pred: &Expr, tables: &HashSet<String>) -> bool {
157    // Only push into inner joins — pushing into LEFT/RIGHT/FULL
158    // changes semantics.
159    if !matches!(join.join_type, JoinType::Inner | JoinType::Cross) {
160        return false;
161    }
162
163    // Get the table name/alias of this join's source
164    let join_table = source_alias(&join.table);
165    let join_table = match join_table {
166        Some(t) => t,
167        None => return false,
168    };
169
170    // Predicate must reference only the join's table
171    if tables.is_empty() || tables.len() != 1 || !tables.contains(&join_table) {
172        return false;
173    }
174
175    // Also try pushing into a derived-table join source
176    if matches!(join.table, TableSource::Subquery { .. })
177        && try_push_into_source(&mut join.table, pred, tables)
178    {
179        return true;
180    }
181
182    // Push into the ON condition
183    join.on = match join.on.take() {
184        Some(existing) => Some(Expr::BinaryOp {
185            left: Box::new(existing),
186            op: BinaryOperator::And,
187            right: Box::new(pred.clone()),
188        }),
189        None => Some(pred.clone()),
190    };
191
192    true
193}
194
195// ═══════════════════════════════════════════════════════════════════════
196// Recursion into nested structures
197// ═══════════════════════════════════════════════════════════════════════
198
199/// Recurse into a table source to pushdown predicates in nested derived tables.
200fn recurse_into_source(source: &mut TableSource) {
201    match source {
202        TableSource::Subquery { query, .. } => {
203            *query = Box::new(pushdown_predicates(*query.clone()));
204        }
205        TableSource::Lateral { source } => {
206            recurse_into_source(source);
207        }
208        TableSource::Pivot { source, .. } | TableSource::Unpivot { source, .. } => {
209            recurse_into_source(source);
210        }
211        _ => {}
212    }
213}
214
215// ═══════════════════════════════════════════════════════════════════════
216// Predicate analysis helpers
217// ═══════════════════════════════════════════════════════════════════════
218
219/// Split an expression on AND into a flat list of conjuncts.
220fn split_conjunction(expr: Expr) -> Vec<Expr> {
221    match expr {
222        Expr::BinaryOp {
223            left,
224            op: BinaryOperator::And,
225            right,
226        } => {
227            let mut result = split_conjunction(*left);
228            result.extend(split_conjunction(*right));
229            result
230        }
231        Expr::Nested(inner) => {
232            // Only unwrap if the nested expression is itself an AND
233            if matches!(
234                inner.as_ref(),
235                Expr::BinaryOp {
236                    op: BinaryOperator::And,
237                    ..
238                }
239            ) {
240                split_conjunction(*inner)
241            } else {
242                vec![Expr::Nested(inner)]
243            }
244        }
245        other => vec![other],
246    }
247}
248
249/// Rejoin a list of predicates with AND. Returns None if empty.
250fn conjoin(predicates: Vec<Expr>) -> Option<Expr> {
251    predicates.into_iter().reduce(|a, b| Expr::BinaryOp {
252        left: Box::new(a),
253        op: BinaryOperator::And,
254        right: Box::new(b),
255    })
256}
257
258/// Collect all table qualifiers referenced by column expressions.
259fn referenced_tables(expr: &Expr) -> HashSet<String> {
260    let mut tables = HashSet::new();
261    expr.walk(&mut |e| {
262        if let Expr::Column { table: Some(t), .. } = e {
263            tables.insert(t.clone());
264        }
265        true
266    });
267    tables
268}
269
270/// Check whether a predicate is safe to push down.
271///
272/// Returns false for predicates containing:
273/// - Aggregate functions
274/// - Window functions
275/// - Non-deterministic functions
276/// - Subqueries
277fn is_pushable(expr: &Expr) -> bool {
278    let mut safe = true;
279    expr.walk(&mut |e| {
280        if !safe {
281            return false;
282        }
283        match e {
284            // Subqueries should not be pushed
285            Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => {
286                safe = false;
287                false
288            }
289            // Aggregate functions
290            Expr::Function { name, .. } if is_aggregate_function(name) => {
291                safe = false;
292                false
293            }
294            // Window functions (have OVER clause)
295            Expr::Function { over: Some(_), .. } | Expr::TypedFunction { over: Some(_), .. } => {
296                safe = false;
297                false
298            }
299            // Non-deterministic functions
300            Expr::Function { name, .. } if is_nondeterministic(name) => {
301                safe = false;
302                false
303            }
304            Expr::TypedFunction {
305                func: TypedFunction::CurrentTimestamp,
306                ..
307            } => {
308                safe = false;
309                false
310            }
311            _ => true,
312        }
313    });
314    safe
315}
316
317/// Check whether the target SELECT is safe for predicate pushdown.
318///
319/// We cannot push through LIMIT, OFFSET, DISTINCT, or window functions.
320fn is_pushdown_safe_target(sel: &SelectStatement) -> bool {
321    if sel.limit.is_some() || sel.offset.is_some() || sel.fetch_first.is_some() || sel.distinct {
322        return false;
323    }
324    // Check for window functions in SELECT list
325    for item in &sel.columns {
326        if let SelectItem::Expr { expr, .. } = item {
327            if contains_window_function(expr) {
328                return false;
329            }
330        }
331    }
332    true
333}
334
335/// Check whether an expression contains a window function.
336fn contains_window_function(expr: &Expr) -> bool {
337    let mut has_window = false;
338    expr.walk(&mut |e| {
339        if has_window {
340            return false;
341        }
342        match e {
343            Expr::Function { over: Some(_), .. } | Expr::TypedFunction { over: Some(_), .. } => {
344                has_window = true;
345                false
346            }
347            _ => true,
348        }
349    });
350    has_window
351}
352
/// Report whether `name` is one of the aggregate function names this pass
/// recognizes. The comparison ignores case.
fn is_aggregate_function(name: &str) -> bool {
    const AGGREGATE_NAMES: &[&str] = &[
        "COUNT",
        "SUM",
        "AVG",
        "MIN",
        "MAX",
        "GROUP_CONCAT",
        "STRING_AGG",
        "ARRAY_AGG",
        "LISTAGG",
        "STDDEV",
        "STDDEV_POP",
        "STDDEV_SAMP",
        "VARIANCE",
        "VAR_POP",
        "VAR_SAMP",
        "EVERY",
        "ANY_VALUE",
        "SOME",
        "BIT_AND",
        "BIT_OR",
        "BIT_XOR",
        "BOOL_AND",
        "BOOL_OR",
        "CORR",
        "COVAR_POP",
        "COVAR_SAMP",
        "REGR_SLOPE",
        "REGR_INTERCEPT",
        "PERCENTILE_CONT",
        "PERCENTILE_DISC",
        "APPROX_COUNT_DISTINCT",
        "HLL_COUNT",
        "APPROX_DISTINCT",
    ];

    let upper = name.to_uppercase();
    AGGREGATE_NAMES.contains(&upper.as_str())
}
391
/// Report whether `name` is one of the non-deterministic function names this
/// pass recognizes. The comparison ignores case.
fn is_nondeterministic(name: &str) -> bool {
    const NONDETERMINISTIC_NAMES: &[&str] = &[
        "RAND",
        "RANDOM",
        "UUID",
        "NEWID",
        "GEN_RANDOM_UUID",
        "SYSDATE",
        "SYSTIMESTAMP",
    ];

    let upper = name.to_uppercase();
    NONDETERMINISTIC_NAMES.contains(&upper.as_str())
}
398
399// ═══════════════════════════════════════════════════════════════════════
400// Column remapping for derived-table pushdown
401// ═══════════════════════════════════════════════════════════════════════
402
403/// Rewrite a predicate so its column references match the inner derived
404/// table's namespace. For example, given:
405///
406/// ```sql
407/// SELECT * FROM (SELECT id, name FROM t) AS sub WHERE sub.name = 'foo'
408/// ```
409///
410/// The predicate `sub.name = 'foo'` becomes `name = 'foo'` (or `t.name = 'foo'`
411/// if the inner query qualifies columns).
412///
413/// Returns `None` if the rewrite cannot be performed (e.g., the column
414/// isn't exposed by the derived table's SELECT list).
415fn rewrite_predicate_for_derived_table(
416    pred: &Expr,
417    outer_alias: &str,
418    inner_sel: &SelectStatement,
419) -> Option<Expr> {
420    // Build a mapping: outer column name → inner expression
421    let column_map = build_column_map(inner_sel);
422
423    // Check that all columns referenced in the predicate can be mapped
424    let mut can_rewrite = true;
425    pred.walk(&mut |e| {
426        if !can_rewrite {
427            return false;
428        }
429        if let Expr::Column {
430            table: Some(t),
431            name,
432            ..
433        } = e
434        {
435            if t == outer_alias && !column_map.contains_key(name.as_str()) {
436                can_rewrite = false;
437            }
438        }
439        can_rewrite
440    });
441
442    if !can_rewrite {
443        return None;
444    }
445
446    // If inner SELECT has GROUP BY, only allow pushing predicates that
447    // reference grouped columns (pre-aggregation filters).
448    if !inner_sel.group_by.is_empty() {
449        let grouped_names: HashSet<String> = inner_sel
450            .group_by
451            .iter()
452            .filter_map(|e| match e {
453                Expr::Column { name, .. } => Some(name.clone()),
454                _ => None,
455            })
456            .collect();
457
458        let mut all_grouped = true;
459        pred.walk(&mut |e| {
460            if !all_grouped {
461                return false;
462            }
463            if let Expr::Column {
464                table: Some(t),
465                name,
466                ..
467            } = e
468            {
469                if t == outer_alias {
470                    if let Some(inner_expr) = column_map.get(name.as_str()) {
471                        let inner_name = match inner_expr {
472                            Expr::Column { name: n, .. } => n.clone(),
473                            _ => name.clone(),
474                        };
475                        if !grouped_names.contains(&inner_name) {
476                            all_grouped = false;
477                        }
478                    }
479                }
480            }
481            all_grouped
482        });
483
484        if !all_grouped {
485            return None;
486        }
487    }
488
489    // Perform the rewrite
490    let rewritten = pred.clone().transform(&|e| match e {
491        Expr::Column {
492            table: Some(ref t),
493            ref name,
494            ..
495        } if t == outer_alias => {
496            if let Some(inner_expr) = column_map.get(name.as_str()) {
497                inner_expr.clone()
498            } else {
499                e
500            }
501        }
502        other => other,
503    });
504
505    Some(rewritten)
506}
507
508/// Build a mapping from output column name → inner expression for a SELECT.
509///
510/// For `SELECT id, name AS n, x + 1 AS calc FROM t`:
511/// - "id" → Column { name: "id", ... }
512/// - "n" → Column { name: "name", ... }
513/// - "calc" → BinaryOp(x + 1)
514fn build_column_map(sel: &SelectStatement) -> std::collections::HashMap<&str, Expr> {
515    let mut map = std::collections::HashMap::new();
516
517    for item in &sel.columns {
518        match item {
519            SelectItem::Expr {
520                expr:
521                    Expr::Column {
522                        name,
523                        table,
524                        quote_style,
525                        table_quote_style,
526                    },
527                alias,
528            } => {
529                let output_name = alias.as_deref().unwrap_or(name.as_str());
530                map.insert(
531                    output_name,
532                    Expr::Column {
533                        table: table.clone(),
534                        name: name.clone(),
535                        quote_style: *quote_style,
536                        table_quote_style: *table_quote_style,
537                    },
538                );
539            }
540            SelectItem::Expr { expr, alias } => {
541                if let Some(alias) = alias {
542                    map.insert(alias.as_str(), expr.clone());
543                }
544            }
545            SelectItem::Wildcard | SelectItem::QualifiedWildcard { .. } => {
546                // With *, we can't easily build a column map without schema
547                // info. Bail  — the predicate columns just need to not
548                // have a table qualifier to match.
549            }
550        }
551    }
552
553    map
554}
555
556/// Get the alias or table name for a table source.
557fn source_alias(source: &TableSource) -> Option<String> {
558    match source {
559        TableSource::Table(t) => Some(t.alias.clone().unwrap_or_else(|| t.name.clone())),
560        TableSource::Subquery { alias, .. } => alias.clone(),
561        TableSource::TableFunction { alias, .. } => alias.clone(),
562        TableSource::Unnest { alias, .. } => alias.clone(),
563        TableSource::Lateral { source } => source_alias(source),
564        TableSource::Pivot { alias, .. } | TableSource::Unpivot { alias, .. } => alias.clone(),
565    }
566}
567
568// ═══════════════════════════════════════════════════════════════════════
569// Tests
570// ═══════════════════════════════════════════════════════════════════════
571
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a column reference, optionally qualified with a table name.
    fn column(table: Option<&str>, name: &str) -> Expr {
        Expr::Column {
            table: table.map(Into::into),
            name: name.into(),
            quote_style: QuoteStyle::None,
            table_quote_style: QuoteStyle::None,
        }
    }

    /// Build `left <op> right`.
    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Build a plain (untyped) function call.
    fn call(name: &str, args: Vec<Expr>, over: Option<WindowSpec>) -> Expr {
        Expr::Function {
            name: name.into(),
            args,
            distinct: false,
            filter: None,
            over,
        }
    }

    /// Build a SELECT statement with every clause empty.
    fn empty_select() -> SelectStatement {
        SelectStatement {
            comments: vec![],
            ctes: vec![],
            distinct: false,
            top: None,
            columns: vec![],
            from: None,
            joins: vec![],
            where_clause: None,
            group_by: vec![],
            having: None,
            order_by: vec![],
            limit: None,
            offset: None,
            fetch_first: None,
            qualify: None,
            window_definitions: vec![],
        }
    }

    #[test]
    fn test_split_conjunction_single() {
        let parts = split_conjunction(Expr::Boolean(true));
        assert_eq!(parts.len(), 1);
    }

    #[test]
    fn test_split_conjunction_and() {
        let expr = binary(
            Expr::Boolean(true),
            BinaryOperator::And,
            Expr::Boolean(false),
        );
        assert_eq!(split_conjunction(expr).len(), 2);
    }

    #[test]
    fn test_split_conjunction_nested_and() {
        // (a AND b) AND c
        let a_and_b = binary(column(None, "a"), BinaryOperator::And, column(None, "b"));
        let expr = binary(a_and_b, BinaryOperator::And, column(None, "c"));
        assert_eq!(split_conjunction(expr).len(), 3);
    }

    #[test]
    fn test_conjoin_empty() {
        assert!(conjoin(vec![]).is_none());
    }

    #[test]
    fn test_conjoin_single() {
        let joined = conjoin(vec![Expr::Boolean(true)]);
        assert_eq!(joined, Some(Expr::Boolean(true)));
    }

    #[test]
    fn test_is_pushable_simple_comparison() {
        // t.x > 5
        let expr = binary(
            column(Some("t"), "x"),
            BinaryOperator::Gt,
            Expr::Number("5".into()),
        );
        assert!(is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_aggregate() {
        let expr = call("COUNT", vec![Expr::Star], None);
        assert!(!is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_window() {
        let window = WindowSpec {
            window_ref: None,
            partition_by: vec![],
            order_by: vec![],
            frame: None,
        };
        let expr = call("ROW_NUMBER", vec![], Some(window));
        assert!(!is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_subquery() {
        let expr = Expr::Exists {
            subquery: Box::new(Statement::Select(empty_select())),
            negated: false,
        };
        assert!(!is_pushable(&expr));
    }

    #[test]
    fn test_referenced_tables() {
        // a.x = b.y
        let expr = binary(
            column(Some("a"), "x"),
            BinaryOperator::Eq,
            column(Some("b"), "y"),
        );
        let tables = referenced_tables(&expr);
        assert_eq!(tables.len(), 2);
        assert!(tables.contains("a"));
        assert!(tables.contains("b"));
    }
}