Skip to main content

sqrust_rules/lint/
create_trigger_statement.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use std::collections::HashSet;
3
4pub struct CreateTriggerStatement;
5
6impl Rule for CreateTriggerStatement {
7    fn name(&self) -> &'static str {
8        "Lint/CreateTriggerStatement"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        let source = &ctx.source;
13        let skip = build_skip_set(source);
14        let lower = source.to_lowercase();
15        let bytes = lower.as_bytes();
16        let len = bytes.len();
17        let mut diags = Vec::new();
18        let mut i = 0;
19
20        while i < len {
21            if skip.contains(&i) {
22                i += 1;
23                continue;
24            }
25
26            if !lower[i..].starts_with("create") {
27                i += 1;
28                continue;
29            }
30
31            let create_start = i;
32            let create_end = i + 6; // "create".len()
33
34            // Word-boundary check before "create".
35            let before_ok = create_start == 0 || {
36                let b = bytes[create_start - 1];
37                !b.is_ascii_alphanumeric() && b != b'_'
38            };
39
40            // Word-boundary check after "create".
41            let after_ok = create_end >= len || {
42                let b = bytes[create_end];
43                !b.is_ascii_alphanumeric() && b != b'_'
44            };
45
46            if !before_ok || !after_ok {
47                i += 1;
48                continue;
49            }
50
51            // Skip whitespace between CREATE and next keyword.
52            let mut j = create_end;
53            while j < len
54                && (bytes[j] == b' '
55                    || bytes[j] == b'\t'
56                    || bytes[j] == b'\r'
57                    || bytes[j] == b'\n')
58                && !skip.contains(&j)
59            {
60                j += 1;
61            }
62
63            // Scan the rest of the statement for "trigger" keyword.
64            let stmt_end = find_stmt_end(&lower, &skip, j);
65            let stmt_slice = &lower[j..stmt_end];
66
67            if contains_word_boundary_keyword(stmt_slice, "trigger") {
68                let (line, col) = offset_to_line_col(source, create_start);
69                diags.push(Diagnostic {
70                    rule: self.name(),
71                    message: "CREATE TRIGGER statements should not appear in SQL files"
72                        .to_string(),
73                    line,
74                    col,
75                });
76
77                i = stmt_end + 1;
78                continue;
79            }
80
81            i += 1;
82        }
83
84        diags
85    }
86}
87
/// Returns the byte index of the first `;` at or after `from` whose position is not in
/// `skip`. Returns `lower.len()` when there is no such terminator, including when `from`
/// is already past the end.
fn find_stmt_end(lower: &str, skip: &HashSet<usize>, from: usize) -> usize {
    let text = lower.as_bytes();
    let end_of_input = text.len();

    (from..end_of_input)
        .find(|&idx| text[idx] == b';' && !skip.contains(&idx))
        .unwrap_or(end_of_input)
}
100
/// Reports whether `keyword` occurs in `text` as a standalone word.
///
/// A match is standalone when the byte before it and the byte after it are either absent
/// or not an ASCII alphanumeric or `_`. Matching is byte-exact and therefore
/// case-sensitive. After a rejected match, the search resumes one byte later, so
/// overlapping occurrences are also considered.
fn contains_word_boundary_keyword(text: &str, keyword: &str) -> bool {
    let haystack = text.as_bytes();
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    let mut cursor = 0;

    while cursor < haystack.len() {
        let start = match text[cursor..].find(keyword) {
            Some(offset) => cursor + offset,
            None => return false,
        };
        let end = start + keyword.len();

        let clear_before = start == 0 || !is_word_byte(haystack[start - 1]);
        let clear_after = end >= haystack.len() || !is_word_byte(haystack[end]);
        if clear_before && clear_after {
            return true;
        }

        cursor = start + 1;
    }

    false
}
131
/// Converts a byte offset into `source` to a 1-based `(line, column)` pair.
///
/// The column counts bytes from the start of the line. It therefore equals the character
/// column only when that line's prefix is ASCII. Panics if `offset` is past the end of
/// `source` or not on a character boundary, because it slices `source[..offset]`.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = prefix.rfind('\n').map_or(0, |newline| newline + 1);
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    (line, offset - line_start + 1)
}
138
/// Collects the byte offsets of `source` that lie inside regions the lint must ignore:
///
/// - `-- ...` line comments: up to, but not including, the next `\n`.
/// - `/* ... */` block comments: through the closing `*/`. Comments are not nested.
/// - `'...'`, `"..."` and `` `...` `` quoted regions: through the next identical quote.
///   A doubled quote (e.g. `''`) simply closes the region and immediately opens a new one,
///   which skips the same bytes as treating it as an escape.
///
/// An unterminated comment or quote extends to the end of `source`.
fn build_skip_set(source: &str) -> HashSet<usize> {
    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut skip = HashSet::new();
    let mut pos = 0;

    while pos < len {
        let rest = &bytes[pos..];

        // Exclusive end offset of the region starting at `pos`, if one starts here.
        let region_end = if rest.starts_with(b"--") {
            bytes[pos + 2..]
                .iter()
                .position(|&b| b == b'\n')
                .map_or(len, |off| pos + 2 + off)
        } else if rest.starts_with(b"/*") {
            // The search starts after the opener, so `/*/` does not close itself.
            bytes[pos + 2..]
                .windows(2)
                .position(|pair| pair == b"*/")
                .map_or(len, |off| pos + 2 + off + 2)
        } else if let quote @ (b'\'' | b'"' | b'`') = rest[0] {
            bytes[pos + 1..]
                .iter()
                .position(|&b| b == quote)
                .map_or(len, |off| pos + 1 + off + 1)
        } else {
            pos += 1;
            continue;
        };

        skip.extend(pos..region_end);
        pos = region_end;
    }

    skip
}