solar_parse/lexer/unescape/
mod.rs

//! Utilities for validating string literals and turning them into the values they represent.
2
3use alloy_primitives::hex;
4use solar_data_structures::trustme;
5use std::{borrow::Cow, ops::Range, slice, str::Chars};
6
7mod errors;
8pub(crate) use errors::emit_unescape_error;
9pub use errors::EscapeError;
10
/// The kind of string literal being parsed; selects how its contents are unescaped.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Mode {
    /// Normal string literal (e.g. `"a"`). Non-ASCII characters are reported as errors.
    Str,
    /// Unicode string literal (e.g. `unicode"😀"`). Non-ASCII characters are allowed.
    UnicodeStr,
    /// Hex string literal (e.g. `hex"1234"`). Pairs of hex digits, optionally separated by
    /// underscores, that are decoded into bytes.
    HexStr,
}
21
22/// Parses a string literal (without quotes) into a byte array.
23pub fn parse_string_literal(src: &str, mode: Mode) -> Cow<'_, [u8]> {
24    try_parse_string_literal(src, mode, |_, _| {})
25}
26
27/// Parses a string literal (without quotes) into a byte array.
28/// `f` is called for each escape error.
29#[instrument(name = "parse_string_literal", level = "debug", skip_all)]
30pub fn try_parse_string_literal<F>(src: &str, mode: Mode, f: F) -> Cow<'_, [u8]>
31where
32    F: FnMut(Range<usize>, EscapeError),
33{
34    let mut bytes = if needs_unescape(src, mode) {
35        Cow::Owned(parse_literal_unescape(src, mode, f))
36    } else {
37        Cow::Borrowed(src.as_bytes())
38    };
39    if mode == Mode::HexStr {
40        // Currently this should never fail, but it's a good idea to check anyway.
41        if let Ok(decoded) = hex::decode(&bytes) {
42            bytes = Cow::Owned(decoded);
43        }
44    }
45    bytes
46}
47
48#[cold]
49fn parse_literal_unescape<F>(src: &str, mode: Mode, f: F) -> Vec<u8>
50where
51    F: FnMut(Range<usize>, EscapeError),
52{
53    let mut bytes = Vec::with_capacity(src.len());
54    parse_literal_unescape_into(src, mode, f, &mut bytes);
55    bytes
56}
57
58fn parse_literal_unescape_into<F>(src: &str, mode: Mode, mut f: F, dst_buf: &mut Vec<u8>)
59where
60    F: FnMut(Range<usize>, EscapeError),
61{
62    // `src.len()` is enough capacity for the unescaped string, so we can just use a slice.
63    // SAFETY: The buffer is never read from.
64    debug_assert!(dst_buf.is_empty());
65    debug_assert!(dst_buf.capacity() >= src.len());
66    let mut dst = unsafe { slice::from_raw_parts_mut(dst_buf.as_mut_ptr(), dst_buf.capacity()) };
67    unescape_literal_unchecked(src, mode, |range, res| match res {
68        Ok(c) => {
69            // NOTE: We can't use `char::encode_utf8` because `c` can be an invalid unicode code.
70            let written = super::utf8::encode_utf8_raw(c, dst).len();
71
72            // SAFETY: Unescaping guarantees that the final unescaped byte array is shorter than
73            // the initial string.
74            debug_assert!(dst.len() >= written);
75            let advanced = unsafe { dst.get_unchecked_mut(written..) };
76
77            // SAFETY: I don't know why this triggers E0521.
78            dst = unsafe { trustme::decouple_lt_mut(advanced) };
79        }
80        Err(e) => f(range, e),
81    });
82    unsafe { dst_buf.set_len(dst_buf.capacity() - dst.len()) };
83}
84
85/// Unescapes the contents of a string literal (without quotes).
86///
87/// The callback is invoked with a range and either a unicode code point or an error.
88#[instrument(level = "debug", skip_all)]
89pub fn unescape_literal<F>(src: &str, mode: Mode, mut callback: F)
90where
91    F: FnMut(Range<usize>, Result<u32, EscapeError>),
92{
93    if needs_unescape(src, mode) {
94        unescape_literal_unchecked(src, mode, callback)
95    } else {
96        for (i, ch) in src.char_indices() {
97            callback(i..i + ch.len_utf8(), Ok(ch as u32));
98        }
99    }
100}
101
102/// Unescapes the contents of a string literal (without quotes).
103///
104/// See [`unescape_literal`] for more details.
105fn unescape_literal_unchecked<F>(src: &str, mode: Mode, callback: F)
106where
107    F: FnMut(Range<usize>, Result<u32, EscapeError>),
108{
109    match mode {
110        Mode::Str | Mode::UnicodeStr => {
111            unescape_str(src, matches!(mode, Mode::UnicodeStr), callback)
112        }
113        Mode::HexStr => unescape_hex_str(src, callback),
114    }
115}
116
117/// Fast-path check for whether a string literal needs to be unescaped or errors need to be
118/// reported.
119fn needs_unescape(src: &str, mode: Mode) -> bool {
120    fn needs_unescape_chars(src: &str) -> bool {
121        memchr::memchr3(b'\\', b'\n', b'\r', src.as_bytes()).is_some()
122    }
123
124    match mode {
125        Mode::Str => needs_unescape_chars(src) || !src.is_ascii(),
126        Mode::UnicodeStr => needs_unescape_chars(src),
127        Mode::HexStr => src.len() % 2 != 0 || !hex::check_raw(src),
128    }
129}
130
131fn scan_escape(chars: &mut Chars<'_>) -> Result<u32, EscapeError> {
132    // Previous character was '\\', unescape what follows.
133    // https://docs.soliditylang.org/en/latest/grammar.html#a4.SolidityLexer.EscapeSequence
134    // Note that hex and unicode escape codes are not validated since string literals are allowed
135    // to contain invalid UTF-8.
136    Ok(match chars.next().ok_or(EscapeError::LoneSlash)? {
137        // Both quotes are always valid escapes.
138        '\'' => '\'' as u32,
139        '"' => '"' as u32,
140
141        '\\' => '\\' as u32,
142        'n' => '\n' as u32,
143        'r' => '\r' as u32,
144        't' => '\t' as u32,
145
146        'x' => {
147            // Parse hexadecimal character code.
148            let mut value = 0;
149            for _ in 0..2 {
150                let d = chars.next().ok_or(EscapeError::HexEscapeTooShort)?;
151                let d = d.to_digit(16).ok_or(EscapeError::InvalidHexEscape)?;
152                value = value * 16 + d;
153            }
154            value
155        }
156
157        'u' => {
158            // Parse hexadecimal unicode character code.
159            let mut value = 0;
160            for _ in 0..4 {
161                let d = chars.next().ok_or(EscapeError::UnicodeEscapeTooShort)?;
162                let d = d.to_digit(16).ok_or(EscapeError::InvalidUnicodeEscape)?;
163                value = value * 16 + d;
164            }
165            value
166        }
167
168        _ => return Err(EscapeError::InvalidEscape),
169    })
170}
171
172/// Unescape characters in a string literal.
173///
174/// See [`unescape_literal`] for more details.
175fn unescape_str<F>(src: &str, is_unicode: bool, mut callback: F)
176where
177    F: FnMut(Range<usize>, Result<u32, EscapeError>),
178{
179    let mut chars = src.chars();
180    // The `start` and `end` computation here is complicated because
181    // `skip_ascii_whitespace` makes us to skip over chars without counting
182    // them in the range computation.
183    while let Some(c) = chars.next() {
184        let start = src.len() - chars.as_str().len() - c.len_utf8();
185        let res = match c {
186            '\\' => match chars.clone().next() {
187                Some('\r') if chars.clone().nth(1) == Some('\n') => {
188                    // +2 for the '\\' and '\r' characters.
189                    skip_ascii_whitespace(&mut chars, start + 2, &mut callback);
190                    continue;
191                }
192                Some('\n') => {
193                    // +1 for the '\\' character.
194                    skip_ascii_whitespace(&mut chars, start + 1, &mut callback);
195                    continue;
196                }
197                _ => scan_escape(&mut chars),
198            },
199            '\n' => Err(EscapeError::StrNewline),
200            '\r' => {
201                if chars.clone().next() == Some('\n') {
202                    continue;
203                }
204                Err(EscapeError::BareCarriageReturn)
205            }
206            c if !is_unicode && !c.is_ascii() => Err(EscapeError::StrNonAsciiChar),
207            c => Ok(c as u32),
208        };
209        let end = src.len() - chars.as_str().len();
210        callback(start..end, res);
211    }
212}
213
214/// Skips over whitespace after a "\\\n" escape sequence.
215///
216/// Reports errors if multiple newlines are encountered.
217fn skip_ascii_whitespace<F>(chars: &mut Chars<'_>, mut start: usize, callback: &mut F)
218where
219    F: FnMut(Range<usize>, Result<u32, EscapeError>),
220{
221    // Skip the first newline.
222    let mut nl = chars.next();
223    if let Some('\r') = nl {
224        nl = chars.next();
225    }
226    debug_assert_eq!(nl, Some('\n'));
227    let mut tail = chars.as_str();
228    start += 1;
229
230    while tail.starts_with(|c: char| c.is_ascii_whitespace()) {
231        let first_non_space =
232            tail.bytes().position(|b| !matches!(b, b' ' | b'\t')).unwrap_or(tail.len());
233        tail = &tail[first_non_space..];
234        start += first_non_space;
235
236        if let Some(tail2) = tail.strip_prefix('\n').or_else(|| tail.strip_prefix("\r\n")) {
237            let skipped = tail.len() - tail2.len();
238            tail = tail2;
239            callback(start..start + skipped, Err(EscapeError::CannotSkipMultipleLines));
240            start += skipped;
241        }
242    }
243    *chars = tail.chars();
244}
245
246/// Unescape characters in a hex string literal.
247///
248/// See [`unescape_literal`] for more details.
249fn unescape_hex_str<F>(src: &str, mut callback: F)
250where
251    F: FnMut(Range<usize>, Result<u32, EscapeError>),
252{
253    let mut chars = src.char_indices();
254    if src.starts_with("0x") || src.starts_with("0X") {
255        chars.next();
256        chars.next();
257        callback(0..2, Err(EscapeError::HexPrefix));
258    }
259
260    let count = chars.clone().filter(|(_, c)| c.is_ascii_hexdigit()).count();
261    if count % 2 != 0 {
262        callback(0..src.len(), Err(EscapeError::HexOddDigits));
263        return;
264    }
265
266    let mut emit_underscore_errors = true;
267    let mut allow_underscore = false;
268    let mut even = true;
269    for (start, c) in chars {
270        let res = match c {
271            '_' => {
272                if emit_underscore_errors && (!allow_underscore || !even) {
273                    // Don't spam errors for multiple underscores.
274                    emit_underscore_errors = false;
275                    Err(EscapeError::HexBadUnderscore)
276                } else {
277                    allow_underscore = false;
278                    continue;
279                }
280            }
281            c if !c.is_ascii_hexdigit() => Err(EscapeError::HexNotHexDigit),
282            c => Ok(c as u32),
283        };
284
285        if res.is_ok() {
286            even = !even;
287            allow_underscore = true;
288        }
289
290        let end = start + c.len_utf8();
291        callback(start..end, res);
292    }
293
294    if emit_underscore_errors && src.len() > 1 && src.ends_with('_') {
295        callback(src.len() - 1..src.len(), Err(EscapeError::HexBadUnderscore));
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use EscapeError::*;

    type ExErr = (Range<usize>, EscapeError);

    /// Checks `src` in `mode` against the expected unescaped text and errors. Both
    /// `unescape_literal` and `try_parse_string_literal` are exercised, and they must agree.
    fn check(mode: Mode, src: &str, expected_str: &str, expected_errs: &[ExErr]) {
        let panic_str = format!("{mode:?}: {src:?}");

        let mut ok = String::with_capacity(src.len());
        let mut errs = Vec::with_capacity(expected_errs.len());
        unescape_literal(src, mode, |range, c| match c {
            Ok(c) => ok.push(char::try_from(c).unwrap()),
            Err(e) => errs.push((range, e)),
        });
        assert_eq!(errs, expected_errs, "{panic_str}");
        assert_eq!(ok, expected_str, "{panic_str}");

        let mut errs2 = Vec::with_capacity(errs.len());
        let out = try_parse_string_literal(src, mode, |range, e| {
            errs2.push((range, e));
        });
        assert_eq!(errs2, errs, "{panic_str}");
        if mode == Mode::HexStr {
            // Hex literals decode to raw bytes; compare them in their hex form.
            assert_eq!(hex::encode(out), expected_str, "{panic_str}");
        } else {
            assert_eq!(hex::encode(out), hex::encode(expected_str), "{panic_str}");
        }
    }

    #[test]
    fn unescape_str() {
        let cases: &[(&str, &str, &[ExErr])] = &[
            ("", "", &[]),
            (" ", " ", &[]),
            ("\t", "\t", &[]),
            (" \t ", " \t ", &[]),
            ("foo", "foo", &[]),
            ("hello world", "hello world", &[]),
            (r"\", "", &[(0..1, LoneSlash)]),
            (r"\\", "\\", &[]),
            (r"\\\", "\\", &[(2..3, LoneSlash)]),
            (r"\\\\", "\\\\", &[]),
            (r"\\ ", "\\ ", &[]),
            (r"\\ \", "\\ ", &[(3..4, LoneSlash)]),
            (r"\\ \\", "\\ \\", &[]),
            (r"\x", "", &[(0..2, HexEscapeTooShort)]),
            (r"\x1", "", &[(0..3, HexEscapeTooShort)]),
            (r"\xz", "", &[(0..3, InvalidHexEscape)]),
            (r"\xzf", "f", &[(0..3, InvalidHexEscape)]),
            (r"\xzz", "z", &[(0..3, InvalidHexEscape)]),
            (r"\x69", "\x69", &[]),
            (r"\xE8", "è", &[]),
            (r"\u", "", &[(0..2, UnicodeEscapeTooShort)]),
            (r"\u1", "", &[(0..3, UnicodeEscapeTooShort)]),
            (r"\uz", "", &[(0..3, InvalidUnicodeEscape)]),
            (r"\uzf", "f", &[(0..3, InvalidUnicodeEscape)]),
            (r"\u12", "", &[(0..4, UnicodeEscapeTooShort)]),
            (r"\u123", "", &[(0..5, UnicodeEscapeTooShort)]),
            (r"\u1234", "\u{1234}", &[]),
            (r"\u00e8", "è", &[]),
            (r"\r", "\r", &[]),
            (r"\t", "\t", &[]),
            (r"\n", "\n", &[]),
            (r"\n\n", "\n\n", &[]),
            (r"\ ", "", &[(0..2, InvalidEscape)]),
            (r"\?", "", &[(0..2, InvalidEscape)]),
            ("\r\n", "", &[(1..2, StrNewline)]),
            ("\n", "", &[(0..1, StrNewline)]),
            ("\\\n", "", &[]),
            ("\\\na", "a", &[]),
            ("\\\n  a", "a", &[]),
            ("a \\\n  b", "a b", &[]),
            ("a\\n\\\n  b", "a\nb", &[]),
            ("a\\t\\\n  b", "a\tb", &[]),
            ("a\\n \\\n  b", "a\n b", &[]),
            ("a\\n \\\n \tb", "a\n b", &[]),
            ("a\\t \\\n  b", "a\t b", &[]),
            ("\\\n \t a", "a", &[]),
            (" \\\n \t a", " a", &[]),
            ("\\\n \t a\n", "a", &[(6..7, StrNewline)]),
            ("\\\n   \t   ", "", &[]),
            (" \\\n   \t   ", " ", &[]),
            (" he\\\n \\\nllo \\\n wor\\\nld", " hello world", &[]),
            ("\\\n\na\\\nb", "ab", &[(2..3, CannotSkipMultipleLines)]),
            ("\\\n \na\\\nb", "ab", &[(3..4, CannotSkipMultipleLines)]),
            (
                "\\\n \n\na\\\nb",
                "ab",
                &[(3..4, CannotSkipMultipleLines), (4..5, CannotSkipMultipleLines)],
            ),
            (
                "a\\\n \n \t \nb\\\nc",
                "abc",
                &[(4..5, CannotSkipMultipleLines), (8..9, CannotSkipMultipleLines)],
            ),
        ];
        for &(src, expected_str, expected_errs) in cases {
            check(Mode::Str, src, expected_str, expected_errs);
            check(Mode::UnicodeStr, src, expected_str, expected_errs);
        }
    }

    #[test]
    fn unescape_unicode_str() {
        // (source, unicode output, unicode errors, plain-string errors)
        let cases: &[(&str, &str, &[ExErr], &[ExErr])] = &[
            ("è", "è", &[], &[(0..2, StrNonAsciiChar)]),
            ("😀", "😀", &[], &[(0..4, StrNonAsciiChar)]),
        ];
        for &(src, expected_str, e1, e2) in cases {
            check(Mode::UnicodeStr, src, expected_str, e1);
            check(Mode::Str, src, "", e2);
        }
    }

    #[test]
    fn unescape_hex_str() {
        let cases: &[(&str, &str, &[ExErr])] = &[
            ("", "", &[]),
            ("z", "", &[(0..1, HexNotHexDigit)]),
            ("\n", "", &[(0..1, HexNotHexDigit)]),
            ("  11", "11", &[(0..1, HexNotHexDigit), (1..2, HexNotHexDigit)]),
            ("0x", "", &[(0..2, HexPrefix)]),
            ("0X", "", &[(0..2, HexPrefix)]),
            ("0x11", "11", &[(0..2, HexPrefix)]),
            ("0X11", "11", &[(0..2, HexPrefix)]),
            ("1", "", &[(0..1, HexOddDigits)]),
            ("12", "12", &[]),
            ("123", "", &[(0..3, HexOddDigits)]),
            ("1234", "1234", &[]),
            ("_", "", &[(0..1, HexBadUnderscore)]),
            ("_11", "11", &[(0..1, HexBadUnderscore)]),
            ("_11_", "11", &[(0..1, HexBadUnderscore)]),
            ("11_", "11", &[(2..3, HexBadUnderscore)]),
            ("11_22", "1122", &[]),
            ("11__", "11", &[(3..4, HexBadUnderscore)]),
            ("11__22", "1122", &[(3..4, HexBadUnderscore)]),
            ("1_2", "12", &[(1..2, HexBadUnderscore)]),
        ];
        for &(src, expected_str, expected_errs) in cases {
            check(Mode::HexStr, src, expected_str, expected_errs);
        }
    }
}