Skip to main content

sqrust_rules/ambiguous/or_in_join_condition.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{BinaryOperator, Expr, Join, JoinConstraint, JoinOperator, Query, SetExpr,
3    Statement, TableFactor};
4
5pub struct OrInJoinCondition;
6
7impl Rule for OrInJoinCondition {
8    fn name(&self) -> &'static str {
9        "Ambiguous/OrInJoinCondition"
10    }
11
12    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13        if !ctx.parse_errors.is_empty() {
14            return Vec::new();
15        }
16
17        let mut diags = Vec::new();
18        for stmt in &ctx.statements {
19            if let Statement::Query(query) = stmt {
20                check_query(query, &ctx.source, &mut diags);
21            }
22        }
23        diags
24    }
25}
26
27fn check_query(query: &Query, source: &str, diags: &mut Vec<Diagnostic>) {
28    // Recurse into CTEs.
29    if let Some(with) = &query.with {
30        for cte in &with.cte_tables {
31            check_query(&cte.query, source, diags);
32        }
33    }
34    check_set_expr(&query.body, source, diags);
35}
36
37fn check_set_expr(expr: &SetExpr, source: &str, diags: &mut Vec<Diagnostic>) {
38    match expr {
39        SetExpr::Select(select) => {
40            for twj in &select.from {
41                // Check each join in this FROM item.
42                for join in &twj.joins {
43                    check_join(join, source, diags);
44                }
45                // Recurse into subqueries inside table factors.
46                recurse_table_factor(&twj.relation, source, diags);
47                for join in &twj.joins {
48                    recurse_table_factor(&join.relation, source, diags);
49                }
50            }
51        }
52        SetExpr::SetOperation { left, right, .. } => {
53            check_set_expr(left, source, diags);
54            check_set_expr(right, source, diags);
55        }
56        SetExpr::Query(inner) => {
57            check_query(inner, source, diags);
58        }
59        _ => {}
60    }
61}
62
63/// Extracts the `ON` expression from a join operator, if it has one.
64fn on_expr(join: &Join) -> Option<&Expr> {
65    match &join.join_operator {
66        JoinOperator::Inner(JoinConstraint::On(e))
67        | JoinOperator::LeftOuter(JoinConstraint::On(e))
68        | JoinOperator::RightOuter(JoinConstraint::On(e))
69        | JoinOperator::FullOuter(JoinConstraint::On(e)) => Some(e),
70        _ => None,
71    }
72}
73
74/// Checks a single join for an OR condition in its ON clause.
75fn check_join(join: &Join, source: &str, diags: &mut Vec<Diagnostic>) {
76    if let Some(expr) = on_expr(join) {
77        if has_or(expr) {
78            let (line, col) = find_or_position(source);
79            diags.push(Diagnostic {
80                rule: "Ambiguous/OrInJoinCondition",
81                message: "OR condition in JOIN ON clause; this may produce unintended cross-join-like results"
82                    .to_string(),
83                line,
84                col,
85            });
86        }
87    }
88}
89
90/// Returns `true` if `expr` contains an OR operator at any nesting level.
91fn has_or(expr: &Expr) -> bool {
92    match expr {
93        Expr::BinaryOp {
94            op: BinaryOperator::Or,
95            ..
96        } => true,
97        Expr::BinaryOp { left, right, .. } => has_or(left) || has_or(right),
98        Expr::Nested(e) => has_or(e),
99        Expr::UnaryOp { expr: e, .. } => has_or(e),
100        _ => false,
101    }
102}
103
104/// Recurses into a `TableFactor::Derived` (subquery) to check joins inside it.
105fn recurse_table_factor(tf: &TableFactor, source: &str, diags: &mut Vec<Diagnostic>) {
106    if let TableFactor::Derived { subquery, .. } = tf {
107        check_query(subquery, source, diags);
108    }
109}
110
/// Locates the `OR` keyword to report for this rule and returns its
/// 1-indexed `(line, col)`.
///
/// The search is textual and covers the whole file, so it is best-effort:
///
/// * `--` line comments, `/* ... */` block comments, `'...'` string literals,
///   and `"..."` / `` `...` `` quoted identifiers are skipped. Inside them a
///   doubled quote character is an escaped quote, as in standard SQL. An "or"
///   written in a comment or literal is therefore never reported.
/// * The first `OR` inside a join `ON` clause wins. An `ON` clause runs from
///   the `ON` keyword to whichever comes first: the next clause keyword
///   (`JOIN`, `WHERE`, `GROUP`, …) or comma at the same parenthesis depth, a
///   `)` that leaves that depth, or a `;`.
/// * Otherwise the first `OR` keyword anywhere outside comments and literals
///   is used, and `(1, 1)` when there is none.
///
/// Keywords are matched as whole words. ASCII letters, digits, `_`, and every
/// byte of a non-ASCII character count as word characters, so `ORDER` and
/// `color` never match.
fn find_or_position(source: &str) -> (usize, usize) {
    let bytes = source.as_bytes();
    let len = bytes.len();
    // Offset of the first `OR` seen anywhere; used when no `ON` clause has one.
    let mut first_or: Option<usize> = None;
    let mut depth: usize = 0;
    // Parenthesis depth at which the `ON` clause being scanned started, if any.
    let mut on_depth: Option<usize> = None;

    let mut i = 0;
    while i < len {
        let b = bytes[i];
        match b {
            b'-' if bytes.get(i + 1) == Some(&b'-') => i = skip_past(bytes, i + 2, b"\n"),
            b'/' if bytes.get(i + 1) == Some(&b'*') => i = skip_past(bytes, i + 2, b"*/"),
            b'\'' | b'"' | b'`' => i = skip_quoted(bytes, i),
            b'(' => {
                depth += 1;
                i += 1;
            }
            b')' => {
                depth = depth.saturating_sub(1);
                // Closing the parenthesis that enclosed the ON clause ends it.
                if on_depth.map_or(false, |d| depth < d) {
                    on_depth = None;
                }
                i += 1;
            }
            b',' => {
                // A top-level comma after an ON clause starts the next FROM item.
                if on_depth == Some(depth) {
                    on_depth = None;
                }
                i += 1;
            }
            b';' => {
                depth = 0;
                on_depth = None;
                i += 1;
            }
            _ if is_word_byte(b) => {
                let start = i;
                while i < len && is_word_byte(bytes[i]) {
                    i += 1;
                }
                let word = &bytes[start..i];
                if word.eq_ignore_ascii_case(b"OR") {
                    if on_depth.is_some() {
                        return offset_to_line_col(source, start);
                    }
                    if first_or.is_none() {
                        first_or = Some(start);
                    }
                } else if word.eq_ignore_ascii_case(b"ON") {
                    on_depth = Some(depth);
                } else if on_depth == Some(depth) && is_clause_keyword(word) {
                    on_depth = None;
                }
            }
            _ => i += 1,
        }
    }

    first_or.map_or((1, 1), |offset| offset_to_line_col(source, offset))
}

/// Returns `true` for keywords that end a join `ON` clause when they appear
/// at the clause's own parenthesis depth.
fn is_clause_keyword(word: &[u8]) -> bool {
    const TERMINATORS: &[&[u8]] = &[
        b"JOIN",
        b"APPLY",
        b"WHERE",
        b"GROUP",
        b"HAVING",
        b"WINDOW",
        b"QUALIFY",
        b"ORDER",
        b"LIMIT",
        b"UNION",
        b"INTERSECT",
        b"EXCEPT",
    ];
    TERMINATORS.iter().any(|kw| word.eq_ignore_ascii_case(kw))
}

/// Word characters for keyword matching: ASCII alphanumerics, `_`, and every
/// byte of a multi-byte UTF-8 character.
fn is_word_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii()
}

/// Returns the index just past the first `terminator` found at or after
/// `from`. Returns `bytes.len()` if there is none (e.g. an unterminated
/// comment).
fn skip_past(bytes: &[u8], from: usize, terminator: &[u8]) -> usize {
    bytes
        .get(from..)
        .and_then(|rest| rest.windows(terminator.len()).position(|w| w == terminator))
        .map_or(bytes.len(), |pos| from + pos + terminator.len())
}

/// Returns the index just past the quoted region opening at `start`, whose
/// byte is the quote character. A doubled quote inside the region is an
/// escaped quote. An unterminated region runs to the end of input.
fn skip_quoted(bytes: &[u8], start: usize) -> usize {
    let quote = bytes[start];
    let mut j = start + 1;
    while j < bytes.len() {
        if bytes[j] == quote {
            if bytes.get(j + 1) == Some(&quote) {
                j += 2;
                continue;
            }
            return j + 1;
        }
        j += 1;
    }
    bytes.len()
}

/// Converts a byte offset in `source` to a 1-indexed `(line, col)` pair.
/// The column counts bytes from the start of the line.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line = prefix.matches('\n').count() + 1;
    let line_start = prefix.rfind('\n').map_or(0, |nl| nl + 1);
    (line, offset - line_start + 1)
}