Skip to main content

squawk_syntax/
unescape.rs

1use std::fmt;
2use std::ops::{Range, RangeInclusive};
3
/// The two spellings of a Unicode escape: the escape character followed by
/// four hex digits, or by `+` and six hex digits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnicodeEscapeKind {
    /// `<escape>+XXXXXX` — six hex digits after a `+`.
    Extended,
    /// `<escape>XXXX` — four hex digits.
    Short,
}

impl UnicodeEscapeKind {
    /// Number of hex digits this escape form requires.
    fn count(&self) -> u32 {
        match self {
            UnicodeEscapeKind::Extended => 6,
            UnicodeEscapeKind::Short => 4,
        }
    }
}

/// Reasons a Unicode escape sequence can be rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnicodeEscError {
    /// The escape character was not followed by a valid escape body.
    InvalidEscape,
    /// A surrogate code point appeared without its matching partner.
    InvalidSurrogatePair,
    /// The escaped value is above the largest Unicode code point.
    OutOfRange,
    /// Fewer hex digits were supplied than the escape form needs.
    RequiresHexDigits {
        /// Which escape form was being read.
        kind: UnicodeEscapeKind,
        /// The escape character in effect, echoed back in the message.
        escape_char: char,
    },
}

impl fmt::Display for UnicodeEscError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidEscape => f.write_str("Invalid Unicode escape sequence"),
            Self::InvalidSurrogatePair => f.write_str("Invalid Unicode surrogate pair"),
            Self::OutOfRange => f.write_str("Unicode escape value out of range"),
            Self::RequiresHexDigits { kind, escape_char } => {
                let required = kind.count();
                // The extended form is spelled with a `+` between the escape
                // character and the digits.
                let plus = match kind {
                    UnicodeEscapeKind::Extended => "+",
                    UnicodeEscapeKind::Short => "",
                };
                let xs = "X".repeat(required as usize);
                write!(
                    f,
                    "Unicode escape requires {required} hex digits: {escape_char}{plus}{xs}"
                )
            }
        }
    }
}

// Lets callers box the error or propagate it through `dyn Error` plumbing.
impl std::error::Error for UnicodeEscError {}
49
50pub fn escape_unicode_esc_str<F>(text: &str, escape_char: char, mut callback: F)
51where
52    F: FnMut(Range<usize>, Result<char, UnicodeEscError>),
53{
54    const HIGH_SURROGATE: RangeInclusive<u32> = 0xD800..=0xDBFF;
55    const LOW_SURROGATE: RangeInclusive<u32> = 0xDC00..=0xDFFF;
56    const MAX_CODEPOINT: u32 = 0x10FFFF;
57
58    let mut chars = text.char_indices().peekable();
59    let mut high_surrogate: Option<(Range<usize>, u32)> = None;
60
61    while let Some((escape_start, c)) = chars.next() {
62        if c != escape_char {
63            if let Some((hi_range, _)) = high_surrogate.take() {
64                callback(hi_range, Err(UnicodeEscError::InvalidSurrogatePair));
65            }
66            callback(escape_start..escape_start + c.len_utf8(), Ok(c));
67            continue;
68        }
69        let kind = match chars.peek() {
70            Some(&(_, c)) if c == escape_char => {
71                chars.next();
72                if let Some((hi_range, _)) = high_surrogate.take() {
73                    callback(hi_range, Err(UnicodeEscError::InvalidSurrogatePair));
74                }
75                let end = escape_start + escape_char.len_utf8() * 2;
76                callback(escape_start..end, Ok(escape_char));
77                continue;
78            }
79            Some(&(_, '+')) => {
80                chars.next();
81                UnicodeEscapeKind::Extended
82            }
83            Some(&(_, c)) if c.is_ascii_hexdigit() => UnicodeEscapeKind::Short,
84            _ => {
85                let end = chars
86                    .next()
87                    .map(|(i, c)| i + c.len_utf8())
88                    .unwrap_or(text.len());
89                if let Some((hi_range, _)) = high_surrogate.take() {
90                    callback(hi_range, Err(UnicodeEscError::InvalidSurrogatePair));
91                }
92                callback(escape_start..end, Err(UnicodeEscError::InvalidEscape));
93                continue;
94            }
95        };
96        let mut codepoint: u32 = 0;
97        let mut got_all = true;
98        let mut last_end = chars.peek().map(|&(i, _)| i).unwrap_or(text.len());
99        for _ in 0..kind.count() {
100            let radix = 16;
101            let Some(&(i, ch)) = chars.peek() else {
102                got_all = false;
103                break;
104            };
105            let Some(d) = ch.to_digit(radix) else {
106                got_all = false;
107                break;
108            };
109            chars.next();
110            codepoint = codepoint * radix + d;
111            last_end = i + ch.len_utf8();
112        }
113        if !got_all {
114            if let Some((hi_range, _)) = high_surrogate.take() {
115                callback(hi_range, Err(UnicodeEscError::InvalidSurrogatePair));
116            }
117            callback(
118                escape_start..last_end,
119                Err(UnicodeEscError::RequiresHexDigits { kind, escape_char }),
120            );
121            continue;
122        }
123        if let Some((hi_range, hi_cp)) = high_surrogate.take() {
124            if LOW_SURROGATE.contains(&codepoint) {
125                let combined = 0x10000 + ((hi_cp - 0xD800) << 10) + (codepoint - 0xDC00);
126                let ch = char::from_u32(combined).unwrap();
127                callback(hi_range.start..last_end, Ok(ch));
128                continue;
129            }
130            callback(
131                hi_range.start..last_end,
132                Err(UnicodeEscError::InvalidSurrogatePair),
133            );
134            continue;
135        }
136        if codepoint > MAX_CODEPOINT {
137            callback(escape_start..last_end, Err(UnicodeEscError::OutOfRange));
138        } else if HIGH_SURROGATE.contains(&codepoint) {
139            high_surrogate = Some((escape_start..last_end, codepoint));
140        } else if LOW_SURROGATE.contains(&codepoint) {
141            callback(
142                escape_start..last_end,
143                Err(UnicodeEscError::InvalidSurrogatePair),
144            );
145        } else {
146            let ch = char::from_u32(codepoint).unwrap();
147            callback(escape_start..last_end, Ok(ch));
148        }
149    }
150    if let Some((range, _)) = high_surrogate {
151        callback(range, Err(UnicodeEscError::InvalidSurrogatePair));
152    }
153}
154
// https://github.com/postgres/postgres/blob/228a1f9542792c6533ef74c2e7aefad0da1d9a7a/src/backend/parser/parser.c#L350
/// Returns whether `byte` may serve as a `UESCAPE` character: it must not be
/// a hex digit, `+`, a single or double quote, or whitespace.
const fn is_valid_uescape_char(byte: u8) -> bool {
    match byte {
        b'+' | b'\'' | b'"' => false,
        // space, tab, newline, carriage return, vertical tab, form feed
        b' ' | b'\t' | b'\n' | b'\r' | 0x0B | 0x0C => false,
        _ => !byte.is_ascii_hexdigit(),
    }
}

/// Extracts the escape character from a quoted one-byte literal such as `'!'`.
///
/// Returns `None` when `text` is not wrapped in single quotes, when the quoted
/// content is not exactly one byte, or when that byte is not an acceptable
/// escape character.
pub fn uescape_char(text: &str) -> Option<char> {
    let inner = text
        .strip_prefix('\'')
        .and_then(|rest| rest.strip_suffix('\''))?;
    match inner.as_bytes() {
        [byte] if is_valid_uescape_char(*byte) => Some(char::from(*byte)),
        _ => None,
    }
}
174
/// Appends `inner` to `out`, collapsing every doubled single quote (`''`)
/// into one `'`. All other characters are copied unchanged.
pub fn decode_plain_string(inner: &str, out: &mut String) {
    // Splitting on the doubled quote consumes pairs left to right; putting a
    // single quote back between the pieces yields the collapsed text.
    let mut pieces = inner.split("''");
    if let Some(first) = pieces.next() {
        out.push_str(first);
    }
    for piece in pieces {
        out.push('\'');
        out.push_str(piece);
    }
}
186
/// Appends the UTF-8 encoding of `c` to `bytes`.
fn push_char_bytes(c: char, bytes: &mut Vec<u8>) {
    let mut buf = [0; 4];
    let encoded = c.encode_utf8(&mut buf);
    bytes.extend_from_slice(encoded.as_bytes());
}

/// Appends `byte` to `bytes` unless it is NUL; NUL bytes are dropped.
fn push_non_nul_byte(byte: u8, bytes: &mut Vec<u8>) {
    if byte != 0 {
        bytes.push(byte);
    }
}

/// Maps the letter of a single-letter control escape (`\b`, `\f`, `\n`, `\r`,
/// `\t`) to the byte it stands for.
const fn control_escape_byte(letter: char) -> Option<u8> {
    match letter {
        'b' => Some(0x08),
        'f' => Some(0x0C),
        'n' => Some(b'\n'),
        'r' => Some(b'\r'),
        't' => Some(b'\t'),
        _ => None,
    }
}

/// Consumes up to `max` digits of the given `radix` from `chars`.
///
/// Returns the accumulated value and how many digits were consumed. Callers
/// keep `max` small enough that the value cannot overflow a `u32`.
fn take_digits<I>(chars: &mut std::iter::Peekable<I>, radix: u32, max: u32) -> (u32, u32)
where
    I: Iterator<Item = char>,
{
    let mut value = 0;
    let mut taken = 0;
    while taken < max {
        let Some(digit) = chars.peek().and_then(|ch| ch.to_digit(radix)) else {
            break;
        };
        chars.next();
        value = value * radix + digit;
        taken += 1;
    }
    (value, taken)
}

/// Decodes the body of a backslash-escaped string and appends it to `out`.
///
/// Handled forms:
///
/// * `''` — one single quote,
/// * `\b`, `\f`, `\n`, `\r`, `\t` — the corresponding control byte,
/// * `\o`, `\oo`, `\ooo` — a byte given in octal (only the low 8 bits are kept),
/// * `\xh`, `\xhh` — a byte given in hex; `\x` without digits yields `x`,
/// * `\uXXXX`, `\UXXXXXXXX` — a code point; incomplete or invalid escapes
///   produce nothing,
/// * `\` followed by any other character — that character,
/// * a trailing lone `\` — a backslash.
///
/// NUL bytes and the NUL character are dropped. Escaped bytes are collected
/// first and converted as one UTF-8 sequence, so `\xC3\xA9` becomes `é`;
/// invalid sequences are replaced with U+FFFD.
pub fn decode_esc_string(inner: &str, out: &mut String) {
    let mut chars = inner.chars().peekable();
    // Every escape decodes to at most as many bytes as it occupies in the
    // source, so the input length is an upper bound.
    let mut bytes = Vec::with_capacity(inner.len());

    while let Some(c) = chars.next() {
        if c == '\'' && chars.peek() == Some(&'\'') {
            chars.next();
            bytes.push(b'\'');
            continue;
        }
        if c != '\\' {
            push_char_bytes(c, &mut bytes);
            continue;
        }
        let Some(&next) = chars.peek() else {
            bytes.push(b'\\');
            break;
        };
        if let Some(byte) = control_escape_byte(next) {
            chars.next();
            bytes.push(byte);
            continue;
        }
        match next {
            '0'..='7' => {
                // `next` is the first octal digit and has not been consumed yet.
                let (value, _) = take_digits(&mut chars, 8, 3);
                // Three octal digits reach 0o777; keep the low byte and test
                // *that* for NUL so `\400` is dropped just like `\000`.
                push_non_nul_byte((value & 0xFF) as u8, &mut bytes);
            }
            'x' => {
                chars.next();
                match take_digits(&mut chars, 16, 2) {
                    (_, 0) => bytes.push(b'x'),
                    // Two hex digits never exceed 0xFF.
                    (value, _) => push_non_nul_byte((value & 0xFF) as u8, &mut bytes),
                }
            }
            'u' | 'U' => {
                chars.next();
                let required = if next == 'u' { 4 } else { 8 };
                let (value, found) = take_digits(&mut chars, 16, required);
                let decoded = (found == required)
                    .then_some(value)
                    .and_then(char::from_u32)
                    .filter(|&ch| ch != '\0');
                if let Some(ch) = decoded {
                    push_char_bytes(ch, &mut bytes);
                }
            }
            _ => {
                chars.next();
                push_char_bytes(next, &mut bytes);
            }
        }
    }

    out.push_str(&String::from_utf8_lossy(&bytes));
}
302
303pub fn decode_unicode_esc_string(inner: &str, escape_char: char, out: &mut String) {
304    let inner = inner.replace("''", "'");
305    escape_unicode_esc_str(&inner, escape_char, |_range, result| {
306        if let Ok(ch) = result {
307            out.push(ch);
308        }
309    });
310}
311
#[cfg(test)]
mod tests {
    use insta::assert_snapshot;

    use super::*;

    /// Renders every callback invocation as one `start..end ok|err …` line.
    fn unicode_escape_events(text: &str, escape_char: char) -> String {
        let mut lines: Vec<String> = Vec::new();

        escape_unicode_esc_str(text, escape_char, |span, outcome| {
            let (start, end) = (span.start, span.end);
            lines.push(match outcome {
                Ok(ch) => format!("{start}..{end} ok {ch:?}"),
                Err(err) => format!("{start}..{end} err {err}"),
            });
        });

        lines.join("\n")
    }

    /// Runs `decode_esc_string` into a fresh buffer and returns it.
    fn decode_escape_string(inner: &str) -> String {
        let mut decoded = String::new();
        decode_esc_string(inner, &mut decoded);
        decoded
    }

    /// Runs `decode_unicode_esc_string` into a fresh buffer and returns it.
    fn decode_unicode_escape_string(inner: &str, escape_char: char) -> String {
        let mut decoded = String::new();
        decode_unicode_esc_string(inner, escape_char, &mut decoded);
        decoded
    }

    #[test]
    fn ok() {
        assert_snapshot!(unicode_escape_events(r"hello world", '\\'), @"
        0..1 ok 'h'
        1..2 ok 'e'
        2..3 ok 'l'
        3..4 ok 'l'
        4..5 ok 'o'
        5..6 ok ' '
        6..7 ok 'w'
        7..8 ok 'o'
        8..9 ok 'r'
        9..10 ok 'l'
        10..11 ok 'd'
        ");
    }

    // A too-short escape between the two halves must not let them pair up.
    #[test]
    fn incomplete_unicode_escape_breaks_surrogate_pairing() {
        assert_snapshot!(unicode_escape_events(r"\D800\006\DC00", '\\'), @r"
        0..5 err Invalid Unicode surrogate pair
        5..9 err Unicode escape requires 4 hex digits: \XXXX
        9..14 err Invalid Unicode surrogate pair
        ");
    }

    #[test]
    fn invalid_unicode_escape_breaks_surrogate_pairing() {
        assert_snapshot!(unicode_escape_events(r"\D800\Q\DC00", '\\'), @r"
        0..5 err Invalid Unicode surrogate pair
        5..7 err Invalid Unicode escape sequence
        7..12 err Invalid Unicode surrogate pair
        ");
    }

    // The character after a bad escape is part of the error, not a literal.
    #[test]
    fn invalid_unicode_escape_does_not_emit_literal_char() {
        assert_snapshot!(unicode_escape_events(r"\0061\Q\0062", '\\'), @r"
        0..5 ok 'a'
        5..7 err Invalid Unicode escape sequence
        7..12 ok 'b'
        ");
    }

    #[test]
    fn invalid_unicode_escape_works_with_custom_escape_char() {
        assert_snapshot!(unicode_escape_events("!0061!Q!0062", '!'), @r"
        0..5 ok 'a'
        5..7 err Invalid Unicode escape sequence
        7..12 ok 'b'
        ");
    }

    #[test]
    fn valid_unicode_escape_after_high_surrogate_only_emits_error() {
        assert_snapshot!(unicode_escape_events(r"\D800\0061", '\\'), @r"
        0..10 err Invalid Unicode surrogate pair
        ");
    }

    #[test]
    fn decode_escape_string_hex_bytes_as_utf8() {
        assert_snapshot!(decode_escape_string(r"\xC3\xA9"), @"é");
    }

    #[test]
    fn decode_escape_string_skips_nul_byte() {
        assert_snapshot!(decode_escape_string(r"a\000b"), @"ab");
    }

    #[test]
    fn decode_unicode_string_collapses_doubled_quotes() {
        assert_snapshot!(decode_unicode_escape_string("a''b", '\\'), @"a'b");
    }
}