Skip to main content

sqrust_rules/structure/
having_conditions_count.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{BinaryOperator, Expr, Query, Select, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
/// Lint rule that flags `HAVING` clauses containing too many conditions.
///
/// Conditions are counted as the number of `AND`/`OR` operators found in the
/// clause plus one. The `Default` implementation allows up to 5 conditions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HavingConditionsCount {
    /// Maximum number of conditions allowed in a HAVING clause.
    /// A clause with N conditions is connected by N-1 AND/OR operators.
    /// When the condition count exceeds this maximum the clause is flagged.
    pub max_conditions: usize,
}
12
13impl Default for HavingConditionsCount {
14    fn default() -> Self {
15        HavingConditionsCount { max_conditions: 5 }
16    }
17}
18
19impl Rule for HavingConditionsCount {
20    fn name(&self) -> &'static str {
21        "Structure/HavingConditionsCount"
22    }
23
24    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
25        if !ctx.parse_errors.is_empty() {
26            return Vec::new();
27        }
28
29        let mut diags = Vec::new();
30
31        for stmt in &ctx.statements {
32            if let Statement::Query(query) = stmt {
33                check_query(query, self.max_conditions, ctx, &mut diags);
34            }
35        }
36
37        diags
38    }
39}
40
41// ── AST walking ───────────────────────────────────────────────────────────────
42
43fn check_query(
44    query: &Query,
45    max: usize,
46    ctx: &FileContext,
47    diags: &mut Vec<Diagnostic>,
48) {
49    // Visit CTEs.
50    if let Some(with) = &query.with {
51        for cte in &with.cte_tables {
52            check_query(&cte.query, max, ctx, diags);
53        }
54    }
55
56    check_set_expr(&query.body, max, ctx, diags);
57}
58
59fn check_set_expr(
60    expr: &SetExpr,
61    max: usize,
62    ctx: &FileContext,
63    diags: &mut Vec<Diagnostic>,
64) {
65    match expr {
66        SetExpr::Select(sel) => {
67            check_select(sel, max, ctx, diags);
68        }
69        SetExpr::Query(inner) => {
70            check_query(inner, max, ctx, diags);
71        }
72        SetExpr::SetOperation { left, right, .. } => {
73            check_set_expr(left, max, ctx, diags);
74            check_set_expr(right, max, ctx, diags);
75        }
76        _ => {}
77    }
78}
79
80fn check_select(
81    sel: &Select,
82    max: usize,
83    ctx: &FileContext,
84    diags: &mut Vec<Diagnostic>,
85) {
86    if let Some(having) = &sel.having {
87        // AND/OR operator count; number of distinct conditions = ops + 1
88        let ops = count_and_or_ops(having);
89        let conditions = ops + 1;
90        if conditions > max {
91            let (line, col) = find_keyword_pos(&ctx.source, "HAVING");
92            diags.push(Diagnostic {
93                rule: "Structure/HavingConditionsCount",
94                message: format!(
95                    "HAVING clause has {conditions} conditions, exceeding the maximum of {max}",
96                ),
97                line,
98                col,
99            });
100        }
101    }
102
103    // Recurse into subqueries in the FROM clause.
104    for twj in &sel.from {
105        check_table_factor(&twj.relation, max, ctx, diags);
106        for join in &twj.joins {
107            check_table_factor(&join.relation, max, ctx, diags);
108        }
109    }
110}
111
112fn check_table_factor(
113    tf: &TableFactor,
114    max: usize,
115    ctx: &FileContext,
116    diags: &mut Vec<Diagnostic>,
117) {
118    if let TableFactor::Derived { subquery, .. } = tf {
119        check_query(subquery, max, ctx, diags);
120    }
121}
122
123// ── condition counting ────────────────────────────────────────────────────────
124
125/// Count the number of top-level AND/OR binary operations in an expression.
126/// Each AND or OR operator adds 1 to the count.
127/// A HAVING clause with N conditions is connected by N-1 such operators.
128fn count_and_or_ops(expr: &Expr) -> usize {
129    match expr {
130        Expr::BinaryOp {
131            left,
132            op: BinaryOperator::And | BinaryOperator::Or,
133            right,
134        } => 1 + count_and_or_ops(left) + count_and_or_ops(right),
135        Expr::BinaryOp { left, right, .. } => {
136            count_and_or_ops(left) + count_and_or_ops(right)
137        }
138        Expr::UnaryOp { expr: inner, .. } => count_and_or_ops(inner),
139        Expr::Nested(inner) => count_and_or_ops(inner),
140        _ => 0,
141    }
142}
143
144// ── keyword position helper ───────────────────────────────────────────────────
145
146/// Find the first occurrence of a keyword (case-insensitive, word-boundary,
147/// outside strings/comments) in `source`. Returns a 1-indexed (line, col)
148/// pair. Falls back to (1, 1) if not found.
149fn find_keyword_pos(source: &str, keyword: &str) -> (usize, usize) {
150    let bytes = source.as_bytes();
151    let len = bytes.len();
152    let skip_map = SkipMap::build(source);
153    let kw_upper: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
154    let kw_len = kw_upper.len();
155
156    let mut i = 0;
157    while i + kw_len <= len {
158        if !skip_map.is_code(i) {
159            i += 1;
160            continue;
161        }
162
163        // Word boundary before.
164        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
165        if !before_ok {
166            i += 1;
167            continue;
168        }
169
170        // Case-insensitive match.
171        let matches = bytes[i..i + kw_len]
172            .iter()
173            .zip(kw_upper.iter())
174            .all(|(a, b)| a.eq_ignore_ascii_case(b));
175
176        if matches {
177            // Word boundary after.
178            let after = i + kw_len;
179            let after_ok = after >= len || !is_word_char(bytes[after]);
180            let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
181
182            if after_ok && all_code {
183                return line_col(source, i);
184            }
185        }
186
187        i += 1;
188    }
189
190    (1, 1)
191}
192
/// Translates a byte offset into `source` to a 1-indexed `(line, col)` pair.
///
/// The line is one more than the number of `'\n'` characters before `offset`;
/// the column is the byte distance from the start of that line, plus one.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line = prefix.matches('\n').count() + 1;
    let col = match prefix.rfind('\n') {
        // Distance from the newline itself already includes the +1.
        Some(newline_at) => offset - newline_at,
        // Still on the first line: the offset is the 0-based column.
        None => offset + 1,
    };
    (line, col)
}