solar_parse/lexer/unescape/
mod.rs

1//! Utilities for validating string and char literals and turning them into values they represent.
2
3use alloy_primitives::hex;
4use solar_ast::Span;
5use solar_data_structures::trustme;
6use solar_interface::{BytePos, Session};
7use std::{borrow::Cow, ops::Range, slice, str::Chars};
8
9mod errors;
10pub use errors::EscapeError;
11pub(crate) use errors::emit_unescape_error;
12
13pub use solar_ast::StrKind;
14
15pub fn parse_string_literal<'a>(
16    src: &'a str,
17    kind: StrKind,
18    lit_span: Span,
19    sess: &Session,
20) -> (Cow<'a, [u8]>, bool) {
21    let mut has_error = false;
22    let content_start = lit_span.lo() + BytePos(1) + BytePos(kind.prefix().len() as u32);
23    let bytes = try_parse_string_literal(src, kind, |range, error| {
24        has_error = true;
25        let (start, end) = (range.start as u32, range.end as u32);
26        let lo = content_start + BytePos(start);
27        let hi = lo + BytePos(end - start);
28        let span = Span::new(lo, hi);
29        emit_unescape_error(&sess.dcx, src, span, range, error);
30    });
31    (bytes, has_error)
32}
33
34/// Parses a string literal (without quotes) into a byte array.
35/// `f` is called for each escape error.
36#[instrument(name = "parse_string_literal", level = "trace", skip_all)]
37pub fn try_parse_string_literal<F>(src: &str, kind: StrKind, f: F) -> Cow<'_, [u8]>
38where
39    F: FnMut(Range<usize>, EscapeError),
40{
41    let mut bytes = if needs_unescape(src, kind) {
42        Cow::Owned(parse_literal_unescape(src, kind, f))
43    } else {
44        Cow::Borrowed(src.as_bytes())
45    };
46    if kind == StrKind::Hex {
47        // Currently this should never fail, but it's a good idea to check anyway.
48        if let Ok(decoded) = hex::decode(&bytes) {
49            bytes = Cow::Owned(decoded);
50        }
51    }
52    bytes
53}
54
55#[cold]
56fn parse_literal_unescape<F>(src: &str, kind: StrKind, f: F) -> Vec<u8>
57where
58    F: FnMut(Range<usize>, EscapeError),
59{
60    let mut bytes = Vec::with_capacity(src.len());
61    parse_literal_unescape_into(src, kind, f, &mut bytes);
62    bytes
63}
64
65fn parse_literal_unescape_into<F>(src: &str, kind: StrKind, mut f: F, dst_buf: &mut Vec<u8>)
66where
67    F: FnMut(Range<usize>, EscapeError),
68{
69    // `src.len()` is enough capacity for the unescaped string, so we can just use a slice.
70    // SAFETY: The buffer is never read from.
71    debug_assert!(dst_buf.is_empty());
72    debug_assert!(dst_buf.capacity() >= src.len());
73    let mut dst = unsafe { slice::from_raw_parts_mut(dst_buf.as_mut_ptr(), dst_buf.capacity()) };
74    unescape_literal_unchecked(src, kind, |range, res| match res {
75        Ok(c) => {
76            // NOTE: We can't use `char::encode_utf8` because `c` can be an invalid unicode code.
77            let written = super::utf8::encode_utf8_raw(c, dst).len();
78
79            // SAFETY: Unescaping guarantees that the final unescaped byte array is shorter than
80            // the initial string.
81            debug_assert!(dst.len() >= written);
82            let advanced = unsafe { dst.get_unchecked_mut(written..) };
83
84            // SAFETY: I don't know why this triggers E0521.
85            dst = unsafe { trustme::decouple_lt_mut(advanced) };
86        }
87        Err(e) => f(range, e),
88    });
89    unsafe { dst_buf.set_len(dst_buf.capacity() - dst.len()) };
90}
91
92/// Unescapes the contents of a string literal (without quotes).
93///
94/// The callback is invoked with a range and either a unicode code point or an error.
95#[instrument(level = "debug", skip_all)]
96pub fn unescape_literal<F>(src: &str, kind: StrKind, mut callback: F)
97where
98    F: FnMut(Range<usize>, Result<u32, EscapeError>),
99{
100    if needs_unescape(src, kind) {
101        unescape_literal_unchecked(src, kind, callback)
102    } else {
103        for (i, ch) in src.char_indices() {
104            callback(i..i + ch.len_utf8(), Ok(ch as u32));
105        }
106    }
107}
108
109/// Unescapes the contents of a string literal (without quotes).
110///
111/// See [`unescape_literal`] for more details.
112fn unescape_literal_unchecked<F>(src: &str, kind: StrKind, callback: F)
113where
114    F: FnMut(Range<usize>, Result<u32, EscapeError>),
115{
116    match kind {
117        StrKind::Str | StrKind::Unicode => {
118            unescape_str(src, matches!(kind, StrKind::Unicode), callback)
119        }
120        StrKind::Hex => unescape_hex_str(src, callback),
121    }
122}
123
124/// Fast-path check for whether a string literal needs to be unescaped or errors need to be
125/// reported.
126fn needs_unescape(src: &str, kind: StrKind) -> bool {
127    fn needs_unescape_chars(src: &str) -> bool {
128        memchr::memchr3(b'\\', b'\n', b'\r', src.as_bytes()).is_some()
129    }
130
131    match kind {
132        StrKind::Str => needs_unescape_chars(src) || !src.is_ascii(),
133        StrKind::Unicode => needs_unescape_chars(src),
134        StrKind::Hex => !src.len().is_multiple_of(2) || !hex::check_raw(src),
135    }
136}
137
138fn scan_escape(chars: &mut Chars<'_>) -> Result<u32, EscapeError> {
139    // Previous character was '\\', unescape what follows.
140    // https://docs.soliditylang.org/en/latest/grammar.html#a4.SolidityLexer.EscapeSequence
141    // Note that hex and unicode escape codes are not validated since string literals are allowed
142    // to contain invalid UTF-8.
143    Ok(match chars.next().ok_or(EscapeError::LoneSlash)? {
144        // Both quotes are always valid escapes.
145        '\'' => '\'' as u32,
146        '"' => '"' as u32,
147
148        '\\' => '\\' as u32,
149        'n' => '\n' as u32,
150        'r' => '\r' as u32,
151        't' => '\t' as u32,
152
153        'x' => {
154            // Parse hexadecimal character code.
155            let mut value = 0;
156            for _ in 0..2 {
157                let d = chars.next().ok_or(EscapeError::HexEscapeTooShort)?;
158                let d = d.to_digit(16).ok_or(EscapeError::InvalidHexEscape)?;
159                value = value * 16 + d;
160            }
161            value
162        }
163
164        'u' => {
165            // Parse hexadecimal unicode character code.
166            let mut value = 0;
167            for _ in 0..4 {
168                let d = chars.next().ok_or(EscapeError::UnicodeEscapeTooShort)?;
169                let d = d.to_digit(16).ok_or(EscapeError::InvalidUnicodeEscape)?;
170                value = value * 16 + d;
171            }
172            value
173        }
174
175        _ => return Err(EscapeError::InvalidEscape),
176    })
177}
178
179/// Unescape characters in a string literal.
180///
181/// See [`unescape_literal`] for more details.
182fn unescape_str<F>(src: &str, is_unicode: bool, mut callback: F)
183where
184    F: FnMut(Range<usize>, Result<u32, EscapeError>),
185{
186    let mut chars = src.chars();
187    // The `start` and `end` computation here is complicated because
188    // `skip_ascii_whitespace` makes us to skip over chars without counting
189    // them in the range computation.
190    while let Some(c) = chars.next() {
191        let start = src.len() - chars.as_str().len() - c.len_utf8();
192        let res = match c {
193            '\\' => match chars.clone().next() {
194                Some('\r') if chars.clone().nth(1) == Some('\n') => {
195                    // +2 for the '\\' and '\r' characters.
196                    skip_ascii_whitespace(&mut chars, start + 2, &mut callback);
197                    continue;
198                }
199                Some('\n') => {
200                    // +1 for the '\\' character.
201                    skip_ascii_whitespace(&mut chars, start + 1, &mut callback);
202                    continue;
203                }
204                _ => scan_escape(&mut chars),
205            },
206            '\n' => Err(EscapeError::StrNewline),
207            '\r' => {
208                if chars.clone().next() == Some('\n') {
209                    continue;
210                }
211                Err(EscapeError::BareCarriageReturn)
212            }
213            c if !is_unicode && !c.is_ascii() => Err(EscapeError::StrNonAsciiChar),
214            c => Ok(c as u32),
215        };
216        let end = src.len() - chars.as_str().len();
217        callback(start..end, res);
218    }
219}
220
221/// Skips over whitespace after a "\\\n" escape sequence.
222///
223/// Reports errors if multiple newlines are encountered.
224fn skip_ascii_whitespace<F>(chars: &mut Chars<'_>, mut start: usize, callback: &mut F)
225where
226    F: FnMut(Range<usize>, Result<u32, EscapeError>),
227{
228    // Skip the first newline.
229    let mut nl = chars.next();
230    if let Some('\r') = nl {
231        nl = chars.next();
232    }
233    debug_assert_eq!(nl, Some('\n'));
234    let mut tail = chars.as_str();
235    start += 1;
236
237    while tail.starts_with(|c: char| c.is_ascii_whitespace()) {
238        let first_non_space =
239            tail.bytes().position(|b| !matches!(b, b' ' | b'\t')).unwrap_or(tail.len());
240        tail = &tail[first_non_space..];
241        start += first_non_space;
242
243        if let Some(tail2) = tail.strip_prefix('\n').or_else(|| tail.strip_prefix("\r\n")) {
244            let skipped = tail.len() - tail2.len();
245            tail = tail2;
246            callback(start..start + skipped, Err(EscapeError::CannotSkipMultipleLines));
247            start += skipped;
248        }
249    }
250    *chars = tail.chars();
251}
252
253/// Unescape characters in a hex string literal.
254///
255/// See [`unescape_literal`] for more details.
256fn unescape_hex_str<F>(src: &str, mut callback: F)
257where
258    F: FnMut(Range<usize>, Result<u32, EscapeError>),
259{
260    let mut chars = src.char_indices();
261    if src.starts_with("0x") || src.starts_with("0X") {
262        chars.next();
263        chars.next();
264        callback(0..2, Err(EscapeError::HexPrefix));
265    }
266
267    let count = chars.clone().filter(|(_, c)| c.is_ascii_hexdigit()).count();
268    if !count.is_multiple_of(2) {
269        callback(0..src.len(), Err(EscapeError::HexOddDigits));
270        return;
271    }
272
273    let mut emit_underscore_errors = true;
274    let mut allow_underscore = false;
275    let mut even = true;
276    for (start, c) in chars {
277        let res = match c {
278            '_' => {
279                if emit_underscore_errors && (!allow_underscore || !even) {
280                    // Don't spam errors for multiple underscores.
281                    emit_underscore_errors = false;
282                    Err(EscapeError::HexBadUnderscore)
283                } else {
284                    allow_underscore = false;
285                    continue;
286                }
287            }
288            c if !c.is_ascii_hexdigit() => Err(EscapeError::HexNotHexDigit),
289            c => Ok(c as u32),
290        };
291
292        if res.is_ok() {
293            even = !even;
294            allow_underscore = true;
295        }
296
297        let end = start + c.len_utf8();
298        callback(start..end, res);
299    }
300
301    if emit_underscore_errors && src.len() > 1 && src.ends_with('_') {
302        callback(src.len() - 1..src.len(), Err(EscapeError::HexBadUnderscore));
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use EscapeError::*;

    /// An expected error: the byte range it covers in the source, and its kind.
    type ExErr = (Range<usize>, EscapeError);

    /// Runs `src` through both the char-level and the byte-level API and checks that they agree
    /// with each other and with the expected output and errors.
    fn check(kind: StrKind, src: &str, expected_str: &str, expected_errs: &[ExErr]) {
        let ctx = format!("{kind:?}: {src:?}");

        // Char-level API.
        let mut unescaped = String::with_capacity(src.len());
        let mut errors: Vec<ExErr> = Vec::with_capacity(expected_errs.len());
        unescape_literal(src, kind, |range, res| match res {
            Ok(code) => unescaped.push(char::try_from(code).unwrap()),
            Err(err) => errors.push((range, err)),
        });
        assert_eq!(errors, expected_errs, "{ctx}");
        assert_eq!(unescaped, expected_str, "{ctx}");

        // Byte-level API: must report exactly the same errors.
        let mut byte_errors: Vec<ExErr> = Vec::with_capacity(errors.len());
        let bytes =
            try_parse_string_literal(src, kind, |range, err| byte_errors.push((range, err)));
        assert_eq!(byte_errors, errors, "{ctx}");

        // Hex literals decode to raw bytes whose hex encoding must equal the expected digits;
        // all other literals must produce exactly the UTF-8 bytes of the expected string.
        let expected_hex = if kind == StrKind::Hex {
            expected_str.to_owned()
        } else {
            hex::encode(expected_str)
        };
        assert_eq!(hex::encode(bytes), expected_hex, "{ctx}");
    }

    #[test]
    fn unescape_str() {
        let cases: &[(&str, &str, &[ExErr])] = &[
            ("", "", &[]),
            (" ", " ", &[]),
            ("\t", "\t", &[]),
            (" \t ", " \t ", &[]),
            ("foo", "foo", &[]),
            ("hello world", "hello world", &[]),
            (r"\", "", &[(0..1, LoneSlash)]),
            (r"\\", "\\", &[]),
            (r"\\\", "\\", &[(2..3, LoneSlash)]),
            (r"\\\\", "\\\\", &[]),
            (r"\\ ", "\\ ", &[]),
            (r"\\ \", "\\ ", &[(3..4, LoneSlash)]),
            (r"\\ \\", "\\ \\", &[]),
            (r"\x", "", &[(0..2, HexEscapeTooShort)]),
            (r"\x1", "", &[(0..3, HexEscapeTooShort)]),
            (r"\xz", "", &[(0..3, InvalidHexEscape)]),
            (r"\xzf", "f", &[(0..3, InvalidHexEscape)]),
            (r"\xzz", "z", &[(0..3, InvalidHexEscape)]),
            (r"\x69", "\x69", &[]),
            (r"\xE8", "è", &[]),
            (r"\u", "", &[(0..2, UnicodeEscapeTooShort)]),
            (r"\u1", "", &[(0..3, UnicodeEscapeTooShort)]),
            (r"\uz", "", &[(0..3, InvalidUnicodeEscape)]),
            (r"\uzf", "f", &[(0..3, InvalidUnicodeEscape)]),
            (r"\u12", "", &[(0..4, UnicodeEscapeTooShort)]),
            (r"\u123", "", &[(0..5, UnicodeEscapeTooShort)]),
            (r"\u1234", "\u{1234}", &[]),
            (r"\u00e8", "è", &[]),
            (r"\r", "\r", &[]),
            (r"\t", "\t", &[]),
            (r"\n", "\n", &[]),
            (r"\n\n", "\n\n", &[]),
            (r"\ ", "", &[(0..2, InvalidEscape)]),
            (r"\?", "", &[(0..2, InvalidEscape)]),
            ("\r\n", "", &[(1..2, StrNewline)]),
            ("\n", "", &[(0..1, StrNewline)]),
            ("\\\n", "", &[]),
            ("\\\na", "a", &[]),
            ("\\\n  a", "a", &[]),
            ("a \\\n  b", "a b", &[]),
            ("a\\n\\\n  b", "a\nb", &[]),
            ("a\\t\\\n  b", "a\tb", &[]),
            ("a\\n \\\n  b", "a\n b", &[]),
            ("a\\n \\\n \tb", "a\n b", &[]),
            ("a\\t \\\n  b", "a\t b", &[]),
            ("\\\n \t a", "a", &[]),
            (" \\\n \t a", " a", &[]),
            ("\\\n \t a\n", "a", &[(6..7, StrNewline)]),
            ("\\\n   \t   ", "", &[]),
            (" \\\n   \t   ", " ", &[]),
            (" he\\\n \\\nllo \\\n wor\\\nld", " hello world", &[]),
            ("\\\n\na\\\nb", "ab", &[(2..3, CannotSkipMultipleLines)]),
            ("\\\n \na\\\nb", "ab", &[(3..4, CannotSkipMultipleLines)]),
            (
                "\\\n \n\na\\\nb",
                "ab",
                &[(3..4, CannotSkipMultipleLines), (4..5, CannotSkipMultipleLines)],
            ),
            (
                "a\\\n \n \t \nb\\\nc",
                "abc",
                &[(4..5, CannotSkipMultipleLines), (8..9, CannotSkipMultipleLines)],
            ),
        ];
        // Plain strings and unicode strings behave identically on ASCII input.
        for &(src, expected_str, expected_errs) in cases {
            for kind in [StrKind::Str, StrKind::Unicode] {
                check(kind, src, expected_str, expected_errs);
            }
        }
    }

    #[test]
    fn unescape_unicode_str() {
        // (source, unicode output, unicode errors, plain-string errors)
        let cases: &[(&str, &str, &[ExErr], &[ExErr])] = &[
            ("è", "è", &[], &[(0..2, StrNonAsciiChar)]),
            ("😀", "😀", &[], &[(0..4, StrNonAsciiChar)]),
        ];
        for &(src, expected_str, unicode_errs, str_errs) in cases {
            check(StrKind::Unicode, src, expected_str, unicode_errs);
            check(StrKind::Str, src, "", str_errs);
        }
    }

    #[test]
    fn unescape_hex_str() {
        let cases: &[(&str, &str, &[ExErr])] = &[
            ("", "", &[]),
            ("z", "", &[(0..1, HexNotHexDigit)]),
            ("\n", "", &[(0..1, HexNotHexDigit)]),
            ("  11", "11", &[(0..1, HexNotHexDigit), (1..2, HexNotHexDigit)]),
            ("0x", "", &[(0..2, HexPrefix)]),
            ("0X", "", &[(0..2, HexPrefix)]),
            ("0x11", "11", &[(0..2, HexPrefix)]),
            ("0X11", "11", &[(0..2, HexPrefix)]),
            ("1", "", &[(0..1, HexOddDigits)]),
            ("12", "12", &[]),
            ("123", "", &[(0..3, HexOddDigits)]),
            ("1234", "1234", &[]),
            ("_", "", &[(0..1, HexBadUnderscore)]),
            ("_11", "11", &[(0..1, HexBadUnderscore)]),
            ("_11_", "11", &[(0..1, HexBadUnderscore)]),
            ("11_", "11", &[(2..3, HexBadUnderscore)]),
            ("11_22", "1122", &[]),
            ("11__", "11", &[(3..4, HexBadUnderscore)]),
            ("11__22", "1122", &[(3..4, HexBadUnderscore)]),
            ("1_2", "12", &[(1..2, HexBadUnderscore)]),
        ];
        cases.iter().for_each(|&(src, expected_str, expected_errs)| {
            check(StrKind::Hex, src, expected_str, expected_errs)
        });
    }
}
450}