//! `sqrust_rules/ambiguous/implicit_cross_join.rs` — the
//! `Ambiguous/ImplicitCrossJoin` lint rule.

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Query, SetExpr, Statement, TableFactor};
3
4pub struct ImplicitCrossJoin;
5
6impl Rule for ImplicitCrossJoin {
7    fn name(&self) -> &'static str {
8        "Ambiguous/ImplicitCrossJoin"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        if !ctx.parse_errors.is_empty() {
13            return Vec::new();
14        }
15
16        let mut diags = Vec::new();
17        for stmt in &ctx.statements {
18            if let Statement::Query(query) = stmt {
19                check_query(query, &ctx.source, self.name(), &mut diags);
20            }
21        }
22        diags
23    }
24}
25
26fn check_query(query: &Query, source: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
27    // Recurse into CTEs.
28    if let Some(with) = &query.with {
29        for cte in &with.cte_tables {
30            check_query(&cte.query, source, rule, diags);
31        }
32    }
33    check_set_expr(&query.body, source, rule, diags);
34}
35
36fn check_set_expr(expr: &SetExpr, source: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
37    match expr {
38        SetExpr::Select(sel) => {
39            // When `from` has more than one element, the tables were listed
40            // comma-separated in the FROM clause — an implicit cross join.
41            if sel.from.len() > 1 {
42                let (line, col) = find_keyword_position(source, "FROM");
43                diags.push(Diagnostic {
44                    rule,
45                    message: "Implicit cross join from comma-separated tables; use explicit JOIN syntax".to_string(),
46                    line,
47                    col,
48                });
49            }
50
51            // Recurse into subqueries inside FROM / JOIN clauses.
52            for twj in &sel.from {
53                recurse_table_factor(&twj.relation, source, rule, diags);
54                for join in &twj.joins {
55                    recurse_table_factor(&join.relation, source, rule, diags);
56                }
57            }
58        }
59        SetExpr::SetOperation { left, right, .. } => {
60            check_set_expr(left, source, rule, diags);
61            check_set_expr(right, source, rule, diags);
62        }
63        SetExpr::Query(inner) => {
64            check_query(inner, source, rule, diags);
65        }
66        _ => {}
67    }
68}
69
70fn recurse_table_factor(
71    tf: &TableFactor,
72    source: &str,
73    rule: &'static str,
74    diags: &mut Vec<Diagnostic>,
75) {
76    if let TableFactor::Derived { subquery, .. } = tf {
77        check_query(subquery, source, rule, diags);
78    }
79}
80
81/// Finds the first occurrence of `keyword` (case-insensitive, word-boundary-checked)
82/// in `source` and returns a 1-indexed (line, col). Falls back to (1, 1) if not found.
83fn find_keyword_position(source: &str, keyword: &str) -> (usize, usize) {
84    let upper = source.to_uppercase();
85    let kw_upper = keyword.to_uppercase();
86    let bytes = upper.as_bytes();
87    let kw_bytes = kw_upper.as_bytes();
88    let kw_len = kw_bytes.len();
89
90    let mut i = 0;
91    while i + kw_len <= bytes.len() {
92        if bytes[i..i + kw_len] == *kw_bytes {
93            let before_ok =
94                i == 0 || (!bytes[i - 1].is_ascii_alphanumeric() && bytes[i - 1] != b'_');
95            let after = i + kw_len;
96            let after_ok = after >= bytes.len()
97                || (!bytes[after].is_ascii_alphanumeric() && bytes[after] != b'_');
98            if before_ok && after_ok {
99                return offset_to_line_col(source, i);
100            }
101        }
102        i += 1;
103    }
104    (1, 1)
105}
106
/// Translates a byte `offset` into `source` into a 1-indexed `(line, col)`
/// pair.
///
/// Columns count `char`s, not bytes, and every `'\n'` starts a new line. All
/// characters that start strictly before `offset` are consumed. If no character
/// starts exactly at `offset` (an offset at or past the end of `source`, or one
/// inside a multi-byte character), the whole string is consumed and the
/// position just after its last character is returned.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    source
        .char_indices()
        .take_while(|&(idx, _)| idx != offset)
        .fold((1usize, 1usize), |(line, col), (_, ch)| {
            if ch == '\n' {
                (line + 1, 1)
            } else {
                (line, col + 1)
            }
        })
}