1use alloy_primitives::hex;
4use solar_ast::Span;
5use solar_data_structures::trustme;
6use solar_interface::{BytePos, Session};
7use std::{borrow::Cow, ops::Range, slice, str::Chars};
8
9mod errors;
10pub use errors::EscapeError;
11pub(crate) use errors::emit_unescape_error;
12
13pub use solar_ast::StrKind;
14
15pub fn parse_string_literal<'a>(
16 src: &'a str,
17 kind: StrKind,
18 lit_span: Span,
19 sess: &Session,
20) -> (Cow<'a, [u8]>, bool) {
21 let mut has_error = false;
22 let content_start = lit_span.lo() + BytePos(1) + BytePos(kind.prefix().len() as u32);
23 let bytes = try_parse_string_literal(src, kind, |range, error| {
24 has_error = true;
25 let (start, end) = (range.start as u32, range.end as u32);
26 let lo = content_start + BytePos(start);
27 let hi = lo + BytePos(end - start);
28 let span = Span::new(lo, hi);
29 emit_unescape_error(&sess.dcx, src, span, range, error);
30 });
31 (bytes, has_error)
32}
33
34#[instrument(name = "parse_string_literal", level = "trace", skip_all)]
37pub fn try_parse_string_literal<F>(src: &str, kind: StrKind, f: F) -> Cow<'_, [u8]>
38where
39 F: FnMut(Range<usize>, EscapeError),
40{
41 let mut bytes = if needs_unescape(src, kind) {
42 Cow::Owned(parse_literal_unescape(src, kind, f))
43 } else {
44 Cow::Borrowed(src.as_bytes())
45 };
46 if kind == StrKind::Hex {
47 if let Ok(decoded) = hex::decode(&bytes) {
49 bytes = Cow::Owned(decoded);
50 }
51 }
52 bytes
53}
54
55#[cold]
56fn parse_literal_unescape<F>(src: &str, kind: StrKind, f: F) -> Vec<u8>
57where
58 F: FnMut(Range<usize>, EscapeError),
59{
60 let mut bytes = Vec::with_capacity(src.len());
61 parse_literal_unescape_into(src, kind, f, &mut bytes);
62 bytes
63}
64
/// Unescapes `src` into `dst_buf`, forwarding every escape error to `f`.
///
/// `dst_buf` must be empty and have a capacity of at least `src.len()` bytes; both conditions
/// are only checked in debug builds. On return, the length of `dst_buf` is set to the number of
/// bytes that were written.
fn parse_literal_unescape_into<F>(src: &str, kind: StrKind, mut f: F, dst_buf: &mut Vec<u8>)
where
    F: FnMut(Range<usize>, EscapeError),
{
    debug_assert!(dst_buf.is_empty());
    debug_assert!(dst_buf.capacity() >= src.len());
    // View the vector's whole capacity as a slice so the encoder can write into it directly.
    // SAFETY: the pointer and the length both come from `dst_buf`, so the slice stays inside its
    // allocation.
    // NOTE(review): the memory behind the slice is not initialized yet; this relies on the slice
    // only being written to, never read from. Confirm `encode_utf8_raw` does not read `dst`.
    let mut dst = unsafe { slice::from_raw_parts_mut(dst_buf.as_mut_ptr(), dst_buf.capacity()) };
    unescape_literal_unchecked(src, kind, |range, res| match res {
        Ok(c) => {
            // Presumably encodes `c` at the front of `dst` and returns the part it wrote (verify
            // against `utf8::encode_utf8_raw`); only the length of that part is used here.
            let written = super::utf8::encode_utf8_raw(c, dst).len();

            debug_assert!(dst.len() >= written);
            // Drop the bytes just written from the front of the remaining window.
            // SAFETY: requires `written <= dst.len()`, which is asserted above in debug builds
            // only.
            let advanced = unsafe { dst.get_unchecked_mut(written..) };

            // SAFETY: presumably this only detaches the reborrow's lifetime so that the shrunken
            // window can be stored back into `dst`; verify against `trustme::decouple_lt_mut`.
            dst = unsafe { trustme::decouple_lt_mut(advanced) };
        }
        Err(e) => f(range, e),
    });
    // `dst` is the still-unused tail of the capacity-sized window, so the difference is the
    // number of bytes written.
    // SAFETY: the new length cannot exceed the capacity; this assumes every byte reported as
    // written above was actually initialized by the encoder.
    unsafe { dst_buf.set_len(dst_buf.capacity() - dst.len()) };
}
91
92#[instrument(level = "debug", skip_all)]
96pub fn unescape_literal<F>(src: &str, kind: StrKind, mut callback: F)
97where
98 F: FnMut(Range<usize>, Result<u32, EscapeError>),
99{
100 if needs_unescape(src, kind) {
101 unescape_literal_unchecked(src, kind, callback)
102 } else {
103 for (i, ch) in src.char_indices() {
104 callback(i..i + ch.len_utf8(), Ok(ch as u32));
105 }
106 }
107}
108
109fn unescape_literal_unchecked<F>(src: &str, kind: StrKind, callback: F)
113where
114 F: FnMut(Range<usize>, Result<u32, EscapeError>),
115{
116 match kind {
117 StrKind::Str | StrKind::Unicode => {
118 unescape_str(src, matches!(kind, StrKind::Unicode), callback)
119 }
120 StrKind::Hex => unescape_hex_str(src, callback),
121 }
122}
123
124fn needs_unescape(src: &str, kind: StrKind) -> bool {
127 fn needs_unescape_chars(src: &str) -> bool {
128 memchr::memchr3(b'\\', b'\n', b'\r', src.as_bytes()).is_some()
129 }
130
131 match kind {
132 StrKind::Str => needs_unescape_chars(src) || !src.is_ascii(),
133 StrKind::Unicode => needs_unescape_chars(src),
134 StrKind::Hex => !src.len().is_multiple_of(2) || !hex::check_raw(src),
135 }
136}
137
138fn scan_escape(chars: &mut Chars<'_>) -> Result<u32, EscapeError> {
139 Ok(match chars.next().ok_or(EscapeError::LoneSlash)? {
144 '\'' => '\'' as u32,
146 '"' => '"' as u32,
147
148 '\\' => '\\' as u32,
149 'n' => '\n' as u32,
150 'r' => '\r' as u32,
151 't' => '\t' as u32,
152
153 'x' => {
154 let mut value = 0;
156 for _ in 0..2 {
157 let d = chars.next().ok_or(EscapeError::HexEscapeTooShort)?;
158 let d = d.to_digit(16).ok_or(EscapeError::InvalidHexEscape)?;
159 value = value * 16 + d;
160 }
161 value
162 }
163
164 'u' => {
165 let mut value = 0;
167 for _ in 0..4 {
168 let d = chars.next().ok_or(EscapeError::UnicodeEscapeTooShort)?;
169 let d = d.to_digit(16).ok_or(EscapeError::InvalidUnicodeEscape)?;
170 value = value * 16 + d;
171 }
172 value
173 }
174
175 _ => return Err(EscapeError::InvalidEscape),
176 })
177}
178
179fn unescape_str<F>(src: &str, is_unicode: bool, mut callback: F)
183where
184 F: FnMut(Range<usize>, Result<u32, EscapeError>),
185{
186 let mut chars = src.chars();
187 while let Some(c) = chars.next() {
191 let start = src.len() - chars.as_str().len() - c.len_utf8();
192 let res = match c {
193 '\\' => match chars.clone().next() {
194 Some('\r') if chars.clone().nth(1) == Some('\n') => {
195 skip_ascii_whitespace(&mut chars, start + 2, &mut callback);
197 continue;
198 }
199 Some('\n') => {
200 skip_ascii_whitespace(&mut chars, start + 1, &mut callback);
202 continue;
203 }
204 _ => scan_escape(&mut chars),
205 },
206 '\n' => Err(EscapeError::StrNewline),
207 '\r' => {
208 if chars.clone().next() == Some('\n') {
209 continue;
210 }
211 Err(EscapeError::BareCarriageReturn)
212 }
213 c if !is_unicode && !c.is_ascii() => Err(EscapeError::StrNonAsciiChar),
214 c => Ok(c as u32),
215 };
216 let end = src.len() - chars.as_str().len();
217 callback(start..end, res);
218 }
219}
220
221fn skip_ascii_whitespace<F>(chars: &mut Chars<'_>, mut start: usize, callback: &mut F)
225where
226 F: FnMut(Range<usize>, Result<u32, EscapeError>),
227{
228 let mut nl = chars.next();
230 if let Some('\r') = nl {
231 nl = chars.next();
232 }
233 debug_assert_eq!(nl, Some('\n'));
234 let mut tail = chars.as_str();
235 start += 1;
236
237 while tail.starts_with(|c: char| c.is_ascii_whitespace()) {
238 let first_non_space =
239 tail.bytes().position(|b| !matches!(b, b' ' | b'\t')).unwrap_or(tail.len());
240 tail = &tail[first_non_space..];
241 start += first_non_space;
242
243 if let Some(tail2) = tail.strip_prefix('\n').or_else(|| tail.strip_prefix("\r\n")) {
244 let skipped = tail.len() - tail2.len();
245 tail = tail2;
246 callback(start..start + skipped, Err(EscapeError::CannotSkipMultipleLines));
247 start += skipped;
248 }
249 }
250 *chars = tail.chars();
251}
252
253fn unescape_hex_str<F>(src: &str, mut callback: F)
257where
258 F: FnMut(Range<usize>, Result<u32, EscapeError>),
259{
260 let mut chars = src.char_indices();
261 if src.starts_with("0x") || src.starts_with("0X") {
262 chars.next();
263 chars.next();
264 callback(0..2, Err(EscapeError::HexPrefix));
265 }
266
267 let count = chars.clone().filter(|(_, c)| c.is_ascii_hexdigit()).count();
268 if !count.is_multiple_of(2) {
269 callback(0..src.len(), Err(EscapeError::HexOddDigits));
270 return;
271 }
272
273 let mut emit_underscore_errors = true;
274 let mut allow_underscore = false;
275 let mut even = true;
276 for (start, c) in chars {
277 let res = match c {
278 '_' => {
279 if emit_underscore_errors && (!allow_underscore || !even) {
280 emit_underscore_errors = false;
282 Err(EscapeError::HexBadUnderscore)
283 } else {
284 allow_underscore = false;
285 continue;
286 }
287 }
288 c if !c.is_ascii_hexdigit() => Err(EscapeError::HexNotHexDigit),
289 c => Ok(c as u32),
290 };
291
292 if res.is_ok() {
293 even = !even;
294 allow_underscore = true;
295 }
296
297 let end = start + c.len_utf8();
298 callback(start..end, res);
299 }
300
301 if emit_underscore_errors && src.len() > 1 && src.ends_with('_') {
302 callback(src.len() - 1..src.len(), Err(EscapeError::HexBadUnderscore));
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use EscapeError::*;

    /// An expected error: the byte range it is reported at and its kind.
    type ExErr = (Range<usize>, EscapeError);

    /// Unescapes `src` as a literal of the given `kind` and checks both the produced characters
    /// and the reported errors, through `unescape_literal` as well as through
    /// `try_parse_string_literal`.
    fn check(kind: StrKind, src: &str, expected_str: &str, expected_errs: &[ExErr]) {
        let panic_str = format!("{kind:?}: {src:?}");

        let mut ok = String::with_capacity(src.len());
        let mut errs = Vec::with_capacity(expected_errs.len());
        unescape_literal(src, kind, |range, c| match c {
            Ok(c) => ok.push(char::try_from(c).unwrap()),
            Err(e) => errs.push((range, e)),
        });
        assert_eq!(errs, expected_errs, "{panic_str}");
        assert_eq!(ok, expected_str, "{panic_str}");

        // The byte-producing entry point must report exactly the same errors.
        let mut errs2 = Vec::with_capacity(errs.len());
        let out = try_parse_string_literal(src, kind, |range, e| {
            errs2.push((range, e));
        });
        assert_eq!(errs2, errs, "{panic_str}");
        if kind == StrKind::Hex {
            // Hex literals are decoded, so re-encoding must give back the digits.
            assert_eq!(hex::encode(out), expected_str, "{panic_str}");
        } else {
            assert_eq!(hex::encode(out), hex::encode(expected_str), "{panic_str}");
        }
    }

    #[test]
    fn unescape_str() {
        let cases: &[(&str, &str, &[ExErr])] = &[
            ("", "", &[]),
            (" ", " ", &[]),
            ("\t", "\t", &[]),
            (" \t ", " \t ", &[]),
            ("foo", "foo", &[]),
            ("hello world", "hello world", &[]),
            (r"\", "", &[(0..1, LoneSlash)]),
            (r"\\", "\\", &[]),
            (r"\\\", "\\", &[(2..3, LoneSlash)]),
            (r"\\\\", "\\\\", &[]),
            (r"\\ ", "\\ ", &[]),
            (r"\\ \", "\\ ", &[(3..4, LoneSlash)]),
            (r"\\ \\", "\\ \\", &[]),
            (r"\x", "", &[(0..2, HexEscapeTooShort)]),
            (r"\x1", "", &[(0..3, HexEscapeTooShort)]),
            (r"\xz", "", &[(0..3, InvalidHexEscape)]),
            (r"\xzf", "f", &[(0..3, InvalidHexEscape)]),
            (r"\xzz", "z", &[(0..3, InvalidHexEscape)]),
            (r"\x69", "\x69", &[]),
            (r"\xE8", "è", &[]),
            (r"\u", "", &[(0..2, UnicodeEscapeTooShort)]),
            (r"\u1", "", &[(0..3, UnicodeEscapeTooShort)]),
            (r"\uz", "", &[(0..3, InvalidUnicodeEscape)]),
            (r"\uzf", "f", &[(0..3, InvalidUnicodeEscape)]),
            (r"\u12", "", &[(0..4, UnicodeEscapeTooShort)]),
            (r"\u123", "", &[(0..5, UnicodeEscapeTooShort)]),
            (r"\u1234", "\u{1234}", &[]),
            (r"\u00e8", "è", &[]),
            (r"\r", "\r", &[]),
            (r"\t", "\t", &[]),
            (r"\n", "\n", &[]),
            (r"\n\n", "\n\n", &[]),
            (r"\ ", "", &[(0..2, InvalidEscape)]),
            (r"\?", "", &[(0..2, InvalidEscape)]),
            ("\r\n", "", &[(1..2, StrNewline)]),
            ("\n", "", &[(0..1, StrNewline)]),
            ("\\\n", "", &[]),
            ("\\\na", "a", &[]),
            ("\\\n a", "a", &[]),
            ("a \\\n b", "a b", &[]),
            ("a\\n\\\n b", "a\nb", &[]),
            ("a\\t\\\n b", "a\tb", &[]),
            ("a\\n \\\n b", "a\n b", &[]),
            ("a\\n \\\n \tb", "a\n b", &[]),
            ("a\\t \\\n b", "a\t b", &[]),
            ("\\\n \t a", "a", &[]),
            (" \\\n \t a", " a", &[]),
            ("\\\n \t a\n", "a", &[(6..7, StrNewline)]),
            ("\\\n \t ", "", &[]),
            (" \\\n \t ", " ", &[]),
            (" he\\\n \\\nllo \\\n wor\\\nld", " hello world", &[]),
            ("\\\n\na\\\nb", "ab", &[(2..3, CannotSkipMultipleLines)]),
            ("\\\n \na\\\nb", "ab", &[(3..4, CannotSkipMultipleLines)]),
            (
                "\\\n \n\na\\\nb",
                "ab",
                &[(3..4, CannotSkipMultipleLines), (4..5, CannotSkipMultipleLines)],
            ),
            (
                "a\\\n \n \t \nb\\\nc",
                "abc",
                &[(4..5, CannotSkipMultipleLines), (8..9, CannotSkipMultipleLines)],
            ),
        ];
        // Every case here is pure ASCII, so both kinds must behave identically.
        for &(src, expected_str, expected_errs) in cases {
            check(StrKind::Str, src, expected_str, expected_errs);
            check(StrKind::Unicode, src, expected_str, expected_errs);
        }
    }

    #[test]
    fn unescape_unicode_str() {
        // Non-ASCII characters: accepted in unicode literals, rejected in regular ones.
        let cases: &[(&str, &str, &[ExErr], &[ExErr])] = &[
            ("è", "è", &[], &[(0..2, StrNonAsciiChar)]),
            ("😀", "😀", &[], &[(0..4, StrNonAsciiChar)]),
        ];
        for &(src, expected_str, e1, e2) in cases {
            check(StrKind::Unicode, src, expected_str, e1);
            check(StrKind::Str, src, "", e2);
        }
    }

    #[test]
    fn unescape_hex_str() {
        let cases: &[(&str, &str, &[ExErr])] = &[
            ("", "", &[]),
            ("z", "", &[(0..1, HexNotHexDigit)]),
            ("\n", "", &[(0..1, HexNotHexDigit)]),
            // Two leading spaces: one `HexNotHexDigit` error is expected per space.
            ("  11", "11", &[(0..1, HexNotHexDigit), (1..2, HexNotHexDigit)]),
            ("0x", "", &[(0..2, HexPrefix)]),
            ("0X", "", &[(0..2, HexPrefix)]),
            ("0x11", "11", &[(0..2, HexPrefix)]),
            ("0X11", "11", &[(0..2, HexPrefix)]),
            ("1", "", &[(0..1, HexOddDigits)]),
            ("12", "12", &[]),
            ("123", "", &[(0..3, HexOddDigits)]),
            ("1234", "1234", &[]),
            ("_", "", &[(0..1, HexBadUnderscore)]),
            ("_11", "11", &[(0..1, HexBadUnderscore)]),
            ("_11_", "11", &[(0..1, HexBadUnderscore)]),
            ("11_", "11", &[(2..3, HexBadUnderscore)]),
            ("11_22", "1122", &[]),
            ("11__", "11", &[(3..4, HexBadUnderscore)]),
            ("11__22", "1122", &[(3..4, HexBadUnderscore)]),
            ("1_2", "12", &[(1..2, HexBadUnderscore)]),
        ];
        for &(src, expected_str, expected_errs) in cases {
            check(StrKind::Hex, src, expected_str, expected_errs);
        }
    }
}