1use alloy_primitives::hex;
4use solar_data_structures::trustme;
5use std::{borrow::Cow, ops::Range, slice, str::Chars};
6
7mod errors;
8pub(crate) use errors::emit_unescape_error;
9pub use errors::EscapeError;
10
/// Selects which kind of string literal is being unescaped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Regular string literal: escape sequences are processed and every non-ASCII character is
    /// reported as [`EscapeError::StrNonAsciiChar`].
    Str,
    /// Unicode string literal: processed like [`Mode::Str`], except that non-ASCII characters
    /// are accepted.
    UnicodeStr,
    /// Hex string literal: a sequence of hex digits, optionally separated by underscores, which
    /// [`try_parse_string_literal`] additionally hex-decodes into raw bytes.
    HexStr,
}
21
22pub fn parse_string_literal(src: &str, mode: Mode) -> Cow<'_, [u8]> {
24 try_parse_string_literal(src, mode, |_, _| {})
25}
26
27#[instrument(name = "parse_string_literal", level = "debug", skip_all)]
30pub fn try_parse_string_literal<F>(src: &str, mode: Mode, f: F) -> Cow<'_, [u8]>
31where
32 F: FnMut(Range<usize>, EscapeError),
33{
34 let mut bytes = if needs_unescape(src, mode) {
35 Cow::Owned(parse_literal_unescape(src, mode, f))
36 } else {
37 Cow::Borrowed(src.as_bytes())
38 };
39 if mode == Mode::HexStr {
40 if let Ok(decoded) = hex::decode(&bytes) {
42 bytes = Cow::Owned(decoded);
43 }
44 }
45 bytes
46}
47
48#[cold]
49fn parse_literal_unescape<F>(src: &str, mode: Mode, f: F) -> Vec<u8>
50where
51 F: FnMut(Range<usize>, EscapeError),
52{
53 let mut bytes = Vec::with_capacity(src.len());
54 parse_literal_unescape_into(src, mode, f, &mut bytes);
55 bytes
56}
57
/// Unescapes `src` according to `mode`, writing the produced bytes into `dst_buf`.
///
/// Each successfully unescaped code point is encoded with `super::utf8::encode_utf8_raw`
/// directly into the allocation already owned by `dst_buf`; each error is forwarded to `f`
/// together with its range. The vector is never grown here, only its length is set at the end.
///
/// `dst_buf` is expected to be empty and to have a capacity of at least `src.len()` bytes. Both
/// conditions are only checked by `debug_assert!`, i.e. not in release builds.
fn parse_literal_unescape_into<F>(src: &str, mode: Mode, mut f: F, dst_buf: &mut Vec<u8>)
where
    F: FnMut(Range<usize>, EscapeError),
{
    debug_assert!(dst_buf.is_empty());
    debug_assert!(dst_buf.capacity() >= src.len());
    // SAFETY: the pointer and the length both come from `dst_buf`, so the slice stays inside the
    // vector's allocation.
    // NOTE(review): the slice covers capacity that this function has not written yet — confirm
    // that forming a `&mut [u8]` over it is acceptable for this crate.
    let mut dst = unsafe { slice::from_raw_parts_mut(dst_buf.as_mut_ptr(), dst_buf.capacity()) };
    unescape_literal_unchecked(src, mode, |range, res| match res {
        Ok(c) => {
            // Only the length of the returned value is used: it is the number of bytes by which
            // the write cursor must advance.
            let written = super::utf8::encode_utf8_raw(c, dst).len();

            debug_assert!(dst.len() >= written);
            // SAFETY: relies on `written <= dst.len()`, which is only asserted in debug builds
            // (see above).
            let advanced = unsafe { dst.get_unchecked_mut(written..) };

            // `dst` now denotes the part of the buffer that has not been written yet.
            // SAFETY: NOTE(review): `decouple_lt_mut` presumably detaches the reborrow's lifetime
            // so that the tail can be stored back into the captured `dst` — confirm against
            // `solar_data_structures::trustme`.
            dst = unsafe { trustme::decouple_lt_mut(advanced) };
        }
        Err(e) => f(range, e),
    });
    // `dst.len()` is the amount of capacity left unwritten, so the difference is the number of
    // bytes written by the closure above.
    // SAFETY: the new length is at most `capacity()`, and the bytes below it are the ones the
    // closure encoded into the buffer.
    unsafe { dst_buf.set_len(dst_buf.capacity() - dst.len()) };
}
84
85#[instrument(level = "debug", skip_all)]
89pub fn unescape_literal<F>(src: &str, mode: Mode, mut callback: F)
90where
91 F: FnMut(Range<usize>, Result<u32, EscapeError>),
92{
93 if needs_unescape(src, mode) {
94 unescape_literal_unchecked(src, mode, callback)
95 } else {
96 for (i, ch) in src.char_indices() {
97 callback(i..i + ch.len_utf8(), Ok(ch as u32));
98 }
99 }
100}
101
102fn unescape_literal_unchecked<F>(src: &str, mode: Mode, callback: F)
106where
107 F: FnMut(Range<usize>, Result<u32, EscapeError>),
108{
109 match mode {
110 Mode::Str | Mode::UnicodeStr => {
111 unescape_str(src, matches!(mode, Mode::UnicodeStr), callback)
112 }
113 Mode::HexStr => unescape_hex_str(src, callback),
114 }
115}
116
117fn needs_unescape(src: &str, mode: Mode) -> bool {
120 fn needs_unescape_chars(src: &str) -> bool {
121 memchr::memchr3(b'\\', b'\n', b'\r', src.as_bytes()).is_some()
122 }
123
124 match mode {
125 Mode::Str => needs_unescape_chars(src) || !src.is_ascii(),
126 Mode::UnicodeStr => needs_unescape_chars(src),
127 Mode::HexStr => src.len() % 2 != 0 || !hex::check_raw(src),
128 }
129}
130
131fn scan_escape(chars: &mut Chars<'_>) -> Result<u32, EscapeError> {
132 Ok(match chars.next().ok_or(EscapeError::LoneSlash)? {
137 '\'' => '\'' as u32,
139 '"' => '"' as u32,
140
141 '\\' => '\\' as u32,
142 'n' => '\n' as u32,
143 'r' => '\r' as u32,
144 't' => '\t' as u32,
145
146 'x' => {
147 let mut value = 0;
149 for _ in 0..2 {
150 let d = chars.next().ok_or(EscapeError::HexEscapeTooShort)?;
151 let d = d.to_digit(16).ok_or(EscapeError::InvalidHexEscape)?;
152 value = value * 16 + d;
153 }
154 value
155 }
156
157 'u' => {
158 let mut value = 0;
160 for _ in 0..4 {
161 let d = chars.next().ok_or(EscapeError::UnicodeEscapeTooShort)?;
162 let d = d.to_digit(16).ok_or(EscapeError::InvalidUnicodeEscape)?;
163 value = value * 16 + d;
164 }
165 value
166 }
167
168 _ => return Err(EscapeError::InvalidEscape),
169 })
170}
171
172fn unescape_str<F>(src: &str, is_unicode: bool, mut callback: F)
176where
177 F: FnMut(Range<usize>, Result<u32, EscapeError>),
178{
179 let mut chars = src.chars();
180 while let Some(c) = chars.next() {
184 let start = src.len() - chars.as_str().len() - c.len_utf8();
185 let res = match c {
186 '\\' => match chars.clone().next() {
187 Some('\r') if chars.clone().nth(1) == Some('\n') => {
188 skip_ascii_whitespace(&mut chars, start + 2, &mut callback);
190 continue;
191 }
192 Some('\n') => {
193 skip_ascii_whitespace(&mut chars, start + 1, &mut callback);
195 continue;
196 }
197 _ => scan_escape(&mut chars),
198 },
199 '\n' => Err(EscapeError::StrNewline),
200 '\r' => {
201 if chars.clone().next() == Some('\n') {
202 continue;
203 }
204 Err(EscapeError::BareCarriageReturn)
205 }
206 c if !is_unicode && !c.is_ascii() => Err(EscapeError::StrNonAsciiChar),
207 c => Ok(c as u32),
208 };
209 let end = src.len() - chars.as_str().len();
210 callback(start..end, res);
211 }
212}
213
214fn skip_ascii_whitespace<F>(chars: &mut Chars<'_>, mut start: usize, callback: &mut F)
218where
219 F: FnMut(Range<usize>, Result<u32, EscapeError>),
220{
221 let mut nl = chars.next();
223 if let Some('\r') = nl {
224 nl = chars.next();
225 }
226 debug_assert_eq!(nl, Some('\n'));
227 let mut tail = chars.as_str();
228 start += 1;
229
230 while tail.starts_with(|c: char| c.is_ascii_whitespace()) {
231 let first_non_space =
232 tail.bytes().position(|b| !matches!(b, b' ' | b'\t')).unwrap_or(tail.len());
233 tail = &tail[first_non_space..];
234 start += first_non_space;
235
236 if let Some(tail2) = tail.strip_prefix('\n').or_else(|| tail.strip_prefix("\r\n")) {
237 let skipped = tail.len() - tail2.len();
238 tail = tail2;
239 callback(start..start + skipped, Err(EscapeError::CannotSkipMultipleLines));
240 start += skipped;
241 }
242 }
243 *chars = tail.chars();
244}
245
246fn unescape_hex_str<F>(src: &str, mut callback: F)
250where
251 F: FnMut(Range<usize>, Result<u32, EscapeError>),
252{
253 let mut chars = src.char_indices();
254 if src.starts_with("0x") || src.starts_with("0X") {
255 chars.next();
256 chars.next();
257 callback(0..2, Err(EscapeError::HexPrefix));
258 }
259
260 let count = chars.clone().filter(|(_, c)| c.is_ascii_hexdigit()).count();
261 if count % 2 != 0 {
262 callback(0..src.len(), Err(EscapeError::HexOddDigits));
263 return;
264 }
265
266 let mut emit_underscore_errors = true;
267 let mut allow_underscore = false;
268 let mut even = true;
269 for (start, c) in chars {
270 let res = match c {
271 '_' => {
272 if emit_underscore_errors && (!allow_underscore || !even) {
273 emit_underscore_errors = false;
275 Err(EscapeError::HexBadUnderscore)
276 } else {
277 allow_underscore = false;
278 continue;
279 }
280 }
281 c if !c.is_ascii_hexdigit() => Err(EscapeError::HexNotHexDigit),
282 c => Ok(c as u32),
283 };
284
285 if res.is_ok() {
286 even = !even;
287 allow_underscore = true;
288 }
289
290 let end = start + c.len_utf8();
291 callback(start..end, res);
292 }
293
294 if emit_underscore_errors && src.len() > 1 && src.ends_with('_') {
295 callback(src.len() - 1..src.len(), Err(EscapeError::HexBadUnderscore));
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use EscapeError::*;

    /// An expected error together with the byte range it is reported for.
    type ExErr = (Range<usize>, EscapeError);

    /// Runs `src` through both [`unescape_literal`] and [`try_parse_string_literal`] and checks
    /// that they produce `expected_str` and report exactly `expected_errs`.
    fn check(mode: Mode, src: &str, expected_str: &str, expected_errs: &[ExErr]) {
        let panic_str = format!("{mode:?}: {src:?}");

        let mut ok = String::with_capacity(src.len());
        let mut errs = Vec::with_capacity(expected_errs.len());
        unescape_literal(src, mode, |range, c| match c {
            Ok(c) => ok.push(char::try_from(c).unwrap()),
            Err(e) => errs.push((range, e)),
        });
        assert_eq!(errs, expected_errs, "{panic_str}");
        assert_eq!(ok, expected_str, "{panic_str}");

        let mut errs2 = Vec::with_capacity(errs.len());
        let out = try_parse_string_literal(src, mode, |range, e| {
            errs2.push((range, e));
        });
        assert_eq!(errs2, errs, "{panic_str}");
        if mode == Mode::HexStr {
            // Hex strings are decoded, so re-encoding the output must give the digits back.
            assert_eq!(hex::encode(out), expected_str, "{panic_str}");
        } else {
            assert_eq!(hex::encode(out), hex::encode(expected_str), "{panic_str}");
        }
    }

    #[test]
    fn unescape_str() {
        let cases: &[(&str, &str, &[ExErr])] = &[
            ("", "", &[]),
            (" ", " ", &[]),
            ("\t", "\t", &[]),
            (" \t ", " \t ", &[]),
            ("foo", "foo", &[]),
            ("hello world", "hello world", &[]),
            (r"\", "", &[(0..1, LoneSlash)]),
            (r"\\", "\\", &[]),
            (r"\\\", "\\", &[(2..3, LoneSlash)]),
            (r"\\\\", "\\\\", &[]),
            (r"\\ ", "\\ ", &[]),
            (r"\\ \", "\\ ", &[(3..4, LoneSlash)]),
            (r"\\ \\", "\\ \\", &[]),
            (r"\x", "", &[(0..2, HexEscapeTooShort)]),
            (r"\x1", "", &[(0..3, HexEscapeTooShort)]),
            (r"\xz", "", &[(0..3, InvalidHexEscape)]),
            (r"\xzf", "f", &[(0..3, InvalidHexEscape)]),
            (r"\xzz", "z", &[(0..3, InvalidHexEscape)]),
            (r"\x69", "\x69", &[]),
            // `\xE8` yields the code point U+00E8.
            (r"\xE8", "\u{e8}", &[]),
            (r"\u", "", &[(0..2, UnicodeEscapeTooShort)]),
            (r"\u1", "", &[(0..3, UnicodeEscapeTooShort)]),
            (r"\uz", "", &[(0..3, InvalidUnicodeEscape)]),
            (r"\uzf", "f", &[(0..3, InvalidUnicodeEscape)]),
            (r"\u12", "", &[(0..4, UnicodeEscapeTooShort)]),
            (r"\u123", "", &[(0..5, UnicodeEscapeTooShort)]),
            (r"\u1234", "\u{1234}", &[]),
            (r"\u00e8", "\u{e8}", &[]),
            (r"\r", "\r", &[]),
            (r"\t", "\t", &[]),
            (r"\n", "\n", &[]),
            (r"\n\n", "\n\n", &[]),
            (r"\ ", "", &[(0..2, InvalidEscape)]),
            (r"\?", "", &[(0..2, InvalidEscape)]),
            ("\r\n", "", &[(1..2, StrNewline)]),
            ("\n", "", &[(0..1, StrNewline)]),
            ("\\\n", "", &[]),
            ("\\\na", "a", &[]),
            ("\\\n a", "a", &[]),
            ("a \\\n b", "a b", &[]),
            ("a\\n\\\n b", "a\nb", &[]),
            ("a\\t\\\n b", "a\tb", &[]),
            ("a\\n \\\n b", "a\n b", &[]),
            ("a\\n \\\n \tb", "a\n b", &[]),
            ("a\\t \\\n b", "a\t b", &[]),
            ("\\\n \t a", "a", &[]),
            (" \\\n \t a", " a", &[]),
            ("\\\n \t a\n", "a", &[(6..7, StrNewline)]),
            ("\\\n \t ", "", &[]),
            (" \\\n \t ", " ", &[]),
            (" he\\\n \\\nllo \\\n wor\\\nld", " hello world", &[]),
            ("\\\n\na\\\nb", "ab", &[(2..3, CannotSkipMultipleLines)]),
            ("\\\n \na\\\nb", "ab", &[(3..4, CannotSkipMultipleLines)]),
            (
                "\\\n \n\na\\\nb",
                "ab",
                &[(3..4, CannotSkipMultipleLines), (4..5, CannotSkipMultipleLines)],
            ),
            (
                "a\\\n \n \t \nb\\\nc",
                "abc",
                &[(4..5, CannotSkipMultipleLines), (8..9, CannotSkipMultipleLines)],
            ),
        ];
        for &(src, expected_str, expected_errs) in cases {
            check(Mode::Str, src, expected_str, expected_errs);
            check(Mode::UnicodeStr, src, expected_str, expected_errs);
        }
    }

    #[test]
    fn unescape_unicode_str() {
        // Each case: source, expected output in unicode mode, errors in unicode mode, errors in
        // regular mode. The non-ASCII characters are written as escapes so that the byte ranges
        // below cannot be invalidated by a re-encoding of this file.
        let cases: &[(&str, &str, &[ExErr], &[ExErr])] = &[
            // U+00E8: two bytes in UTF-8.
            ("\u{e8}", "\u{e8}", &[], &[(0..2, StrNonAsciiChar)]),
            // U+1F600: four bytes in UTF-8.
            ("\u{1F600}", "\u{1F600}", &[], &[(0..4, StrNonAsciiChar)]),
        ];
        for &(src, expected_str, e1, e2) in cases {
            check(Mode::UnicodeStr, src, expected_str, e1);
            check(Mode::Str, src, "", e2);
        }
    }

    #[test]
    fn unescape_hex_str() {
        let cases: &[(&str, &str, &[ExErr])] = &[
            ("", "", &[]),
            ("z", "", &[(0..1, HexNotHexDigit)]),
            ("\n", "", &[(0..1, HexNotHexDigit)]),
            // Two leading spaces, one error each.
            ("  11", "11", &[(0..1, HexNotHexDigit), (1..2, HexNotHexDigit)]),
            ("0x", "", &[(0..2, HexPrefix)]),
            ("0X", "", &[(0..2, HexPrefix)]),
            ("0x11", "11", &[(0..2, HexPrefix)]),
            ("0X11", "11", &[(0..2, HexPrefix)]),
            ("1", "", &[(0..1, HexOddDigits)]),
            ("12", "12", &[]),
            ("123", "", &[(0..3, HexOddDigits)]),
            ("1234", "1234", &[]),
            ("_", "", &[(0..1, HexBadUnderscore)]),
            ("_11", "11", &[(0..1, HexBadUnderscore)]),
            ("_11_", "11", &[(0..1, HexBadUnderscore)]),
            ("11_", "11", &[(2..3, HexBadUnderscore)]),
            ("11_22", "1122", &[]),
            ("11__", "11", &[(3..4, HexBadUnderscore)]),
            ("11__22", "1122", &[(3..4, HexBadUnderscore)]),
            ("1_2", "12", &[(1..2, HexBadUnderscore)]),
        ];
        for &(src, expected_str, expected_errs) in cases {
            check(Mode::HexStr, src, expected_str, expected_errs);
        }
    }
}