Skip to main content

sqrust_rules/ambiguous/
distinct_with_window_function.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3    Expr, FunctionArguments, Query, Select, SelectItem, SetExpr, Statement, TableFactor,
4};
5
6pub struct DistinctWithWindowFunction;
7
8impl Rule for DistinctWithWindowFunction {
9    fn name(&self) -> &'static str {
10        "Ambiguous/DistinctWithWindowFunction"
11    }
12
13    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14        if !ctx.parse_errors.is_empty() {
15            return Vec::new();
16        }
17
18        let mut diags = Vec::new();
19        for stmt in &ctx.statements {
20            collect_from_statement(stmt, ctx, &mut diags);
21        }
22        diags
23    }
24}
25
26fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
27    if let Statement::Query(query) = stmt {
28        collect_from_query(query, ctx, diags);
29    }
30}
31
32fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
33    if let Some(with) = &query.with {
34        for cte in &with.cte_tables {
35            collect_from_query(&cte.query, ctx, diags);
36        }
37    }
38    collect_from_set_expr(&query.body, ctx, diags);
39}
40
41fn collect_from_set_expr(set_expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
42    match set_expr {
43        SetExpr::Select(select) => {
44            collect_from_select(select, ctx, diags);
45        }
46        SetExpr::Query(inner) => {
47            collect_from_query(inner, ctx, diags);
48        }
49        SetExpr::SetOperation { left, right, .. } => {
50            collect_from_set_expr(left, ctx, diags);
51            collect_from_set_expr(right, ctx, diags);
52        }
53        _ => {}
54    }
55}
56
57fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
58    // Only flag if this SELECT has DISTINCT.
59    if select.distinct.is_some() {
60        // Check if any projection item contains a window function.
61        let has_window_fn = select.projection.iter().any(|item| {
62            let expr = match item {
63                SelectItem::UnnamedExpr(e) => e,
64                SelectItem::ExprWithAlias { expr: e, .. } => e,
65                _ => return false,
66            };
67            expr_contains_window_fn(expr)
68        });
69
70        if has_window_fn {
71            let (line, col) = find_keyword_pos(&ctx.source, "SELECT");
72            diags.push(Diagnostic {
73                rule: "Ambiguous/DistinctWithWindowFunction",
74                message: "DISTINCT with window functions may produce unexpected results — window functions run before DISTINCT is applied".to_string(),
75                line,
76                col,
77            });
78        }
79    }
80
81    // Always recurse into FROM subqueries.
82    for twj in &select.from {
83        collect_from_table_factor(&twj.relation, ctx, diags);
84        for join in &twj.joins {
85            collect_from_table_factor(&join.relation, ctx, diags);
86        }
87    }
88
89    // Recurse into WHERE subqueries.
90    if let Some(selection) = &select.selection {
91        collect_subqueries_from_expr(selection, ctx, diags);
92    }
93
94    // Recurse into HAVING subqueries.
95    if let Some(having) = &select.having {
96        collect_subqueries_from_expr(having, ctx, diags);
97    }
98}
99
100fn collect_from_table_factor(
101    factor: &TableFactor,
102    ctx: &FileContext,
103    diags: &mut Vec<Diagnostic>,
104) {
105    if let TableFactor::Derived { subquery, .. } = factor {
106        collect_from_query(subquery, ctx, diags);
107    }
108}
109
110/// Returns `true` if `expr` or any sub-expression is a window function call
111/// (i.e., a `Function` node with `over: Some(_)`).
112fn expr_contains_window_fn(expr: &Expr) -> bool {
113    match expr {
114        Expr::Function(func) => {
115            if func.over.is_some() {
116                return true;
117            }
118            // Recurse into function arguments.
119            if let FunctionArguments::List(arg_list) = &func.args {
120                use sqlparser::ast::{FunctionArg, FunctionArgExpr};
121                for arg in &arg_list.args {
122                    let expr_arg = match arg {
123                        FunctionArg::Named { arg, .. }
124                        | FunctionArg::ExprNamed { arg, .. }
125                        | FunctionArg::Unnamed(arg) => arg,
126                    };
127                    if let FunctionArgExpr::Expr(e) = expr_arg {
128                        if expr_contains_window_fn(e) {
129                            return true;
130                        }
131                    }
132                }
133            }
134            false
135        }
136        Expr::BinaryOp { left, right, .. } => {
137            expr_contains_window_fn(left) || expr_contains_window_fn(right)
138        }
139        Expr::UnaryOp { expr: inner, .. } => expr_contains_window_fn(inner),
140        Expr::Nested(inner) => expr_contains_window_fn(inner),
141        Expr::Case {
142            operand,
143            conditions,
144            results,
145            else_result,
146        } => {
147            operand.as_deref().map_or(false, expr_contains_window_fn)
148                || conditions.iter().any(expr_contains_window_fn)
149                || results.iter().any(expr_contains_window_fn)
150                || else_result
151                    .as_deref()
152                    .map_or(false, expr_contains_window_fn)
153        }
154        _ => false,
155    }
156}
157
158/// Recurse into subqueries nested inside WHERE / HAVING expressions.
159fn collect_subqueries_from_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
160    match expr {
161        Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
162            collect_from_query(q, ctx, diags);
163        }
164        Expr::BinaryOp { left, right, .. } => {
165            collect_subqueries_from_expr(left, ctx, diags);
166            collect_subqueries_from_expr(right, ctx, diags);
167        }
168        Expr::UnaryOp { expr: inner, .. } => {
169            collect_subqueries_from_expr(inner, ctx, diags);
170        }
171        Expr::Nested(inner) => {
172            collect_subqueries_from_expr(inner, ctx, diags);
173        }
174        _ => {}
175    }
176}
177
178/// Finds the first occurrence of `keyword` (case-insensitive, word-boundary)
179/// in `source` and returns a 1-indexed (line, col). Falls back to (1, 1).
180fn find_keyword_pos(source: &str, keyword: &str) -> (usize, usize) {
181    let upper = source.to_uppercase();
182    let kw_upper = keyword.to_uppercase();
183    let bytes = upper.as_bytes();
184    let kw_bytes = kw_upper.as_bytes();
185    let kw_len = kw_bytes.len();
186
187    let mut i = 0;
188    while i + kw_len <= bytes.len() {
189        if bytes[i..i + kw_len] == *kw_bytes {
190            let before_ok =
191                i == 0 || (!bytes[i - 1].is_ascii_alphanumeric() && bytes[i - 1] != b'_');
192            let after = i + kw_len;
193            let after_ok = after >= bytes.len()
194                || (!bytes[after].is_ascii_alphanumeric() && bytes[after] != b'_');
195            if before_ok && after_ok {
196                return offset_to_line_col(source, i);
197            }
198        }
199        i += 1;
200    }
201    (1, 1)
202}
203
/// Translates a byte offset into a 1-indexed (line, col) pair.
///
/// The line is one more than the number of `\n` before `offset`; the column
/// is the byte distance from the last preceding `\n` (or from the start of
/// `source` when there is none), counted from 1.
///
/// Panics if `offset` is out of range or not on a character boundary,
/// because `source` is sliced at `offset`.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    match prefix.rfind('\n') {
        Some(last_newline) => (prefix.matches('\n').count() + 1, offset - last_newline),
        None => (1, offset + 1),
    }
}