Skip to main content

sqrust_rules/ambiguous/
subquery_in_group_by.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, GroupByExpr, Query, SetExpr, Statement, TableFactor};
3
4pub struct SubqueryInGroupBy;
5
6impl Rule for SubqueryInGroupBy {
7    fn name(&self) -> &'static str {
8        "Ambiguous/SubqueryInGroupBy"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        if !ctx.parse_errors.is_empty() {
13            return Vec::new();
14        }
15
16        let mut diags = Vec::new();
17        for stmt in &ctx.statements {
18            collect_from_statement(stmt, ctx, &mut diags);
19        }
20        diags
21    }
22}
23
24fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
25    if let Statement::Query(query) = stmt {
26        collect_from_query(query, ctx, diags);
27    }
28}
29
30fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
31    if let Some(with) = &query.with {
32        for cte in &with.cte_tables {
33            collect_from_query(&cte.query, ctx, diags);
34        }
35    }
36    collect_from_set_expr(&query.body, ctx, diags);
37}
38
39fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
40    match expr {
41        SetExpr::Select(select) => {
42            // Check GROUP BY expressions for subqueries.
43            if let GroupByExpr::Expressions(exprs, _) = &select.group_by {
44                for group_expr in exprs {
45                    if contains_subquery(group_expr) {
46                        let (line, col) = find_keyword_position(&ctx.source, "GROUP BY")
47                            .unwrap_or((1, 1));
48                        diags.push(Diagnostic {
49                            rule: "Ambiguous/SubqueryInGroupBy",
50                            message: "Subquery in GROUP BY is non-standard and unsupported by most databases"
51                                .to_string(),
52                            line,
53                            col,
54                        });
55                    }
56                }
57            }
58            // GroupByExpr::All(_) => no flag — GROUP BY ALL is not a subquery.
59
60            // Recurse into FROM subqueries so we catch violations in CTEs/derived tables.
61            for twj in &select.from {
62                collect_from_table_factor(&twj.relation, ctx, diags);
63                for join in &twj.joins {
64                    collect_from_table_factor(&join.relation, ctx, diags);
65                }
66            }
67        }
68        SetExpr::Query(inner) => {
69            collect_from_query(inner, ctx, diags);
70        }
71        SetExpr::SetOperation { left, right, .. } => {
72            collect_from_set_expr(left, ctx, diags);
73            collect_from_set_expr(right, ctx, diags);
74        }
75        _ => {}
76    }
77}
78
79fn collect_from_table_factor(
80    factor: &TableFactor,
81    ctx: &FileContext,
82    diags: &mut Vec<Diagnostic>,
83) {
84    if let TableFactor::Derived { subquery, .. } = factor {
85        collect_from_query(subquery, ctx, diags);
86    }
87}
88
89/// Returns `true` if `expr` is or contains a subquery (`Subquery`, `InSubquery`,
90/// or `Exists`). Recurses into `BinaryOp` and `Nested` to catch subqueries
91/// embedded inside larger expressions.
92fn contains_subquery(expr: &Expr) -> bool {
93    match expr {
94        Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => true,
95        Expr::BinaryOp { left, right, .. } => {
96            contains_subquery(left) || contains_subquery(right)
97        }
98        Expr::Nested(inner) => contains_subquery(inner),
99        Expr::UnaryOp { expr: inner, .. } => contains_subquery(inner),
100        Expr::Case {
101            operand,
102            conditions,
103            results,
104            else_result,
105        } => {
106            operand.as_deref().is_some_and(contains_subquery)
107                || conditions.iter().any(contains_subquery)
108                || results.iter().any(contains_subquery)
109                || else_result.as_deref().is_some_and(contains_subquery)
110        }
111        _ => false,
112    }
113}
114
115/// Finds the first occurrence of the two-word phrase `GROUP BY` (case-insensitive,
116/// with any whitespace between) in `source` and returns `Some((line, col))`.
117/// Falls back to `None` if not found.
118fn find_keyword_position(source: &str, _keyword: &str) -> Option<(usize, usize)> {
119    let bytes = source.as_bytes();
120    let len = bytes.len();
121    let group = b"GROUP";
122    let by = b"BY";
123
124    let mut i = 0;
125    while i < len {
126        // Try to match GROUP at a word boundary.
127        if i + group.len() <= len
128            && bytes[i..i + group.len()].eq_ignore_ascii_case(group)
129            && (i == 0 || !is_word_char(bytes[i - 1]))
130            && (i + group.len() >= len || !is_word_char(bytes[i + group.len()]))
131        {
132            // Skip whitespace after GROUP.
133            let mut j = i + group.len();
134            while j < len && (bytes[j] == b' ' || bytes[j] == b'\t' || bytes[j] == b'\n' || bytes[j] == b'\r') {
135                j += 1;
136            }
137            // Try to match BY at a word boundary.
138            if j + by.len() <= len
139                && bytes[j..j + by.len()].eq_ignore_ascii_case(by)
140                && (j + by.len() >= len || !is_word_char(bytes[j + by.len()]))
141            {
142                return Some(offset_to_line_col(source, i));
143            }
144        }
145        i += 1;
146    }
147    None
148}
149
/// Returns `true` for bytes that may be part of an identifier-like word:
/// ASCII letters, ASCII digits and `_`.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'_' | b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z')
}
154
/// Translates a byte offset into `source` to a 1-indexed `(line, col)` pair.
///
/// The line is one more than the number of `\n` characters before `offset`.
/// The column is counted in bytes from the start of that line.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or does not fall on a
/// character boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    match prefix.rfind('\n') {
        // `offset - last_newline` is the 1-indexed distance past the newline.
        Some(last_newline) => (prefix.matches('\n').count() + 1, offset - last_newline),
        None => (1, offset + 1),
    }
}