Skip to main content

sqrust_rules/ambiguous/
full_outer_join.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{JoinOperator, Query, SetExpr, Statement, TableFactor};
3
4pub struct FullOuterJoin;
5
6impl Rule for FullOuterJoin {
7    fn name(&self) -> &'static str {
8        "Ambiguous/FullOuterJoin"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        if !ctx.parse_errors.is_empty() {
13            return Vec::new();
14        }
15
16        let mut diags = Vec::new();
17        let mut occurrence = 0usize;
18        for stmt in &ctx.statements {
19            if let Statement::Query(query) = stmt {
20                check_query(query, &ctx.source, &mut occurrence, &mut diags);
21            }
22        }
23        diags
24    }
25}
26
27fn check_query(
28    query: &Query,
29    source: &str,
30    occurrence: &mut usize,
31    diags: &mut Vec<Diagnostic>,
32) {
33    // Recurse into CTEs.
34    if let Some(with) = &query.with {
35        for cte in &with.cte_tables {
36            check_query(&cte.query, source, occurrence, diags);
37        }
38    }
39    check_set_expr(&query.body, source, occurrence, diags);
40}
41
42fn check_set_expr(
43    expr: &SetExpr,
44    source: &str,
45    occurrence: &mut usize,
46    diags: &mut Vec<Diagnostic>,
47) {
48    match expr {
49        SetExpr::Select(sel) => {
50            for twj in &sel.from {
51                recurse_table_factor(&twj.relation, source, occurrence, diags);
52                for join in &twj.joins {
53                    if matches!(join.join_operator, JoinOperator::FullOuter(_)) {
54                        let (line, col) =
55                            find_nth_keyword_position(source, "FULL", *occurrence);
56                        *occurrence += 1;
57                        diags.push(Diagnostic {
58                            rule: "Ambiguous/FullOuterJoin",
59                            message: "FULL OUTER JOIN may produce unintentionally large results; verify this is intentional".to_string(),
60                            line,
61                            col,
62                        });
63                    }
64                    recurse_table_factor(&join.relation, source, occurrence, diags);
65                }
66            }
67        }
68        SetExpr::SetOperation { left, right, .. } => {
69            check_set_expr(left, source, occurrence, diags);
70            check_set_expr(right, source, occurrence, diags);
71        }
72        SetExpr::Query(inner) => {
73            check_query(inner, source, occurrence, diags);
74        }
75        _ => {}
76    }
77}
78
79fn recurse_table_factor(
80    tf: &TableFactor,
81    source: &str,
82    occurrence: &mut usize,
83    diags: &mut Vec<Diagnostic>,
84) {
85    if let TableFactor::Derived { subquery, .. } = tf {
86        check_query(subquery, source, occurrence, diags);
87    }
88}
89
/// Finds the `n`-th (0-indexed) occurrence of `keyword` in `source` and
/// returns its 1-indexed `(line, col)` position.
///
/// Matching is ASCII case-insensitive and word-boundary-checked: the bytes
/// directly before and after a match must not be ASCII alphanumerics or `_`.
///
/// The scan runs over the bytes of `source` itself rather than over an
/// upper-cased copy. Unicode upper-casing can change the byte length of the
/// text (for example `ŉ` becomes `ʼN`), so offsets found in such a copy do not
/// reliably index into `source`.
///
/// Falls back to `(1, 1)` if there are fewer than `n + 1` matches.
fn find_nth_keyword_position(source: &str, keyword: &str, n: usize) -> (usize, usize) {
    let haystack = source.as_bytes();
    let needle = keyword.as_bytes();
    if needle.len() > haystack.len() {
        return (1, 1);
    }

    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    let last_start = haystack.len() - needle.len();

    (0..=last_start)
        .filter(|&start| {
            let end = start + needle.len();
            haystack[start..end].eq_ignore_ascii_case(needle)
                && (start == 0 || !is_word_byte(haystack[start - 1]))
                && (end == haystack.len() || !is_word_byte(haystack[end]))
        })
        .nth(n)
        .map_or((1, 1), |start| offset_to_line_col(source, start))
}

/// Converts a byte offset in `source` to a 1-indexed `(line, col)` pair.
///
/// `col` is counted in bytes from the start of the line. The computation is
/// done on the raw bytes, so an offset that is not on a character boundary
/// does not panic; an offset past the end of `source` is clamped to its
/// length.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let bytes = source.as_bytes();
    let before = &bytes[..offset.min(bytes.len())];

    let line = before.iter().filter(|&&b| b == b'\n').count() + 1;
    // Index of the first byte of the line that contains `offset`.
    let line_start = before
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);

    (line, before.len() - line_start + 1)
}