// sqrust_rules/convention/no_isnull_function.rs
use std::collections::HashSet;

use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// Lint rule that reports calls to the dialect-specific `ISNULL()` function.
///
/// Stateless unit struct; all configuration lives in the rule logic itself.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoIsnullFunction;
4
/// Diagnostic text attached to every reported `ISNULL(` call site.
const MESSAGE: &str = concat!(
    "ISNULL() is SQL Server/Sybase-specific; ",
    "use COALESCE() for two-argument null replacement ",
    "or the IS NULL predicate for null checks"
);
7
8impl Rule for NoIsnullFunction {
9 fn name(&self) -> &'static str {
10 "Convention/NoIsnullFunction"
11 }
12
13 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14 find_violations(&ctx.source, self.name())
15 }
16}
17
/// Collects the byte offsets of `source` that are not SQL code and must be
/// ignored when searching for keywords.
///
/// Recognised non-code regions:
/// * single-quoted string literals, where `''` is an escaped quote;
/// * double-quoted identifiers, where `""` is an escaped quote;
/// * `--` line comments, up to but not including the terminating newline;
/// * `/* ... */` block comments (not nested), including both delimiters. An
///   unterminated block comment extends to the end of the input.
///
/// For quoted regions only the contents (and escaped quote pairs) are
/// recorded; the opening and closing quotes are not. That is harmless because
/// a quote byte can never be part of a keyword.
fn build_skip_set(source: &str) -> HashSet<usize> {
    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut skip = HashSet::new();

    let mut i = 0;
    while i < len {
        match bytes[i] {
            b'\'' | b'"' => {
                let quote = bytes[i];
                i += 1;
                while i < len {
                    if bytes[i] != quote {
                        skip.insert(i);
                        i += 1;
                    } else if i + 1 < len && bytes[i + 1] == quote {
                        // Doubled quote: an escaped quote inside the literal.
                        skip.insert(i);
                        skip.insert(i + 1);
                        i += 2;
                    } else {
                        // Closing quote.
                        i += 1;
                        break;
                    }
                }
            }
            b'-' if i + 1 < len && bytes[i + 1] == b'-' => {
                while i < len && bytes[i] != b'\n' {
                    skip.insert(i);
                    i += 1;
                }
            }
            b'/' if i + 1 < len && bytes[i + 1] == b'*' => {
                let start = i;
                i += 2;
                while i < len && !bytes[i..].starts_with(b"*/") {
                    i += 1;
                }
                // Step over the closing `*/`; clamp for an unterminated comment.
                i = (i + 2).min(len);
                skip.extend(start..i);
            }
            _ => i += 1,
        }
    }

    skip
}
51
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// Lines are delimited by `\n`. The column is the number of bytes between the
/// start of the line and `offset`, plus one. It is byte-based, so a multi-byte
/// UTF-8 character earlier on the same line advances it by more than one.
///
/// Panics if `offset` is past the end of `source` or does not fall on a
/// character boundary, because `source[..offset]` is sliced.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line = prefix.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    (line, offset - line_start + 1)
}
59
/// Returns true for bytes that may appear inside an unquoted identifier for
/// the purposes of this rule: ASCII letters, ASCII digits and underscore.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_')
}
64
65fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
66 let bytes = source.as_bytes();
67 let len = bytes.len();
68
69 if len == 0 {
70 return Vec::new();
71 }
72
73 let skip = build_skip_set(source);
74 let mut diags = Vec::new();
75
76 let keyword = b"ISNULL";
78 let kw_len = keyword.len();
79
80 let mut i = 0;
81 while i + kw_len <= len {
82 if skip.contains(&i) {
84 i += 1;
85 continue;
86 }
87
88 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
90 if !before_ok {
91 i += 1;
92 continue;
93 }
94
95 if !bytes[i..i + kw_len].eq_ignore_ascii_case(keyword) {
97 i += 1;
98 continue;
99 }
100
101 let all_code = (0..kw_len).all(|k| !skip.contains(&(i + k)));
103 if !all_code {
104 i += 1;
105 continue;
106 }
107
108 let kw_end = i + kw_len;
109
110 if kw_end >= len || bytes[kw_end] != b'(' {
112 i += 1;
113 continue;
114 }
115
116 let (line, col) = line_col(source, i);
117 diags.push(Diagnostic {
118 rule: rule_name,
119 message: MESSAGE.to_string(),
120 line,
121 col,
122 });
123
124 i = kw_end + 1;
125 }
126
127 diags
128}