1use std::fmt;
23
/// Severity attached to every lint finding.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LintSeverity {
    /// A violation; any finding of this severity makes `lint_declarative_schema` return `Err`.
    Error,
    /// An advisory finding; it is only returned alongside at least one `Error`.
    Warning,
}
29
30impl fmt::Display for LintSeverity {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 match self {
33 Self::Error => f.write_str("error"),
34 Self::Warning => f.write_str("warning"),
35 }
36 }
37}
38
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub struct LintError {
41 pub line: u32,
42 pub column: u32,
43 pub severity: LintSeverity,
44 pub message: String,
45 pub source: String,
46}
47
48impl fmt::Display for LintError {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 write!(
51 f,
52 "{}:{}:{}: {}: {}",
53 self.source, self.line, self.column, self.severity, self.message
54 )
55 }
56}
57
58pub fn lint_declarative_schema(sql: &str, source: &str) -> Result<(), Vec<LintError>> {
64 let statements = split_top_level_statements(sql, source)?;
65 let mut errors = Vec::new();
66 for stmt in &statements {
67 if let Some(err) = classify(stmt, source) {
68 errors.push(err);
69 }
70 }
71 if errors.iter().any(|e| e.severity == LintSeverity::Error) {
72 return Err(errors);
73 }
74 Ok(())
75}
76
/// One top-level SQL statement produced by `split_top_level_statements`.
#[derive(Clone, Debug)]
struct TopStatement {
    /// Trimmed statement text without its terminating `;`. Comments inside the
    /// statement are kept.
    text: String,
    /// 1-based line of the statement's first significant byte.
    start_line: u32,
    /// 1-based column of that byte.
    start_column: u32,
}
83
/// Scanner mode used by `split_top_level_statements` while walking SQL bytes.
enum LexState {
    /// Ordinary SQL text, where `;` terminates a statement.
    Normal,
    /// Inside a `'…'` string literal (a doubled `''` is an escaped quote).
    SingleQuote,
    /// Inside a dollar-quoted body. Holds the complete opening delimiter
    /// (e.g. `$$` or `$fn$`), which must appear again to close the body.
    DollarQuote(String),
    /// Inside a `--` comment, which ends at the next newline.
    LineComment,
    /// Inside a `/* … */` comment. The value is the current nesting depth.
    BlockComment(u32),
}
91
/// Recognizes a PostgreSQL dollar-quote delimiter (`$$` or `$tag$`) that
/// begins at `bytes[start]`, which must be `$`.
///
/// Returns the index of the delimiter's closing `$`. Returns `None` when the
/// text at `start` is not a delimiter, e.g. a `$1` positional parameter or a
/// lone `$`.
///
/// Following PostgreSQL's lexer, a tag must start with an ASCII letter, `_`,
/// or a non-ASCII byte. It may continue with those or ASCII digits, and it
/// cannot contain `$`. The returned index always points at an ASCII `$`, so
/// `&text[start..=end]` is a valid `str` slice.
fn dollar_tag_end(bytes: &[u8], start: usize) -> Option<usize> {
    debug_assert_eq!(bytes[start], b'$');
    let is_tag_start = |c: u8| c.is_ascii_alphabetic() || c == b'_' || !c.is_ascii();
    let tag_start = start + 1;
    let rest = bytes.get(tag_start..).unwrap_or_default();

    // The first byte that cannot be part of a tag must be the closing `$`.
    let close = rest
        .iter()
        .position(|&c| !(is_tag_start(c) || c.is_ascii_digit()))?;
    if rest[close] != b'$' {
        return None;
    }
    // `$1$…` is not a dollar quote: tags may not begin with a digit.
    if matches!(rest.first(), Some(c) if c.is_ascii_digit()) {
        return None;
    }
    Some(tag_start + close)
}
107
108fn split_top_level_statements(
109 sql: &str,
110 source: &str,
111) -> Result<Vec<TopStatement>, Vec<LintError>> {
112 let bytes = sql.as_bytes();
113 let mut statements: Vec<TopStatement> = Vec::new();
114 let mut state = LexState::Normal;
115 let mut start = 0usize;
116 let mut i = 0usize;
117 let mut start_line: u32 = 1;
118 let mut start_col: u32 = 1;
119 let mut line: u32 = 1;
120 let mut col: u32 = 1;
121 let mut stmt_line: u32 = 1;
122 let mut stmt_col: u32 = 1;
123 let mut has_content = false;
124
125 while i < bytes.len() {
126 let b = bytes[i];
127 match &mut state {
128 LexState::Normal => match b {
129 b'\'' => {
130 if !has_content {
131 stmt_line = line;
132 stmt_col = col;
133 }
134 has_content = true;
135 state = LexState::SingleQuote;
136 advance(&mut i, &mut line, &mut col, b);
137 },
138 b'-' if bytes.get(i + 1) == Some(&b'-') => {
139 state = LexState::LineComment;
140 advance(&mut i, &mut line, &mut col, b);
141 advance(&mut i, &mut line, &mut col, b'-');
142 },
143 b'/' if bytes.get(i + 1) == Some(&b'*') => {
144 state = LexState::BlockComment(1);
145 advance(&mut i, &mut line, &mut col, b);
146 advance(&mut i, &mut line, &mut col, b'*');
147 },
148 b'$' => {
149 if !has_content {
150 stmt_line = line;
151 stmt_col = col;
152 }
153 has_content = true;
154 if let Some(tag_end) = dollar_tag_end(bytes, i) {
155 let tag = sql[i..=tag_end].to_string();
156 let advance_by = tag_end - i + 1;
157 for _ in 0..advance_by {
158 advance(&mut i, &mut line, &mut col, b'$');
159 }
160 state = LexState::DollarQuote(tag);
161 } else {
162 advance(&mut i, &mut line, &mut col, b);
163 }
164 },
165 b';' => {
166 if has_content {
167 let text = sql[start..i].trim().to_string();
168 if !text.is_empty() {
169 statements.push(TopStatement {
170 text,
171 start_line: stmt_line,
172 start_column: stmt_col,
173 });
174 }
175 }
176 has_content = false;
177 advance(&mut i, &mut line, &mut col, b);
178 start = i;
179 start_line = line;
180 start_col = col;
181 },
182 _ => {
183 if !b.is_ascii_whitespace() {
184 if !has_content {
185 stmt_line = line;
186 stmt_col = col;
187 }
188 has_content = true;
189 }
190 advance(&mut i, &mut line, &mut col, b);
191 },
192 },
193 LexState::SingleQuote => {
194 if b == b'\'' {
195 if bytes.get(i + 1) == Some(&b'\'') {
196 advance(&mut i, &mut line, &mut col, b);
197 advance(&mut i, &mut line, &mut col, b'\'');
198 } else {
199 state = LexState::Normal;
200 advance(&mut i, &mut line, &mut col, b);
201 }
202 } else {
203 advance(&mut i, &mut line, &mut col, b);
204 }
205 },
206 LexState::DollarQuote(tag) => {
207 let tag_bytes = tag.as_bytes();
208 if i + tag_bytes.len() <= bytes.len() && &bytes[i..i + tag_bytes.len()] == tag_bytes
209 {
210 for _ in 0..tag_bytes.len() {
211 advance(&mut i, &mut line, &mut col, b'$');
212 }
213 state = LexState::Normal;
214 } else {
215 advance(&mut i, &mut line, &mut col, b);
216 }
217 },
218 LexState::LineComment => {
219 if b == b'\n' {
220 state = LexState::Normal;
221 }
222 advance(&mut i, &mut line, &mut col, b);
223 },
224 LexState::BlockComment(depth) => {
225 if b == b'/' && bytes.get(i + 1) == Some(&b'*') {
226 *depth += 1;
227 advance(&mut i, &mut line, &mut col, b);
228 advance(&mut i, &mut line, &mut col, b'*');
229 } else if b == b'*' && bytes.get(i + 1) == Some(&b'/') {
230 *depth -= 1;
231 let zero = *depth == 0;
232 advance(&mut i, &mut line, &mut col, b);
233 advance(&mut i, &mut line, &mut col, b'/');
234 if zero {
235 state = LexState::Normal;
236 }
237 } else {
238 advance(&mut i, &mut line, &mut col, b);
239 }
240 },
241 }
242 }
243
244 match state {
245 LexState::Normal | LexState::LineComment => {
246 if has_content {
247 let text = sql[start..].trim().to_string();
248 if !text.is_empty() {
249 statements.push(TopStatement {
250 text,
251 start_line: stmt_line,
252 start_column: stmt_col,
253 });
254 }
255 }
256 Ok(statements)
257 },
258 LexState::SingleQuote => Err(vec![LintError {
259 line: start_line,
260 column: start_col,
261 severity: LintSeverity::Error,
262 message: "unterminated string literal".into(),
263 source: source.to_string(),
264 }]),
265 LexState::DollarQuote(tag) => Err(vec![LintError {
266 line: start_line,
267 column: start_col,
268 severity: LintSeverity::Error,
269 message: format!("unterminated dollar-quoted string: {tag}"),
270 source: source.to_string(),
271 }]),
272 LexState::BlockComment(_) => Err(vec![LintError {
273 line: start_line,
274 column: start_col,
275 severity: LintSeverity::Error,
276 message: "unterminated block comment".into(),
277 source: source.to_string(),
278 }]),
279 }
280}
281
/// Moves the scanner cursor past byte `b` and updates the 1-based
/// line/column position.
///
/// A newline starts a new line at column 1. Every other byte bumps the column
/// by one, so columns count bytes and a multi-byte UTF-8 character spans
/// several columns.
fn advance(i: &mut usize, line: &mut u32, col: &mut u32, b: u8) {
    *i += 1;
    match b {
        b'\n' => {
            *line += 1;
            *col = 1;
        },
        _ => *col += 1,
    }
}
291
292fn classify(stmt: &TopStatement, source: &str) -> Option<LintError> {
293 let stripped = strip_sql_comments(&stmt.text);
294 let upper = uppercase_keywords(&stripped);
295 let tokens: Vec<&str> = upper.split_whitespace().collect();
296 if tokens.is_empty() {
297 return None;
298 }
299
300 let leading = tokens[0];
301
302 let reject = |reason: &str| LintError {
303 line: stmt.start_line,
304 column: stmt.start_column,
305 severity: LintSeverity::Error,
306 message: format!(
307 "imperative SQL in declarative schema: {reason} — move to \
308 schema/migrations/NNN_<name>.sql"
309 ),
310 source: source.to_string(),
311 };
312
313 match leading {
314 "ALTER" => return Some(reject("ALTER")),
315 "DROP" => return Some(reject("DROP")),
316 "UPDATE" => return Some(reject("UPDATE")),
317 "INSERT" => return Some(reject("INSERT")),
318 "DELETE" => return Some(reject("DELETE")),
319 "TRUNCATE" => return Some(reject("TRUNCATE")),
320 "GRANT" => return Some(reject("GRANT")),
321 "REVOKE" => return Some(reject("REVOKE")),
322 "DO" => return Some(reject("DO $$ block")),
323 _ => {},
324 }
325
326 if leading == "CREATE" {
327 return classify_create(&tokens, stmt, source);
328 }
329
330 if leading == "COMMENT" && tokens.get(1) == Some(&"ON") {
331 return None;
332 }
333
334 if leading == "SELECT" {
335 return Some(LintError {
336 line: stmt.start_line,
337 column: stmt.start_column,
338 severity: LintSeverity::Error,
339 message: "imperative SQL in declarative schema: SELECT — move to \
340 schema/migrations/NNN_<name>.sql"
341 .into(),
342 source: source.to_string(),
343 });
344 }
345
346 None
347}
348
349fn classify_create(tokens: &[&str], stmt: &TopStatement, source: &str) -> Option<LintError> {
350 let mut idx = 1;
351
352 if tokens.get(idx) == Some(&"OR") && tokens.get(idx + 1) == Some(&"REPLACE") {
353 idx += 2;
354 }
355
356 if tokens.get(idx) == Some(&"UNIQUE") {
357 idx += 1;
358 }
359
360 let kind = match tokens.get(idx) {
361 Some(k) => *k,
362 None => return None,
363 };
364 idx += 1;
365
366 let has_if_not_exists = tokens.get(idx) == Some(&"IF")
367 && tokens.get(idx + 1) == Some(&"NOT")
368 && tokens.get(idx + 2) == Some(&"EXISTS");
369
370 match kind {
371 "TABLE" => {
372 if !has_if_not_exists {
373 return Some(LintError {
374 line: stmt.start_line,
375 column: stmt.start_column,
376 severity: LintSeverity::Warning,
377 message: "CREATE TABLE without IF NOT EXISTS — add IF NOT EXISTS for \
378 idempotency"
379 .into(),
380 source: source.to_string(),
381 });
382 }
383 None
384 },
385 "EXTENSION" => {
386 if !has_if_not_exists {
387 return Some(LintError {
388 line: stmt.start_line,
389 column: stmt.start_column,
390 severity: LintSeverity::Warning,
391 message: "CREATE EXTENSION without IF NOT EXISTS".into(),
392 source: source.to_string(),
393 });
394 }
395 None
396 },
397 _ => None,
398 }
399}
400
401fn strip_sql_comments(text: &str) -> String {
402 let bytes = text.as_bytes();
403 let mut out = String::with_capacity(text.len());
404 let mut i = 0;
405 let mut in_single = false;
406 let mut in_dollar: Option<String> = None;
407 while i < bytes.len() {
408 let b = bytes[i];
409 if let Some(tag) = &in_dollar {
410 let tag_b = tag.as_bytes();
411 if i + tag_b.len() <= bytes.len() && &bytes[i..i + tag_b.len()] == tag_b {
412 out.push_str(tag);
413 i += tag_b.len();
414 in_dollar = None;
415 } else {
416 out.push(b as char);
417 i += 1;
418 }
419 continue;
420 }
421 if in_single {
422 out.push(b as char);
423 if b == b'\'' {
424 if bytes.get(i + 1) == Some(&b'\'') {
425 out.push('\'');
426 i += 2;
427 continue;
428 }
429 in_single = false;
430 }
431 i += 1;
432 continue;
433 }
434 if b == b'\'' {
435 in_single = true;
436 out.push('\'');
437 i += 1;
438 continue;
439 }
440 if b == b'$' {
441 if let Some(end) = dollar_tag_end(bytes, i) {
442 let tag = text[i..=end].to_string();
443 out.push_str(&tag);
444 i = end + 1;
445 in_dollar = Some(tag);
446 continue;
447 }
448 }
449 if b == b'-' && bytes.get(i + 1) == Some(&b'-') {
450 while i < bytes.len() && bytes[i] != b'\n' {
451 i += 1;
452 }
453 continue;
454 }
455 if b == b'/' && bytes.get(i + 1) == Some(&b'*') {
456 let mut depth = 1u32;
457 i += 2;
458 while i < bytes.len() && depth > 0 {
459 if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
460 depth += 1;
461 i += 2;
462 } else if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
463 depth -= 1;
464 i += 2;
465 } else {
466 i += 1;
467 }
468 }
469 continue;
470 }
471 out.push(b as char);
472 i += 1;
473 }
474 out
475}
476
477fn uppercase_keywords(text: &str) -> String {
478 let mut out = String::with_capacity(text.len());
479 let mut in_string = false;
480 let mut in_dollar = false;
481 let bytes = text.as_bytes();
482 let mut i = 0;
483 while i < bytes.len() {
484 let b = bytes[i];
485 if !in_string && !in_dollar && b == b'$' {
486 if let Some(end) = dollar_tag_end(bytes, i) {
487 out.push_str(&text[i..=end]);
488 i = end + 1;
489 in_dollar = true;
490 continue;
491 }
492 }
493 if in_dollar && b == b'$' {
494 if let Some(end) = dollar_tag_end(bytes, i) {
495 out.push_str(&text[i..=end]);
496 i = end + 1;
497 in_dollar = false;
498 continue;
499 }
500 }
501 if in_dollar {
502 out.push(b as char);
503 i += 1;
504 continue;
505 }
506 if b == b'\'' {
507 in_string = !in_string;
508 out.push('\'');
509 i += 1;
510 continue;
511 }
512 if in_string {
513 out.push(b as char);
514 i += 1;
515 continue;
516 }
517 out.push(b.to_ascii_uppercase() as char);
518 i += 1;
519 }
520 out
521}