Skip to main content

sqrust_rules/structure/
column_count.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
/// Lint rule that flags any SELECT whose projection list is too wide.
///
/// Configure the limit through `max_columns`; the `Default` impl supplies the
/// out-of-the-box value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnCount {
    /// Maximum number of columns allowed in a single SELECT list.
    pub max_columns: usize,
}
10
11impl Default for ColumnCount {
12    fn default() -> Self {
13        ColumnCount { max_columns: 20 }
14    }
15}
16
17impl Rule for ColumnCount {
18    fn name(&self) -> &'static str {
19        "Structure/ColumnCount"
20    }
21
22    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
23        if !ctx.parse_errors.is_empty() {
24            return Vec::new();
25        }
26
27        // Collect all SELECT keyword offsets in source order so we can assign
28        // accurate line/col to each violating SELECT.
29        let select_offsets = collect_select_offsets(&ctx.source);
30
31        let mut diags = Vec::new();
32        // `select_index` tracks how many SELECT keywords we have consumed so
33        // far as we walk the AST in source order.
34        let mut select_index: usize = 0;
35
36        for stmt in &ctx.statements {
37            if let Statement::Query(query) = stmt {
38                check_query(
39                    query,
40                    self.max_columns,
41                    self.name(),
42                    &ctx.source,
43                    &select_offsets,
44                    &mut select_index,
45                    &mut diags,
46                );
47            }
48        }
49
50        diags
51    }
52}
53
54// ── AST walking ───────────────────────────────────────────────────────────────
55
56fn check_query(
57    query: &Query,
58    max: usize,
59    rule: &'static str,
60    source: &str,
61    offsets: &[usize],
62    idx: &mut usize,
63    diags: &mut Vec<Diagnostic>,
64) {
65    // CTEs are visited before the main body in source order.
66    if let Some(with) = &query.with {
67        for cte in &with.cte_tables {
68            check_query(&cte.query, max, rule, source, offsets, idx, diags);
69        }
70    }
71    check_set_expr(&query.body, max, rule, source, offsets, idx, diags);
72}
73
74fn check_set_expr(
75    expr: &SetExpr,
76    max: usize,
77    rule: &'static str,
78    source: &str,
79    offsets: &[usize],
80    idx: &mut usize,
81    diags: &mut Vec<Diagnostic>,
82) {
83    match expr {
84        SetExpr::Select(sel) => {
85            check_select(sel, max, rule, source, offsets, idx, diags);
86        }
87        SetExpr::Query(inner) => {
88            check_query(inner, max, rule, source, offsets, idx, diags);
89        }
90        SetExpr::SetOperation { left, right, .. } => {
91            check_set_expr(left, max, rule, source, offsets, idx, diags);
92            check_set_expr(right, max, rule, source, offsets, idx, diags);
93        }
94        _ => {}
95    }
96}
97
98fn check_select(
99    sel: &Select,
100    max: usize,
101    rule: &'static str,
102    source: &str,
103    offsets: &[usize],
104    idx: &mut usize,
105    diags: &mut Vec<Diagnostic>,
106) {
107    // Each SELECT in the AST corresponds to one SELECT keyword in source.
108    let offset = offsets.get(*idx).copied().unwrap_or(0);
109    *idx += 1;
110
111    let count = sel.projection.len();
112    if count > max {
113        let (line, col) = line_col(source, offset);
114        diags.push(Diagnostic {
115            rule,
116            message: format!(
117                "SELECT has {count} columns; maximum is {max}",
118                count = count,
119                max = max,
120            ),
121            line,
122            col,
123        });
124    }
125
126    // Recurse into subqueries inside FROM / JOIN clauses.
127    for table in &sel.from {
128        recurse_table_factor(&table.relation, max, rule, source, offsets, idx, diags);
129        for join in &table.joins {
130            recurse_table_factor(&join.relation, max, rule, source, offsets, idx, diags);
131        }
132    }
133
134    // Recurse into subqueries inside the WHERE clause.
135    if let Some(selection) = &sel.selection {
136        recurse_expr(selection, max, rule, source, offsets, idx, diags);
137    }
138
139    // Recurse into scalar subqueries in the projection list.
140    for item in &sel.projection {
141        if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
142            recurse_expr(e, max, rule, source, offsets, idx, diags);
143        }
144    }
145}
146
147fn recurse_table_factor(
148    tf: &TableFactor,
149    max: usize,
150    rule: &'static str,
151    source: &str,
152    offsets: &[usize],
153    idx: &mut usize,
154    diags: &mut Vec<Diagnostic>,
155) {
156    if let TableFactor::Derived { subquery, .. } = tf {
157        check_query(subquery, max, rule, source, offsets, idx, diags);
158    }
159}
160
161fn recurse_expr(
162    expr: &Expr,
163    max: usize,
164    rule: &'static str,
165    source: &str,
166    offsets: &[usize],
167    idx: &mut usize,
168    diags: &mut Vec<Diagnostic>,
169) {
170    match expr {
171        Expr::Subquery(q) => check_query(q, max, rule, source, offsets, idx, diags),
172        Expr::InSubquery { subquery, .. } => {
173            check_query(subquery, max, rule, source, offsets, idx, diags)
174        }
175        Expr::Exists { subquery, .. } => {
176            check_query(subquery, max, rule, source, offsets, idx, diags)
177        }
178        Expr::BinaryOp { left, right, .. } => {
179            recurse_expr(left, max, rule, source, offsets, idx, diags);
180            recurse_expr(right, max, rule, source, offsets, idx, diags);
181        }
182        _ => {}
183    }
184}
185
186// ── helpers ───────────────────────────────────────────────────────────────────
187
188/// Collect byte offsets of every `SELECT` keyword (case-insensitive,
189/// word-boundary, outside strings/comments) in source order.
190fn collect_select_offsets(source: &str) -> Vec<usize> {
191    let bytes = source.as_bytes();
192    let len = bytes.len();
193    let skip_map = SkipMap::build(source);
194    let kw = b"SELECT";
195    let kw_len = kw.len();
196    let mut offsets = Vec::new();
197
198    let mut i = 0;
199    while i + kw_len <= len {
200        if !skip_map.is_code(i) {
201            i += 1;
202            continue;
203        }
204
205        // Word boundary check before.
206        let before_ok = i == 0 || {
207            let b = bytes[i - 1];
208            !is_word_char(b)
209        };
210        if !before_ok {
211            i += 1;
212            continue;
213        }
214
215        // Case-insensitive match.
216        let matches = bytes[i..i + kw_len]
217            .iter()
218            .zip(kw.iter())
219            .all(|(a, b)| a.eq_ignore_ascii_case(b));
220
221        if matches {
222            // Word boundary check after.
223            let after = i + kw_len;
224            let after_ok = after >= len || !is_word_char(bytes[after]);
225
226            // All bytes of SELECT must be real code.
227            let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
228
229            if after_ok && all_code {
230                offsets.push(i);
231                i += kw_len;
232                continue;
233            }
234        }
235
236        i += 1;
237    }
238
239    offsets
240}
241
/// Maps a byte `offset` into `source` to a 1-indexed `(line, column)` pair.
///
/// The column is measured in bytes from the start of the line. `offset` must
/// lie on a UTF-8 character boundary, or slicing panics.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line = prefix.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = prefix.rfind('\n').map_or(0, |newline| newline + 1);
    (line, offset - line_start + 1)
}