Skip to main content

sqrust_rules/convention/
left_join.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{JoinOperator, Query, Select, SetExpr, Statement, TableFactor};
3use crate::capitalisation::{is_word_char, SkipMap};
4
5pub struct LeftJoin;
6
7impl Rule for LeftJoin {
8    fn name(&self) -> &'static str {
9        "Convention/LeftJoin"
10    }
11
12    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13        if !ctx.parse_errors.is_empty() {
14            return Vec::new();
15        }
16        let mut diags = Vec::new();
17        let mut count = 0usize;
18        for stmt in &ctx.statements {
19            if let Statement::Query(q) = stmt {
20                check_query(q, &ctx.source, &mut count, &mut diags);
21            }
22        }
23        diags
24    }
25}
26
27fn check_query(q: &Query, src: &str, count: &mut usize, diags: &mut Vec<Diagnostic>) {
28    if let Some(with) = &q.with {
29        for cte in &with.cte_tables {
30            check_query(&cte.query, src, count, diags);
31        }
32    }
33    check_set_expr(&q.body, src, count, diags);
34}
35
36fn check_set_expr(expr: &SetExpr, src: &str, count: &mut usize, diags: &mut Vec<Diagnostic>) {
37    match expr {
38        SetExpr::Select(sel) => check_select(sel, src, count, diags),
39        SetExpr::Query(q) => check_query(q, src, count, diags),
40        SetExpr::SetOperation { left, right, .. } => {
41            check_set_expr(left, src, count, diags);
42            check_set_expr(right, src, count, diags);
43        }
44        _ => {}
45    }
46}
47
48fn check_select(sel: &Select, src: &str, count: &mut usize, diags: &mut Vec<Diagnostic>) {
49    for twj in &sel.from {
50        recurse_factor(&twj.relation, src, count, diags);
51        for join in &twj.joins {
52            recurse_factor(&join.relation, src, count, diags);
53            if is_right_join(&join.join_operator) {
54                let occ = *count;
55                *count += 1;
56                if let Some(offset) = find_nth_keyword(src, b"RIGHT", occ) {
57                    let (line, col) = offset_to_line_col(src, offset);
58                    diags.push(Diagnostic {
59                        rule: "Convention/LeftJoin",
60                        message: "Prefer LEFT JOIN over RIGHT JOIN; rewrite from the other table's perspective".to_string(),
61                        line,
62                        col,
63                    });
64                }
65            }
66        }
67    }
68}
69
70fn recurse_factor(tf: &TableFactor, src: &str, count: &mut usize, diags: &mut Vec<Diagnostic>) {
71    if let TableFactor::Derived { subquery, .. } = tf {
72        check_query(subquery, src, count, diags);
73    }
74}
75
76fn is_right_join(op: &JoinOperator) -> bool {
77    matches!(
78        op,
79        JoinOperator::RightOuter(_) | JoinOperator::RightSemi(_) | JoinOperator::RightAnti(_)
80    )
81}
82
83fn find_nth_keyword(source: &str, keyword: &[u8], nth: usize) -> Option<usize> {
84    let bytes = source.as_bytes();
85    let kw_len = keyword.len();
86    let len = bytes.len();
87    let skip = SkipMap::build(source);
88    let mut count = 0;
89    let mut i = 0;
90    while i + kw_len <= len {
91        if !skip.is_code(i) {
92            i += 1;
93            continue;
94        }
95        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
96        if !before_ok {
97            i += 1;
98            continue;
99        }
100        let matches = bytes[i..i + kw_len]
101            .iter()
102            .zip(keyword.iter())
103            .all(|(&a, &b)| a.to_ascii_uppercase() == b.to_ascii_uppercase());
104        if matches {
105            let end = i + kw_len;
106            let after_ok = end >= len || !is_word_char(bytes[end]);
107            let all_code = (i..end).all(|k| skip.is_code(k));
108            if after_ok && all_code {
109                if count == nth {
110                    return Some(i);
111                }
112                count += 1;
113                i += kw_len;
114                continue;
115            }
116        }
117        i += 1;
118    }
119    None
120}
121
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// The column is counted in bytes from the start of the line. An `offset`
/// past the end of `source` is clamped to the end, and an `offset` that
/// falls inside a multi-byte character is moved back to the start of that
/// character, so the function never panics and line and column always
/// describe the same position.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut end = offset.min(source.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !source.is_char_boundary(end) {
        end -= 1;
    }
    let before = &source[..end];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let col = match before.rfind('\n') {
        // Bytes after the last newline, plus one for 1-based columns.
        Some(newline) => end - newline,
        None => end + 1,
    };
    (line, col)
}