sqrust_rules/convention/
string_agg_separator.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3pub struct StringAggSeparator;
4
5impl Rule for StringAggSeparator {
6 fn name(&self) -> &'static str {
7 "Convention/StringAggSeparator"
8 }
9
10 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
11 find_violations(&ctx.source, self.name())
12 }
13}
14
/// Collects the byte offsets of `source` that the keyword scan must ignore:
/// the contents of single-quoted string literals, `--` line comments and
/// `/* ... */` block comments.
///
/// Marking rules:
/// - String literal: every byte between the quotes is marked; the opening
///   and closing quote are not. For an escaped quote (`''`) only the first
///   byte of the pair is marked. Neither unmarked byte can start a keyword,
///   so this does not affect the scan.
/// - Line comment: every byte from the leading `--` up to, but not
///   including, the terminating newline.
/// - Block comment: every byte from `/*` through the closing `*/`. An
///   unterminated comment extends to the end of the input. Nesting is not
///   tracked: the first `*/` closes the comment.
///
/// A string literal that is never closed likewise extends to the end of the
/// input.
fn build_skip_set(source: &str) -> std::collections::HashSet<usize> {
    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut skip = std::collections::HashSet::new();

    let mut i = 0;
    while i < len {
        match bytes[i] {
            b'\'' => {
                // Step over the opening quote, then consume the literal body.
                i += 1;
                while i < len {
                    if bytes[i] != b'\'' {
                        skip.insert(i);
                        i += 1;
                    } else if bytes.get(i + 1) == Some(&b'\'') {
                        // `''` is an escaped quote: stay inside the literal.
                        skip.insert(i);
                        i += 2;
                    } else {
                        // Closing quote.
                        i += 1;
                        break;
                    }
                }
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                while i < len && bytes[i] != b'\n' {
                    skip.insert(i);
                    i += 1;
                }
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                // Search for the terminator after the opener so that `/*/`
                // is not mistaken for a complete comment.
                let body_start = i + 2;
                let end = bytes[body_start..]
                    .windows(2)
                    .position(|pair| pair == b"*/")
                    .map_or(len, |rel| body_start + rel + 2);
                skip.extend(i..end);
                i = end;
            }
            _ => i += 1,
        }
    }
    skip
}
50
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The line is one plus the number of `\n` bytes before `offset`. The column
/// is the number of bytes between the start of that line and `offset`, plus
/// one, so it counts bytes rather than characters.
///
/// The computation works on raw bytes and clamps `offset` to the length of
/// `source`, so it never panics: an offset past the end is treated as the
/// end of the input, and an offset inside a multi-byte character is accepted
/// as-is.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let bytes = source.as_bytes();
    let prefix = &bytes[..offset.min(bytes.len())];

    let line = 1 + prefix.iter().filter(|&&b| b == b'\n').count();
    // Index of the first byte of the line that contains `offset`.
    let line_start = prefix
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);

    (line, prefix.len() - line_start + 1)
}
58
59fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
60 let bytes = source.as_bytes();
61 let len = bytes.len();
62
63 if len == 0 {
64 return Vec::new();
65 }
66
67 let skip = build_skip_set(source);
68 let mut diags = Vec::new();
69
70 let group_concat_kw = b"GROUP_CONCAT";
71 let group_concat_len = group_concat_kw.len();
72 let listagg_kw = b"LISTAGG";
73 let listagg_len = listagg_kw.len();
74
75 let mut i = 0;
76 while i < len {
77 if skip.contains(&i) {
78 i += 1;
79 continue;
80 }
81
82 if i + group_concat_len <= len {
84 let before_ok = i == 0 || {
85 let b = bytes[i - 1];
86 !b.is_ascii_alphanumeric() && b != b'_'
87 };
88 if before_ok && bytes[i..i + group_concat_len].eq_ignore_ascii_case(group_concat_kw) {
89 let after = i + group_concat_len;
90 let after_ok = after >= len || {
91 let b = bytes[after];
92 !b.is_ascii_alphanumeric() && b != b'_'
93 };
94 if after_ok {
95 let (line, col) = line_col(source, i);
96 diags.push(Diagnostic {
97 rule: rule_name,
98 message: "GROUP_CONCAT() is MySQL-specific — use STRING_AGG(col, separator) for portable string aggregation (PostgreSQL, SQL Server, BigQuery)".to_string(),
99 line,
100 col,
101 });
102 i += group_concat_len;
103 continue;
104 }
105 }
106 }
107
108 if i + listagg_len <= len {
110 let before_ok = i == 0 || {
111 let b = bytes[i - 1];
112 !b.is_ascii_alphanumeric() && b != b'_'
113 };
114 if before_ok && bytes[i..i + listagg_len].eq_ignore_ascii_case(listagg_kw) {
115 let after = i + listagg_len;
116 let after_ok = after >= len || {
117 let b = bytes[after];
118 !b.is_ascii_alphanumeric() && b != b'_'
119 };
120 if after_ok {
121 let (line, col) = line_col(source, i);
122 diags.push(Diagnostic {
123 rule: rule_name,
124 message: "LISTAGG() is Oracle/Snowflake-specific — use STRING_AGG(col, separator) for portable string aggregation".to_string(),
125 line,
126 col,
127 });
128 i += listagg_len;
129 continue;
130 }
131 }
132 }
133
134 i += 1;
135 }
136
137 diags
138}