Skip to main content

sqrust_rules/structure/
subquery_in_select.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4pub struct SubqueryInSelect;
5
6impl Rule for SubqueryInSelect {
7    fn name(&self) -> &'static str {
8        "Structure/SubqueryInSelect"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        if !ctx.parse_errors.is_empty() {
13            return Vec::new();
14        }
15
16        let mut diags = Vec::new();
17
18        for stmt in &ctx.statements {
19            if let Statement::Query(query) = stmt {
20                check_query(query, self.name(), &ctx.source, &mut diags);
21            }
22        }
23
24        diags
25    }
26}
27
28// ── AST walking ───────────────────────────────────────────────────────────────
29
30fn check_query(
31    query: &Query,
32    rule: &'static str,
33    source: &str,
34    diags: &mut Vec<Diagnostic>,
35) {
36    // Visit CTEs first.
37    if let Some(with) = &query.with {
38        for cte in &with.cte_tables {
39            check_query(&cte.query, rule, source, diags);
40        }
41    }
42    check_set_expr(&query.body, rule, source, diags);
43}
44
45fn check_set_expr(
46    expr: &SetExpr,
47    rule: &'static str,
48    source: &str,
49    diags: &mut Vec<Diagnostic>,
50) {
51    match expr {
52        SetExpr::Select(sel) => {
53            check_select(sel, rule, source, diags);
54        }
55        SetExpr::Query(inner) => {
56            check_query(inner, rule, source, diags);
57        }
58        SetExpr::SetOperation { left, right, .. } => {
59            check_set_expr(left, rule, source, diags);
60            check_set_expr(right, rule, source, diags);
61        }
62        _ => {}
63    }
64}
65
66fn check_select(
67    sel: &Select,
68    rule: &'static str,
69    source: &str,
70    diags: &mut Vec<Diagnostic>,
71) {
72    // Check each item in the projection list for a scalar subquery.
73    for item in &sel.projection {
74        let expr = match item {
75            SelectItem::UnnamedExpr(e) => Some(e),
76            SelectItem::ExprWithAlias { expr, .. } => Some(expr),
77            _ => None,
78        };
79
80        if let Some(Expr::Subquery(subquery)) = expr {
81            let (line, col) = find_subquery_pos(source, subquery);
82            diags.push(Diagnostic {
83                rule,
84                message: "Scalar subquery in SELECT list may cause N+1 query performance issues; consider using a JOIN".to_string(),
85                line,
86                col,
87            });
88            // Recurse into the subquery's own SELECT list (nested pattern).
89            check_query(subquery, rule, source, diags);
90        }
91    }
92
93    // Recurse into subqueries in the FROM clause.
94    for table in &sel.from {
95        recurse_table_factor(&table.relation, rule, source, diags);
96        for join in &table.joins {
97            recurse_table_factor(&join.relation, rule, source, diags);
98        }
99    }
100
101    // Recurse into the WHERE clause expressions to catch any scalar subqueries
102    // that appear in nested queries reached through WHERE (not flagged, just
103    // walked so nested SELECT-list subqueries can be found).
104    if let Some(selection) = &sel.selection {
105        recurse_expr_for_queries(selection, rule, source, diags);
106    }
107}
108
109fn recurse_table_factor(
110    tf: &TableFactor,
111    rule: &'static str,
112    source: &str,
113    diags: &mut Vec<Diagnostic>,
114) {
115    if let TableFactor::Derived { subquery, .. } = tf {
116        check_query(subquery, rule, source, diags);
117    }
118}
119
120/// Walk an expression only to find nested Query nodes (e.g. in WHERE/IN/EXISTS)
121/// so that any SELECT-list subqueries inside those are checked. We do NOT flag
122/// the expression itself here — only SELECT-list items are flagged.
123fn recurse_expr_for_queries(
124    expr: &Expr,
125    rule: &'static str,
126    source: &str,
127    diags: &mut Vec<Diagnostic>,
128) {
129    match expr {
130        Expr::Subquery(q) => check_query(q, rule, source, diags),
131        Expr::InSubquery { subquery, .. } => check_query(subquery, rule, source, diags),
132        Expr::Exists { subquery, .. } => check_query(subquery, rule, source, diags),
133        Expr::BinaryOp { left, right, .. } => {
134            recurse_expr_for_queries(left, rule, source, diags);
135            recurse_expr_for_queries(right, rule, source, diags);
136        }
137        _ => {}
138    }
139}
140
141// ── helpers ───────────────────────────────────────────────────────────────────
142
143/// Find the position of the opening `(SELECT` for a scalar subquery.
144/// Scans the source for `(` followed by optional whitespace followed by SELECT.
145/// Falls back to (1, 1) if not found.
146fn find_subquery_pos(source: &str, _query: &Query) -> (usize, usize) {
147    let bytes = source.as_bytes();
148    let len = bytes.len();
149
150    let mut i = 0;
151    while i < len {
152        if bytes[i] == b'(' {
153            // Scan forward past optional whitespace.
154            let mut j = i + 1;
155            while j < len
156                && (bytes[j] == b' '
157                    || bytes[j] == b'\t'
158                    || bytes[j] == b'\n'
159                    || bytes[j] == b'\r')
160            {
161                j += 1;
162            }
163
164            // Check for SELECT keyword (case-insensitive, word-boundary after).
165            let kw = b"SELECT";
166            let kw_len = kw.len();
167            if j + kw_len <= len {
168                let matches = bytes[j..j + kw_len]
169                    .iter()
170                    .zip(kw.iter())
171                    .all(|(a, b)| a.eq_ignore_ascii_case(b));
172
173                let boundary_after = j + kw_len >= len || {
174                    let nb = bytes[j + kw_len];
175                    !nb.is_ascii_alphanumeric() && nb != b'_'
176                };
177
178                if matches && boundary_after {
179                    return line_col(source, i);
180                }
181            }
182        }
183        i += 1;
184    }
185
186    (1, 1)
187}
188
/// Translates a byte offset into `source` to a 1-indexed `(line, col)` pair.
///
/// The line is one more than the number of `\n` characters before `offset`;
/// the column is the byte distance from the start of that line, plus one.
/// Slicing panics if `offset` is past the end of `source` or not on a
/// character boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let newlines = prefix.bytes().filter(|&b| b == b'\n').count();
    // Byte index at which the current line begins.
    let line_start = match prefix.rfind('\n') {
        Some(newline_at) => newline_at + 1,
        None => 0,
    };
    (newlines + 1, offset - line_start + 1)
}