use sqrust_core::{Diagnostic, FileContext, Rule};
use sqlparser::ast::{BinaryOperator, Expr, Ident, Query, Select, SetExpr, Statement, TableFactor};
use crate::capitalisation::{is_word_char, SkipMap};
4
5pub struct JoinConditionStyle;
6
7impl Rule for JoinConditionStyle {
8 fn name(&self) -> &'static str {
9 "Convention/JoinConditionStyle"
10 }
11
12 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13 if !ctx.parse_errors.is_empty() {
14 return Vec::new();
15 }
16 let mut diags = Vec::new();
17 let mut count = 0usize;
18 for stmt in &ctx.statements {
19 if let Statement::Query(q) = stmt {
20 check_query(q, ctx, &mut count, &mut diags);
21 }
22 }
23 diags
24 }
25}
26
27fn check_query(q: &Query, ctx: &FileContext, count: &mut usize, diags: &mut Vec<Diagnostic>) {
28 if let Some(with) = &q.with {
29 for cte in &with.cte_tables {
30 check_query(&cte.query, ctx, count, diags);
31 }
32 }
33 check_set_expr(&q.body, ctx, count, diags);
34}
35
36fn check_set_expr(expr: &SetExpr, ctx: &FileContext, count: &mut usize, diags: &mut Vec<Diagnostic>) {
37 match expr {
38 SetExpr::Select(sel) => check_select(sel, ctx, count, diags),
39 SetExpr::Query(q) => check_query(q, ctx, count, diags),
40 SetExpr::SetOperation { left, right, .. } => {
41 check_set_expr(left, ctx, count, diags);
42 check_set_expr(right, ctx, count, diags);
43 }
44 _ => {}
45 }
46}
47
48fn check_select(sel: &Select, ctx: &FileContext, count: &mut usize, diags: &mut Vec<Diagnostic>) {
49 for twj in &sel.from {
50 recurse_factor(&twj.relation, ctx, count, diags);
51 for join in &twj.joins {
52 recurse_factor(&join.relation, ctx, count, diags);
53 }
54 }
55 if let Some(where_expr) = &sel.selection {
56 collect_cross_table_eq(where_expr, ctx, count, diags);
57 }
58}
59
60fn recurse_factor(tf: &TableFactor, ctx: &FileContext, count: &mut usize, diags: &mut Vec<Diagnostic>) {
61 if let TableFactor::Derived { subquery, .. } = tf {
62 check_query(subquery, ctx, count, diags);
63 }
64}
65
66fn collect_cross_table_eq(expr: &Expr, ctx: &FileContext, count: &mut usize, diags: &mut Vec<Diagnostic>) {
67 match expr {
68 Expr::BinaryOp { left, op, right } => {
69 if matches!(op, BinaryOperator::Eq) {
70 if let (Expr::CompoundIdentifier(l_parts), Expr::CompoundIdentifier(r_parts)) =
71 (left.as_ref(), right.as_ref())
72 {
73 if l_parts.len() >= 2 && r_parts.len() >= 2 {
74 let l_table = l_parts[0].value.to_lowercase();
75 let r_table = r_parts[0].value.to_lowercase();
76 if l_table != r_table {
77 let occ = *count;
78 *count += 1;
79 if let Some(offset) = find_nth_word(&ctx.source, &l_parts[0].value, occ) {
80 let (line, col) = offset_to_line_col(&ctx.source, offset);
81 diags.push(Diagnostic {
82 rule: "Convention/JoinConditionStyle",
83 message: "Join condition found in WHERE clause; move it to the ON clause".to_string(),
84 line,
85 col,
86 });
87 }
88 return;
89 }
90 }
91 }
92 }
93 collect_cross_table_eq(left, ctx, count, diags);
94 collect_cross_table_eq(right, ctx, count, diags);
95 }
96 Expr::Nested(inner) => collect_cross_table_eq(inner, ctx, count, diags),
97 _ => {}
98 }
99}
100
101fn find_nth_word(source: &str, word: &str, nth: usize) -> Option<usize> {
102 let bytes = source.as_bytes();
103 let word_upper: Vec<u8> = word.bytes().map(|b| b.to_ascii_uppercase()).collect();
104 let wlen = word_upper.len();
105 let len = bytes.len();
106 let skip = SkipMap::build(source);
107 let mut count = 0;
108 let mut i = 0;
109 while i + wlen <= len {
110 if !skip.is_code(i) {
111 i += 1;
112 continue;
113 }
114 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
115 if !before_ok {
116 i += 1;
117 continue;
118 }
119 let matches = bytes[i..i + wlen]
120 .iter()
121 .zip(word_upper.iter())
122 .all(|(&a, &b)| a.to_ascii_uppercase() == b);
123 if matches {
124 let end = i + wlen;
125 let after_ok = end >= len || !is_word_char(bytes[end]);
126 if after_ok && (i..end).all(|k| skip.is_code(k)) {
127 if count == nth {
128 return Some(i);
129 }
130 count += 1;
131 i += wlen;
132 continue;
133 }
134 }
135 i += 1;
136 }
137 None
138}
139
/// Converts a byte offset in `source` into a 1-based `(line, column)` pair.
///
/// The column is measured in bytes from the start of the line. An offset past
/// the end of `source` is clamped to its length (so the column is one past
/// the last byte). An offset that falls inside a multi-byte character is
/// moved back to that character's start instead of panicking on the slice.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut end = offset.min(source.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !source.is_char_boundary(end) {
        end -= 1;
    }
    let before = &source[..end];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |p| p + 1);
    // Use the clamped `end` (not the raw offset) so line and column agree.
    let col = end - line_start + 1;
    (line, col)
}