Skip to main content

sqrust_rules/convention/
string_agg_separator.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3pub struct StringAggSeparator;
4
5impl Rule for StringAggSeparator {
6    fn name(&self) -> &'static str {
7        "Convention/StringAggSeparator"
8    }
9
10    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
11        find_violations(&ctx.source, self.name())
12    }
13}
14
/// Builds the set of byte offsets in `source` that must never be treated as the
/// start of a keyword match.
///
/// Covered regions:
/// * the contents of single-quoted string literals (`'...'`). A doubled quote
///   (`''`) is an escaped quote and does not end the literal. The delimiting
///   quotes themselves are not included.
/// * `--` line comments, from the first `-` up to (not including) the newline.
/// * `/* ... */` block comments, including both delimiters. Block comments do
///   not nest.
///
/// An unterminated string literal or block comment extends to the end of `source`.
fn build_skip_set(source: &str) -> std::collections::HashSet<usize> {
    let mut skip = std::collections::HashSet::new();
    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    while i < len {
        if bytes[i] == b'\'' {
            i += 1;
            while i < len {
                if bytes[i] == b'\'' {
                    if i + 1 < len && bytes[i + 1] == b'\'' {
                        // `''` is an escaped quote; the literal continues.
                        skip.insert(i);
                        skip.insert(i + 1);
                        i += 2;
                    } else {
                        // Closing quote.
                        i += 1;
                        break;
                    }
                } else {
                    skip.insert(i);
                    i += 1;
                }
            }
        } else if i + 1 < len && bytes[i] == b'-' && bytes[i + 1] == b'-' {
            while i < len && bytes[i] != b'\n' {
                skip.insert(i);
                i += 1;
            }
        } else if i + 1 < len && bytes[i] == b'/' && bytes[i + 1] == b'*' {
            // Block comment. Without this branch, keywords in block comments were
            // reported. An apostrophe inside one (e.g. `/* don't */`) also opened a
            // bogus string literal that corrupted skipping for the rest of the file.
            skip.insert(i);
            skip.insert(i + 1);
            i += 2;
            while i < len {
                if i + 1 < len && bytes[i] == b'*' && bytes[i + 1] == b'/' {
                    skip.insert(i);
                    skip.insert(i + 1);
                    i += 2;
                    break;
                }
                skip.insert(i);
                i += 1;
            }
        } else {
            i += 1;
        }
    }
    skip
}
50
/// Maps a byte `offset` into `source` to a 1-indexed `(line, column)` pair.
///
/// The line is one plus the number of `'\n'` characters before `offset`. The column
/// is the number of *bytes* between the start of that line and `offset`, plus one.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or not on a UTF-8 char boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let (newlines, line_start) = prefix
        .match_indices('\n')
        .fold((0usize, 0usize), |(count, _), (pos, _)| (count + 1, pos + 1));
    (newlines + 1, offset - line_start + 1)
}
58
59fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
60    let bytes = source.as_bytes();
61    let len = bytes.len();
62
63    if len == 0 {
64        return Vec::new();
65    }
66
67    let skip = build_skip_set(source);
68    let mut diags = Vec::new();
69
70    let group_concat_kw = b"GROUP_CONCAT";
71    let group_concat_len = group_concat_kw.len();
72    let listagg_kw = b"LISTAGG";
73    let listagg_len = listagg_kw.len();
74
75    let mut i = 0;
76    while i < len {
77        if skip.contains(&i) {
78            i += 1;
79            continue;
80        }
81
82        // Try to match GROUP_CONCAT
83        if i + group_concat_len <= len {
84            let before_ok = i == 0 || {
85                let b = bytes[i - 1];
86                !b.is_ascii_alphanumeric() && b != b'_'
87            };
88            if before_ok && bytes[i..i + group_concat_len].eq_ignore_ascii_case(group_concat_kw) {
89                let after = i + group_concat_len;
90                let after_ok = after >= len || {
91                    let b = bytes[after];
92                    !b.is_ascii_alphanumeric() && b != b'_'
93                };
94                if after_ok {
95                    let (line, col) = line_col(source, i);
96                    diags.push(Diagnostic {
97                        rule: rule_name,
98                        message: "GROUP_CONCAT() is MySQL-specific — use STRING_AGG(col, separator) for portable string aggregation (PostgreSQL, SQL Server, BigQuery)".to_string(),
99                        line,
100                        col,
101                    });
102                    i += group_concat_len;
103                    continue;
104                }
105            }
106        }
107
108        // Try to match LISTAGG
109        if i + listagg_len <= len {
110            let before_ok = i == 0 || {
111                let b = bytes[i - 1];
112                !b.is_ascii_alphanumeric() && b != b'_'
113            };
114            if before_ok && bytes[i..i + listagg_len].eq_ignore_ascii_case(listagg_kw) {
115                let after = i + listagg_len;
116                let after_ok = after >= len || {
117                    let b = bytes[after];
118                    !b.is_ascii_alphanumeric() && b != b'_'
119                };
120                if after_ok {
121                    let (line, col) = line_col(source, i);
122                    diags.push(Diagnostic {
123                        rule: rule_name,
124                        message: "LISTAGG() is Oracle/Snowflake-specific — use STRING_AGG(col, separator) for portable string aggregation".to_string(),
125                        line,
126                        col,
127                    });
128                    i += listagg_len;
129                    continue;
130                }
131            }
132        }
133
134        i += 1;
135    }
136
137    diags
138}