sqrust_rules/lint/
insert_overwrite.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2use std::collections::HashSet;
3
4pub struct InsertOverwrite;
5
6impl Rule for InsertOverwrite {
7 fn name(&self) -> &'static str {
8 "Lint/InsertOverwrite"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 let source = &ctx.source;
13 let skip = build_skip_set(source);
14 let mut diags = Vec::new();
15
16 for (line, col) in find_two_word_keyword(source, "insert", "overwrite", &skip) {
17 diags.push(Diagnostic {
18 rule: self.name(),
19 message: "INSERT OVERWRITE is Hive/Spark SQL-specific syntax; use standard INSERT INTO or CREATE TABLE AS SELECT"
20 .to_string(),
21 line,
22 col,
23 });
24 }
25
26 diags.sort_by_key(|d| (d.line, d.col));
27 diags
28 }
29}
30
/// Collects the byte offsets of `source` that lie inside single-quoted string
/// literals or `--` line comments, so keyword scanners can ignore them.
///
/// * The delimiting quotes of a string literal are not recorded, only the
///   bytes between them. An escaped quote (`''`) is string content, so both
///   of its bytes are recorded.
/// * A line comment contributes every byte from the leading `--` up to, but
///   not including, the terminating newline (or the end of input).
/// * An unterminated string literal extends to the end of input.
///
/// Block comments (`/* ... */`) are not recognised by this function.
fn build_skip_set(source: &str) -> HashSet<usize> {
    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut skip = HashSet::new();
    let mut i = 0;

    while i < len {
        match bytes[i] {
            b'\'' => {
                // Step past the opening quote, then consume the literal body.
                i += 1;
                while i < len {
                    if bytes[i] != b'\'' {
                        skip.insert(i);
                        i += 1;
                    } else if bytes.get(i + 1) == Some(&b'\'') {
                        // `''` is an escaped quote: both bytes belong to the literal.
                        skip.insert(i);
                        skip.insert(i + 1);
                        i += 2;
                    } else {
                        // Closing quote ends the literal.
                        i += 1;
                        break;
                    }
                }
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                // The comment runs to the next newline, which itself stays outside the set.
                let end = bytes[i..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(len, |p| i + p);
                skip.extend(i..end);
                i = end;
            }
            _ => i += 1,
        }
    }

    skip
}
65
66fn find_two_word_keyword(
70 source: &str,
71 word1: &str,
72 word2: &str,
73 skip: &HashSet<usize>,
74) -> Vec<(usize, usize)> {
75 let lower = source.to_lowercase();
76 let w1_len = word1.len();
77 let w2_len = word2.len();
78 let bytes = lower.as_bytes();
79 let len = bytes.len();
80 let mut results = Vec::new();
81 let mut i = 0;
82
83 while i + w1_len <= len {
84 if skip.contains(&i) {
85 i += 1;
86 continue;
87 }
88
89 if !lower[i..].starts_with(word1) {
90 i += 1;
91 continue;
92 }
93
94 let before_ok = i == 0
96 || {
97 let b = bytes[i - 1];
98 !b.is_ascii_alphanumeric() && b != b'_'
99 };
100
101 let after_w1 = i + w1_len;
103 let after_w1_ok = after_w1 < len && {
104 let b = bytes[after_w1];
105 !b.is_ascii_alphanumeric() && b != b'_'
106 };
107
108 if before_ok && after_w1_ok {
109 let mut j = after_w1;
111 while j < len && (bytes[j] == b' ' || bytes[j] == b'\t' || bytes[j] == b'\n' || bytes[j] == b'\r') {
112 j += 1;
113 }
114
115 if j + w2_len <= len && !skip.contains(&j) && lower[j..].starts_with(word2) {
116 let after_w2 = j + w2_len;
118 let after_w2_ok = after_w2 >= len
119 || {
120 let b = bytes[after_w2];
121 !b.is_ascii_alphanumeric() && b != b'_'
122 };
123
124 if after_w2_ok {
125 let (line, col) = offset_to_line_col(source, i);
126 results.push((line, col));
127 }
128 }
129 }
130
131 i += 1;
132 }
133
134 results
135}
136
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The line is one more than the number of `\n` bytes before `offset`; the
/// column is the 1-based byte distance from the start of that line.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or does not fall on a
/// character boundary, because the prefix is taken by string slicing.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    let col = match prefix.rfind('\n') {
        // Distance from the newline to `offset` is already 1-based.
        Some(newline_at) => offset - newline_at,
        None => offset + 1,
    };
    (line, col)
}