Skip to main content

sqrust_rules/structure/
large_in_list.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
/// Lint rule that flags `IN (...)` lists containing too many values.
///
/// The threshold is configurable through [`LargeInList::max_values`]; the
/// `Default` implementation supplies the stock limit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LargeInList {
    /// Maximum number of values allowed in an IN list.
    /// IN lists with strictly more values than this are flagged; a list with
    /// exactly `max_values` entries is accepted.
    pub max_values: usize,
}
11
12impl Default for LargeInList {
13    fn default() -> Self {
14        LargeInList { max_values: 10 }
15    }
16}
17
18impl Rule for LargeInList {
19    fn name(&self) -> &'static str {
20        "LargeInList"
21    }
22
23    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
24        if !ctx.parse_errors.is_empty() {
25            return Vec::new();
26        }
27
28        let mut diags = Vec::new();
29        // Track how many IN keywords we've consumed so we can map each
30        // violation to the correct source position.
31        let mut in_occurrence: usize = 0;
32
33        for stmt in &ctx.statements {
34            if let Statement::Query(query) = stmt {
35                check_query(query, self.max_values, ctx, &mut in_occurrence, &mut diags);
36            }
37        }
38
39        diags
40    }
41}
42
43// ── AST walking ───────────────────────────────────────────────────────────────
44
45fn check_query(
46    query: &Query,
47    max: usize,
48    ctx: &FileContext,
49    occurrence: &mut usize,
50    diags: &mut Vec<Diagnostic>,
51) {
52    if let Some(with) = &query.with {
53        for cte in &with.cte_tables {
54            check_query(&cte.query, max, ctx, occurrence, diags);
55        }
56    }
57
58    check_set_expr(&query.body, max, ctx, occurrence, diags);
59}
60
61fn check_set_expr(
62    expr: &SetExpr,
63    max: usize,
64    ctx: &FileContext,
65    occurrence: &mut usize,
66    diags: &mut Vec<Diagnostic>,
67) {
68    match expr {
69        SetExpr::Select(sel) => check_select(sel, max, ctx, occurrence, diags),
70        SetExpr::Query(inner) => check_query(inner, max, ctx, occurrence, diags),
71        SetExpr::SetOperation { left, right, .. } => {
72            check_set_expr(left, max, ctx, occurrence, diags);
73            check_set_expr(right, max, ctx, occurrence, diags);
74        }
75        _ => {}
76    }
77}
78
79fn check_select(
80    sel: &Select,
81    max: usize,
82    ctx: &FileContext,
83    occurrence: &mut usize,
84    diags: &mut Vec<Diagnostic>,
85) {
86    // SELECT projection.
87    for item in &sel.projection {
88        if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
89            check_expr(e, max, ctx, occurrence, diags);
90        }
91    }
92
93    // WHERE clause.
94    if let Some(selection) = &sel.selection {
95        check_expr(selection, max, ctx, occurrence, diags);
96    }
97
98    // FROM / JOIN — recurse into derived tables.
99    for twj in &sel.from {
100        check_table_factor(&twj.relation, max, ctx, occurrence, diags);
101        for join in &twj.joins {
102            check_table_factor(&join.relation, max, ctx, occurrence, diags);
103        }
104    }
105}
106
107fn check_table_factor(
108    tf: &TableFactor,
109    max: usize,
110    ctx: &FileContext,
111    occurrence: &mut usize,
112    diags: &mut Vec<Diagnostic>,
113) {
114    if let TableFactor::Derived { subquery, .. } = tf {
115        check_query(subquery, max, ctx, occurrence, diags);
116    }
117}
118
119fn check_expr(
120    expr: &Expr,
121    max: usize,
122    ctx: &FileContext,
123    occurrence: &mut usize,
124    diags: &mut Vec<Diagnostic>,
125) {
126    match expr {
127        Expr::InList { expr: inner, list, .. } => {
128            // Consume one IN occurrence.
129            let occ = *occurrence;
130            *occurrence += 1;
131
132            let n = list.len();
133            if n > max {
134                let (line, col) = find_nth_keyword_pos(&ctx.source, "IN", occ);
135                diags.push(Diagnostic {
136                    rule: "LargeInList",
137                    message: format!(
138                        "IN list has {n} values, exceeding the maximum of {max}"
139                    ),
140                    line,
141                    col,
142                });
143            }
144
145            // Recurse into the expression being tested and the list values.
146            check_expr(inner, max, ctx, occurrence, diags);
147            for val in list {
148                check_expr(val, max, ctx, occurrence, diags);
149            }
150        }
151
152        Expr::BinaryOp { left, right, .. } => {
153            check_expr(left, max, ctx, occurrence, diags);
154            check_expr(right, max, ctx, occurrence, diags);
155        }
156        Expr::UnaryOp { expr: inner, .. } => {
157            check_expr(inner, max, ctx, occurrence, diags);
158        }
159        Expr::Subquery(q) => check_query(q, max, ctx, occurrence, diags),
160        Expr::InSubquery { subquery, expr: e, .. } => {
161            check_expr(e, max, ctx, occurrence, diags);
162            check_query(subquery, max, ctx, occurrence, diags);
163        }
164        Expr::Exists { subquery, .. } => check_query(subquery, max, ctx, occurrence, diags),
165        Expr::Nested(inner) => check_expr(inner, max, ctx, occurrence, diags),
166        Expr::Case {
167            operand,
168            conditions,
169            results,
170            else_result,
171        } => {
172            if let Some(op) = operand {
173                check_expr(op, max, ctx, occurrence, diags);
174            }
175            for cond in conditions {
176                check_expr(cond, max, ctx, occurrence, diags);
177            }
178            for res in results {
179                check_expr(res, max, ctx, occurrence, diags);
180            }
181            if let Some(els) = else_result {
182                check_expr(els, max, ctx, occurrence, diags);
183            }
184        }
185        Expr::Function(f) => {
186            use sqlparser::ast::{FunctionArg, FunctionArgExpr, FunctionArguments};
187            if let FunctionArguments::List(list) = &f.args {
188                for arg in &list.args {
189                    let arg_expr = match arg {
190                        FunctionArg::Unnamed(e) => Some(e),
191                        FunctionArg::Named { arg: e, .. } => Some(e),
192                        FunctionArg::ExprNamed { arg: e, .. } => Some(e),
193                    };
194                    if let Some(FunctionArgExpr::Expr(inner)) = arg_expr {
195                        check_expr(inner, max, ctx, occurrence, diags);
196                    }
197                }
198            }
199        }
200        _ => {}
201    }
202}
203
204// ── keyword position helper ───────────────────────────────────────────────────
205
206/// Find the `nth` occurrence (0-indexed) of a keyword (case-insensitive,
207/// word-boundary, outside strings/comments) in `source`.
208/// Returns a 1-indexed (line, col) pair. Falls back to (1, 1) if not found.
209fn find_nth_keyword_pos(source: &str, keyword: &str, nth: usize) -> (usize, usize) {
210    let bytes = source.as_bytes();
211    let len = bytes.len();
212    let skip_map = SkipMap::build(source);
213    let kw_upper: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
214    let kw_len = kw_upper.len();
215
216    let mut count = 0usize;
217    let mut i = 0;
218    while i + kw_len <= len {
219        if !skip_map.is_code(i) {
220            i += 1;
221            continue;
222        }
223
224        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
225        if !before_ok {
226            i += 1;
227            continue;
228        }
229
230        let matches = bytes[i..i + kw_len]
231            .iter()
232            .zip(kw_upper.iter())
233            .all(|(a, b)| a.eq_ignore_ascii_case(b));
234
235        if matches {
236            let after = i + kw_len;
237            let after_ok = after >= len || !is_word_char(bytes[after]);
238            let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
239
240            if after_ok && all_code {
241                if count == nth {
242                    return line_col(source, i);
243                }
244                count += 1;
245            }
246        }
247
248        i += 1;
249    }
250
251    (1, 1)
252}
253
/// Translates a byte `offset` into `source` to a 1-indexed `(line, col)` pair.
///
/// The line is one more than the number of newlines before `offset`; the
/// column is the byte distance from the start of that line, plus one.
/// `offset` must fall on a character boundary within `source`, otherwise the
/// slice below panics.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];

    // Byte index at which the current line begins.
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let newlines = prefix.bytes().filter(|&b| b == b'\n').count();

    (newlines + 1, offset - line_start + 1)
}