// sqrust_rules/lint/set_variable_statement.rs

use sqrust_core::{Diagnostic, FileContext, Rule};
use std::collections::HashSet;

4pub struct SetVariableStatement;
5
6impl Rule for SetVariableStatement {
7    fn name(&self) -> &'static str {
8        "Lint/SetVariableStatement"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        let source = &ctx.source;
13        let skip = build_skip_set(source);
14        let mut diags = Vec::new();
15
16        let lower = source.to_lowercase();
17        let bytes = lower.as_bytes();
18        let src_bytes = source.as_bytes();
19        let len = bytes.len();
20        let keyword = b"set";
21        let kw_len = keyword.len();
22
23        let mut i = 0;
24        while i + kw_len <= len {
25            if !skip.contains(&i) && bytes[i..i + kw_len] == *keyword {
26                let before_ok = i == 0
27                    || {
28                        let b = bytes[i - 1];
29                        !b.is_ascii_alphanumeric() && b != b'_'
30                    };
31                let after_pos = i + kw_len;
32                let after_ok = after_pos >= len
33                    || {
34                        let b = bytes[after_pos];
35                        !b.is_ascii_alphanumeric() && b != b'_'
36                    };
37
38                if before_ok && after_ok {
39                    // Skip whitespace after SET keyword
40                    let mut j = after_pos;
41                    while j < len
42                        && (src_bytes[j] == b' '
43                            || src_bytes[j] == b'\t'
44                            || src_bytes[j] == b'\r'
45                            || src_bytes[j] == b'\n')
46                    {
47                        j += 1;
48                    }
49                    // Check if next non-whitespace character is '@'
50                    if j < len && src_bytes[j] == b'@' {
51                        let (line, col) = offset_to_line_col(source, i);
52                        diags.push(Diagnostic {
53                            rule: self.name(),
54                            message: "SET @variable is a dialect-specific variable assignment \
55                                      (MySQL/SQL Server); not supported in standard SQL or \
56                                      analytical databases"
57                                .to_string(),
58                            line,
59                            col,
60                        });
61                    }
62                }
63            }
64            i += 1;
65        }
66
67        diags
68    }
69}
70
/// Returns the set of byte offsets in `source` that must not be scanned for
/// keywords.
///
/// The skipped regions are:
/// - the contents of single-quoted string literals (a doubled `''` is an
///   escaped quote and both of its bytes are skipped; the delimiting quotes
///   themselves are not included);
/// - `--` line comments, up to but not including the terminating `\n`;
/// - `/* ... */` block comments, including both delimiters. An unterminated
///   block comment runs to the end of the input. Nesting is not supported.
///
/// NOTE(review): MySQL `/*! ... */` executable comments are also skipped,
/// even though MySQL runs their contents.
fn build_skip_set(source: &str) -> HashSet<usize> {
    let mut skip = HashSet::new();
    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    while i < len {
        if bytes[i] == b'\'' {
            // Single-quoted string literal.
            i += 1;
            while i < len {
                if bytes[i] == b'\'' {
                    if bytes.get(i + 1) == Some(&b'\'') {
                        // Escaped quote (`''`) stays inside the literal.
                        skip.insert(i);
                        skip.insert(i + 1);
                        i += 2;
                    } else {
                        // Closing quote.
                        i += 1;
                        break;
                    }
                } else {
                    skip.insert(i);
                    i += 1;
                }
            }
        } else if bytes[i] == b'-' && bytes.get(i + 1) == Some(&b'-') {
            // Line comment: runs until (not including) the next newline.
            while i < len && bytes[i] != b'\n' {
                skip.insert(i);
                i += 1;
            }
        } else if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
            // Block comment: runs through the closing `*/`, or to EOF.
            skip.insert(i);
            skip.insert(i + 1);
            i += 2;
            while i < len {
                if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
                    skip.insert(i);
                    skip.insert(i + 1);
                    i += 2;
                    break;
                }
                skip.insert(i);
                i += 1;
            }
        } else {
            i += 1;
        }
    }
    skip
}

/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// The line is one plus the number of `\n` bytes before `offset`. The column
/// is counted in bytes from the start of that line, starting at 1.
///
/// # Panics
/// Panics if `offset` is past the end of `source` or does not fall on a
/// UTF-8 char boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let (line, line_start) = match prefix.rfind('\n') {
        Some(newline) => (prefix.matches('\n').count() + 1, newline + 1),
        None => (1, 0),
    };
    (line, offset - line_start + 1)
}