// sqrust_rules/ambiguous/convert_function.rs
use sqrust_core::{Diagnostic, FileContext, Rule};

/// Lint rule that flags every `CONVERT(...)` call, because the argument order of
/// `CONVERT` differs between SQL dialects; `CAST(value AS type)` is the portable form.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ConvertFunction;

/// Diagnostic text attached to every flagged `CONVERT()` call: it names the two
/// conflicting argument orders and points at the portable `CAST` syntax.
const CONVERT_MESSAGE: &str = concat!(
    "CONVERT() argument order varies by dialect ",
    "(SQL Server: CONVERT(type, value), MySQL: CONVERT(value, type)); ",
    "use CAST(value AS type) for standard SQL",
);

9impl Rule for ConvertFunction {
10 fn name(&self) -> &'static str {
11 "Ambiguous/ConvertFunction"
12 }
13
14 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
15 find_violations(&ctx.source, self.name())
16 }
17}
18
19fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
20 let bytes = source.as_bytes();
21 let len = bytes.len();
22
23 if len == 0 {
24 return Vec::new();
25 }
26
27 let skip = build_skip_set(bytes, len);
28 let mut diags = Vec::new();
29
30 scan_for_function(source, bytes, len, &skip, "CONVERT", CONVERT_MESSAGE, rule_name, &mut diags);
31
32 diags
33}
34
35fn scan_for_function(
37 source: &str,
38 bytes: &[u8],
39 len: usize,
40 skip: &[bool],
41 func_name: &str,
42 message: &str,
43 rule_name: &'static str,
44 diags: &mut Vec<Diagnostic>,
45) {
46 let kw = func_name.as_bytes();
47 let kw_len = kw.len();
48 let mut i = 0;
49
50 while i + kw_len <= len {
51 if skip[i] {
52 i += 1;
53 continue;
54 }
55
56 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
57 if before_ok && bytes[i..i + kw_len].eq_ignore_ascii_case(kw) {
58 let after = i + kw_len;
59 let after_ok = after >= len || !is_word_char(bytes[after]);
61 if after_ok {
62 let mut j = after;
64 while j < len && (bytes[j] == b' ' || bytes[j] == b'\t') {
65 j += 1;
66 }
67 if j < len && bytes[j] == b'(' {
68 let (line, col) = line_col(source, i);
69 diags.push(Diagnostic {
70 rule: rule_name,
71 message: message.to_string(),
72 line,
73 col,
74 });
75 i += kw_len;
76 continue;
77 }
78 }
79 }
80
81 i += 1;
82 }
83}
84
/// Returns `true` for bytes that may continue an identifier in this scan:
/// ASCII letters, ASCII digits, and `_`. Every other byte (including all
/// non-ASCII bytes) is a word boundary.
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}

/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The column counts characters (not bytes) from the start of the line, so a
/// multi-byte UTF-8 character earlier on the line advances it by one. An offset
/// past the end of `source` is clamped to the end, and an offset that falls
/// inside a multi-byte character is rounded down to that character's start, so
/// the function never panics.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut end = offset.min(source.len());
    // Slicing a `str` off a char boundary panics; step back to the nearest one.
    while !source.is_char_boundary(end) {
        end -= 1;
    }
    let before = &source[..end];

    let line_start = before.rfind('\n').map_or(0, |nl| nl + 1);
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let col = before[line_start..].chars().count() + 1;
    (line, col)
}

/// Builds a per-byte mask (`len` entries) marking bytes that keyword scans must
/// ignore.
///
/// * `'…'` and `"…"` spans are masked including both delimiters; a doubled
///   delimiter (`''` / `""`) is an escaped quote and stays inside the span.
/// * `/* … */` is masked including both markers (nesting is not tracked).
/// * `-- …` is masked up to, but not including, the terminating `\n`.
///
/// Unterminated spans are masked through the end of the input. Only the first
/// `len` bytes of `bytes` are examined.
fn build_skip_set(bytes: &[u8], len: usize) -> Vec<bool> {
    /// Masks a quoted span opening at `start`; returns the index just past it.
    fn mask_quoted(bytes: &[u8], mask: &mut [bool], start: usize, quote: u8) -> usize {
        let end = mask.len();
        mask[start] = true;
        let mut cur = start + 1;
        while cur < end {
            mask[cur] = true;
            if bytes[cur] != quote {
                cur += 1;
            } else if cur + 1 < end && bytes[cur + 1] == quote {
                // Doubled delimiter: an escaped quote, the span continues.
                mask[cur + 1] = true;
                cur += 2;
            } else {
                return cur + 1;
            }
        }
        cur
    }

    /// Masks a `/* … */` comment opening at `start`; returns the index just past it.
    fn mask_block_comment(bytes: &[u8], mask: &mut [bool], start: usize) -> usize {
        let end = mask.len();
        mask[start] = true;
        mask[start + 1] = true;
        let mut cur = start + 2;
        while cur < end {
            mask[cur] = true;
            if bytes[cur] == b'*' && cur + 1 < end && bytes[cur + 1] == b'/' {
                mask[cur + 1] = true;
                return cur + 2;
            }
            cur += 1;
        }
        cur
    }

    /// Masks a `--` comment opening at `start`; returns the index of its `\n`
    /// (left unmasked) or the end of input.
    fn mask_line_comment(bytes: &[u8], mask: &mut [bool], start: usize) -> usize {
        let end = mask.len();
        let stop = bytes[start..end]
            .iter()
            .position(|&b| b == b'\n')
            .map_or(end, |off| start + off);
        mask[start..stop].fill(true);
        stop
    }

    let mut mask = vec![false; len];
    let mut pos = 0;

    while pos < len {
        let following = if pos + 1 < len { Some(bytes[pos + 1]) } else { None };
        pos = match (bytes[pos], following) {
            (quote @ (b'\'' | b'"'), _) => mask_quoted(bytes, &mut mask, pos, quote),
            (b'/', Some(b'*')) => mask_block_comment(bytes, &mut mask, pos),
            (b'-', Some(b'-')) => mask_line_comment(bytes, &mut mask, pos),
            _ => pos + 1,
        };
    }

    mask
}