Skip to main content

sqrust_rules/lint/
column_alias_in_where.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, With};
3use std::collections::HashSet;
4
5pub struct ColumnAliasInWhere;
6
7impl Rule for ColumnAliasInWhere {
8    fn name(&self) -> &'static str {
9        "Lint/ColumnAliasInWhere"
10    }
11
12    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13        if !ctx.parse_errors.is_empty() {
14            return Vec::new();
15        }
16        let mut diags = Vec::new();
17        for stmt in &ctx.statements {
18            check_stmt(stmt, &ctx.source, "Lint/ColumnAliasInWhere", &mut diags);
19        }
20        diags
21    }
22}
23
24fn check_stmt(stmt: &Statement, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
25    if let Statement::Query(q) = stmt {
26        check_query(q, src, rule, diags);
27    }
28}
29
30fn check_query(q: &Query, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
31    if let Some(With { cte_tables, .. }) = &q.with {
32        for cte in cte_tables {
33            check_query(&cte.query, src, rule, diags);
34        }
35    }
36    check_set_expr(&q.body, src, rule, diags);
37}
38
39fn check_set_expr(body: &SetExpr, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
40    match body {
41        SetExpr::Select(s) => check_select(s, src, rule, diags),
42        SetExpr::SetOperation { left, right, .. } => {
43            check_set_expr(left, src, rule, diags);
44            check_set_expr(right, src, rule, diags);
45        }
46        SetExpr::Query(q) => check_query(q, src, rule, diags),
47        _ => {}
48    }
49}
50
51fn check_select(sel: &Select, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
52    // Collect SELECT aliases
53    let mut aliases: HashSet<String> = HashSet::new();
54    for item in &sel.projection {
55        if let SelectItem::ExprWithAlias { alias, .. } = item {
56            aliases.insert(alias.value.to_lowercase());
57        }
58    }
59
60    if aliases.is_empty() {
61        return;
62    }
63
64    // Walk WHERE for identifiers matching aliases
65    if let Some(where_expr) = &sel.selection {
66        let start_offset = find_where_offset(src);
67        find_alias_refs(where_expr, &aliases, src, rule, diags, start_offset);
68    }
69}
70
71fn find_where_offset(src: &str) -> usize {
72    let bytes = src.as_bytes();
73    let kw = b"WHERE";
74    let mut i = 0;
75    while i + 5 <= bytes.len() {
76        if bytes[i..i + 5].eq_ignore_ascii_case(kw) {
77            let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
78            let after_ok = i + 5 >= bytes.len() || !is_word_char(bytes[i + 5]);
79            if before_ok && after_ok {
80                return i;
81            }
82        }
83        i += 1;
84    }
85    0
86}
87
88fn find_alias_refs(
89    expr: &Expr,
90    aliases: &HashSet<String>,
91    src: &str,
92    rule: &'static str,
93    diags: &mut Vec<Diagnostic>,
94    start_offset: usize,
95) {
96    match expr {
97        Expr::Identifier(ident) => {
98            let lower = ident.value.to_lowercase();
99            if aliases.contains(&lower) {
100                if let Some(off) = find_word_in_source(src, &ident.value, start_offset) {
101                    let (line, col) = offset_to_line_col(src, off);
102                    diags.push(Diagnostic {
103                        rule,
104                        message: format!(
105                            "Column alias '{}' is used in WHERE clause; aliases are not available in WHERE (evaluated before SELECT)",
106                            ident.value
107                        ),
108                        line,
109                        col,
110                    });
111                }
112            }
113        }
114        Expr::BinaryOp { left, right, .. } => {
115            find_alias_refs(left, aliases, src, rule, diags, start_offset);
116            find_alias_refs(right, aliases, src, rule, diags, start_offset);
117        }
118        Expr::UnaryOp { expr, .. } | Expr::Nested(expr) => {
119            find_alias_refs(expr, aliases, src, rule, diags, start_offset);
120        }
121        Expr::Between { expr, low, high, .. } => {
122            find_alias_refs(expr, aliases, src, rule, diags, start_offset);
123            find_alias_refs(low, aliases, src, rule, diags, start_offset);
124            find_alias_refs(high, aliases, src, rule, diags, start_offset);
125        }
126        Expr::InList { expr, list, .. } => {
127            find_alias_refs(expr, aliases, src, rule, diags, start_offset);
128            for e in list {
129                find_alias_refs(e, aliases, src, rule, diags, start_offset);
130            }
131        }
132        Expr::IsNull(e) | Expr::IsNotNull(e) => {
133            find_alias_refs(e, aliases, src, rule, diags, start_offset);
134        }
135        Expr::Like { expr, pattern, .. } | Expr::ILike { expr, pattern, .. } => {
136            find_alias_refs(expr, aliases, src, rule, diags, start_offset);
137            find_alias_refs(pattern, aliases, src, rule, diags, start_offset);
138        }
139        _ => {}
140    }
141}
142
143fn find_word_in_source(src: &str, word: &str, start: usize) -> Option<usize> {
144    let bytes = src.as_bytes();
145    let wbytes = word.as_bytes();
146    let wlen = wbytes.len();
147    if wlen == 0 {
148        return None;
149    }
150    let mut i = start;
151    while i + wlen <= bytes.len() {
152        if bytes[i..i + wlen].eq_ignore_ascii_case(wbytes) {
153            let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
154            let after_ok = i + wlen >= bytes.len() || !is_word_char(bytes[i + wlen]);
155            if before_ok && after_ok {
156                return Some(i);
157            }
158        }
159        i += 1;
160    }
161    None
162}
163
/// Returns `true` for bytes that can be part of a word: ASCII letters, ASCII
/// digits and the underscore.
fn is_word_char(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
167
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// The column is counted in bytes from the start of the line. An offset past
/// the end of `source` is clamped to the end, and an offset that falls inside
/// a multi-byte character is moved back to the start of that character, so
/// the function never panics and line and column always describe the same
/// position.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut end = offset.min(source.len());
    // Offset 0 is always a char boundary, so this loop terminates.
    while !source.is_char_boundary(end) {
        end -= 1;
    }
    let before = &source[..end];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |newline| newline + 1);
    (line, end - line_start + 1)
}
173}