sqrust_rules/lint/
create_or_replace.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::Statement;
3
/// Lint rule `Lint/CreateOrReplace`: reports `CREATE OR REPLACE` statements
/// for tables, views and functions, because they overwrite any existing
/// object of the same name without warning.
pub struct CreateOrReplace;
5
6impl Rule for CreateOrReplace {
7 fn name(&self) -> &'static str {
8 "Lint/CreateOrReplace"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 if !ctx.parse_errors.is_empty() {
14 return Vec::new();
15 }
16
17 let mut diags = Vec::new();
18 let source = &ctx.source;
19 let source_upper = source.to_uppercase();
20
21 for stmt in &ctx.statements {
22 let object_type: Option<&'static str> = match stmt {
23 Statement::CreateTable(ct) if ct.or_replace => Some("TABLE"),
24 Statement::CreateView { or_replace, .. } if *or_replace => Some("VIEW"),
25 Statement::CreateFunction(cf) if cf.or_replace => Some("FUNCTION"),
26 _ => None,
27 };
28
29 if let Some(kind) = object_type {
30 let (line, col) = find_keyword_position(source, &source_upper, "CREATE");
31 diags.push(Diagnostic {
32 rule: self.name(),
33 message: format!(
34 "CREATE OR REPLACE {} silently replaces an existing database object",
35 kind
36 ),
37 line,
38 col,
39 });
40 }
41 }
42
43 diags
44 }
45}
46
47fn find_keyword_position(source: &str, source_upper: &str, keyword: &str) -> (usize, usize) {
51 let kw_len = keyword.len();
52 let bytes = source_upper.as_bytes();
53 let text_len = bytes.len();
54
55 let mut search_from = 0usize;
56 while search_from < text_len {
57 let Some(rel) = source_upper[search_from..].find(keyword) else {
58 break;
59 };
60 let abs = search_from + rel;
61
62 let before_ok = abs == 0
63 || {
64 let b = bytes[abs - 1];
65 !b.is_ascii_alphanumeric() && b != b'_'
66 };
67 let after = abs + kw_len;
68 let after_ok = after >= text_len
69 || {
70 let b = bytes[after];
71 !b.is_ascii_alphanumeric() && b != b'_'
72 };
73
74 if before_ok && after_ok {
75 return offset_to_line_col(source, abs);
76 }
77 search_from = abs + 1;
78 }
79
80 (1, 1)
81}
82
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// Lines are delimited by `'\n'`. The column is the number of bytes between
/// the start of the line and `offset`, plus one.
///
/// # Panics
///
/// Panics if `offset` is greater than `source.len()` or does not fall on a
/// UTF-8 character boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    // Count newlines and remember where the final line begins in one pass.
    let (newlines, line_start) = prefix
        .match_indices('\n')
        .fold((0usize, 0usize), |(count, _), (idx, _)| (count + 1, idx + 1));
    (newlines + 1, offset - line_start + 1)
}