Skip to main content

sqrust_rules/structure/
max_select_columns.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
/// Lint rule that flags `SELECT` lists containing too many explicit columns.
///
/// Wildcard items (`*`, `t.*`) are not counted towards the limit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MaxSelectColumns {
    /// Maximum number of non-wildcard columns allowed in a single SELECT list.
    /// A SELECT is reported only when its count is strictly greater than this.
    pub max_columns: usize,
}
10
11impl Default for MaxSelectColumns {
12    fn default() -> Self {
13        MaxSelectColumns { max_columns: 20 }
14    }
15}
16
17impl Rule for MaxSelectColumns {
18    fn name(&self) -> &'static str {
19        "Structure/MaxSelectColumns"
20    }
21
22    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
23        if !ctx.parse_errors.is_empty() {
24            return Vec::new();
25        }
26
27        let select_offsets = collect_select_offsets(&ctx.source);
28
29        let mut diags = Vec::new();
30        let mut select_index: usize = 0;
31
32        for stmt in &ctx.statements {
33            if let Statement::Query(query) = stmt {
34                check_query(
35                    query,
36                    self.max_columns,
37                    self.name(),
38                    &ctx.source,
39                    &select_offsets,
40                    &mut select_index,
41                    &mut diags,
42                );
43            }
44        }
45
46        diags
47    }
48}
49
50// ── AST walking ───────────────────────────────────────────────────────────────
51
52fn check_query(
53    query: &Query,
54    max: usize,
55    rule: &'static str,
56    source: &str,
57    offsets: &[usize],
58    idx: &mut usize,
59    diags: &mut Vec<Diagnostic>,
60) {
61    if let Some(with) = &query.with {
62        for cte in &with.cte_tables {
63            check_query(&cte.query, max, rule, source, offsets, idx, diags);
64        }
65    }
66    check_set_expr(&query.body, max, rule, source, offsets, idx, diags);
67}
68
69fn check_set_expr(
70    expr: &SetExpr,
71    max: usize,
72    rule: &'static str,
73    source: &str,
74    offsets: &[usize],
75    idx: &mut usize,
76    diags: &mut Vec<Diagnostic>,
77) {
78    match expr {
79        SetExpr::Select(sel) => {
80            check_select(sel, max, rule, source, offsets, idx, diags);
81        }
82        SetExpr::Query(inner) => {
83            check_query(inner, max, rule, source, offsets, idx, diags);
84        }
85        SetExpr::SetOperation { left, right, .. } => {
86            check_set_expr(left, max, rule, source, offsets, idx, diags);
87            check_set_expr(right, max, rule, source, offsets, idx, diags);
88        }
89        _ => {}
90    }
91}
92
93fn check_select(
94    sel: &Select,
95    max: usize,
96    rule: &'static str,
97    source: &str,
98    offsets: &[usize],
99    idx: &mut usize,
100    diags: &mut Vec<Diagnostic>,
101) {
102    let offset = offsets.get(*idx).copied().unwrap_or(0);
103    *idx += 1;
104
105    // Count only non-wildcard items.
106    let count = sel.projection.iter().filter(|item| {
107        !matches!(item, SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(..))
108    }).count();
109
110    if count > max {
111        let (line, col) = line_col(source, offset);
112        diags.push(Diagnostic {
113            rule,
114            message: format!(
115                "SELECT has {count} columns; maximum is {max}",
116                count = count,
117                max = max,
118            ),
119            line,
120            col,
121        });
122    }
123
124    // Recurse into subqueries inside FROM / JOIN clauses.
125    for table in &sel.from {
126        recurse_table_factor(&table.relation, max, rule, source, offsets, idx, diags);
127        for join in &table.joins {
128            recurse_table_factor(&join.relation, max, rule, source, offsets, idx, diags);
129        }
130    }
131
132    // Recurse into subqueries inside the WHERE clause.
133    if let Some(selection) = &sel.selection {
134        recurse_expr(selection, max, rule, source, offsets, idx, diags);
135    }
136
137    // Recurse into scalar subqueries in the projection list.
138    for item in &sel.projection {
139        if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
140            recurse_expr(e, max, rule, source, offsets, idx, diags);
141        }
142    }
143}
144
145fn recurse_table_factor(
146    tf: &TableFactor,
147    max: usize,
148    rule: &'static str,
149    source: &str,
150    offsets: &[usize],
151    idx: &mut usize,
152    diags: &mut Vec<Diagnostic>,
153) {
154    if let TableFactor::Derived { subquery, .. } = tf {
155        check_query(subquery, max, rule, source, offsets, idx, diags);
156    }
157}
158
159fn recurse_expr(
160    expr: &Expr,
161    max: usize,
162    rule: &'static str,
163    source: &str,
164    offsets: &[usize],
165    idx: &mut usize,
166    diags: &mut Vec<Diagnostic>,
167) {
168    match expr {
169        Expr::Subquery(q) => check_query(q, max, rule, source, offsets, idx, diags),
170        Expr::InSubquery { subquery, .. } => {
171            check_query(subquery, max, rule, source, offsets, idx, diags)
172        }
173        Expr::Exists { subquery, .. } => {
174            check_query(subquery, max, rule, source, offsets, idx, diags)
175        }
176        Expr::BinaryOp { left, right, .. } => {
177            recurse_expr(left, max, rule, source, offsets, idx, diags);
178            recurse_expr(right, max, rule, source, offsets, idx, diags);
179        }
180        _ => {}
181    }
182}
183
184// ── helpers ───────────────────────────────────────────────────────────────────
185
186/// Collect byte offsets of every `SELECT` keyword (case-insensitive,
187/// word-boundary, outside strings/comments) in source order.
188fn collect_select_offsets(source: &str) -> Vec<usize> {
189    let bytes = source.as_bytes();
190    let len = bytes.len();
191    let skip_map = SkipMap::build(source);
192    let kw = b"SELECT";
193    let kw_len = kw.len();
194    let mut offsets = Vec::new();
195
196    let mut i = 0;
197    while i + kw_len <= len {
198        if !skip_map.is_code(i) {
199            i += 1;
200            continue;
201        }
202
203        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
204        if !before_ok {
205            i += 1;
206            continue;
207        }
208
209        let matches = bytes[i..i + kw_len]
210            .iter()
211            .zip(kw.iter())
212            .all(|(a, b)| a.eq_ignore_ascii_case(b));
213
214        if matches {
215            let after = i + kw_len;
216            let after_ok = after >= len || !is_word_char(bytes[after]);
217            let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
218
219            if after_ok && all_code {
220                offsets.push(i);
221                i += kw_len;
222                continue;
223            }
224        }
225
226        i += 1;
227    }
228
229    offsets
230}
231
/// Translates a byte offset within `source` into a 1-based `(line, col)` pair.
///
/// The column is counted in bytes from the start of the line. Panics if
/// `offset` is past the end of `source` or not on a character boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];

    // One line per newline seen before the offset, plus the first line.
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();

    // Distance from the last newline (or from the start of the text).
    let col = match prefix.rfind('\n') {
        Some(newline_at) => offset - newline_at,
        None => offset + 1,
    };

    (line, col)
}