Skip to main content

sqrust_rules/ambiguous/
cross_join_keyword.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{JoinOperator, Query, SetExpr, Statement, TableFactor};
3
/// Lint rule `Ambiguous/CrossJoinKeyword`: reports each `CROSS JOIN` found
/// in the `FROM` clauses of parsed query statements.
///
/// The rule carries no state, so it is trivially copyable and
/// default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct CrossJoinKeyword;
5
6impl Rule for CrossJoinKeyword {
7    fn name(&self) -> &'static str {
8        "Ambiguous/CrossJoinKeyword"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        if !ctx.parse_errors.is_empty() {
13            return Vec::new();
14        }
15
16        let mut diags = Vec::new();
17        let mut occurrence = 0usize;
18        for stmt in &ctx.statements {
19            if let Statement::Query(query) = stmt {
20                check_query(query, &ctx.source, &mut occurrence, &mut diags);
21            }
22        }
23        diags
24    }
25}
26
27fn check_query(
28    query: &Query,
29    source: &str,
30    occurrence: &mut usize,
31    diags: &mut Vec<Diagnostic>,
32) {
33    // Recurse into CTEs.
34    if let Some(with) = &query.with {
35        for cte in &with.cte_tables {
36            check_query(&cte.query, source, occurrence, diags);
37        }
38    }
39    check_set_expr(&query.body, source, occurrence, diags);
40}
41
42fn check_set_expr(
43    expr: &SetExpr,
44    source: &str,
45    occurrence: &mut usize,
46    diags: &mut Vec<Diagnostic>,
47) {
48    match expr {
49        SetExpr::Select(sel) => {
50            for twj in &sel.from {
51                recurse_table_factor(&twj.relation, source, occurrence, diags);
52                for join in &twj.joins {
53                    if matches!(join.join_operator, JoinOperator::CrossJoin) {
54                        let (line, col) =
55                            find_nth_keyword_position(source, "CROSS", *occurrence);
56                        *occurrence += 1;
57                        diags.push(Diagnostic {
58                            rule: "Ambiguous/CrossJoinKeyword",
59                            message:
60                                "CROSS JOIN produces a Cartesian product; verify this is intentional"
61                                    .to_string(),
62                            line,
63                            col,
64                        });
65                    }
66                    recurse_table_factor(&join.relation, source, occurrence, diags);
67                }
68            }
69        }
70        SetExpr::SetOperation { left, right, .. } => {
71            check_set_expr(left, source, occurrence, diags);
72            check_set_expr(right, source, occurrence, diags);
73        }
74        SetExpr::Query(inner) => {
75            check_query(inner, source, occurrence, diags);
76        }
77        _ => {}
78    }
79}
80
81fn recurse_table_factor(
82    tf: &TableFactor,
83    source: &str,
84    occurrence: &mut usize,
85    diags: &mut Vec<Diagnostic>,
86) {
87    if let TableFactor::Derived { subquery, .. } = tf {
88        check_query(subquery, source, occurrence, diags);
89    }
90}
91
92/// Finds the `n`-th (0-indexed) occurrence of `keyword` (case-insensitive,
93/// word-boundary-checked) in `source` and returns a 1-indexed (line, col).
94/// Falls back to (1, 1) if not found.
95fn find_nth_keyword_position(source: &str, keyword: &str, n: usize) -> (usize, usize) {
96    let upper = source.to_uppercase();
97    let kw_upper = keyword.to_uppercase();
98    let bytes = upper.as_bytes();
99    let kw_bytes = kw_upper.as_bytes();
100    let kw_len = kw_bytes.len();
101
102    let mut found = 0usize;
103    let mut i = 0;
104    while i + kw_len <= bytes.len() {
105        if bytes[i..i + kw_len] == *kw_bytes {
106            let before_ok =
107                i == 0 || (!bytes[i - 1].is_ascii_alphanumeric() && bytes[i - 1] != b'_');
108            let after = i + kw_len;
109            let after_ok = after >= bytes.len()
110                || (!bytes[after].is_ascii_alphanumeric() && bytes[after] != b'_');
111            if before_ok && after_ok {
112                if found == n {
113                    return offset_to_line_col(source, i);
114                }
115                found += 1;
116            }
117        }
118        i += 1;
119    }
120    (1, 1)
121}
122
/// Converts a byte offset in `source` to a 1-indexed `(line, col)` pair.
///
/// `col` is counted in bytes from the start of the line. An `offset` that is
/// past the end of `source`, or that falls inside a multi-byte character, is
/// clamped down to the nearest valid position instead of panicking.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    // Clamp to the string, then back up to a character boundary so the slice
    // below cannot panic. Offset 0 is always a boundary, so this terminates.
    let mut end = offset.min(source.len());
    while !source.is_char_boundary(end) {
        end -= 1;
    }

    let before = &source[..end];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |newline| newline + 1);
    (line, end - line_start + 1)
}