// sqrust_rules/lint/create_trigger_statement.rs
use sqrust_core::{Diagnostic, FileContext, Rule};
use std::collections::HashSet;
3
4pub struct CreateTriggerStatement;
5
6impl Rule for CreateTriggerStatement {
7 fn name(&self) -> &'static str {
8 "Lint/CreateTriggerStatement"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 let source = &ctx.source;
13 let skip = build_skip_set(source);
14 let lower = source.to_lowercase();
15 let bytes = lower.as_bytes();
16 let len = bytes.len();
17 let mut diags = Vec::new();
18 let mut i = 0;
19
20 while i < len {
21 if skip.contains(&i) {
22 i += 1;
23 continue;
24 }
25
26 if !lower[i..].starts_with("create") {
27 i += 1;
28 continue;
29 }
30
31 let create_start = i;
32 let create_end = i + 6; let before_ok = create_start == 0 || {
36 let b = bytes[create_start - 1];
37 !b.is_ascii_alphanumeric() && b != b'_'
38 };
39
40 let after_ok = create_end >= len || {
42 let b = bytes[create_end];
43 !b.is_ascii_alphanumeric() && b != b'_'
44 };
45
46 if !before_ok || !after_ok {
47 i += 1;
48 continue;
49 }
50
51 let mut j = create_end;
53 while j < len
54 && (bytes[j] == b' '
55 || bytes[j] == b'\t'
56 || bytes[j] == b'\r'
57 || bytes[j] == b'\n')
58 && !skip.contains(&j)
59 {
60 j += 1;
61 }
62
63 let stmt_end = find_stmt_end(&lower, &skip, j);
65 let stmt_slice = &lower[j..stmt_end];
66
67 if contains_word_boundary_keyword(stmt_slice, "trigger") {
68 let (line, col) = offset_to_line_col(source, create_start);
69 diags.push(Diagnostic {
70 rule: self.name(),
71 message: "CREATE TRIGGER statements should not appear in SQL files"
72 .to_string(),
73 line,
74 col,
75 });
76
77 i = stmt_end + 1;
78 continue;
79 }
80
81 i += 1;
82 }
83
84 diags
85 }
86}
87
/// Returns the byte index of the first `;` at or after `from` that is not
/// inside a comment or quoted text (i.e. whose index is not in `skip`), or
/// `lower.len()` when no such terminator exists.
fn find_stmt_end(lower: &str, skip: &HashSet<usize>, from: usize) -> usize {
    let bytes = lower.as_bytes();
    (from..bytes.len())
        .find(|&k| bytes[k] == b';' && !skip.contains(&k))
        .unwrap_or(bytes.len())
}
100
/// Returns `true` if `keyword` occurs in `text` as a whole word.
///
/// A match counts only when the bytes immediately before and after it are
/// not ASCII alphanumerics or `_` (or are the start/end of `text`).
/// Every occurrence is examined, including overlapping ones. Candidate
/// positions are visited on character boundaries, so non-ASCII `text` or
/// `keyword` never causes a slicing panic.
fn contains_word_boundary_keyword(text: &str, keyword: &str) -> bool {
    let bytes = text.as_bytes();
    let is_word = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    text.char_indices()
        .map(|(start, _)| start)
        .filter(|&start| text[start..].starts_with(keyword))
        .any(|start| {
            let end = start + keyword.len();
            let before_ok = start == 0 || !is_word(bytes[start - 1]);
            let after_ok = end >= bytes.len() || !is_word(bytes[end]);
            before_ok && after_ok
        })
}
131
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The column is counted in bytes from the start of the line. The input is
/// treated as raw bytes, so an offset inside a multi-byte character does not
/// panic.
///
/// # Panics
///
/// Panics if `offset > source.len()`.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source.as_bytes()[..offset];
    let line = before.iter().filter(|&&b| b == b'\n').count() + 1;
    let line_start = before
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);
    (line, offset - line_start + 1)
}
137}
138
/// Returns the byte offsets of `source` that lie inside comments or quoted text
/// and must be ignored by keyword scanning.
///
/// Recognised regions (all delimiters included in the set):
/// - `-- ...` line comments, up to but not including the newline;
/// - `/* ... */` block comments (not nested);
/// - `'...'` string literals, where `''` is an escaped quote;
/// - `"..."` and `` `...` `` quoted identifiers.
///
/// An unterminated region extends to the end of `source`.
fn build_skip_set(source: &str) -> HashSet<usize> {
    let bytes = source.as_bytes();
    let mut skip = HashSet::new();
    let mut i = 0;

    while i < bytes.len() {
        let end = match bytes[i] {
            b'-' if bytes.get(i + 1) == Some(&b'-') => line_comment_end(bytes, i),
            b'/' if bytes.get(i + 1) == Some(&b'*') => block_comment_end(bytes, i),
            b'\'' => single_quoted_end(bytes, i),
            b'"' => delimited_end(bytes, i, b'"'),
            b'`' => delimited_end(bytes, i, b'`'),
            _ => {
                i += 1;
                continue;
            }
        };
        skip.extend(i..end);
        i = end;
    }

    skip
}

/// End (exclusive) of a `--` comment starting at `start`: the index of the next
/// newline, which is not part of the comment, or the end of input.
fn line_comment_end(bytes: &[u8], start: usize) -> usize {
    bytes[start + 2..]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(bytes.len(), |p| start + 2 + p)
}

/// End (exclusive) of a `/* ... */` comment starting at `start`, just past the
/// first closing `*/`, or the end of input if it is unterminated.
fn block_comment_end(bytes: &[u8], start: usize) -> usize {
    bytes[start + 2..]
        .windows(2)
        .position(|w| w == b"*/")
        .map_or(bytes.len(), |p| start + 2 + p + 2)
}

/// End (exclusive) of a single-quoted literal starting at `start`, treating a
/// doubled `''` as an escaped quote rather than the terminator.
fn single_quoted_end(bytes: &[u8], start: usize) -> usize {
    let mut i = start + 1;
    while i < bytes.len() {
        if bytes[i] == b'\'' {
            if bytes.get(i + 1) == Some(&b'\'') {
                i += 2;
                continue;
            }
            return i + 1;
        }
        i += 1;
    }
    bytes.len()
}

/// End (exclusive) of a region opened by `delim` at `start` and closed by the
/// next `delim`, or the end of input if it is unterminated.
fn delimited_end(bytes: &[u8], start: usize, delim: u8) -> usize {
    bytes[start + 1..]
        .iter()
        .position(|&b| b == delim)
        .map_or(bytes.len(), |p| start + 1 + p + 1)
}