Skip to main content

sqrust_rules/layout/
parenthesis_spacing.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3pub struct ParenthesisSpacing;
4
5impl Rule for ParenthesisSpacing {
6    fn name(&self) -> &'static str {
7        "ParenthesisSpacing"
8    }
9
10    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
11        find_violations(&ctx.source, self.name())
12    }
13
14    fn fix(&self, ctx: &FileContext) -> Option<String> {
15        let violations = find_violations(&ctx.source, self.name());
16        if violations.is_empty() {
17            return None;
18        }
19
20        let bytes = ctx.source.as_bytes();
21        let len = bytes.len();
22        let skip = build_skip_set(bytes, len);
23
24        let mut result: Vec<u8> = Vec::with_capacity(len);
25        let mut i = 0;
26
27        while i < len {
28            let ch = bytes[i];
29
30            // Remove space(s) immediately after `(` (but not newlines)
31            if ch == b'(' {
32                result.push(ch);
33                i += 1;
34                // Skip any spaces that follow (not newlines)
35                while i < len && bytes[i] == b' ' && !skip[i] {
36                    i += 1;
37                }
38                continue;
39            }
40
41            // Remove space(s) immediately before `)`
42            // We need to look ahead: if current byte is space and next non-space is `)`, drop spaces
43            if ch == b' ' && !skip[i] {
44                // Scan forward to see if there's a `)` after only spaces
45                let mut j = i;
46                while j < len && bytes[j] == b' ' && !skip[j] {
47                    j += 1;
48                }
49                if j < len && bytes[j] == b')' {
50                    // Suppress all these spaces; let `)` be emitted on next iteration
51                    i = j;
52                    continue;
53                } else {
54                    result.push(ch);
55                    i += 1;
56                    continue;
57                }
58            }
59
60            result.push(ch);
61            i += 1;
62        }
63
64        Some(String::from_utf8(result).expect("source was valid UTF-8"))
65    }
66}
67
68/// Scans the source for paren spacing violations.
69fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
70    let bytes = source.as_bytes();
71    let len = bytes.len();
72    let skip = build_skip_set(bytes, len);
73
74    let mut diags = Vec::new();
75
76    for i in 0..len {
77        // Space after opening parenthesis: `(` immediately followed by ` `
78        if i + 1 < len
79            && bytes[i] == b'('
80            && bytes[i + 1] == b' '
81            && !skip[i + 1]
82        {
83            let (line, col) = byte_offset_to_line_col(source, i + 1);
84            diags.push(Diagnostic {
85                rule: rule_name,
86                message: "Space after opening parenthesis; remove the space".to_string(),
87                line,
88                col,
89            });
90        }
91
92        // Space before closing parenthesis: ` ` immediately followed by `)`
93        if i + 1 < len
94            && bytes[i] == b' '
95            && bytes[i + 1] == b')'
96            && !skip[i]
97        {
98            let (line, col) = byte_offset_to_line_col(source, i);
99            diags.push(Diagnostic {
100                rule: rule_name,
101                message: "Space before closing parenthesis; remove the space".to_string(),
102                line,
103                col,
104            });
105        }
106    }
107
108    diags
109}
110
/// Builds a boolean skip-set (indexed by byte offset).
///
/// A byte is in the skip set if it lies inside:
///   - a single-quoted string `'...'` (with `''` escaping)
///   - a double-quoted identifier `"..."` (with `""` escaping)
///   - a block comment `/* ... */`
///   - a line comment `-- ...` (up to, but not including, the newline)
///
/// A string, identifier or block comment that is never closed extends to the
/// end of the input, so its contents are still protected from the rule.
/// `len` is expected to equal `bytes.len()`.
fn build_skip_set(bytes: &[u8], len: usize) -> Vec<bool> {
    let mut skip = vec![false; len];
    let mut i = 0;

    while i < len {
        let end = match bytes[i] {
            b'\'' | b'"' => quoted_region_end(bytes, len, i),
            b'/' if i + 1 < len && bytes[i + 1] == b'*' => block_comment_end(bytes, len, i),
            b'-' if i + 1 < len && bytes[i + 1] == b'-' => line_comment_end(bytes, len, i),
            _ => {
                i += 1;
                continue;
            }
        };
        // Fill each region once, rather than once per escaped quote.
        skip[i..end].fill(true);
        i = end;
    }

    skip
}

/// Returns the offset one past the closing quote of the quoted region opening
/// at `start` (a doubled quote is an escape, not a terminator), or `len` if
/// the region is never closed.
fn quoted_region_end(bytes: &[u8], len: usize, start: usize) -> usize {
    let quote = bytes[start];
    let mut i = start + 1;
    while i < len {
        if bytes[i] == quote {
            if i + 1 < len && bytes[i + 1] == quote {
                i += 2;
                continue;
            }
            return i + 1;
        }
        i += 1;
    }
    len
}

/// Returns the offset one past the `*/` closing the block comment opening at
/// `start`, or `len` if the comment is never closed.
fn block_comment_end(bytes: &[u8], len: usize, start: usize) -> usize {
    (start + 2..len.saturating_sub(1))
        .find(|&k| bytes[k] == b'*' && bytes[k + 1] == b'/')
        .map_or(len, |k| k + 2)
}

/// Returns the offset of the newline ending the line comment opening at
/// `start` (the newline itself is not part of the comment), or `len`.
fn line_comment_end(bytes: &[u8], len: usize, start: usize) -> usize {
    bytes[start..len]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(len, |p| start + p)
}
196
/// Maps a byte `offset` within `source` to a 1-based `(line, column)` pair.
///
/// The column is a byte distance from the start of the line, so multi-byte
/// characters earlier on the same line each advance it by more than one.
fn byte_offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    // Walk the characters in front of `offset`, counting newlines and
    // remembering where the most recent line began.
    let (line, line_start) = source
        .char_indices()
        .take_while(|&(idx, _)| idx != offset)
        .filter(|&(_, c)| c == '\n')
        .fold((1usize, 0usize), |(n, _), (idx, _)| (n + 1, idx + 1));
    (line, offset - line_start + 1)
}