sqrust_rules/lint/
set_variable_statement.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2use std::collections::HashSet;
3
/// Lint rule that reports `SET @variable` statements.
///
/// The rule carries no state, so it is a unit struct; the common derives are
/// provided so it can be freely copied, defaulted and printed in debug output.
#[derive(Debug, Clone, Copy, Default)]
pub struct SetVariableStatement;
5
6impl Rule for SetVariableStatement {
7 fn name(&self) -> &'static str {
8 "Lint/SetVariableStatement"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 let source = &ctx.source;
13 let skip = build_skip_set(source);
14 let mut diags = Vec::new();
15
16 let lower = source.to_lowercase();
17 let bytes = lower.as_bytes();
18 let src_bytes = source.as_bytes();
19 let len = bytes.len();
20 let keyword = b"set";
21 let kw_len = keyword.len();
22
23 let mut i = 0;
24 while i + kw_len <= len {
25 if !skip.contains(&i) && bytes[i..i + kw_len] == *keyword {
26 let before_ok = i == 0
27 || {
28 let b = bytes[i - 1];
29 !b.is_ascii_alphanumeric() && b != b'_'
30 };
31 let after_pos = i + kw_len;
32 let after_ok = after_pos >= len
33 || {
34 let b = bytes[after_pos];
35 !b.is_ascii_alphanumeric() && b != b'_'
36 };
37
38 if before_ok && after_ok {
39 let mut j = after_pos;
41 while j < len
42 && (src_bytes[j] == b' '
43 || src_bytes[j] == b'\t'
44 || src_bytes[j] == b'\r'
45 || src_bytes[j] == b'\n')
46 {
47 j += 1;
48 }
49 if j < len && src_bytes[j] == b'@' {
51 let (line, col) = offset_to_line_col(source, i);
52 diags.push(Diagnostic {
53 rule: self.name(),
54 message: "SET @variable is a dialect-specific variable assignment \
55 (MySQL/SQL Server); not supported in standard SQL or \
56 analytical databases"
57 .to_string(),
58 line,
59 col,
60 });
61 }
62 }
63 }
64 i += 1;
65 }
66
67 diags
68 }
69}
70
/// Collects the byte offsets of `source` that must be ignored when searching
/// for keywords.
///
/// The returned set contains:
/// * every byte between the quotes of a single-quoted string literal (the
///   delimiting quotes themselves are not included; both bytes of an escaped
///   `''` pair are);
/// * every byte of a `--` line comment, from the first dash up to but not
///   including the terminating newline;
/// * every byte of a `/* ... */` block comment, delimiters included.
///
/// Unterminated strings and block comments extend to the end of `source`.
/// Block comments are not treated as nesting: the first `*/` closes the comment.
fn build_skip_set(source: &str) -> HashSet<usize> {
    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut skip = HashSet::new();

    let mut i = 0;
    while i < len {
        match bytes[i] {
            b'\'' => {
                // Step over the opening quote, then consume the literal body.
                i += 1;
                while i < len {
                    if bytes[i] != b'\'' {
                        skip.insert(i);
                        i += 1;
                    } else if bytes.get(i + 1) == Some(&b'\'') {
                        // A doubled quote is an escaped quote inside the literal.
                        skip.insert(i);
                        skip.insert(i + 1);
                        i += 2;
                    } else {
                        // Closing quote: leave the literal.
                        i += 1;
                        break;
                    }
                }
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                let end = bytes[i..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(len, |rel| i + rel);
                skip.extend(i..end);
                i = end;
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                // Handling block comments here also prevents an apostrophe or
                // `--` inside the comment from being misread as the start of a
                // string or line comment.
                let body_start = i + 2;
                let end = bytes[body_start..]
                    .windows(2)
                    .position(|pair| pair == b"*/")
                    .map_or(len, |rel| body_start + rel + 2);
                skip.extend(i..end);
                i = end;
            }
            _ => i += 1,
        }
    }

    skip
}
104
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The line number is one more than the count of `\n` bytes before `offset`.
/// The column is the number of bytes between the start of that line and
/// `offset`, plus one. Slicing `source` at `offset` panics if the offset is
/// past the end or not on a character boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];

    // Byte index at which the current line begins.
    let line_start = match prefix.rfind('\n') {
        Some(newline_at) => newline_at + 1,
        None => 0,
    };
    let newlines_seen = prefix.bytes().filter(|&b| b == b'\n').count();

    (newlines_seen + 1, offset - line_start + 1)
}