// sqrust_rules/ambiguous/convert_function.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// Rule that flags calls to `CONVERT(...)`, whose argument order differs
/// between SQL dialects, and suggests `CAST(value AS type)` instead.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConvertFunction;

/// Diagnostic text attached to every flagged `CONVERT(` call.
const CONVERT_MESSAGE: &str = concat!(
    "CONVERT() argument order varies by dialect ",
    "(SQL Server: CONVERT(type, value), MySQL: CONVERT(value, type)); ",
    "use CAST(value AS type) for standard SQL"
);

9impl Rule for ConvertFunction {
10    fn name(&self) -> &'static str {
11        "Ambiguous/ConvertFunction"
12    }
13
14    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
15        find_violations(&ctx.source, self.name())
16    }
17}
18
19fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
20    let bytes = source.as_bytes();
21    let len = bytes.len();
22
23    if len == 0 {
24        return Vec::new();
25    }
26
27    let skip = build_skip_set(bytes, len);
28    let mut diags = Vec::new();
29
30    scan_for_function(source, bytes, len, &skip, "CONVERT", CONVERT_MESSAGE, rule_name, &mut diags);
31
32    diags
33}
34
35/// Scan for `func_name(` (case-insensitive) with a word boundary before.
36fn scan_for_function(
37    source: &str,
38    bytes: &[u8],
39    len: usize,
40    skip: &[bool],
41    func_name: &str,
42    message: &str,
43    rule_name: &'static str,
44    diags: &mut Vec<Diagnostic>,
45) {
46    let kw = func_name.as_bytes();
47    let kw_len = kw.len();
48    let mut i = 0;
49
50    while i + kw_len <= len {
51        if skip[i] {
52            i += 1;
53            continue;
54        }
55
56        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
57        if before_ok && bytes[i..i + kw_len].eq_ignore_ascii_case(kw) {
58            let after = i + kw_len;
59            // Word boundary after: next char must not be a word char
60            let after_ok = after >= len || !is_word_char(bytes[after]);
61            if after_ok {
62                // Skip optional whitespace then check for '('
63                let mut j = after;
64                while j < len && (bytes[j] == b' ' || bytes[j] == b'\t') {
65                    j += 1;
66                }
67                if j < len && bytes[j] == b'(' {
68                    let (line, col) = line_col(source, i);
69                    diags.push(Diagnostic {
70                        rule: rule_name,
71                        message: message.to_string(),
72                        line,
73                        col,
74                    });
75                    i += kw_len;
76                    continue;
77                }
78            }
79        }
80
81        i += 1;
82    }
83}
84
/// Whether `ch` counts as part of a word when checking keyword boundaries:
/// an ASCII digit, an ASCII letter, or an underscore.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_')
}

/// Convert a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The line is one plus the number of `\n` bytes before `offset`. The column
/// is one plus the number of bytes between the start of that line and
/// `offset`, so it counts bytes, not characters. The prefix is clamped to
/// `source.len()`, but the column arithmetic uses `offset` as given. Panics if
/// the clamped offset is not on a UTF-8 character boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset.min(source.len())];
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    (line, offset - line_start + 1)
}

/// Build a per-byte mask over `bytes[..len]` in which `true` marks bytes the
/// keyword scanner must ignore:
///
/// * single-quoted string literals (`'...'`, with `''` as an escaped quote),
/// * double-quoted identifiers (`"..."`, with `""` as an escaped quote),
/// * block comments (`/* ... */`, not nested),
/// * line comments (from `--` up to, but not including, the newline).
///
/// An unterminated string, identifier, or block comment extends to the end of
/// the input. Callers pass `len == bytes.len()`.
fn build_skip_set(bytes: &[u8], len: usize) -> Vec<bool> {
    /// Exclusive end of the quoted run opened by `quote` at index `open`.
    fn quoted_end(text: &[u8], open: usize, quote: u8) -> usize {
        let mut k = open + 1;
        while k < text.len() {
            if text[k] != quote {
                k += 1;
            } else if text.get(k + 1) == Some(&quote) {
                // A doubled quote is an escape; the run continues.
                k += 2;
            } else {
                return k + 1;
            }
        }
        text.len()
    }

    /// Exclusive end of the `/* ... */` comment opened at index `open`.
    fn block_comment_end(text: &[u8], open: usize) -> usize {
        let mut k = open + 2;
        while k < text.len() {
            if text[k] == b'*' && text.get(k + 1) == Some(&b'/') {
                return k + 2;
            }
            k += 1;
        }
        text.len()
    }

    /// Exclusive end of the `--` comment opened at index `open` (the newline
    /// itself is not part of the comment).
    fn line_comment_end(text: &[u8], open: usize) -> usize {
        text[open + 2..]
            .iter()
            .position(|&b| b == b'\n')
            .map_or(text.len(), |off| open + 2 + off)
    }

    let text = &bytes[..len];
    let mut mask = vec![false; len];
    let mut pos = 0;

    while pos < len {
        let next = text.get(pos + 1).copied();
        let end = match text[pos] {
            quote @ (b'\'' | b'"') => quoted_end(text, pos, quote),
            b'/' if next == Some(b'*') => block_comment_end(text, pos),
            b'-' if next == Some(b'-') => line_comment_end(text, pos),
            _ => {
                pos += 1;
                continue;
            }
        };
        mask[pos..end].fill(true);
        pos = end;
    }

    mask
}