Skip to main content

sqrust_rules/convention/
no_values_function.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3pub struct NoValuesFunction;
4
5impl Rule for NoValuesFunction {
6    fn name(&self) -> &'static str {
7        "Convention/NoValuesFunction"
8    }
9
10    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
11        find_violations(&ctx.source, self.name())
12    }
13}
14
/// Returns the byte offsets of `source` that are not SQL code and must be
/// ignored when looking for keywords.
///
/// Masked regions:
/// * the contents of `'...'` string literals, `"..."` strings / quoted
///   identifiers and `` `...` `` quoted identifiers. The opening and closing
///   delimiters themselves are not masked. Both bytes of a doubled delimiter
///   (`''`, `""`, ``` `` ```), which escapes the delimiter inside the span,
///   are masked.
/// * `-- ...` line comments, from the first `-` up to (not including) the
///   terminating newline.
/// * `/* ... */` block comments, delimiters included.
///
/// An unterminated literal or block comment extends to the end of `source`.
fn build_skip_set(source: &str) -> std::collections::HashSet<usize> {
    let mut skip = std::collections::HashSet::new();
    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    while i < len {
        match bytes[i] {
            quote @ (b'\'' | b'"' | b'`') => {
                // Step over the opening delimiter; mask everything up to the
                // matching (non-doubled) closing delimiter.
                i += 1;
                while i < len {
                    if bytes[i] != quote {
                        skip.insert(i);
                        i += 1;
                    } else if i + 1 < len && bytes[i + 1] == quote {
                        // Doubled delimiter: an escaped quote, still inside the span.
                        skip.insert(i);
                        skip.insert(i + 1);
                        i += 2;
                    } else {
                        // Closing delimiter.
                        i += 1;
                        break;
                    }
                }
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                while i < len && bytes[i] != b'\n' {
                    skip.insert(i);
                    i += 1;
                }
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                skip.insert(i);
                skip.insert(i + 1);
                i += 2;
                while i < len {
                    if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
                        skip.insert(i);
                        skip.insert(i + 1);
                        i += 2;
                        break;
                    }
                    skip.insert(i);
                    i += 1;
                }
            }
            _ => i += 1,
        }
    }
    skip
}
48
/// Translates a byte `offset` into `source` into a 1-based `(line, column)` pair.
///
/// The column is counted in bytes from the start of the line containing
/// `offset`. `offset` must lie on a `char` boundary of `source`.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line_no = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    (line_no, offset - line_start + 1)
}
56
/// Returns `true` for bytes that can appear inside an SQL word: ASCII
/// letters, ASCII digits and the underscore.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
61
62/// Check if the byte slice starting at `pos` (within `bytes`) contains `keyword`
63/// (case-insensitive) as a complete word. Returns true if found.
64fn contains_word_ci(bytes: &[u8], pos: usize, end: usize, keyword: &[u8]) -> bool {
65    let kw_len = keyword.len();
66    if end < kw_len {
67        return false;
68    }
69    let mut i = pos;
70    while i + kw_len <= end {
71        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
72        if before_ok && bytes[i..i + kw_len].eq_ignore_ascii_case(keyword) {
73            let after = i + kw_len;
74            let after_ok = after >= end || !is_word_char(bytes[after]);
75            if after_ok {
76                return true;
77            }
78        }
79        i += 1;
80    }
81    false
82}
83
84fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
85    let bytes = source.as_bytes();
86    let len = bytes.len();
87
88    if len == 0 {
89        return Vec::new();
90    }
91
92    let skip = build_skip_set(source);
93    let mut diags = Vec::new();
94
95    // VALUES keyword length
96    let values_kw = b"VALUES";
97    let values_len = values_kw.len();
98
99    let mut i = 0;
100    while i < len {
101        if skip.contains(&i) {
102            i += 1;
103            continue;
104        }
105
106        // Try to match VALUES at position i with word boundary before
107        if i + values_len > len {
108            break;
109        }
110
111        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
112        if !before_ok {
113            i += 1;
114            continue;
115        }
116
117        if !bytes[i..i + values_len].eq_ignore_ascii_case(values_kw) {
118            i += 1;
119            continue;
120        }
121
122        // Ensure all VALUES chars are code (not in string/comment)
123        let all_code = (0..values_len).all(|k| !skip.contains(&(i + k)));
124        if !all_code {
125            i += 1;
126            continue;
127        }
128
129        let values_end = i + values_len;
130
131        // Word boundary after VALUES must be `(` for it to be a function call
132        if values_end >= len || bytes[values_end] != b'(' {
133            i += 1;
134            continue;
135        }
136
137        // Found `VALUES(`. Now determine if this is a function call (not INSERT clause).
138        // Look back up to 300 characters for expression-context keywords.
139        let window_start = if i >= 300 { i - 300 } else { 0 };
140        let context_slice = &bytes[window_start..i];
141
142        // Expression context keywords that indicate VALUES() function usage:
143        // ON DUPLICATE KEY UPDATE (most common case)
144        // SET (UPDATE ... SET col = VALUES(col))
145        // AND, OR, THEN, ELSE, WHERE (general expression context)
146        let is_expression_context = contains_word_ci(context_slice, 0, context_slice.len(), b"UPDATE")
147            || contains_word_ci(context_slice, 0, context_slice.len(), b"SET");
148
149        if is_expression_context {
150            // Additional check: make sure there's an INSERT in the context which means
151            // this is INSERT...VALUES clause, not the VALUES() function.
152            // If INSERT is present but VALUES( appears AFTER ON DUPLICATE KEY UPDATE,
153            // then it IS the function.
154            // We detect this by checking if ON DUPLICATE KEY UPDATE appears in context.
155            let has_on_duplicate = contains_word_ci(context_slice, 0, context_slice.len(), b"DUPLICATE");
156            let has_insert = contains_word_ci(context_slice, 0, context_slice.len(), b"INSERT");
157
158            // If INSERT is present and DUPLICATE is not, this might be INSERT SET (MySQL extension)
159            // or UPDATE SET. Flag it either way if UPDATE/SET is present.
160            let _ = has_insert; // captured for clarity; we flag based on expression context
161            let _ = has_on_duplicate;
162
163            let (line, col) = line_col(source, i);
164            diags.push(Diagnostic {
165                rule: rule_name,
166                message: "VALUES() function is MySQL-specific (used in ON DUPLICATE KEY UPDATE) — not supported in other databases".to_string(),
167                line,
168                col,
169            });
170        }
171
172        i = values_end + 1;
173    }
174
175    diags
176}