// sqrust_rules/structure/window_frame_all_rows.rs

use sqlparser::ast::{
    Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, OrderByExpr, Query, Select,
    SelectItem, SetExpr, Statement, TableFactor, TableWithJoins, WindowFrameBound,
    WindowFrameUnits, WindowType,
};
use sqrust_core::{Diagnostic, FileContext, Rule};

7pub struct WindowFrameAllRows;
8
9impl Rule for WindowFrameAllRows {
10    fn name(&self) -> &'static str {
11        "Structure/WindowFrameAllRows"
12    }
13
14    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
15        if !ctx.parse_errors.is_empty() {
16            return Vec::new();
17        }
18
19        let mut diags = Vec::new();
20
21        for stmt in &ctx.statements {
22            if let Statement::Query(query) = stmt {
23                check_query(query, self.name(), ctx, &mut diags);
24            }
25        }
26
27        diags
28    }
29}
30
31// ── AST walking ───────────────────────────────────────────────────────────────
32
33fn check_query(
34    query: &Query,
35    rule: &'static str,
36    ctx: &FileContext,
37    diags: &mut Vec<Diagnostic>,
38) {
39    // Visit CTEs.
40    if let Some(with) = &query.with {
41        for cte in &with.cte_tables {
42            check_query(&cte.query, rule, ctx, diags);
43        }
44    }
45
46    check_set_expr(&query.body, rule, ctx, diags);
47
48    // Check ORDER BY expressions at query level.
49    if let Some(order_by) = &query.order_by {
50        for ob in &order_by.exprs {
51            check_order_by_expr(ob, rule, ctx, diags);
52        }
53    }
54}
55
56fn check_set_expr(
57    expr: &SetExpr,
58    rule: &'static str,
59    ctx: &FileContext,
60    diags: &mut Vec<Diagnostic>,
61) {
62    match expr {
63        SetExpr::Select(sel) => {
64            check_select(sel, rule, ctx, diags);
65        }
66        SetExpr::Query(inner) => {
67            check_query(inner, rule, ctx, diags);
68        }
69        SetExpr::SetOperation { left, right, .. } => {
70            check_set_expr(left, rule, ctx, diags);
71            check_set_expr(right, rule, ctx, diags);
72        }
73        _ => {}
74    }
75}
76
77fn check_select(
78    sel: &Select,
79    rule: &'static str,
80    ctx: &FileContext,
81    diags: &mut Vec<Diagnostic>,
82) {
83    for item in &sel.projection {
84        if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
85            check_expr(e, rule, ctx, diags);
86        }
87    }
88
89    if let Some(selection) = &sel.selection {
90        check_expr(selection, rule, ctx, diags);
91    }
92
93    if let Some(having) = &sel.having {
94        check_expr(having, rule, ctx, diags);
95    }
96
97    for twj in &sel.from {
98        check_table_factor(&twj.relation, rule, ctx, diags);
99        for join in &twj.joins {
100            check_table_factor(&join.relation, rule, ctx, diags);
101        }
102    }
103}
104
105fn check_table_factor(
106    tf: &TableFactor,
107    rule: &'static str,
108    ctx: &FileContext,
109    diags: &mut Vec<Diagnostic>,
110) {
111    if let TableFactor::Derived { subquery, .. } = tf {
112        check_query(subquery, rule, ctx, diags);
113    }
114}
115
116fn check_order_by_expr(
117    ob: &OrderByExpr,
118    rule: &'static str,
119    ctx: &FileContext,
120    diags: &mut Vec<Diagnostic>,
121) {
122    check_expr(&ob.expr, rule, ctx, diags);
123}
124
125fn check_expr(
126    expr: &Expr,
127    rule: &'static str,
128    ctx: &FileContext,
129    diags: &mut Vec<Diagnostic>,
130) {
131    match expr {
132        Expr::Function(func) => {
133            check_function(func, rule, ctx, diags);
134        }
135        Expr::BinaryOp { left, right, .. } => {
136            check_expr(left, rule, ctx, diags);
137            check_expr(right, rule, ctx, diags);
138        }
139        Expr::UnaryOp { expr, .. } => {
140            check_expr(expr, rule, ctx, diags);
141        }
142        Expr::Nested(e) => {
143            check_expr(e, rule, ctx, diags);
144        }
145        Expr::IsNull(e) => {
146            check_expr(e, rule, ctx, diags);
147        }
148        Expr::IsNotNull(e) => {
149            check_expr(e, rule, ctx, diags);
150        }
151        Expr::Case {
152            operand,
153            conditions,
154            results,
155            else_result,
156        } => {
157            if let Some(op) = operand {
158                check_expr(op, rule, ctx, diags);
159            }
160            for c in conditions {
161                check_expr(c, rule, ctx, diags);
162            }
163            for r in results {
164                check_expr(r, rule, ctx, diags);
165            }
166            if let Some(el) = else_result {
167                check_expr(el, rule, ctx, diags);
168            }
169        }
170        Expr::Subquery(q) => {
171            check_query(q, rule, ctx, diags);
172        }
173        Expr::InSubquery { subquery, .. } => {
174            check_query(subquery, rule, ctx, diags);
175        }
176        Expr::Exists { subquery, .. } => {
177            check_query(subquery, rule, ctx, diags);
178        }
179        _ => {}
180    }
181}
182
183/// Check a Function node: flag window functions with ROWS BETWEEN UNBOUNDED
184/// PRECEDING AND UNBOUNDED FOLLOWING but no PARTITION BY.
185fn check_function(
186    func: &Function,
187    rule: &'static str,
188    ctx: &FileContext,
189    diags: &mut Vec<Diagnostic>,
190) {
191    if let Some(WindowType::WindowSpec(spec)) = &func.over {
192        if spec.partition_by.is_empty() {
193            if let Some(frame) = &spec.window_frame {
194                if is_rows_unbounded_all(frame) {
195                    let (line, col) = find_over_pos(&ctx.source);
196                    diags.push(Diagnostic {
197                        rule,
198                        message: "Window function with ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING and no PARTITION BY processes the entire table — verify this is intentional".to_string(),
199                        line,
200                        col,
201                    });
202                }
203            }
204        }
205    }
206
207    // Recurse into function arguments.
208    if let FunctionArguments::List(list) = &func.args {
209        for arg in &list.args {
210            let fae = match arg {
211                FunctionArg::Named { arg, .. }
212                | FunctionArg::ExprNamed { arg, .. }
213                | FunctionArg::Unnamed(arg) => arg,
214            };
215            if let FunctionArgExpr::Expr(e) = fae {
216                check_expr(e, rule, ctx, diags);
217            }
218        }
219    }
220}
221
222/// Returns true when the frame is exactly `ROWS BETWEEN UNBOUNDED PRECEDING
223/// AND UNBOUNDED FOLLOWING`.
224fn is_rows_unbounded_all(frame: &sqlparser::ast::WindowFrame) -> bool {
225    if frame.units != WindowFrameUnits::Rows {
226        return false;
227    }
228    let start_ok = matches!(frame.start_bound, WindowFrameBound::Preceding(None));
229    let end_ok = matches!(
230        frame.end_bound,
231        Some(WindowFrameBound::Following(None))
232    );
233    start_ok && end_ok
234}
235
// ── keyword position helper ───────────────────────────────────────────────────

/// Scan source for the first `OVER` keyword and return its 1-indexed
/// (line, col). Matching is ASCII case-insensitive and respects word
/// boundaries. Falls back to (1, 1) when no keyword is found.
///
/// The following are skipped, so the word "over" inside them is never reported:
/// - `-- ...` line comments and `/* ... */` block comments;
/// - `'...'` string literals;
/// - `"..."` and `` `...` `` quoted identifiers.
///
/// A doubled quote (`''`) inside quoted text is handled. Backslash escapes and
/// dollar-quoting are not.
///
/// Matching runs directly on the bytes of `source`. An uppercased copy is not
/// used because its byte offsets can differ from the original: `ı` is 2 bytes
/// but uppercases to the 1-byte `I`. The offset handed to `line_col` is
/// therefore always a char boundary of `source`.
fn find_over_pos(source: &str) -> (usize, usize) {
    const KEYWORD: &[u8] = b"OVER";
    let bytes = source.as_bytes();
    let len = bytes.len();
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let mut pos = 0;
    while pos < len {
        match bytes[pos] {
            b'-' if bytes.get(pos + 1) == Some(&b'-') => {
                pos = skip_past(bytes, pos + 2, b"\n");
            }
            b'/' if bytes.get(pos + 1) == Some(&b'*') => {
                pos = skip_past(bytes, pos + 2, b"*/");
            }
            quote @ (b'\'' | b'"' | b'`') => {
                // A doubled quote closes the text and immediately reopens it,
                // which skips the escaped pair correctly.
                pos = skip_past(bytes, pos + 1, &[quote]);
            }
            _ => {
                let end = pos + KEYWORD.len();
                let is_keyword = end <= len
                    && bytes[pos..end].eq_ignore_ascii_case(KEYWORD)
                    && (pos == 0 || !is_word_byte(bytes[pos - 1]))
                    && (end == len || !is_word_byte(bytes[end]));
                if is_keyword {
                    return line_col(source, pos);
                }
                pos += 1;
            }
        }
    }

    (1, 1)
}

/// Returns the index just past the first occurrence of `terminator` in
/// `bytes[from..]`. If there is none, returns `bytes.len()`: an unterminated
/// comment or quote runs to the end of the input.
fn skip_past(bytes: &[u8], from: usize, terminator: &[u8]) -> usize {
    bytes
        .get(from..)
        .and_then(|rest| rest.windows(terminator.len()).position(|w| w == terminator))
        .map_or(bytes.len(), |rel| from + rel + terminator.len())
}

/// Converts a byte offset in `source` to a 1-indexed (line, col) pair.
///
/// The column is counted in bytes from the start of the line, so a multi-byte
/// character before the offset advances it by more than one. The offset is
/// clamped to `source.len()`; the function works on bytes and never panics on a
/// non-char-boundary offset.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source.as_bytes()[..offset.min(source.len())];
    let line = before.iter().filter(|&&b| b == b'\n').count() + 1;
    let line_start = before
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);
    (line, before.len() - line_start + 1)
}