Skip to main content

sqrust_rules/lint/
non_deterministic_function.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3    Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
4    Statement, TableFactor,
5};
6
/// Lint rule reporting calls to functions whose names appear in
/// `NON_DETERMINISTIC`.
///
/// The rule carries no state of its own; everything it needs is taken from
/// the `FileContext` handed to `Rule::check`.
pub struct NonDeterministicFunction;
8
/// Upper-cased names of the functions this rule treats as non-deterministic.
///
/// Lookups compare against the upper-cased final identifier of a call's
/// (possibly qualified) name, so every entry here must already be upper-case.
const NON_DETERMINISTIC: &[&str] = &[
    "RAND",
    "RANDOM",
    "UUID",
    "NEWID",
    "NEWSEQUENTIALID",
    "GEN_RANDOM_UUID",
];
18
19impl Rule for NonDeterministicFunction {
20    fn name(&self) -> &'static str {
21        "Lint/NonDeterministicFunction"
22    }
23
24    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
25        // Skip files that failed to parse — AST may be incomplete.
26        if !ctx.parse_errors.is_empty() {
27            return Vec::new();
28        }
29
30        let mut diags = Vec::new();
31        // Per-function-name occurrence counter so that find_occurrence can
32        // locate the correct position in source when a name appears multiple times.
33        let mut occurrence_counters: std::collections::HashMap<String, usize> =
34            std::collections::HashMap::new();
35
36        for stmt in &ctx.statements {
37            walk_statement(stmt, &ctx.source, &mut occurrence_counters, &mut diags);
38        }
39
40        diags
41    }
42}
43
44// ── Statement walker ──────────────────────────────────────────────────────────
45
46fn walk_statement(
47    stmt: &Statement,
48    source: &str,
49    counters: &mut std::collections::HashMap<String, usize>,
50    diags: &mut Vec<Diagnostic>,
51) {
52    match stmt {
53        Statement::Query(q) => walk_query(q, source, counters, diags),
54        Statement::Insert(insert) => {
55            if let Some(src) = &insert.source {
56                walk_query(src, source, counters, diags);
57            }
58        }
59        Statement::Update {
60            selection, assignments, ..
61        } => {
62            if let Some(expr) = selection {
63                walk_expr(expr, source, counters, diags);
64            }
65            for assign in assignments {
66                walk_expr(&assign.value, source, counters, diags);
67            }
68        }
69        Statement::Delete(delete) => {
70            if let Some(expr) = &delete.selection {
71                walk_expr(expr, source, counters, diags);
72            }
73        }
74        _ => {}
75    }
76}
77
78// ── Query / SET-expression walker ─────────────────────────────────────────────
79
80fn walk_query(
81    query: &Query,
82    source: &str,
83    counters: &mut std::collections::HashMap<String, usize>,
84    diags: &mut Vec<Diagnostic>,
85) {
86    if let Some(with) = &query.with {
87        for cte in &with.cte_tables {
88            walk_query(&cte.query, source, counters, diags);
89        }
90    }
91    walk_set_expr(&query.body, source, counters, diags);
92}
93
94fn walk_set_expr(
95    expr: &SetExpr,
96    source: &str,
97    counters: &mut std::collections::HashMap<String, usize>,
98    diags: &mut Vec<Diagnostic>,
99) {
100    match expr {
101        SetExpr::Select(sel) => walk_select(sel, source, counters, diags),
102        SetExpr::Query(inner) => walk_query(inner, source, counters, diags),
103        SetExpr::SetOperation { left, right, .. } => {
104            walk_set_expr(left, source, counters, diags);
105            walk_set_expr(right, source, counters, diags);
106        }
107        _ => {}
108    }
109}
110
111fn walk_select(
112    sel: &Select,
113    source: &str,
114    counters: &mut std::collections::HashMap<String, usize>,
115    diags: &mut Vec<Diagnostic>,
116) {
117    // Projection
118    for item in &sel.projection {
119        match item {
120            SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
121                walk_expr(e, source, counters, diags);
122            }
123            _ => {}
124        }
125    }
126
127    // WHERE clause
128    if let Some(expr) = &sel.selection {
129        walk_expr(expr, source, counters, diags);
130    }
131
132    // HAVING clause
133    if let Some(expr) = &sel.having {
134        walk_expr(expr, source, counters, diags);
135    }
136
137    // GROUP BY expressions
138    if let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &sel.group_by {
139        for e in exprs {
140            walk_expr(e, source, counters, diags);
141        }
142    }
143
144    // Subqueries inside FROM
145    for twj in &sel.from {
146        walk_table_factor(&twj.relation, source, counters, diags);
147        for join in &twj.joins {
148            walk_table_factor(&join.relation, source, counters, diags);
149        }
150    }
151}
152
153fn walk_table_factor(
154    tf: &TableFactor,
155    source: &str,
156    counters: &mut std::collections::HashMap<String, usize>,
157    diags: &mut Vec<Diagnostic>,
158) {
159    if let TableFactor::Derived { subquery, .. } = tf {
160        walk_query(subquery, source, counters, diags);
161    }
162}
163
164// ── Expression walker ─────────────────────────────────────────────────────────
165
166fn walk_expr(
167    expr: &Expr,
168    source: &str,
169    counters: &mut std::collections::HashMap<String, usize>,
170    diags: &mut Vec<Diagnostic>,
171) {
172    match expr {
173        Expr::Function(func) => {
174            // Extract the last ident in the function name (handles schema-qualified calls)
175            let name_upper = func
176                .name
177                .0
178                .last()
179                .map(|ident| ident.value.to_uppercase())
180                .unwrap_or_default();
181
182            if NON_DETERMINISTIC.contains(&name_upper.as_str()) {
183                let occ = counters.entry(name_upper.clone()).or_insert(0);
184                let occurrence = *occ;
185                *occ += 1;
186
187                let offset = find_occurrence(source, &name_upper, occurrence);
188                let (line, col) = offset_to_line_col(source, offset);
189
190                diags.push(Diagnostic {
191                    rule: "Lint/NonDeterministicFunction",
192                    message: format!(
193                        "Non-deterministic function {}() produces different results on each call",
194                        name_upper
195                    ),
196                    line,
197                    col,
198                });
199            }
200
201            // Recurse into function arguments
202            if let FunctionArguments::List(list) = &func.args {
203                for arg in &list.args {
204                    let inner_expr = match arg {
205                        FunctionArg::Named { arg, .. }
206                        | FunctionArg::Unnamed(arg)
207                        | FunctionArg::ExprNamed { arg, .. } => match arg {
208                            FunctionArgExpr::Expr(e) => Some(e),
209                            _ => None,
210                        },
211                    };
212                    if let Some(e) = inner_expr {
213                        walk_expr(e, source, counters, diags);
214                    }
215                }
216            }
217        }
218
219        Expr::BinaryOp { left, right, .. } => {
220            walk_expr(left, source, counters, diags);
221            walk_expr(right, source, counters, diags);
222        }
223
224        Expr::UnaryOp { expr: inner, .. } => {
225            walk_expr(inner, source, counters, diags);
226        }
227
228        Expr::Nested(inner) => walk_expr(inner, source, counters, diags),
229
230        Expr::Case {
231            operand,
232            conditions,
233            results,
234            else_result,
235        } => {
236            if let Some(op) = operand {
237                walk_expr(op, source, counters, diags);
238            }
239            for c in conditions {
240                walk_expr(c, source, counters, diags);
241            }
242            for r in results {
243                walk_expr(r, source, counters, diags);
244            }
245            if let Some(e) = else_result {
246                walk_expr(e, source, counters, diags);
247            }
248        }
249
250        Expr::InList {
251            expr: inner,
252            list,
253            ..
254        } => {
255            walk_expr(inner, source, counters, diags);
256            for e in list {
257                walk_expr(e, source, counters, diags);
258            }
259        }
260
261        Expr::InSubquery {
262            expr: inner,
263            subquery,
264            ..
265        } => {
266            walk_expr(inner, source, counters, diags);
267            walk_query(subquery, source, counters, diags);
268        }
269
270        Expr::Exists { subquery, .. } => walk_query(subquery, source, counters, diags),
271
272        Expr::Subquery(q) => walk_query(q, source, counters, diags),
273
274        Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
275            walk_expr(inner, source, counters, diags);
276        }
277
278        Expr::Between {
279            expr: inner,
280            low,
281            high,
282            ..
283        } => {
284            walk_expr(inner, source, counters, diags);
285            walk_expr(low, source, counters, diags);
286            walk_expr(high, source, counters, diags);
287        }
288
289        Expr::Like {
290            expr: inner,
291            pattern,
292            ..
293        }
294        | Expr::ILike {
295            expr: inner,
296            pattern,
297            ..
298        } => {
299            walk_expr(inner, source, counters, diags);
300            walk_expr(pattern, source, counters, diags);
301        }
302
303        // Literals, identifiers, wildcards, etc. — nothing to recurse into
304        _ => {}
305    }
306}
307
308// ── Source-text helpers ───────────────────────────────────────────────────────
309
/// Returns the byte offset of the `nth` (0-indexed) whole-word occurrence of
/// `name` in `source`, comparing ASCII letters case-insensitively.
///
/// An occurrence counts as a whole word when the bytes directly before and
/// after it (if any) are neither ASCII alphanumerics nor `_`.
///
/// Returns 0 when there is no such occurrence; callers cannot tell that apart
/// from a genuine match at the very start of `source`.
fn find_occurrence(source: &str, name: &str, nth: usize) -> usize {
    let haystack = source.as_bytes();
    let needle = name.as_bytes();
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    // Last index at which `needle` could still fit; a needle longer than the
    // haystack can never match.
    let last_start = match haystack.len().checked_sub(needle.len()) {
        Some(last) => last,
        None => return 0,
    };

    (0..=last_start)
        .filter(|&start| {
            let end = start + needle.len();
            let boundary_before = start == 0 || !is_word_byte(haystack[start - 1]);
            let boundary_after = end == haystack.len() || !is_word_byte(haystack[end]);
            boundary_before && boundary_after && haystack[start..end].eq_ignore_ascii_case(needle)
        })
        .nth(nth)
        .unwrap_or(0)
}
356
/// Translates a byte offset into `source` to a 1-indexed `(line, col)` pair.
///
/// Lines are delimited by `'\n'`. The column is counted in bytes from the
/// start of the line, so a multi-byte character earlier on the same line
/// advances it by more than one.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or does not fall on a
/// UTF-8 character boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];

    match prefix.rfind('\n') {
        // The distance from the last newline is already the 1-indexed column.
        Some(newline_at) => (prefix.matches('\n').count() + 1, offset - newline_at),
        // Still on the first line.
        None => (1, offset + 1),
    }
}