sqrust_rules/convention/
nvl_function.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// Lint rule that reports calls to the Oracle-specific `NVL()` and `NVL2()`
/// functions.
///
/// The rule carries no state; it is registered under the name
/// `Convention/NvlFunction`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NvlFunction;
4
/// Diagnostic text attached to every reported `NVL(` call.
const MESSAGE_NVL: &str = "NVL() is Oracle-specific; use COALESCE() for standard SQL";

/// Diagnostic text attached to every reported `NVL2(` call.
const MESSAGE_NVL2: &str =
    "NVL2() is Oracle-specific; use CASE WHEN col IS NOT NULL THEN ... ELSE ... END instead";
10
11impl Rule for NvlFunction {
12 fn name(&self) -> &'static str {
13 "Convention/NvlFunction"
14 }
15
16 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
17 find_violations(&ctx.source, self.name())
18 }
19}
20
/// Collects the byte offsets of `source` that are not live SQL code.
///
/// The returned set contains:
///
/// * every byte *inside* a single-quoted string literal, including both
///   bytes of a doubled `''` escape (the delimiting quotes themselves are
///   not marked);
/// * every byte of a `--` line comment, from the first dash up to, but not
///   including, the terminating newline;
/// * every byte of a `/* ... */` block comment, delimiters included. An
///   unterminated block comment extends to the end of the input. Nested
///   block comments are not recognised.
///
/// An unterminated string literal likewise runs to the end of the input.
fn build_skip_set(source: &str) -> std::collections::HashSet<usize> {
    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut skip = std::collections::HashSet::new();

    let mut i = 0;
    while i < len {
        match bytes[i] {
            b'\'' => {
                // Step over the opening quote; only the literal's interior is marked.
                i += 1;
                while i < len {
                    if bytes[i] != b'\'' {
                        skip.insert(i);
                        i += 1;
                    } else if bytes.get(i + 1) == Some(&b'\'') {
                        // `''` is an escaped quote: both bytes belong to the literal.
                        skip.insert(i);
                        skip.insert(i + 1);
                        i += 2;
                    } else {
                        // Closing quote ends the literal.
                        i += 1;
                        break;
                    }
                }
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                // The newline itself stays unmarked and is consumed by the outer loop.
                while i < len && bytes[i] != b'\n' {
                    skip.insert(i);
                    i += 1;
                }
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                // Search for the terminator after the opener so `/*/` is not
                // mistaken for a complete comment.
                let body = i + 2;
                let end = bytes[body..]
                    .windows(2)
                    .position(|pair| pair == b"*/")
                    .map_or(len, |found| body + found + 2);
                skip.extend(i..end);
                i = end;
            }
            _ => i += 1,
        }
    }

    skip
}
54
/// Translates a byte `offset` into `source` into a 1-based `(line, column)`
/// pair.
///
/// The line is one more than the number of `\n` bytes before `offset`; the
/// column is the byte distance from the start of that line, plus one.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or does not fall on a
/// UTF-8 character boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let newlines_seen = prefix.bytes().filter(|&b| b == b'\n').count();
    let line_start = match prefix.rfind('\n') {
        Some(newline_at) => newline_at + 1,
        None => 0,
    };
    (newlines_seen + 1, offset - line_start + 1)
}
62
/// Returns `true` when `ch` is an ASCII letter, an ASCII digit, or `_`.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'_' | b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z')
}
67
68fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
69 let bytes = source.as_bytes();
70 let len = bytes.len();
71
72 if len == 0 {
73 return Vec::new();
74 }
75
76 let skip = build_skip_set(source);
77 let mut diags = Vec::new();
78
79 let nvl2 = b"NVL2";
82 let nvl = b"NVL";
83
84 let mut i = 0;
85 while i < len {
86 if skip.contains(&i) {
87 i += 1;
88 continue;
89 }
90
91 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
93 if !before_ok {
94 i += 1;
95 continue;
96 }
97
98 if i + nvl2.len() <= len
100 && bytes[i..i + nvl2.len()].eq_ignore_ascii_case(nvl2)
101 {
102 let all_code = (0..nvl2.len()).all(|k| !skip.contains(&(i + k)));
104 if all_code {
105 let kw_end = i + nvl2.len();
106 let after_ok = kw_end < len
108 && bytes[kw_end] == b'('
109 && (kw_end + 1 >= len || !is_word_char(bytes[kw_end]));
110 if kw_end < len && bytes[kw_end] == b'(' {
112 let (line, col) = line_col(source, i);
113 diags.push(Diagnostic {
114 rule: rule_name,
115 message: MESSAGE_NVL2.to_string(),
116 line,
117 col,
118 });
119 i = kw_end + 1;
120 let _ = after_ok;
121 continue;
122 }
123 }
124 }
125
126 if i + nvl.len() <= len
128 && bytes[i..i + nvl.len()].eq_ignore_ascii_case(nvl)
129 {
130 let all_code = (0..nvl.len()).all(|k| !skip.contains(&(i + k)));
132 if all_code {
133 let kw_end = i + nvl.len();
134 if kw_end < len && bytes[kw_end] == b'(' {
136 let (line, col) = line_col(source, i);
137 diags.push(Diagnostic {
138 rule: rule_name,
139 message: MESSAGE_NVL.to_string(),
140 line,
141 col,
142 });
143 i = kw_end + 1;
144 continue;
145 }
146 }
147 }
148
149 i += 1;
150 }
151
152 diags
153}