Skip to main content

sqrust_rules/ambiguous/select_null_expression.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Query, SelectItem, SetExpr, Statement, TableFactor, Value};
3
/// Lint rule that flags a literal `NULL` selected without an alias
/// (e.g. `SELECT NULL FROM t`), reported as `Ambiguous/SelectNullExpression`.
#[derive(Debug, Clone, Copy, Default)]
pub struct SelectNullExpression;
5
6impl Rule for SelectNullExpression {
7    fn name(&self) -> &'static str {
8        "Ambiguous/SelectNullExpression"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        if !ctx.parse_errors.is_empty() {
13            return Vec::new();
14        }
15
16        let mut diags = Vec::new();
17        for stmt in &ctx.statements {
18            if let Statement::Query(query) = stmt {
19                check_query(query, ctx, &mut diags);
20            }
21        }
22        diags
23    }
24}
25
26// ── AST walking ───────────────────────────────────────────────────────────────
27
28fn check_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
29    // Visit CTEs.
30    if let Some(with) = &query.with {
31        for cte in &with.cte_tables {
32            check_query(&cte.query, ctx, diags);
33        }
34    }
35
36    check_set_expr(&query.body, ctx, diags);
37}
38
39fn check_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
40    match expr {
41        SetExpr::Select(sel) => {
42            // Check projection items for bare NULL (no alias).
43            for item in &sel.projection {
44                if let SelectItem::UnnamedExpr(sqlparser::ast::Expr::Value(Value::Null)) = item {
45                    let (line, col) = find_null_pos(&ctx.source);
46                    diags.push(Diagnostic {
47                        rule: "Ambiguous/SelectNullExpression",
48                        message: "Selecting a literal NULL without an alias; add an alias to clarify intent".to_string(),
49                        line,
50                        col,
51                    });
52                }
53            }
54
55            // Recurse into subqueries in FROM / JOIN clauses.
56            for twj in &sel.from {
57                check_table_factor(&twj.relation, ctx, diags);
58                for join in &twj.joins {
59                    check_table_factor(&join.relation, ctx, diags);
60                }
61            }
62        }
63        SetExpr::Query(inner) => check_query(inner, ctx, diags),
64        SetExpr::SetOperation { left, right, .. } => {
65            check_set_expr(left, ctx, diags);
66            check_set_expr(right, ctx, diags);
67        }
68        _ => {}
69    }
70}
71
72fn check_table_factor(tf: &TableFactor, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
73    if let TableFactor::Derived { subquery, .. } = tf {
74        check_query(subquery, ctx, diags);
75    }
76}
77
// ── position helper ───────────────────────────────────────────────────────────

/// Returns the 1-indexed (line, col) of the `NULL` this rule most likely fired on.
///
/// The AST nodes matched by this rule carry no source location, so the position is
/// recovered lexically from `source`. Only significant tokens are considered (see
/// `significant_tokens`), so a `NULL` inside a comment or quoted text is never
/// reported. Among the `NULL` keywords that remain, the first one shaped like an
/// unaliased select-list item is preferred: it must be preceded by `SELECT`, `DISTINCT`,
/// `ALL` or `,` and not followed by `AS`. Otherwise the first `NULL` keyword is used.
/// Falls back to (1, 1) when `source` contains no `NULL` keyword at all.
///
/// This is a heuristic, not a parser: e.g. `COALESCE(a, NULL)` also looks like a
/// select-list item.
fn find_null_pos(source: &str) -> (usize, usize) {
    let tokens = significant_tokens(source);

    // Matching is ASCII case-insensitive on the ORIGINAL text. Full Unicode
    // upper-casing can change byte lengths (e.g. `ı` -> `I`), which would make
    // offsets found in an upper-cased copy invalid for `source`.
    let word_is = |i: usize, kw: &str| {
        matches!(tokens.get(i), Some(&(_, tok)) if tok.eq_ignore_ascii_case(kw))
    };
    let opens_list_item = |i: usize| {
        i > 0
            && ["SELECT", "DISTINCT", "ALL", ","]
                .iter()
                .any(|&kw| word_is(i - 1, kw))
    };

    let first_null = (0..tokens.len()).find(|&i| word_is(i, "NULL"));
    let first_bare_item = (0..tokens.len())
        .find(|&i| word_is(i, "NULL") && opens_list_item(i) && !word_is(i + 1, "AS"));

    first_bare_item
        .or(first_null)
        .map_or((1, 1), |i| line_col(source, tokens[i].0))
}

/// Splits `source` into its significant tokens, each paired with its starting byte offset.
///
/// A token is either a maximal run of word bytes (ASCII alphanumerics, `_`, or any byte
/// of a non-ASCII character) or a single other ASCII byte, so every offset lies on a
/// char boundary. Whitespace, `--` line comments, `/* ... */` block comments and text
/// quoted with `'`, `"` or `` ` `` produce no tokens. Quote handling is deliberately
/// simple: the scan resumes after the next matching quote byte (so doubled `''`
/// escapes work), and an unterminated comment or quote swallows the rest of the input.
fn significant_tokens(source: &str) -> Vec<(usize, &str)> {
    let bytes = source.as_bytes();
    let is_word = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii();
    let mut tokens = Vec::new();
    let mut i = 0;
    while i < bytes.len() {
        let rest = &bytes[i..];
        let b = rest[0];
        if b.is_ascii_whitespace() {
            i += 1;
        } else if rest.starts_with(b"--") {
            // Stop at the newline; the whitespace branch then consumes it.
            i += rest.iter().position(|&c| c == b'\n').unwrap_or(rest.len());
        } else if rest.starts_with(b"/*") {
            // Skip the opening `/*`, the body, and the closing `*/`.
            i += rest[2..]
                .windows(2)
                .position(|w| w == b"*/")
                .map_or(rest.len(), |p| p + 4);
        } else if matches!(b, b'\'' | b'"' | b'`') {
            // Skip the opening quote, the body, and the matching closing quote.
            i += rest[1..]
                .iter()
                .position(|&c| c == b)
                .map_or(rest.len(), |p| p + 2);
        } else if is_word(b) {
            let len = rest.iter().take_while(|&&c| is_word(c)).count();
            tokens.push((i, &source[i..i + len]));
            i += len;
        } else {
            tokens.push((i, &source[i..i + 1]));
            i += 1;
        }
    }
    tokens
}

/// Converts a byte offset in `source` to a 1-indexed (line, col) pair.
///
/// `col` counts bytes from the start of the line, not characters. `offset` must lie
/// on a char boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source[..offset];
    let line_start = before.rfind('\n').map_or(0, |p| p + 1);
    let line = before.matches('\n').count() + 1;
    (line, offset - line_start + 1)
}