Skip to main content

sqlglot_rust/optimizer/
pushdown_predicates.rs

//! Push WHERE predicates into derived tables (subqueries in FROM) and JOIN conditions.
//!
//! This is a standard query optimization that reduces the data processed by
//! inner queries. The pass splits AND-conjunctions and pushes each conjunct
//! as far down the query tree as safety rules allow.
//!
//! ## Supported rewrites
//!
//! | Pattern | Rewrite |
//! |---|---|
//! | `SELECT … FROM (SELECT … FROM t) AS s WHERE s.x > 5` | Push `x > 5` into derived table WHERE |
//! | `SELECT … FROM a JOIN b ON … WHERE b.x > 5` | Move `b.x > 5` into the ON of the join that introduces `b` (INNER / CROSS joins only) |
//! | `WITH c AS (SELECT …) SELECT … FROM c WHERE c.x > 5` | The CTE body is optimized recursively on its own; the outer predicate is not pushed into the CTE body |
//!
//! ## Safety
//!
//! The pass does **not** push predicates:
//! - Through LIMIT / OFFSET / FETCH FIRST
//! - Through DISTINCT
//! - Through GROUP BY (unless the predicate only references grouped columns)
//! - Through window functions in the SELECT list
//! - When a predicate contains non-deterministic functions (RAND, RANDOM, etc.)
//! - When a predicate contains aggregate functions
//! - When a predicate contains subqueries
25use std::collections::HashSet;
26
27use crate::ast::*;
28
29// ═══════════════════════════════════════════════════════════════════════
30// Public API
31// ═══════════════════════════════════════════════════════════════════════
32
33/// Apply predicate pushdown to a statement.
34///
35/// Returns the statement unchanged if no predicates can be pushed down.
36pub fn pushdown_predicates(statement: Statement) -> Statement {
37    match statement {
38        Statement::Select(sel) => Statement::Select(pushdown_select(sel)),
39        other => other,
40    }
41}
42
43// ═══════════════════════════════════════════════════════════════════════
44// Core logic
45// ═══════════════════════════════════════════════════════════════════════
46
47fn pushdown_select(mut sel: SelectStatement) -> SelectStatement {
48    // First, recursively pushdown in any nested derived tables and CTEs,
49    // regardless of whether *this* level has a WHERE clause.
50    if let Some(from) = &mut sel.from {
51        recurse_into_source(&mut from.source);
52    }
53    for join in &mut sel.joins {
54        recurse_into_source(&mut join.table);
55    }
56    for cte in &mut sel.ctes {
57        *cte.query = pushdown_predicates(*cte.query.clone());
58    }
59
60    // Now try to push this level's WHERE predicates down.
61    let where_clause = match sel.where_clause.take() {
62        Some(w) => w,
63        None => return sel,
64    };
65
66    let predicates = split_conjunction(where_clause);
67    let mut remaining: Vec<Expr> = Vec::new();
68
69    for pred in predicates {
70        if !is_pushable(&pred) {
71            remaining.push(pred);
72            continue;
73        }
74
75        let tables = referenced_tables(&pred);
76
77        // Try pushing into FROM derived table
78        let mut pushed = false;
79        if let Some(from) = &mut sel.from {
80            pushed = try_push_into_source(&mut from.source, &pred, &tables);
81        }
82
83        // Try pushing into JOIN ON conditions (inner joins only)
84        if !pushed {
85            for join in &mut sel.joins {
86                if try_push_into_join(join, &pred, &tables) {
87                    pushed = true;
88                    break;
89                }
90            }
91        }
92
93        if !pushed {
94            remaining.push(pred);
95        }
96    }
97
98    sel.where_clause = conjoin(remaining);
99    sel
100}
101
102// ═══════════════════════════════════════════════════════════════════════
103// Push into derived tables (subqueries in FROM)
104// ═══════════════════════════════════════════════════════════════════════
105
106/// Try to push a predicate into a table source. Returns true if pushed.
107fn try_push_into_source(
108    source: &mut TableSource,
109    pred: &Expr,
110    tables: &HashSet<String>,
111) -> bool {
112    match source {
113        TableSource::Subquery { query, alias } => {
114            let alias_name = match alias {
115                Some(a) => a.clone(),
116                None => return false,
117            };
118
119            // Predicate must reference only this derived table alias
120            if tables.is_empty() || !tables.iter().all(|t| t == &alias_name) {
121                return false;
122            }
123
124            // Check the inner query is a simple SELECT we can push into
125            let inner_sel = match query.as_mut() {
126                Statement::Select(sel) => sel,
127                _ => return false,
128            };
129
130            if !is_pushdown_safe_target(inner_sel) {
131                return false;
132            }
133
134            // Rewrite column references: strip the outer alias qualifier
135            // and map to inner column names using the SELECT list.
136            let rewritten = rewrite_predicate_for_derived_table(pred, &alias_name, inner_sel);
137            let rewritten = match rewritten {
138                Some(r) => r,
139                None => return false,
140            };
141
142            // Push into the inner WHERE
143            inner_sel.where_clause = match inner_sel.where_clause.take() {
144                Some(existing) => Some(Expr::BinaryOp {
145                    left: Box::new(existing),
146                    op: BinaryOperator::And,
147                    right: Box::new(rewritten),
148                }),
149                None => Some(rewritten),
150            };
151
152            true
153        }
154        _ => false,
155    }
156}
157
158/// Try to push a predicate into a JOIN's ON condition.
159/// Only safe for INNER and CROSS joins.
160fn try_push_into_join(
161    join: &mut JoinClause,
162    pred: &Expr,
163    tables: &HashSet<String>,
164) -> bool {
165    // Only push into inner joins — pushing into LEFT/RIGHT/FULL
166    // changes semantics.
167    if !matches!(join.join_type, JoinType::Inner | JoinType::Cross) {
168        return false;
169    }
170
171    // Get the table name/alias of this join's source
172    let join_table = source_alias(&join.table);
173    let join_table = match join_table {
174        Some(t) => t,
175        None => return false,
176    };
177
178    // Predicate must reference only the join's table
179    if tables.is_empty() || tables.len() != 1 || !tables.contains(&join_table) {
180        return false;
181    }
182
183    // Also try pushing into a derived-table join source
184    if matches!(join.table, TableSource::Subquery { .. })
185        && try_push_into_source(&mut join.table, pred, tables)
186    {
187        return true;
188    }
189
190    // Push into the ON condition
191    join.on = match join.on.take() {
192        Some(existing) => Some(Expr::BinaryOp {
193            left: Box::new(existing),
194            op: BinaryOperator::And,
195            right: Box::new(pred.clone()),
196        }),
197        None => Some(pred.clone()),
198    };
199
200    true
201}
202
203// ═══════════════════════════════════════════════════════════════════════
204// Recursion into nested structures
205// ═══════════════════════════════════════════════════════════════════════
206
207/// Recurse into a table source to pushdown predicates in nested derived tables.
208fn recurse_into_source(source: &mut TableSource) {
209    match source {
210        TableSource::Subquery { query, .. } => {
211            *query = Box::new(pushdown_predicates(*query.clone()));
212        }
213        TableSource::Lateral { source } => {
214            recurse_into_source(source);
215        }
216        TableSource::Pivot { source, .. } | TableSource::Unpivot { source, .. } => {
217            recurse_into_source(source);
218        }
219        _ => {}
220    }
221}
222
223// ═══════════════════════════════════════════════════════════════════════
224// Predicate analysis helpers
225// ═══════════════════════════════════════════════════════════════════════
226
227/// Split an expression on AND into a flat list of conjuncts.
228fn split_conjunction(expr: Expr) -> Vec<Expr> {
229    match expr {
230        Expr::BinaryOp {
231            left,
232            op: BinaryOperator::And,
233            right,
234        } => {
235            let mut result = split_conjunction(*left);
236            result.extend(split_conjunction(*right));
237            result
238        }
239        Expr::Nested(inner) => {
240            // Only unwrap if the nested expression is itself an AND
241            if matches!(
242                inner.as_ref(),
243                Expr::BinaryOp {
244                    op: BinaryOperator::And,
245                    ..
246                }
247            ) {
248                split_conjunction(*inner)
249            } else {
250                vec![Expr::Nested(inner)]
251            }
252        }
253        other => vec![other],
254    }
255}
256
257/// Rejoin a list of predicates with AND. Returns None if empty.
258fn conjoin(predicates: Vec<Expr>) -> Option<Expr> {
259    predicates.into_iter().reduce(|a, b| Expr::BinaryOp {
260        left: Box::new(a),
261        op: BinaryOperator::And,
262        right: Box::new(b),
263    })
264}
265
266/// Collect all table qualifiers referenced by column expressions.
267fn referenced_tables(expr: &Expr) -> HashSet<String> {
268    let mut tables = HashSet::new();
269    expr.walk(&mut |e| {
270        if let Expr::Column {
271            table: Some(t), ..
272        } = e
273        {
274            tables.insert(t.clone());
275        }
276        true
277    });
278    tables
279}
280
281/// Check whether a predicate is safe to push down.
282///
283/// Returns false for predicates containing:
284/// - Aggregate functions
285/// - Window functions
286/// - Non-deterministic functions
287/// - Subqueries
288fn is_pushable(expr: &Expr) -> bool {
289    let mut safe = true;
290    expr.walk(&mut |e| {
291        if !safe {
292            return false;
293        }
294        match e {
295            // Subqueries should not be pushed
296            Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => {
297                safe = false;
298                false
299            }
300            // Aggregate functions
301            Expr::Function { name, .. } if is_aggregate_function(name) => {
302                safe = false;
303                false
304            }
305            // Window functions (have OVER clause)
306            Expr::Function {
307                over: Some(_), ..
308            }
309            | Expr::TypedFunction {
310                over: Some(_), ..
311            } => {
312                safe = false;
313                false
314            }
315            // Non-deterministic functions
316            Expr::Function { name, .. } if is_nondeterministic(name) => {
317                safe = false;
318                false
319            }
320            Expr::TypedFunction {
321                func: TypedFunction::CurrentTimestamp,
322                ..
323            } => {
324                safe = false;
325                false
326            }
327            _ => true,
328        }
329    });
330    safe
331}
332
333/// Check whether the target SELECT is safe for predicate pushdown.
334///
335/// We cannot push through LIMIT, OFFSET, DISTINCT, or window functions.
336fn is_pushdown_safe_target(sel: &SelectStatement) -> bool {
337    if sel.limit.is_some() || sel.offset.is_some() || sel.fetch_first.is_some() || sel.distinct {
338        return false;
339    }
340    // Check for window functions in SELECT list
341    for item in &sel.columns {
342        if let SelectItem::Expr { expr, .. } = item {
343            if contains_window_function(expr) {
344                return false;
345            }
346        }
347    }
348    true
349}
350
351/// Check whether an expression contains a window function.
352fn contains_window_function(expr: &Expr) -> bool {
353    let mut has_window = false;
354    expr.walk(&mut |e| {
355        if has_window {
356            return false;
357        }
358        match e {
359            Expr::Function {
360                over: Some(_), ..
361            }
362            | Expr::TypedFunction {
363                over: Some(_), ..
364            } => {
365                has_window = true;
366                false
367            }
368            _ => true,
369        }
370    });
371    has_window
372}
373
/// Return `true` when `name`, compared case-insensitively via
/// `str::to_uppercase`, is one of the aggregate function names this pass
/// knows about.
fn is_aggregate_function(name: &str) -> bool {
    const AGGREGATE_NAMES: &[&str] = &[
        "COUNT",
        "SUM",
        "AVG",
        "MIN",
        "MAX",
        "GROUP_CONCAT",
        "STRING_AGG",
        "ARRAY_AGG",
        "LISTAGG",
        "STDDEV",
        "STDDEV_POP",
        "STDDEV_SAMP",
        "VARIANCE",
        "VAR_POP",
        "VAR_SAMP",
        "EVERY",
        "ANY_VALUE",
        "SOME",
        "BIT_AND",
        "BIT_OR",
        "BIT_XOR",
        "BOOL_AND",
        "BOOL_OR",
        "CORR",
        "COVAR_POP",
        "COVAR_SAMP",
        "REGR_SLOPE",
        "REGR_INTERCEPT",
        "PERCENTILE_CONT",
        "PERCENTILE_DISC",
        "APPROX_COUNT_DISTINCT",
        "HLL_COUNT",
        "APPROX_DISTINCT",
    ];

    let upper = name.to_uppercase();
    AGGREGATE_NAMES.iter().any(|known| *known == upper)
}
390
/// Return `true` when `name`, compared case-insensitively via
/// `str::to_uppercase`, is one of the function names this pass treats as
/// non-deterministic.
fn is_nondeterministic(name: &str) -> bool {
    const NONDETERMINISTIC_NAMES: &[&str] = &[
        "RAND",
        "RANDOM",
        "UUID",
        "NEWID",
        "GEN_RANDOM_UUID",
        "SYSDATE",
        "SYSTIMESTAMP",
    ];

    let upper = name.to_uppercase();
    NONDETERMINISTIC_NAMES.iter().any(|known| *known == upper)
}
398
399// ═══════════════════════════════════════════════════════════════════════
400// Column remapping for derived-table pushdown
401// ═══════════════════════════════════════════════════════════════════════
402
403/// Rewrite a predicate so its column references match the inner derived
404/// table's namespace. For example, given:
405///
406/// ```sql
407/// SELECT * FROM (SELECT id, name FROM t) AS sub WHERE sub.name = 'foo'
408/// ```
409///
410/// The predicate `sub.name = 'foo'` becomes `name = 'foo'` (or `t.name = 'foo'`
411/// if the inner query qualifies columns).
412///
413/// Returns `None` if the rewrite cannot be performed (e.g., the column
414/// isn't exposed by the derived table's SELECT list).
415fn rewrite_predicate_for_derived_table(
416    pred: &Expr,
417    outer_alias: &str,
418    inner_sel: &SelectStatement,
419) -> Option<Expr> {
420    // Build a mapping: outer column name → inner expression
421    let column_map = build_column_map(inner_sel);
422
423    // Check that all columns referenced in the predicate can be mapped
424    let mut can_rewrite = true;
425    pred.walk(&mut |e| {
426        if !can_rewrite {
427            return false;
428        }
429        if let Expr::Column {
430            table: Some(t),
431            name,
432            ..
433        } = e
434        {
435            if t == outer_alias && !column_map.contains_key(name.as_str()) {
436                can_rewrite = false;
437            }
438        }
439        can_rewrite
440    });
441
442    if !can_rewrite {
443        return None;
444    }
445
446    // If inner SELECT has GROUP BY, only allow pushing predicates that
447    // reference grouped columns (pre-aggregation filters).
448    if !inner_sel.group_by.is_empty() {
449        let grouped_names: HashSet<String> = inner_sel
450            .group_by
451            .iter()
452            .filter_map(|e| match e {
453                Expr::Column { name, .. } => Some(name.clone()),
454                _ => None,
455            })
456            .collect();
457
458        let mut all_grouped = true;
459        pred.walk(&mut |e| {
460            if !all_grouped {
461                return false;
462            }
463            if let Expr::Column {
464                table: Some(t),
465                name,
466                ..
467            } = e
468            {
469                if t == outer_alias {
470                    if let Some(inner_expr) = column_map.get(name.as_str()) {
471                        let inner_name = match inner_expr {
472                            Expr::Column { name: n, .. } => n.clone(),
473                            _ => name.clone(),
474                        };
475                        if !grouped_names.contains(&inner_name) {
476                            all_grouped = false;
477                        }
478                    }
479                }
480            }
481            all_grouped
482        });
483
484        if !all_grouped {
485            return None;
486        }
487    }
488
489    // Perform the rewrite
490    let rewritten = pred.clone().transform(&|e| match e {
491        Expr::Column {
492            table: Some(ref t),
493            ref name,
494            ..
495        } if t == outer_alias => {
496            if let Some(inner_expr) = column_map.get(name.as_str()) {
497                inner_expr.clone()
498            } else {
499                e
500            }
501        }
502        other => other,
503    });
504
505    Some(rewritten)
506}
507
508/// Build a mapping from output column name → inner expression for a SELECT.
509///
510/// For `SELECT id, name AS n, x + 1 AS calc FROM t`:
511/// - "id" → Column { name: "id", ... }
512/// - "n" → Column { name: "name", ... }
513/// - "calc" → BinaryOp(x + 1)
514fn build_column_map(sel: &SelectStatement) -> std::collections::HashMap<&str, Expr> {
515    let mut map = std::collections::HashMap::new();
516
517    for item in &sel.columns {
518        match item {
519            SelectItem::Expr {
520                expr: Expr::Column { name, table, quote_style, table_quote_style },
521                alias,
522            } => {
523                let output_name = alias.as_deref().unwrap_or(name.as_str());
524                map.insert(
525                    output_name,
526                    Expr::Column {
527                        table: table.clone(),
528                        name: name.clone(),
529                        quote_style: *quote_style,
530                        table_quote_style: *table_quote_style,
531                    },
532                );
533            }
534            SelectItem::Expr { expr, alias } => {
535                if let Some(alias) = alias {
536                    map.insert(alias.as_str(), expr.clone());
537                }
538            }
539            SelectItem::Wildcard | SelectItem::QualifiedWildcard { .. } => {
540                // With *, we can't easily build a column map without schema
541                // info. Bail  — the predicate columns just need to not
542                // have a table qualifier to match.
543            }
544        }
545    }
546
547    map
548}
549
550/// Get the alias or table name for a table source.
551fn source_alias(source: &TableSource) -> Option<String> {
552    match source {
553        TableSource::Table(t) => Some(t.alias.clone().unwrap_or_else(|| t.name.clone())),
554        TableSource::Subquery { alias, .. } => alias.clone(),
555        TableSource::TableFunction { alias, .. } => alias.clone(),
556        TableSource::Unnest { alias, .. } => alias.clone(),
557        TableSource::Lateral { source } => source_alias(source),
558        TableSource::Pivot { alias, .. } | TableSource::Unpivot { alias, .. } => alias.clone(),
559    }
560}
561
562// ═══════════════════════════════════════════════════════════════════════
563// Tests
564// ═══════════════════════════════════════════════════════════════════════
565
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unquoted column reference, optionally qualified by `table`.
    fn col(table: Option<&str>, name: &str) -> Expr {
        Expr::Column {
            table: table.map(String::from),
            name: name.into(),
            quote_style: QuoteStyle::None,
            table_quote_style: QuoteStyle::None,
        }
    }

    /// Build the binary expression `left <op> right`.
    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Build an untyped function call without DISTINCT or FILTER.
    fn call(name: &str, args: Vec<Expr>, over: Option<WindowSpec>) -> Expr {
        Expr::Function {
            name: name.into(),
            args,
            distinct: false,
            filter: None,
            over,
        }
    }

    #[test]
    fn test_split_conjunction_single() {
        let parts = split_conjunction(Expr::Boolean(true));
        assert_eq!(parts.len(), 1);
    }

    #[test]
    fn test_split_conjunction_and() {
        let expr = binary(
            Expr::Boolean(true),
            BinaryOperator::And,
            Expr::Boolean(false),
        );
        assert_eq!(split_conjunction(expr).len(), 2);
    }

    #[test]
    fn test_split_conjunction_nested_and() {
        // (a AND b) AND c
        let a_and_b = binary(col(None, "a"), BinaryOperator::And, col(None, "b"));
        let expr = binary(a_and_b, BinaryOperator::And, col(None, "c"));
        assert_eq!(split_conjunction(expr).len(), 3);
    }

    #[test]
    fn test_conjoin_empty() {
        assert!(conjoin(vec![]).is_none());
    }

    #[test]
    fn test_conjoin_single() {
        let joined = conjoin(vec![Expr::Boolean(true)]);
        assert_eq!(joined, Some(Expr::Boolean(true)));
    }

    #[test]
    fn test_is_pushable_simple_comparison() {
        // t.x > 5
        let expr = binary(
            col(Some("t"), "x"),
            BinaryOperator::Gt,
            Expr::Number("5".into()),
        );
        assert!(is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_aggregate() {
        // COUNT(*)
        let expr = call("COUNT", vec![Expr::Star], None);
        assert!(!is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_window() {
        // ROW_NUMBER() OVER ()
        let empty_window = WindowSpec {
            window_ref: None,
            partition_by: vec![],
            order_by: vec![],
            frame: None,
        };
        let expr = call("ROW_NUMBER", vec![], Some(empty_window));
        assert!(!is_pushable(&expr));
    }

    #[test]
    fn test_is_pushable_rejects_subquery() {
        // EXISTS (<empty select>)
        let empty_select = SelectStatement {
            ctes: vec![],
            distinct: false,
            top: None,
            columns: vec![],
            from: None,
            joins: vec![],
            where_clause: None,
            group_by: vec![],
            having: None,
            order_by: vec![],
            limit: None,
            offset: None,
            fetch_first: None,
            qualify: None,
            window_definitions: vec![],
        };
        let expr = Expr::Exists {
            subquery: Box::new(Statement::Select(empty_select)),
            negated: false,
        };
        assert!(!is_pushable(&expr));
    }

    #[test]
    fn test_referenced_tables() {
        // a.x = b.y
        let expr = binary(
            col(Some("a"), "x"),
            BinaryOperator::Eq,
            col(Some("b"), "y"),
        );
        let tables = referenced_tables(&expr);
        assert_eq!(tables.len(), 2);
        assert!(tables.contains("a"));
        assert!(tables.contains("b"));
    }
}
721}