1use std::fmt;
2use std::ops::{Range, RangeInclusive};
3
/// The two spellings of a Unicode escape, distinguished by how many hex
/// digits must follow the escape character.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnicodeEscapeKind {
    /// Escape character, a `+`, then six hex digits (e.g. `\+01F600`).
    Extended,
    /// Escape character followed directly by four hex digits (e.g. `\0061`).
    Short,
}

impl UnicodeEscapeKind {
    /// Number of hex digits this form of escape requires.
    fn count(&self) -> u32 {
        match self {
            UnicodeEscapeKind::Extended => 6,
            UnicodeEscapeKind::Short => 4,
        }
    }
}

/// Reasons a Unicode escape sequence is rejected.
///
/// Implements [`fmt::Display`] for user-facing messages and
/// [`std::error::Error`] so it composes with `?`, `Box<dyn Error>` and
/// `Result::unwrap`/`expect` (which require `Debug`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UnicodeEscError {
    /// The escape character was not followed by `+`, a hex digit, or a second
    /// escape character.
    InvalidEscape,
    /// A surrogate code point appeared without a matching partner.
    InvalidSurrogatePair,
    /// The escape named a value above the largest Unicode code point.
    OutOfRange,
    /// Fewer hex digits were supplied than the escape form requires.
    RequiresHexDigits {
        /// Which escape form was being parsed.
        kind: UnicodeEscapeKind,
        /// The escape character in effect, echoed back in the message.
        escape_char: char,
    },
}

impl fmt::Display for UnicodeEscError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidEscape => f.write_str("Invalid Unicode escape sequence"),
            Self::InvalidSurrogatePair => f.write_str("Invalid Unicode surrogate pair"),
            Self::OutOfRange => f.write_str("Unicode escape value out of range"),
            Self::RequiresHexDigits { kind, escape_char } => {
                let required = kind.count();
                // The extended form is spelled with a `+` between the escape
                // character and its digits; show that in the template.
                let plus = match kind {
                    UnicodeEscapeKind::Extended => "+",
                    UnicodeEscapeKind::Short => "",
                };
                // One `X` placeholder per required digit, e.g. `\XXXX`.
                let xs = "X".repeat(required as usize);
                write!(
                    f,
                    "Unicode escape requires {required} hex digits: {escape_char}{plus}{xs}"
                )
            }
        }
    }
}

impl std::error::Error for UnicodeEscError {}
49
50pub fn escape_unicode_esc_str<F>(text: &str, escape_char: char, mut callback: F)
51where
52 F: FnMut(Range<usize>, Result<char, UnicodeEscError>),
53{
54 const HIGH_SURROGATE: RangeInclusive<u32> = 0xD800..=0xDBFF;
55 const LOW_SURROGATE: RangeInclusive<u32> = 0xDC00..=0xDFFF;
56 const MAX_CODEPOINT: u32 = 0x10FFFF;
57
58 let mut chars = text.char_indices().peekable();
59 let mut high_surrogate: Option<(Range<usize>, u32)> = None;
60
61 while let Some((escape_start, c)) = chars.next() {
62 if c != escape_char {
63 if let Some((hi_range, _)) = high_surrogate.take() {
64 callback(hi_range, Err(UnicodeEscError::InvalidSurrogatePair));
65 }
66 callback(escape_start..escape_start + c.len_utf8(), Ok(c));
67 continue;
68 }
69 let kind = match chars.peek() {
70 Some(&(_, c)) if c == escape_char => {
71 chars.next();
72 if let Some((hi_range, _)) = high_surrogate.take() {
73 callback(hi_range, Err(UnicodeEscError::InvalidSurrogatePair));
74 }
75 let end = escape_start + escape_char.len_utf8() * 2;
76 callback(escape_start..end, Ok(escape_char));
77 continue;
78 }
79 Some(&(_, '+')) => {
80 chars.next();
81 UnicodeEscapeKind::Extended
82 }
83 Some(&(_, c)) if c.is_ascii_hexdigit() => UnicodeEscapeKind::Short,
84 _ => {
85 let end = chars
86 .next()
87 .map(|(i, c)| i + c.len_utf8())
88 .unwrap_or(text.len());
89 if let Some((hi_range, _)) = high_surrogate.take() {
90 callback(hi_range, Err(UnicodeEscError::InvalidSurrogatePair));
91 }
92 callback(escape_start..end, Err(UnicodeEscError::InvalidEscape));
93 continue;
94 }
95 };
96 let mut codepoint: u32 = 0;
97 let mut got_all = true;
98 let mut last_end = chars.peek().map(|&(i, _)| i).unwrap_or(text.len());
99 for _ in 0..kind.count() {
100 let radix = 16;
101 let Some(&(i, ch)) = chars.peek() else {
102 got_all = false;
103 break;
104 };
105 let Some(d) = ch.to_digit(radix) else {
106 got_all = false;
107 break;
108 };
109 chars.next();
110 codepoint = codepoint * radix + d;
111 last_end = i + ch.len_utf8();
112 }
113 if !got_all {
114 if let Some((hi_range, _)) = high_surrogate.take() {
115 callback(hi_range, Err(UnicodeEscError::InvalidSurrogatePair));
116 }
117 callback(
118 escape_start..last_end,
119 Err(UnicodeEscError::RequiresHexDigits { kind, escape_char }),
120 );
121 continue;
122 }
123 if let Some((hi_range, hi_cp)) = high_surrogate.take() {
124 if LOW_SURROGATE.contains(&codepoint) {
125 let combined = 0x10000 + ((hi_cp - 0xD800) << 10) + (codepoint - 0xDC00);
126 let ch = char::from_u32(combined).unwrap();
127 callback(hi_range.start..last_end, Ok(ch));
128 continue;
129 }
130 callback(
131 hi_range.start..last_end,
132 Err(UnicodeEscError::InvalidSurrogatePair),
133 );
134 continue;
135 }
136 if codepoint > MAX_CODEPOINT {
137 callback(escape_start..last_end, Err(UnicodeEscError::OutOfRange));
138 } else if HIGH_SURROGATE.contains(&codepoint) {
139 high_surrogate = Some((escape_start..last_end, codepoint));
140 } else if LOW_SURROGATE.contains(&codepoint) {
141 callback(
142 escape_start..last_end,
143 Err(UnicodeEscError::InvalidSurrogatePair),
144 );
145 } else {
146 let ch = char::from_u32(codepoint).unwrap();
147 callback(escape_start..last_end, Ok(ch));
148 }
149 }
150 if let Some((range, _)) = high_surrogate {
151 callback(range, Err(UnicodeEscError::InvalidSurrogatePair));
152 }
153}
154
/// Whether `byte` is acceptable as a custom Unicode escape character.
///
/// Everything is accepted except ASCII hex digits, `+`, single and double
/// quotes, and the ASCII whitespace bytes space, tab, newline, carriage
/// return, vertical tab (`0x0B`) and form feed (`0x0C`).
const fn is_valid_uescape_char(byte: u8) -> bool {
    match byte {
        // Hex digits would be ambiguous with the digits of the escape itself.
        b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F' => false,
        b'+' | b'\'' | b'"' => false,
        b' ' | b'\t' | b'\n' | b'\r' | 0x0B | 0x0C => false,
        _ => true,
    }
}

/// Extracts the escape character from a single-quoted literal such as `'!'`.
///
/// Returns `None` when `text` is not wrapped in single quotes, when the quoted
/// content is not exactly one byte long, or when that byte is rejected by
/// `is_valid_uescape_char`.
pub fn uescape_char(text: &str) -> Option<char> {
    let unquoted = text.strip_prefix('\'')?.strip_suffix('\'')?;
    match unquoted.as_bytes() {
        &[byte] if is_valid_uescape_char(byte) => Some(char::from(byte)),
        _ => None,
    }
}
174
/// Appends `inner` to `out`, collapsing every doubled single quote (`''`)
/// into one `'`.
///
/// Pairs are matched left to right without overlap, so an odd run of quotes
/// leaves one unpaired quote, which is copied through unchanged.
pub fn decode_plain_string(inner: &str, out: &mut String) {
    // Splitting on the two-quote sequence yields the text between pairs;
    // stitching the pieces back with a single quote performs the collapse.
    let mut pieces = inner.split("''");
    if let Some(first) = pieces.next() {
        out.push_str(first);
    }
    for piece in pieces {
        out.push('\'');
        out.push_str(piece);
    }
}
186
/// Appends the UTF-8 encoding of `c` to `bytes`.
fn push_char_bytes(c: char, bytes: &mut Vec<u8>) {
    let mut buf = [0; 4];
    bytes.extend_from_slice(c.encode_utf8(&mut buf).as_bytes());
}

/// Decodes the body of a backslash-escaped string literal and appends the
/// result to `out`.
///
/// Handled forms:
///
/// * `''` becomes a single `'`.
/// * `\b`, `\f`, `\n`, `\r`, `\t` become the corresponding control byte.
/// * `\o`, `\oo`, `\ooo` (octal) and `\xh`, `\xhh` (hex) become one raw byte;
///   octal values above `0o377` are truncated to their low eight bits.
/// * `\uXXXX` and `\UXXXXXXXX` become the named code point. Escapes with too
///   few digits, or naming something that is not a valid `char` (such as a
///   surrogate), produce no output.
/// * `\x` with no hex digit after it becomes a literal `x`.
/// * A backslash before any other character yields that character; a trailing
///   lone backslash is kept as is.
///
/// NUL is never emitted: any escape whose resulting byte or code point is zero
/// is dropped. Raw bytes are collected first and converted with
/// `String::from_utf8_lossy` at the end, so byte escapes can spell out a
/// multi-byte UTF-8 sequence, and invalid sequences turn into U+FFFD.
pub fn decode_esc_string(inner: &str, out: &mut String) {
    // Control byte denoted by a single-character escape, if `c` is one.
    fn simple_escape_byte(c: char) -> Option<u8> {
        match c {
            'b' => Some(0x08),
            'f' => Some(0x0C),
            'n' => Some(b'\n'),
            'r' => Some(b'\r'),
            't' => Some(b'\t'),
            _ => None,
        }
    }

    // Consumes up to `max_digits` digits of `radix`, returning the accumulated
    // value and how many digits were consumed. Callers ask for at most eight
    // hex digits, so the value always fits in a `u32`.
    fn take_digits<I>(
        chars: &mut std::iter::Peekable<I>,
        max_digits: usize,
        radix: u32,
    ) -> (u32, usize)
    where
        I: Iterator<Item = char>,
    {
        let mut value = 0;
        let mut taken = 0;
        while taken < max_digits {
            let Some(digit) = chars.peek().and_then(|c| c.to_digit(radix)) else {
                break;
            };
            chars.next();
            value = value * radix + digit;
            taken += 1;
        }
        (value, taken)
    }

    let mut chars = inner.chars().peekable();
    // Every escape decodes to no more bytes than it occupies in the source.
    let mut bytes = Vec::with_capacity(inner.len());

    while let Some(c) = chars.next() {
        if c == '\'' && chars.peek() == Some(&'\'') {
            chars.next();
            bytes.push(b'\'');
            continue;
        }
        if c != '\\' {
            push_char_bytes(c, &mut bytes);
            continue;
        }
        let Some(&next) = chars.peek() else {
            // Trailing backslash with nothing to escape: keep it literally.
            bytes.push(b'\\');
            break;
        };
        if let Some(byte) = simple_escape_byte(next) {
            chars.next();
            bytes.push(byte);
            continue;
        }
        match next {
            '0'..='7' => {
                // The first octal digit is still unconsumed, so it is counted
                // among the three.
                let (value, _) = take_digits(&mut chars, 3, 8);
                // Three octal digits reach 0o777, which does not fit in a
                // byte. Truncate first and test for NUL afterwards, otherwise
                // `\400` (256) would slip past the check and emit 0x00.
                let byte = (value & 0xFF) as u8;
                if byte != 0 {
                    bytes.push(byte);
                }
            }
            'x' => {
                chars.next();
                match take_digits(&mut chars, 2, 16) {
                    // No digits at all: the escape is just a literal `x`.
                    (_, 0) => bytes.push(b'x'),
                    // NUL is dropped.
                    (0, _) => {}
                    // Two hex digits never exceed 0xFF, so this is lossless.
                    (value, _) => bytes.push(value as u8),
                }
            }
            'u' | 'U' => {
                chars.next();
                let required = if next == 'u' { 4 } else { 8 };
                let (value, taken) = take_digits(&mut chars, required, 16);
                if taken == required {
                    match char::from_u32(value) {
                        Some(ch) if ch != '\0' => push_char_bytes(ch, &mut bytes),
                        // NUL, surrogates and out-of-range values are dropped.
                        _ => {}
                    }
                }
            }
            _ => {
                chars.next();
                push_char_bytes(next, &mut bytes);
            }
        }
    }

    out.push_str(&String::from_utf8_lossy(&bytes));
}
302
303pub fn decode_unicode_esc_string(inner: &str, escape_char: char, out: &mut String) {
304 let inner = inner.replace("''", "'");
305 escape_unicode_esc_str(&inner, escape_char, |_range, result| {
306 if let Ok(ch) = result {
307 out.push(ch);
308 }
309 });
310}
311
#[cfg(test)]
mod tests {
    use insta::assert_snapshot;

    use super::*;

    /// Renders every callback event as one `start..end ok|err …` line so a
    /// whole decode run can be compared against an inline snapshot.
    fn unicode_escape_events(text: &str, escape_char: char) -> String {
        let mut lines: Vec<String> = Vec::new();

        escape_unicode_esc_str(text, escape_char, |span, outcome| {
            let (start, end) = (span.start, span.end);
            lines.push(match outcome {
                Ok(ch) => format!("{start}..{end} ok {ch:?}"),
                Err(err) => format!("{start}..{end} err {err}"),
            });
        });

        lines.join("\n")
    }

    /// Convenience wrapper returning the decoded backslash-escaped string.
    fn decode_escape_string(inner: &str) -> String {
        let mut decoded = String::new();
        decode_esc_string(inner, &mut decoded);
        decoded
    }

    /// Convenience wrapper returning the decoded Unicode-escaped string.
    fn decode_unicode_escape_string(inner: &str, escape_char: char) -> String {
        let mut decoded = String::new();
        decode_unicode_esc_string(inner, escape_char, &mut decoded);
        decoded
    }

    /// Text without escapes is reported one character at a time.
    #[test]
    fn ok() {
        assert_snapshot!(unicode_escape_events(r"hello world", '\\'), @"
        0..1 ok 'h'
        1..2 ok 'e'
        2..3 ok 'l'
        3..4 ok 'l'
        4..5 ok 'o'
        5..6 ok ' '
        6..7 ok 'w'
        7..8 ok 'o'
        8..9 ok 'r'
        9..10 ok 'l'
        10..11 ok 'd'
        ");
    }

    /// A too-short escape between two surrogates keeps them from pairing.
    #[test]
    fn incomplete_unicode_escape_breaks_surrogate_pairing() {
        assert_snapshot!(unicode_escape_events(r"\D800\006\DC00", '\\'), @r"
        0..5 err Invalid Unicode surrogate pair
        5..9 err Unicode escape requires 4 hex digits: \XXXX
        9..14 err Invalid Unicode surrogate pair
        ");
    }

    /// An unrecognised escape between two surrogates keeps them from pairing.
    #[test]
    fn invalid_unicode_escape_breaks_surrogate_pairing() {
        assert_snapshot!(unicode_escape_events(r"\D800\Q\DC00", '\\'), @r"
        0..5 err Invalid Unicode surrogate pair
        5..7 err Invalid Unicode escape sequence
        7..12 err Invalid Unicode surrogate pair
        ");
    }

    /// The character after a bad escape is part of the error, not emitted.
    #[test]
    fn invalid_unicode_escape_does_not_emit_literal_char() {
        assert_snapshot!(unicode_escape_events(r"\0061\Q\0062", '\\'), @r"
        0..5 ok 'a'
        5..7 err Invalid Unicode escape sequence
        7..12 ok 'b'
        ");
    }

    /// Same as above, with `!` as the escape character.
    #[test]
    fn invalid_unicode_escape_works_with_custom_escape_char() {
        assert_snapshot!(unicode_escape_events("!0061!Q!0062", '!'), @r"
        0..5 ok 'a'
        5..7 err Invalid Unicode escape sequence
        7..12 ok 'b'
        ");
    }

    /// A complete non-surrogate escape after a high surrogate is folded into
    /// one error covering both escapes.
    #[test]
    fn valid_unicode_escape_after_high_surrogate_only_emits_error() {
        assert_snapshot!(unicode_escape_events(r"\D800\0061", '\\'), @r"
        0..10 err Invalid Unicode surrogate pair
        ");
    }

    /// Consecutive `\x` byte escapes can spell one multi-byte UTF-8 character.
    #[test]
    fn decode_escape_string_hex_bytes_as_utf8() {
        assert_snapshot!(decode_escape_string(r"\xC3\xA9"), @"é");
    }

    /// An escape that evaluates to NUL produces no output.
    #[test]
    fn decode_escape_string_skips_nul_byte() {
        assert_snapshot!(decode_escape_string(r"a\000b"), @"ab");
    }

    /// `''` inside a Unicode-escaped string decodes to one quote.
    #[test]
    fn decode_unicode_string_collapses_doubled_quotes() {
        assert_snapshot!(decode_unicode_escape_string("a''b", '\\'), @"a'b");
    }
}