Skip to main content

sqrust_rules/ambiguous/
join_without_condition.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Join, JoinConstraint, JoinOperator, Query, SetExpr, Statement, TableFactor};
3
4pub struct JoinWithoutCondition;
5
6impl Rule for JoinWithoutCondition {
7    fn name(&self) -> &'static str {
8        "Ambiguous/JoinWithoutCondition"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        if !ctx.parse_errors.is_empty() {
13            return Vec::new();
14        }
15
16        let mut diags = Vec::new();
17        for stmt in &ctx.statements {
18            if let Statement::Query(query) = stmt {
19                check_query(query, &ctx.source, self.name(), &mut diags);
20            }
21        }
22        diags
23    }
24}
25
26fn check_query(query: &Query, source: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
27    // Recurse into CTEs.
28    if let Some(with) = &query.with {
29        for cte in &with.cte_tables {
30            check_query(&cte.query, source, rule, diags);
31        }
32    }
33    check_set_expr(&query.body, source, rule, diags);
34}
35
36fn check_set_expr(expr: &SetExpr, source: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
37    match expr {
38        SetExpr::Select(sel) => {
39            for twj in &sel.from {
40                // Check all joins in this FROM item.
41                for join in &twj.joins {
42                    check_join(join, source, rule, diags);
43                }
44                // Recurse into subqueries in the relation and joins.
45                recurse_table_factor(&twj.relation, source, rule, diags);
46                for join in &twj.joins {
47                    recurse_table_factor(&join.relation, source, rule, diags);
48                }
49            }
50        }
51        SetExpr::SetOperation { left, right, .. } => {
52            check_set_expr(left, source, rule, diags);
53            check_set_expr(right, source, rule, diags);
54        }
55        SetExpr::Query(inner) => {
56            check_query(inner, source, rule, diags);
57        }
58        _ => {}
59    }
60}
61
62/// Returns true if the `JoinConstraint` represents a missing condition (i.e. `None` variant).
63/// `Natural` is excluded — it has an implicit condition.
64fn is_missing_condition(constraint: &JoinConstraint) -> bool {
65    matches!(constraint, JoinConstraint::None)
66}
67
68/// Checks a single `Join` node for a missing ON/USING condition.
69/// Flags `Inner`, `LeftOuter`, `RightOuter`, and `FullOuter` with `JoinConstraint::None`.
70/// Skips `CrossJoin`, `CrossApply`, `OuterApply`, and `Natural` joins.
71fn check_join(join: &Join, source: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
72    let has_violation = match &join.join_operator {
73        JoinOperator::Inner(c) => is_missing_condition(c),
74        JoinOperator::LeftOuter(c) => is_missing_condition(c),
75        JoinOperator::RightOuter(c) => is_missing_condition(c),
76        JoinOperator::FullOuter(c) => is_missing_condition(c),
77        // CrossJoin, CrossApply, OuterApply, Semi*, Anti*, AsOf — no condition expected.
78        _ => false,
79    };
80
81    if has_violation {
82        let (line, col) = find_keyword_position(source, "JOIN");
83        diags.push(Diagnostic {
84            rule,
85            message: "JOIN without ON or USING condition; this will produce a cross join"
86                .to_string(),
87            line,
88            col,
89        });
90    }
91}
92
93/// Recurses into a `TableFactor::Derived` (subquery) to check its joins.
94fn recurse_table_factor(
95    tf: &TableFactor,
96    source: &str,
97    rule: &'static str,
98    diags: &mut Vec<Diagnostic>,
99) {
100    if let TableFactor::Derived { subquery, .. } = tf {
101        check_query(subquery, source, rule, diags);
102    }
103}
104
105/// Finds the first occurrence of `keyword` (case-insensitive, word-boundary-checked)
106/// in `source` and returns a 1-indexed (line, col). Falls back to (1, 1) if not found.
107fn find_keyword_position(source: &str, keyword: &str) -> (usize, usize) {
108    let upper = source.to_uppercase();
109    let kw_upper = keyword.to_uppercase();
110    let bytes = upper.as_bytes();
111    let kw_bytes = kw_upper.as_bytes();
112    let kw_len = kw_bytes.len();
113
114    let mut i = 0;
115    while i + kw_len <= bytes.len() {
116        if bytes[i..i + kw_len] == *kw_bytes {
117            let before_ok =
118                i == 0 || (!bytes[i - 1].is_ascii_alphanumeric() && bytes[i - 1] != b'_');
119            let after = i + kw_len;
120            let after_ok = after >= bytes.len()
121                || (!bytes[after].is_ascii_alphanumeric() && bytes[after] != b'_');
122            if before_ok && after_ok {
123                return offset_to_line_col(source, i);
124            }
125        }
126        i += 1;
127    }
128    (1, 1)
129}
130
/// Translates a byte offset into `source` to a 1-indexed `(line, col)` position.
///
/// Columns count characters, not bytes, and restart at 1 after every `'\n'`.
/// An offset that does not fall on the start of a character (including one past the
/// end of `source`) yields the position just after the last character.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    source
        .char_indices()
        // Consume characters until the one that starts exactly at `offset`.
        .take_while(|&(start, _)| start != offset)
        .fold((1usize, 1usize), |(line, col), (_, ch)| {
            if ch == '\n' {
                (line + 1, 1)
            } else {
                (line, col + 1)
            }
        })
}