sqrust_rules/structure/
natural_join.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Join, JoinConstraint, JoinOperator, Query, SetExpr, Statement, TableFactor};
3
/// Lint rule that reports `NATURAL` joins.
///
/// A natural join picks its join columns implicitly by matching column names,
/// so the rule suggests an explicit `JOIN ... ON` instead. The rule holds no
/// state, so it is cheap to copy and can be default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct NaturalJoin;
5
6impl Rule for NaturalJoin {
7 fn name(&self) -> &'static str {
8 "NaturalJoin"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 if !ctx.parse_errors.is_empty() {
13 return Vec::new();
14 }
15
16 let mut diags = Vec::new();
17
18 for stmt in &ctx.statements {
19 if let Statement::Query(query) = stmt {
20 check_query(query, &ctx.source, self.name(), &mut diags);
21 }
22 }
23
24 diags
25 }
26}
27
28fn check_query(query: &Query, source: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
31 if let Some(with) = &query.with {
33 for cte in &with.cte_tables {
34 check_query(&cte.query, source, rule, diags);
35 }
36 }
37 check_set_expr(&query.body, source, rule, diags);
38}
39
40fn check_set_expr(expr: &SetExpr, source: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
41 match expr {
42 SetExpr::Select(sel) => {
43 for twj in &sel.from {
44 for join in &twj.joins {
45 check_join(join, source, rule, diags);
46 }
47 recurse_table_factor(&twj.relation, source, rule, diags);
48 for join in &twj.joins {
49 recurse_table_factor(&join.relation, source, rule, diags);
50 }
51 }
52 }
53 SetExpr::SetOperation { left, right, .. } => {
54 check_set_expr(left, source, rule, diags);
55 check_set_expr(right, source, rule, diags);
56 }
57 SetExpr::Query(inner) => {
58 check_query(inner, source, rule, diags);
59 }
60 _ => {}
61 }
62}
63
64fn is_natural(op: &JoinOperator) -> bool {
72 match op {
73 JoinOperator::Inner(c) => matches!(c, JoinConstraint::Natural),
74 JoinOperator::LeftOuter(c) => matches!(c, JoinConstraint::Natural),
75 JoinOperator::RightOuter(c) => matches!(c, JoinConstraint::Natural),
76 JoinOperator::FullOuter(c) => matches!(c, JoinConstraint::Natural),
77 _ => false,
78 }
79}
80
81fn check_join(join: &Join, source: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
82 if is_natural(&join.join_operator) {
83 let (line, col) = find_keyword_pos(source, "NATURAL");
84 diags.push(Diagnostic {
85 rule,
86 message: "NATURAL JOIN depends on column naming conventions; use explicit JOIN ON instead".to_string(),
87 line,
88 col,
89 });
90 }
91}
92
93fn recurse_table_factor(
94 tf: &TableFactor,
95 source: &str,
96 rule: &'static str,
97 diags: &mut Vec<Diagnostic>,
98) {
99 if let TableFactor::Derived { subquery, .. } = tf {
100 check_query(subquery, source, rule, diags);
101 }
102}
103
104fn find_keyword_pos(source: &str, keyword: &str) -> (usize, usize) {
110 let upper = source.to_uppercase();
111 let kw_upper = keyword.to_uppercase();
112 let kw_len = kw_upper.len();
113 let bytes = upper.as_bytes();
114 let len = bytes.len();
115
116 let mut pos = 0;
117 while pos + kw_len <= len {
118 if let Some(rel) = upper[pos..].find(kw_upper.as_str()) {
119 let abs = pos + rel;
120
121 let before_ok = abs == 0 || {
123 let b = bytes[abs - 1];
124 !b.is_ascii_alphanumeric() && b != b'_'
125 };
126 let after = abs + kw_len;
127 let after_ok = after >= len || {
128 let b = bytes[after];
129 !b.is_ascii_alphanumeric() && b != b'_'
130 };
131
132 if before_ok && after_ok {
133 return line_col(source, abs);
134 }
135
136 pos = abs + 1;
137 } else {
138 break;
139 }
140 }
141
142 (1, 1)
143}
144
145fn line_col(source: &str, offset: usize) -> (usize, usize) {
147 let before = &source[..offset];
148 let line = before.chars().filter(|&c| c == '\n').count() + 1;
149 let col = before.rfind('\n').map(|p| offset - p - 1).unwrap_or(offset) + 1;
150 (line, col)
151}