Skip to main content

sqrust_rules/convention/
join_condition_style.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{BinaryOperator, Expr, Query, Select, SetExpr, Statement, TableFactor};
3use crate::capitalisation::{is_word_char, SkipMap};
4
5pub struct JoinConditionStyle;
6
7impl Rule for JoinConditionStyle {
8    fn name(&self) -> &'static str {
9        "Convention/JoinConditionStyle"
10    }
11
12    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13        if !ctx.parse_errors.is_empty() {
14            return Vec::new();
15        }
16        let mut diags = Vec::new();
17        let mut count = 0usize;
18        for stmt in &ctx.statements {
19            if let Statement::Query(q) = stmt {
20                check_query(q, ctx, &mut count, &mut diags);
21            }
22        }
23        diags
24    }
25}
26
27fn check_query(q: &Query, ctx: &FileContext, count: &mut usize, diags: &mut Vec<Diagnostic>) {
28    if let Some(with) = &q.with {
29        for cte in &with.cte_tables {
30            check_query(&cte.query, ctx, count, diags);
31        }
32    }
33    check_set_expr(&q.body, ctx, count, diags);
34}
35
36fn check_set_expr(expr: &SetExpr, ctx: &FileContext, count: &mut usize, diags: &mut Vec<Diagnostic>) {
37    match expr {
38        SetExpr::Select(sel) => check_select(sel, ctx, count, diags),
39        SetExpr::Query(q) => check_query(q, ctx, count, diags),
40        SetExpr::SetOperation { left, right, .. } => {
41            check_set_expr(left, ctx, count, diags);
42            check_set_expr(right, ctx, count, diags);
43        }
44        _ => {}
45    }
46}
47
48fn check_select(sel: &Select, ctx: &FileContext, count: &mut usize, diags: &mut Vec<Diagnostic>) {
49    for twj in &sel.from {
50        recurse_factor(&twj.relation, ctx, count, diags);
51        for join in &twj.joins {
52            recurse_factor(&join.relation, ctx, count, diags);
53        }
54    }
55    if let Some(where_expr) = &sel.selection {
56        collect_cross_table_eq(where_expr, ctx, count, diags);
57    }
58}
59
60fn recurse_factor(tf: &TableFactor, ctx: &FileContext, count: &mut usize, diags: &mut Vec<Diagnostic>) {
61    if let TableFactor::Derived { subquery, .. } = tf {
62        check_query(subquery, ctx, count, diags);
63    }
64}
65
66fn collect_cross_table_eq(expr: &Expr, ctx: &FileContext, count: &mut usize, diags: &mut Vec<Diagnostic>) {
67    match expr {
68        Expr::BinaryOp { left, op, right } => {
69            if matches!(op, BinaryOperator::Eq) {
70                if let (Expr::CompoundIdentifier(l_parts), Expr::CompoundIdentifier(r_parts)) =
71                    (left.as_ref(), right.as_ref())
72                {
73                    if l_parts.len() >= 2 && r_parts.len() >= 2 {
74                        let l_table = l_parts[0].value.to_lowercase();
75                        let r_table = r_parts[0].value.to_lowercase();
76                        if l_table != r_table {
77                            let occ = *count;
78                            *count += 1;
79                            if let Some(offset) = find_nth_word(&ctx.source, &l_parts[0].value, occ) {
80                                let (line, col) = offset_to_line_col(&ctx.source, offset);
81                                diags.push(Diagnostic {
82                                    rule: "Convention/JoinConditionStyle",
83                                    message: "Join condition found in WHERE clause; move it to the ON clause".to_string(),
84                                    line,
85                                    col,
86                                });
87                            }
88                            return;
89                        }
90                    }
91                }
92            }
93            collect_cross_table_eq(left, ctx, count, diags);
94            collect_cross_table_eq(right, ctx, count, diags);
95        }
96        Expr::Nested(inner) => collect_cross_table_eq(inner, ctx, count, diags),
97        _ => {}
98    }
99}
100
101fn find_nth_word(source: &str, word: &str, nth: usize) -> Option<usize> {
102    let bytes = source.as_bytes();
103    let word_upper: Vec<u8> = word.bytes().map(|b| b.to_ascii_uppercase()).collect();
104    let wlen = word_upper.len();
105    let len = bytes.len();
106    let skip = SkipMap::build(source);
107    let mut count = 0;
108    let mut i = 0;
109    while i + wlen <= len {
110        if !skip.is_code(i) {
111            i += 1;
112            continue;
113        }
114        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
115        if !before_ok {
116            i += 1;
117            continue;
118        }
119        let matches = bytes[i..i + wlen]
120            .iter()
121            .zip(word_upper.iter())
122            .all(|(&a, &b)| a.to_ascii_uppercase() == b);
123        if matches {
124            let end = i + wlen;
125            let after_ok = end >= len || !is_word_char(bytes[end]);
126            if after_ok && (i..end).all(|k| skip.is_code(k)) {
127                if count == nth {
128                    return Some(i);
129                }
130                count += 1;
131                i += wlen;
132                continue;
133            }
134        }
135        i += 1;
136    }
137    None
138}
139
/// Converts a byte offset into `source` to a 1-based `(line, column)` pair.
///
/// The column is the 1-based byte position within the line. Lines are
/// separated by `'\n'`.
///
/// Out-of-range input is tolerated instead of producing a bogus column or a
/// panic: an `offset` past the end of `source` is treated as the end of the
/// text, and an `offset` that falls inside a multi-byte character is moved
/// back to the start of that character.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    // Clamp once and use the clamped value everywhere, so the slice and the
    // column arithmetic always agree.
    let mut offset = offset.min(source.len());
    // Slicing a `str` off a char boundary panics; offset 0 is always a
    // boundary, so this loop terminates.
    while !source.is_char_boundary(offset) {
        offset -= 1;
    }

    let before = &source[..offset];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let col = match before.rfind('\n') {
        // `newline < offset`, so the subtraction cannot underflow.
        Some(newline) => offset - newline,
        None => offset + 1,
    };
    (line, col)
}