sqrust_rules/convention/
left_join.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{JoinOperator, Query, Select, SetExpr, Statement, TableFactor};
3use crate::capitalisation::{is_word_char, SkipMap};
4
/// Lint rule `Convention/LeftJoin`: reports RIGHT joins and suggests
/// rewriting them as LEFT joins from the other table's perspective.
///
/// The rule is stateless, so it is a cheap `Copy` unit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct LeftJoin;
6
7impl Rule for LeftJoin {
8 fn name(&self) -> &'static str {
9 "Convention/LeftJoin"
10 }
11
12 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13 if !ctx.parse_errors.is_empty() {
14 return Vec::new();
15 }
16 let mut diags = Vec::new();
17 let mut count = 0usize;
18 for stmt in &ctx.statements {
19 if let Statement::Query(q) = stmt {
20 check_query(q, &ctx.source, &mut count, &mut diags);
21 }
22 }
23 diags
24 }
25}
26
27fn check_query(q: &Query, src: &str, count: &mut usize, diags: &mut Vec<Diagnostic>) {
28 if let Some(with) = &q.with {
29 for cte in &with.cte_tables {
30 check_query(&cte.query, src, count, diags);
31 }
32 }
33 check_set_expr(&q.body, src, count, diags);
34}
35
36fn check_set_expr(expr: &SetExpr, src: &str, count: &mut usize, diags: &mut Vec<Diagnostic>) {
37 match expr {
38 SetExpr::Select(sel) => check_select(sel, src, count, diags),
39 SetExpr::Query(q) => check_query(q, src, count, diags),
40 SetExpr::SetOperation { left, right, .. } => {
41 check_set_expr(left, src, count, diags);
42 check_set_expr(right, src, count, diags);
43 }
44 _ => {}
45 }
46}
47
48fn check_select(sel: &Select, src: &str, count: &mut usize, diags: &mut Vec<Diagnostic>) {
49 for twj in &sel.from {
50 recurse_factor(&twj.relation, src, count, diags);
51 for join in &twj.joins {
52 recurse_factor(&join.relation, src, count, diags);
53 if is_right_join(&join.join_operator) {
54 let occ = *count;
55 *count += 1;
56 if let Some(offset) = find_nth_keyword(src, b"RIGHT", occ) {
57 let (line, col) = offset_to_line_col(src, offset);
58 diags.push(Diagnostic {
59 rule: "Convention/LeftJoin",
60 message: "Prefer LEFT JOIN over RIGHT JOIN; rewrite from the other table's perspective".to_string(),
61 line,
62 col,
63 });
64 }
65 }
66 }
67 }
68}
69
70fn recurse_factor(tf: &TableFactor, src: &str, count: &mut usize, diags: &mut Vec<Diagnostic>) {
71 if let TableFactor::Derived { subquery, .. } = tf {
72 check_query(subquery, src, count, diags);
73 }
74}
75
76fn is_right_join(op: &JoinOperator) -> bool {
77 matches!(
78 op,
79 JoinOperator::RightOuter(_) | JoinOperator::RightSemi(_) | JoinOperator::RightAnti(_)
80 )
81}
82
83fn find_nth_keyword(source: &str, keyword: &[u8], nth: usize) -> Option<usize> {
84 let bytes = source.as_bytes();
85 let kw_len = keyword.len();
86 let len = bytes.len();
87 let skip = SkipMap::build(source);
88 let mut count = 0;
89 let mut i = 0;
90 while i + kw_len <= len {
91 if !skip.is_code(i) {
92 i += 1;
93 continue;
94 }
95 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
96 if !before_ok {
97 i += 1;
98 continue;
99 }
100 let matches = bytes[i..i + kw_len]
101 .iter()
102 .zip(keyword.iter())
103 .all(|(&a, &b)| a.to_ascii_uppercase() == b.to_ascii_uppercase());
104 if matches {
105 let end = i + kw_len;
106 let after_ok = end >= len || !is_word_char(bytes[end]);
107 let all_code = (i..end).all(|k| skip.is_code(k));
108 if after_ok && all_code {
109 if count == nth {
110 return Some(i);
111 }
112 count += 1;
113 i += kw_len;
114 continue;
115 }
116 }
117 i += 1;
118 }
119 None
120}
121
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// Lines are counted by `'\n'` characters before the offset. The column is
/// the byte distance from the start of that line plus one. The text used for
/// counting is clamped to `source`, but the column is still computed from the
/// unclamped `offset`.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..source.len().min(offset)];
    let line_start = prefix.rfind('\n').map_or(0, |newline| newline + 1);
    let line = 1 + prefix.matches('\n').count();
    let col = offset - line_start + 1;
    (line, col)
}