Skip to main content

sqrust_rules/structure/
too_many_subqueries.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3    Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor,
4};
5
6use crate::capitalisation::{is_word_char, SkipMap};
7
/// Lint rule that flags query statements containing too many subqueries.
///
/// Subquery expressions (`Expr::Subquery`, `Expr::InSubquery`, `Expr::Exists`),
/// derived tables in `FROM`, and each CTE all count as one subquery each;
/// occurrences nested inside other subqueries are counted as well.
#[derive(Debug, Clone)]
pub struct TooManySubqueries {
    /// Maximum number of subqueries allowed in a single SQL statement.
    /// A statement is reported only when its count is strictly greater than
    /// this value. Defaults to 3.
    pub max_subqueries: usize,
}
14
15impl Default for TooManySubqueries {
16    fn default() -> Self {
17        TooManySubqueries { max_subqueries: 3 }
18    }
19}
20
21impl Rule for TooManySubqueries {
22    fn name(&self) -> &'static str {
23        "Structure/TooManySubqueries"
24    }
25
26    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
27        if !ctx.parse_errors.is_empty() {
28            return Vec::new();
29        }
30
31        let mut diags = Vec::new();
32
33        for stmt in &ctx.statements {
34            if let Statement::Query(query) = stmt {
35                let n = count_subqueries_in_query(query);
36                if n > self.max_subqueries {
37                    let (line, col) = find_nth_select_pos(&ctx.source, 1);
38                    diags.push(Diagnostic {
39                        rule: "Structure/TooManySubqueries",
40                        message: format!(
41                            "Statement contains {n} subqueries, exceeding the maximum of {max}",
42                            max = self.max_subqueries,
43                        ),
44                        line,
45                        col,
46                    });
47                }
48            }
49        }
50
51        diags
52    }
53}
54
55// ── subquery counting ─────────────────────────────────────────────────────────
56
57/// Count all subqueries in a Query node, including its CTEs.
58/// Each CTE body counts as one subquery, plus any inline subquery expressions.
59fn count_subqueries_in_query(query: &Query) -> usize {
60    let mut count = 0;
61
62    // Each CTE counts as one subquery.
63    if let Some(with) = &query.with {
64        count += with.cte_tables.len();
65        // Also count subqueries nested inside each CTE body.
66        for cte in &with.cte_tables {
67            count += count_subqueries_in_set_expr(&cte.query.body);
68            if let Some(with2) = &cte.query.with {
69                count += with2.cte_tables.len();
70            }
71        }
72    }
73
74    // Count subqueries in the main query body.
75    count += count_subqueries_in_set_expr(&query.body);
76
77    count
78}
79
80fn count_subqueries_in_set_expr(expr: &SetExpr) -> usize {
81    match expr {
82        SetExpr::Select(sel) => count_subqueries_in_select(sel),
83        SetExpr::Query(inner) => count_subqueries_in_query(inner),
84        SetExpr::SetOperation { left, right, .. } => {
85            count_subqueries_in_set_expr(left) + count_subqueries_in_set_expr(right)
86        }
87        _ => 0,
88    }
89}
90
91fn count_subqueries_in_select(sel: &Select) -> usize {
92    let mut count = 0;
93
94    // Projection (SELECT list).
95    for item in &sel.projection {
96        let expr = match item {
97            SelectItem::UnnamedExpr(e) => Some(e),
98            SelectItem::ExprWithAlias { expr, .. } => Some(expr),
99            _ => None,
100        };
101        if let Some(e) = expr {
102            count += count_subqueries_in_expr(e);
103        }
104    }
105
106    // FROM clause (derived tables / subqueries).
107    for twj in &sel.from {
108        count += count_subqueries_in_table_factor(&twj.relation);
109        for join in &twj.joins {
110            count += count_subqueries_in_table_factor(&join.relation);
111        }
112    }
113
114    // WHERE clause.
115    if let Some(selection) = &sel.selection {
116        count += count_subqueries_in_expr(selection);
117    }
118
119    // HAVING clause.
120    if let Some(having) = &sel.having {
121        count += count_subqueries_in_expr(having);
122    }
123
124    count
125}
126
127fn count_subqueries_in_table_factor(tf: &TableFactor) -> usize {
128    if let TableFactor::Derived { subquery, .. } = tf {
129        // A derived table (subquery in FROM) counts as one subquery plus any
130        // subqueries nested within it.
131        1 + count_subqueries_in_query(subquery)
132    } else {
133        0
134    }
135}
136
137/// Recursively count all subquery expressions within an expression tree.
138/// Counted variants: `Subquery`, `InSubquery`, `Exists`.
139fn count_subqueries_in_expr(expr: &Expr) -> usize {
140    match expr {
141        Expr::Subquery(q) => {
142            // The subquery itself counts as 1; recurse into it to count nested ones.
143            1 + count_subqueries_in_query(q)
144        }
145        Expr::InSubquery { subquery, expr: e, .. } => {
146            1 + count_subqueries_in_query(subquery) + count_subqueries_in_expr(e)
147        }
148        Expr::Exists { subquery, .. } => {
149            1 + count_subqueries_in_query(subquery)
150        }
151        Expr::BinaryOp { left, right, .. } => {
152            count_subqueries_in_expr(left) + count_subqueries_in_expr(right)
153        }
154        Expr::UnaryOp { expr: inner, .. } => count_subqueries_in_expr(inner),
155        Expr::Nested(inner) => count_subqueries_in_expr(inner),
156        Expr::Between { expr: e, low, high, .. } => {
157            count_subqueries_in_expr(e)
158                + count_subqueries_in_expr(low)
159                + count_subqueries_in_expr(high)
160        }
161        Expr::Case {
162            operand,
163            conditions,
164            results,
165            else_result,
166        } => {
167            operand.as_ref().map_or(0, |e| count_subqueries_in_expr(e))
168                + conditions.iter().map(|e| count_subqueries_in_expr(e)).sum::<usize>()
169                + results.iter().map(|e| count_subqueries_in_expr(e)).sum::<usize>()
170                + else_result
171                    .as_ref()
172                    .map_or(0, |e| count_subqueries_in_expr(e))
173        }
174        _ => 0,
175    }
176}
177
178// ── keyword position helpers ──────────────────────────────────────────────────
179
180/// Find the `nth` (0-indexed) occurrence of a keyword (case-insensitive,
181/// word-boundary, outside strings/comments) in `source`. Returns a 1-indexed
182/// (line, col) pair. Falls back to (1, 1) if not found.
183fn find_nth_keyword_pos(source: &str, keyword: &str, nth: usize) -> (usize, usize) {
184    let bytes = source.as_bytes();
185    let len = bytes.len();
186    let skip_map = SkipMap::build(source);
187    let kw_upper: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
188    let kw_len = kw_upper.len();
189
190    let mut count = 0;
191    let mut i = 0;
192    while i + kw_len <= len {
193        if !skip_map.is_code(i) {
194            i += 1;
195            continue;
196        }
197
198        // Word boundary before.
199        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
200        if !before_ok {
201            i += 1;
202            continue;
203        }
204
205        // Case-insensitive match.
206        let matches = bytes[i..i + kw_len]
207            .iter()
208            .zip(kw_upper.iter())
209            .all(|(a, b)| a.eq_ignore_ascii_case(b));
210
211        if matches {
212            // Word boundary after.
213            let after = i + kw_len;
214            let after_ok = after >= len || !is_word_char(bytes[after]);
215            let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
216
217            if after_ok && all_code {
218                if count == nth {
219                    return line_col(source, i);
220                }
221                count += 1;
222            }
223        }
224
225        i += 1;
226    }
227
228    (1, 1)
229}
230
231/// Find the position of the `nth` (0-indexed) SELECT keyword.
232/// The outer query is SELECT #0; the first subquery is SELECT #1.
233fn find_nth_select_pos(source: &str, nth: usize) -> (usize, usize) {
234    find_nth_keyword_pos(source, "SELECT", nth)
235}
236
/// Maps a byte `offset` in `source` to a 1-indexed (line, col) pair.
///
/// The column is measured in bytes from the start of the line. Panics if
/// `offset` is past the end of `source` or not on a character boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line = prefix.matches('\n').count() + 1;
    let col = match prefix.rfind('\n') {
        Some(newline) => offset - newline,
        None => offset + 1,
    };
    (line, col)
}