sqrust_rules/ambiguous/
subquery_in_group_by.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, GroupByExpr, Query, SetExpr, Statement, TableFactor};
3
/// Lint rule flagging GROUP BY items that are, or contain, a subquery.
///
/// The rule is stateless, so the type is a zero-sized unit struct.
#[derive(Debug, Clone, Copy, Default)]
pub struct SubqueryInGroupBy;
5
6impl Rule for SubqueryInGroupBy {
7 fn name(&self) -> &'static str {
8 "Ambiguous/SubqueryInGroupBy"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 if !ctx.parse_errors.is_empty() {
13 return Vec::new();
14 }
15
16 let mut diags = Vec::new();
17 for stmt in &ctx.statements {
18 collect_from_statement(stmt, ctx, &mut diags);
19 }
20 diags
21 }
22}
23
24fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
25 if let Statement::Query(query) = stmt {
26 collect_from_query(query, ctx, diags);
27 }
28}
29
30fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
31 if let Some(with) = &query.with {
32 for cte in &with.cte_tables {
33 collect_from_query(&cte.query, ctx, diags);
34 }
35 }
36 collect_from_set_expr(&query.body, ctx, diags);
37}
38
39fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
40 match expr {
41 SetExpr::Select(select) => {
42 if let GroupByExpr::Expressions(exprs, _) = &select.group_by {
44 for group_expr in exprs {
45 if contains_subquery(group_expr) {
46 let (line, col) = find_keyword_position(&ctx.source, "GROUP BY")
47 .unwrap_or((1, 1));
48 diags.push(Diagnostic {
49 rule: "Ambiguous/SubqueryInGroupBy",
50 message: "Subquery in GROUP BY is non-standard and unsupported by most databases"
51 .to_string(),
52 line,
53 col,
54 });
55 }
56 }
57 }
58 for twj in &select.from {
62 collect_from_table_factor(&twj.relation, ctx, diags);
63 for join in &twj.joins {
64 collect_from_table_factor(&join.relation, ctx, diags);
65 }
66 }
67 }
68 SetExpr::Query(inner) => {
69 collect_from_query(inner, ctx, diags);
70 }
71 SetExpr::SetOperation { left, right, .. } => {
72 collect_from_set_expr(left, ctx, diags);
73 collect_from_set_expr(right, ctx, diags);
74 }
75 _ => {}
76 }
77}
78
79fn collect_from_table_factor(
80 factor: &TableFactor,
81 ctx: &FileContext,
82 diags: &mut Vec<Diagnostic>,
83) {
84 if let TableFactor::Derived { subquery, .. } = factor {
85 collect_from_query(subquery, ctx, diags);
86 }
87}
88
89fn contains_subquery(expr: &Expr) -> bool {
93 match expr {
94 Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => true,
95 Expr::BinaryOp { left, right, .. } => {
96 contains_subquery(left) || contains_subquery(right)
97 }
98 Expr::Nested(inner) => contains_subquery(inner),
99 Expr::UnaryOp { expr: inner, .. } => contains_subquery(inner),
100 Expr::Case {
101 operand,
102 conditions,
103 results,
104 else_result,
105 } => {
106 operand.as_deref().is_some_and(contains_subquery)
107 || conditions.iter().any(contains_subquery)
108 || results.iter().any(contains_subquery)
109 || else_result.as_deref().is_some_and(contains_subquery)
110 }
111 _ => false,
112 }
113}
114
115fn find_keyword_position(source: &str, _keyword: &str) -> Option<(usize, usize)> {
119 let bytes = source.as_bytes();
120 let len = bytes.len();
121 let group = b"GROUP";
122 let by = b"BY";
123
124 let mut i = 0;
125 while i < len {
126 if i + group.len() <= len
128 && bytes[i..i + group.len()].eq_ignore_ascii_case(group)
129 && (i == 0 || !is_word_char(bytes[i - 1]))
130 && (i + group.len() >= len || !is_word_char(bytes[i + group.len()]))
131 {
132 let mut j = i + group.len();
134 while j < len && (bytes[j] == b' ' || bytes[j] == b'\t' || bytes[j] == b'\n' || bytes[j] == b'\r') {
135 j += 1;
136 }
137 if j + by.len() <= len
139 && bytes[j..j + by.len()].eq_ignore_ascii_case(by)
140 && (j + by.len() >= len || !is_word_char(bytes[j + by.len()]))
141 {
142 return Some(offset_to_line_col(source, i));
143 }
144 }
145 i += 1;
146 }
147 None
148}
149
/// Whether `ch` can be part of an SQL word: an ASCII letter, digit, or `_`.
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
154
/// Converts a byte `offset` into a 1-based `(line, col)` pair.
///
/// The column is measured in bytes from the preceding newline (or from the
/// start of `source` on the first line). `offset` must lie on a char boundary,
/// otherwise the slice panics.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    match prefix.rfind('\n') {
        Some(newline) => (prefix.matches('\n').count() + 1, offset - newline),
        None => (1, offset + 1),
    }
}