Skip to main content

sqrust_rules/structure/
max_join_on_conditions.rs

use sqrust_core::{Diagnostic, FileContext, Rule};
use sqlparser::ast::{
    BinaryOperator, Expr, JoinConstraint, JoinOperator, Query, Select, SetExpr, Statement,
    TableFactor, TableWithJoins,
};

use crate::capitalisation::{is_word_char, SkipMap};

/// Lint rule that flags `JOIN ... ON` clauses whose condition is built from
/// too many AND/OR-connected parts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MaxJoinOnConditions {
    /// Maximum number of conditions allowed in a single JOIN ON clause.
    /// A clause with N conditions is connected by N-1 AND/OR operators.
    /// When the condition count exceeds this maximum the clause is flagged.
    pub max_conditions: usize,
}

impl Default for MaxJoinOnConditions {
    /// Allows at most three conditions per JOIN ON clause; a fourth is flagged.
    fn default() -> Self {
        MaxJoinOnConditions { max_conditions: 3 }
    }
}

22impl Rule for MaxJoinOnConditions {
23    fn name(&self) -> &'static str {
24        "Structure/MaxJoinOnConditions"
25    }
26
27    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
28        if !ctx.parse_errors.is_empty() {
29            return Vec::new();
30        }
31
32        let mut diags = Vec::new();
33
34        for stmt in &ctx.statements {
35            if let Statement::Query(query) = stmt {
36                check_query(query, self.max_conditions, ctx, &mut diags);
37            }
38        }
39
40        diags
41    }
42}
43
44// ── AST walking ───────────────────────────────────────────────────────────────
45
46fn check_query(
47    query: &Query,
48    max: usize,
49    ctx: &FileContext,
50    diags: &mut Vec<Diagnostic>,
51) {
52    // Visit CTEs.
53    if let Some(with) = &query.with {
54        for cte in &with.cte_tables {
55            check_query(&cte.query, max, ctx, diags);
56        }
57    }
58
59    check_set_expr(&query.body, max, ctx, diags);
60}
61
62fn check_set_expr(
63    expr: &SetExpr,
64    max: usize,
65    ctx: &FileContext,
66    diags: &mut Vec<Diagnostic>,
67) {
68    match expr {
69        SetExpr::Select(sel) => {
70            check_select(sel, max, ctx, diags);
71        }
72        SetExpr::Query(inner) => {
73            check_query(inner, max, ctx, diags);
74        }
75        SetExpr::SetOperation { left, right, .. } => {
76            check_set_expr(left, max, ctx, diags);
77            check_set_expr(right, max, ctx, diags);
78        }
79        _ => {}
80    }
81}
82
83fn check_select(
84    sel: &Select,
85    max: usize,
86    ctx: &FileContext,
87    diags: &mut Vec<Diagnostic>,
88) {
89    // Track which ON keyword occurrence we are on (0-indexed) so each JOIN ON
90    // gets its own position reported.
91    let mut on_occurrence: usize = 0;
92
93    for twj in &sel.from {
94        // Recurse into subqueries in the main table factor.
95        check_table_factor(&twj.relation, max, ctx, diags);
96
97        for join in &twj.joins {
98            // Recurse into subqueries inside joined tables.
99            check_table_factor(&join.relation, max, ctx, diags);
100
101            // Extract the ON expression based on join type.
102            let on_expr = match &join.join_operator {
103                JoinOperator::Inner(JoinConstraint::On(expr))
104                | JoinOperator::LeftOuter(JoinConstraint::On(expr))
105                | JoinOperator::RightOuter(JoinConstraint::On(expr))
106                | JoinOperator::FullOuter(JoinConstraint::On(expr)) => Some(expr),
107                _ => None,
108            };
109
110            if let Some(on_expr) = on_expr {
111                let ops = count_and_or_ops(on_expr);
112                let total = ops + 1;
113                if total > max {
114                    let (line, col) = find_keyword_pos(&ctx.source, "ON", on_occurrence);
115                    diags.push(Diagnostic {
116                        rule: "Structure/MaxJoinOnConditions",
117                        message: format!(
118                            "JOIN ON clause has {total} conditions, exceeding the maximum of {max}"
119                        ),
120                        line,
121                        col,
122                    });
123                }
124                on_occurrence += 1;
125            }
126        }
127    }
128}
129
130fn check_table_factor(
131    tf: &TableFactor,
132    max: usize,
133    ctx: &FileContext,
134    diags: &mut Vec<Diagnostic>,
135) {
136    if let TableFactor::Derived { subquery, .. } = tf {
137        check_query(subquery, max, ctx, diags);
138    }
139}
140
141// ── condition counting ────────────────────────────────────────────────────────
142
143/// Count the number of AND/OR binary operations in an expression recursively.
144/// Each AND or OR operator adds 1 to the count.
145fn count_and_or_ops(expr: &Expr) -> usize {
146    match expr {
147        Expr::BinaryOp {
148            left,
149            op: BinaryOperator::And | BinaryOperator::Or,
150            right,
151        } => 1 + count_and_or_ops(left) + count_and_or_ops(right),
152        Expr::BinaryOp { left, right, .. } => {
153            count_and_or_ops(left) + count_and_or_ops(right)
154        }
155        Expr::UnaryOp { expr: inner, .. } => count_and_or_ops(inner),
156        Expr::Nested(inner) => count_and_or_ops(inner),
157        _ => 0,
158    }
159}
160
161// ── keyword position helper ───────────────────────────────────────────────────
162
163/// Find the `nth` (0-indexed) occurrence of a keyword (case-insensitive,
164/// word-boundary, outside strings/comments) in `source`. Returns a
165/// 1-indexed (line, col) pair. Falls back to (1, 1) if not found.
166fn find_keyword_pos(source: &str, keyword: &str, nth: usize) -> (usize, usize) {
167    let bytes = source.as_bytes();
168    let len = bytes.len();
169    let skip_map = SkipMap::build(source);
170    let kw_upper: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
171    let kw_len = kw_upper.len();
172
173    let mut count = 0usize;
174    let mut i = 0;
175    while i + kw_len <= len {
176        if !skip_map.is_code(i) {
177            i += 1;
178            continue;
179        }
180
181        // Word boundary before.
182        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
183        if !before_ok {
184            i += 1;
185            continue;
186        }
187
188        // Case-insensitive match.
189        let matches = bytes[i..i + kw_len]
190            .iter()
191            .zip(kw_upper.iter())
192            .all(|(a, b)| a.eq_ignore_ascii_case(b));
193
194        if matches {
195            // Word boundary after.
196            let after = i + kw_len;
197            let after_ok = after >= len || !is_word_char(bytes[after]);
198            let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
199
200            if after_ok && all_code {
201                if count == nth {
202                    return line_col(source, i);
203                }
204                count += 1;
205            }
206        }
207
208        i += 1;
209    }
210
211    (1, 1)
212}
213
/// Maps a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The line is 1 plus the number of `'\n'` bytes before `offset`. The column
/// is 1 plus the number of bytes between the start of that line and `offset`;
/// it is a byte count, not a character count.
///
/// Panics if `offset` is past the end of `source` or not on a char boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = prefix.rfind('\n').map_or(0, |nl| nl + 1);
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    (line, offset - line_start + 1)
}
220}