1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4pub struct SubqueryInSelect;
5
6impl Rule for SubqueryInSelect {
7 fn name(&self) -> &'static str {
8 "Structure/SubqueryInSelect"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 if !ctx.parse_errors.is_empty() {
13 return Vec::new();
14 }
15
16 let mut diags = Vec::new();
17
18 for stmt in &ctx.statements {
19 if let Statement::Query(query) = stmt {
20 check_query(query, self.name(), &ctx.source, &mut diags);
21 }
22 }
23
24 diags
25 }
26}
27
28fn check_query(
31 query: &Query,
32 rule: &'static str,
33 source: &str,
34 diags: &mut Vec<Diagnostic>,
35) {
36 if let Some(with) = &query.with {
38 for cte in &with.cte_tables {
39 check_query(&cte.query, rule, source, diags);
40 }
41 }
42 check_set_expr(&query.body, rule, source, diags);
43}
44
45fn check_set_expr(
46 expr: &SetExpr,
47 rule: &'static str,
48 source: &str,
49 diags: &mut Vec<Diagnostic>,
50) {
51 match expr {
52 SetExpr::Select(sel) => {
53 check_select(sel, rule, source, diags);
54 }
55 SetExpr::Query(inner) => {
56 check_query(inner, rule, source, diags);
57 }
58 SetExpr::SetOperation { left, right, .. } => {
59 check_set_expr(left, rule, source, diags);
60 check_set_expr(right, rule, source, diags);
61 }
62 _ => {}
63 }
64}
65
66fn check_select(
67 sel: &Select,
68 rule: &'static str,
69 source: &str,
70 diags: &mut Vec<Diagnostic>,
71) {
72 for item in &sel.projection {
74 let expr = match item {
75 SelectItem::UnnamedExpr(e) => Some(e),
76 SelectItem::ExprWithAlias { expr, .. } => Some(expr),
77 _ => None,
78 };
79
80 if let Some(Expr::Subquery(subquery)) = expr {
81 let (line, col) = find_subquery_pos(source, subquery);
82 diags.push(Diagnostic {
83 rule,
84 message: "Scalar subquery in SELECT list may cause N+1 query performance issues; consider using a JOIN".to_string(),
85 line,
86 col,
87 });
88 check_query(subquery, rule, source, diags);
90 }
91 }
92
93 for table in &sel.from {
95 recurse_table_factor(&table.relation, rule, source, diags);
96 for join in &table.joins {
97 recurse_table_factor(&join.relation, rule, source, diags);
98 }
99 }
100
101 if let Some(selection) = &sel.selection {
105 recurse_expr_for_queries(selection, rule, source, diags);
106 }
107}
108
109fn recurse_table_factor(
110 tf: &TableFactor,
111 rule: &'static str,
112 source: &str,
113 diags: &mut Vec<Diagnostic>,
114) {
115 if let TableFactor::Derived { subquery, .. } = tf {
116 check_query(subquery, rule, source, diags);
117 }
118}
119
120fn recurse_expr_for_queries(
124 expr: &Expr,
125 rule: &'static str,
126 source: &str,
127 diags: &mut Vec<Diagnostic>,
128) {
129 match expr {
130 Expr::Subquery(q) => check_query(q, rule, source, diags),
131 Expr::InSubquery { subquery, .. } => check_query(subquery, rule, source, diags),
132 Expr::Exists { subquery, .. } => check_query(subquery, rule, source, diags),
133 Expr::BinaryOp { left, right, .. } => {
134 recurse_expr_for_queries(left, rule, source, diags);
135 recurse_expr_for_queries(right, rule, source, diags);
136 }
137 _ => {}
138 }
139}
140
141fn find_subquery_pos(source: &str, _query: &Query) -> (usize, usize) {
147 let bytes = source.as_bytes();
148 let len = bytes.len();
149
150 let mut i = 0;
151 while i < len {
152 if bytes[i] == b'(' {
153 let mut j = i + 1;
155 while j < len
156 && (bytes[j] == b' '
157 || bytes[j] == b'\t'
158 || bytes[j] == b'\n'
159 || bytes[j] == b'\r')
160 {
161 j += 1;
162 }
163
164 let kw = b"SELECT";
166 let kw_len = kw.len();
167 if j + kw_len <= len {
168 let matches = bytes[j..j + kw_len]
169 .iter()
170 .zip(kw.iter())
171 .all(|(a, b)| a.eq_ignore_ascii_case(b));
172
173 let boundary_after = j + kw_len >= len || {
174 let nb = bytes[j + kw_len];
175 !nb.is_ascii_alphanumeric() && nb != b'_'
176 };
177
178 if matches && boundary_after {
179 return line_col(source, i);
180 }
181 }
182 }
183 i += 1;
184 }
185
186 (1, 1)
187}
188
/// Converts a byte offset into `source` to a 1-based `(line, column)` pair.
///
/// The line is one plus the number of `\n` bytes before `offset`. The column
/// is one plus the number of bytes between the start of that line and
/// `offset`, so it is counted in bytes, not characters.
///
/// An `offset` past the end of `source` is clamped to `source.len()`, and an
/// offset that falls inside a multi-byte character is accepted as-is, so the
/// function never panics.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    // Work on raw bytes: slicing the `str` would panic on an out-of-range
    // offset or one that is not on a character boundary. The byte `\n` never
    // occurs inside a multi-byte UTF-8 sequence, so the counts are unaffected.
    let bytes = source.as_bytes();
    let end = offset.min(bytes.len());
    let before = &bytes[..end];

    let line = before.iter().filter(|&&b| b == b'\n').count() + 1;
    let col = match before.iter().rposition(|&b| b == b'\n') {
        // Bytes after the last newline, plus one for 1-based numbering.
        Some(newline) => end - newline,
        None => end + 1,
    };

    (line, col)
}