1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3 BinaryOperator, Expr, JoinConstraint, JoinOperator, Query, Select, SetExpr, Statement,
4 TableFactor,
5};
6
7use crate::capitalisation::{is_word_char, SkipMap};
8
/// Lint rule that limits how many conditions a single `JOIN ... ON` clause
/// may combine with `AND` / `OR`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MaxJoinOnConditions {
    /// Largest number of conditions an `ON` clause may contain before the
    /// rule reports it.
    pub max_conditions: usize,
}
15
16impl Default for MaxJoinOnConditions {
17 fn default() -> Self {
18 MaxJoinOnConditions { max_conditions: 3 }
19 }
20}
21
22impl Rule for MaxJoinOnConditions {
23 fn name(&self) -> &'static str {
24 "Structure/MaxJoinOnConditions"
25 }
26
27 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
28 if !ctx.parse_errors.is_empty() {
29 return Vec::new();
30 }
31
32 let mut diags = Vec::new();
33
34 for stmt in &ctx.statements {
35 if let Statement::Query(query) = stmt {
36 check_query(query, self.max_conditions, ctx, &mut diags);
37 }
38 }
39
40 diags
41 }
42}
43
44fn check_query(
47 query: &Query,
48 max: usize,
49 ctx: &FileContext,
50 diags: &mut Vec<Diagnostic>,
51) {
52 if let Some(with) = &query.with {
54 for cte in &with.cte_tables {
55 check_query(&cte.query, max, ctx, diags);
56 }
57 }
58
59 check_set_expr(&query.body, max, ctx, diags);
60}
61
62fn check_set_expr(
63 expr: &SetExpr,
64 max: usize,
65 ctx: &FileContext,
66 diags: &mut Vec<Diagnostic>,
67) {
68 match expr {
69 SetExpr::Select(sel) => {
70 check_select(sel, max, ctx, diags);
71 }
72 SetExpr::Query(inner) => {
73 check_query(inner, max, ctx, diags);
74 }
75 SetExpr::SetOperation { left, right, .. } => {
76 check_set_expr(left, max, ctx, diags);
77 check_set_expr(right, max, ctx, diags);
78 }
79 _ => {}
80 }
81}
82
83fn check_select(
84 sel: &Select,
85 max: usize,
86 ctx: &FileContext,
87 diags: &mut Vec<Diagnostic>,
88) {
89 let mut on_occurrence: usize = 0;
92
93 for twj in &sel.from {
94 check_table_factor(&twj.relation, max, ctx, diags);
96
97 for join in &twj.joins {
98 check_table_factor(&join.relation, max, ctx, diags);
100
101 let on_expr = match &join.join_operator {
103 JoinOperator::Inner(JoinConstraint::On(expr))
104 | JoinOperator::LeftOuter(JoinConstraint::On(expr))
105 | JoinOperator::RightOuter(JoinConstraint::On(expr))
106 | JoinOperator::FullOuter(JoinConstraint::On(expr)) => Some(expr),
107 _ => None,
108 };
109
110 if let Some(on_expr) = on_expr {
111 let ops = count_and_or_ops(on_expr);
112 let total = ops + 1;
113 if total > max {
114 let (line, col) = find_keyword_pos(&ctx.source, "ON", on_occurrence);
115 diags.push(Diagnostic {
116 rule: "Structure/MaxJoinOnConditions",
117 message: format!(
118 "JOIN ON clause has {total} conditions, exceeding the maximum of {max}"
119 ),
120 line,
121 col,
122 });
123 }
124 on_occurrence += 1;
125 }
126 }
127 }
128}
129
130fn check_table_factor(
131 tf: &TableFactor,
132 max: usize,
133 ctx: &FileContext,
134 diags: &mut Vec<Diagnostic>,
135) {
136 if let TableFactor::Derived { subquery, .. } = tf {
137 check_query(subquery, max, ctx, diags);
138 }
139}
140
141fn count_and_or_ops(expr: &Expr) -> usize {
146 match expr {
147 Expr::BinaryOp {
148 left,
149 op: BinaryOperator::And | BinaryOperator::Or,
150 right,
151 } => 1 + count_and_or_ops(left) + count_and_or_ops(right),
152 Expr::BinaryOp { left, right, .. } => {
153 count_and_or_ops(left) + count_and_or_ops(right)
154 }
155 Expr::UnaryOp { expr: inner, .. } => count_and_or_ops(inner),
156 Expr::Nested(inner) => count_and_or_ops(inner),
157 _ => 0,
158 }
159}
160
161fn find_keyword_pos(source: &str, keyword: &str, nth: usize) -> (usize, usize) {
167 let bytes = source.as_bytes();
168 let len = bytes.len();
169 let skip_map = SkipMap::build(source);
170 let kw_upper: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
171 let kw_len = kw_upper.len();
172
173 let mut count = 0usize;
174 let mut i = 0;
175 while i + kw_len <= len {
176 if !skip_map.is_code(i) {
177 i += 1;
178 continue;
179 }
180
181 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
183 if !before_ok {
184 i += 1;
185 continue;
186 }
187
188 let matches = bytes[i..i + kw_len]
190 .iter()
191 .zip(kw_upper.iter())
192 .all(|(a, b)| a.eq_ignore_ascii_case(b));
193
194 if matches {
195 let after = i + kw_len;
197 let after_ok = after >= len || !is_word_char(bytes[after]);
198 let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
199
200 if after_ok && all_code {
201 if count == nth {
202 return line_col(source, i);
203 }
204 count += 1;
205 }
206 }
207
208 i += 1;
209 }
210
211 (1, 1)
212}
213
/// Converts a byte offset into `source` to a 1-based `(line, column)` pair.
///
/// The column is counted in bytes from the start of the line. An offset past
/// the end of `source` is clamped to its length, and an offset that falls
/// inside a multi-byte character is moved back to the start of that
/// character, so slicing never panics.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut end = offset.min(source.len());
    // Offset 0 is always a char boundary, so this loop terminates.
    while !source.is_char_boundary(end) {
        end -= 1;
    }

    let before = &source[..end];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    // Byte index of the first character on the current line.
    let line_start = before.rfind('\n').map_or(0, |nl| nl + 1);
    (line, end - line_start + 1)
}