Skip to main content

rustpython_literal/
escape.rs

1use alloc::string::String;
2use rustpython_wtf8::{CodePoint, Wtf8};
3
4#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash, is_macro::Is)]
5pub enum Quote {
6    Single,
7    Double,
8}
9
10impl Quote {
11    #[inline]
12    pub const fn swap(self) -> Self {
13        match self {
14            Self::Single => Self::Double,
15            Self::Double => Self::Single,
16        }
17    }
18
19    #[inline]
20    pub const fn to_byte(&self) -> u8 {
21        match self {
22            Self::Single => b'\'',
23            Self::Double => b'"',
24        }
25    }
26
27    #[inline]
28    pub const fn to_char(&self) -> char {
29        match self {
30            Self::Single => '\'',
31            Self::Double => '"',
32        }
33    }
34}
35
36pub struct EscapeLayout {
37    pub quote: Quote,
38    pub len: Option<usize>,
39}
40
41/// Represents string types that can be escape-printed.
42///
43/// # Safety
44///
45/// `source_len` and `layout` must be accurate, and `layout.len` must not be equal
46/// to `Some(source_len)` if the string contains non-printable characters.
47pub unsafe trait Escape {
48    fn source_len(&self) -> usize;
49    fn layout(&self) -> &EscapeLayout;
50    fn changed(&self) -> bool {
51        self.layout().len != Some(self.source_len())
52    }
53
54    /// Write the body of the string directly to the formatter.
55    ///
56    /// # Safety
57    ///
58    /// This string must only contain printable characters.
59    unsafe fn write_source(&self, formatter: &mut impl core::fmt::Write) -> core::fmt::Result;
60    fn write_body_slow(&self, formatter: &mut impl core::fmt::Write) -> core::fmt::Result;
61    fn write_body(&self, formatter: &mut impl core::fmt::Write) -> core::fmt::Result {
62        if self.changed() {
63            self.write_body_slow(formatter)
64        } else {
65            // SAFETY: verified the string contains only printable characters.
66            unsafe { self.write_source(formatter) }
67        }
68    }
69}
70
71/// Returns the outer quotes to use and the number of quotes that need to be
72/// escaped.
73pub(crate) const fn choose_quote(
74    single_count: usize,
75    double_count: usize,
76    preferred_quote: Quote,
77) -> (Quote, usize) {
78    let (primary_count, secondary_count) = match preferred_quote {
79        Quote::Single => (single_count, double_count),
80        Quote::Double => (double_count, single_count),
81    };
82
83    // always use primary unless we have primary but no secondary
84    let use_secondary = primary_count > 0 && secondary_count == 0;
85    if use_secondary {
86        (preferred_quote.swap(), secondary_count)
87    } else {
88        (preferred_quote, primary_count)
89    }
90}
91
92pub struct UnicodeEscape<'a> {
93    source: &'a Wtf8,
94    layout: EscapeLayout,
95}
96
97impl<'a> UnicodeEscape<'a> {
98    #[inline]
99    pub const fn with_forced_quote(source: &'a Wtf8, quote: Quote) -> Self {
100        let layout = EscapeLayout { quote, len: None };
101        Self { source, layout }
102    }
103    #[inline]
104    pub fn with_preferred_quote(source: &'a Wtf8, quote: Quote) -> Self {
105        let layout = Self::repr_layout(source, quote);
106        Self { source, layout }
107    }
108    #[inline]
109    pub fn new_repr(source: &'a Wtf8) -> Self {
110        Self::with_preferred_quote(source, Quote::Single)
111    }
112    #[inline]
113    pub const fn str_repr<'r>(&'a self) -> StrRepr<'r, 'a> {
114        StrRepr(self)
115    }
116}
117
118pub struct StrRepr<'r, 'a>(&'r UnicodeEscape<'a>);
119
120impl StrRepr<'_, '_> {
121    pub fn write(&self, formatter: &mut impl core::fmt::Write) -> core::fmt::Result {
122        let quote = self.0.layout().quote.to_char();
123        formatter.write_char(quote)?;
124        self.0.write_body(formatter)?;
125        formatter.write_char(quote)
126    }
127
128    pub fn to_string(&self) -> Option<String> {
129        let mut s = String::with_capacity(self.0.layout().len?);
130        self.write(&mut s).unwrap();
131        Some(s)
132    }
133}
134
135impl core::fmt::Display for StrRepr<'_, '_> {
136    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
137        self.write(formatter)
138    }
139}
140
141impl UnicodeEscape<'_> {
142    const REPR_RESERVED_LEN: usize = 2; // for quotes
143
144    pub fn repr_layout(source: &Wtf8, preferred_quote: Quote) -> EscapeLayout {
145        Self::output_layout_with_checker(source, preferred_quote, |a, b| {
146            Some((a as isize).checked_add(b as isize)? as usize)
147        })
148    }
149
150    fn output_layout_with_checker(
151        source: &Wtf8,
152        preferred_quote: Quote,
153        length_add: impl Fn(usize, usize) -> Option<usize>,
154    ) -> EscapeLayout {
155        let mut out_len = Self::REPR_RESERVED_LEN;
156        let mut single_count = 0;
157        let mut double_count = 0;
158
159        for ch in source.code_points() {
160            let incr = match ch.to_char() {
161                Some('\'') => {
162                    single_count += 1;
163                    1
164                }
165                Some('"') => {
166                    double_count += 1;
167                    1
168                }
169                _ => Self::escaped_char_len(ch),
170            };
171            let Some(new_len) = length_add(out_len, incr) else {
172                #[cold]
173                const fn stop(
174                    single_count: usize,
175                    double_count: usize,
176                    preferred_quote: Quote,
177                ) -> EscapeLayout {
178                    EscapeLayout {
179                        quote: choose_quote(single_count, double_count, preferred_quote).0,
180                        len: None,
181                    }
182                }
183                return stop(single_count, double_count, preferred_quote);
184            };
185            out_len = new_len;
186        }
187
188        let (quote, num_escaped_quotes) = choose_quote(single_count, double_count, preferred_quote);
189        // we'll be adding backslashes in front of the existing inner quotes
190        let Some(out_len) = length_add(out_len, num_escaped_quotes) else {
191            return EscapeLayout { quote, len: None };
192        };
193
194        EscapeLayout {
195            quote,
196            len: Some(out_len - Self::REPR_RESERVED_LEN),
197        }
198    }
199
200    fn escaped_char_len(ch: CodePoint) -> usize {
201        // surrogates are \uHHHH
202        let Some(ch) = ch.to_char() else { return 6 };
203        match ch {
204            '\\' | '\t' | '\r' | '\n' => 2,
205            ch if ch < ' ' || ch as u32 == 0x7f => 4, // \xHH
206            ch if ch.is_ascii() => 1,
207            ch if crate::char::is_printable(ch) => {
208                // max = std::cmp::max(ch, max);
209                ch.len_utf8()
210            }
211            ch if (ch as u32) < 0x100 => 4,   // \xHH
212            ch if (ch as u32) < 0x10000 => 6, // \uHHHH
213            _ => 10,                          // \uHHHHHHHH
214        }
215    }
216
217    fn write_char(
218        ch: CodePoint,
219        quote: Quote,
220        formatter: &mut impl core::fmt::Write,
221    ) -> core::fmt::Result {
222        let Some(ch) = ch.to_char() else {
223            return write!(formatter, "\\u{:04x}", ch.to_u32());
224        };
225        match ch {
226            '\n' => formatter.write_str("\\n"),
227            '\t' => formatter.write_str("\\t"),
228            '\r' => formatter.write_str("\\r"),
229            // these 2 branches *would* be handled below, but we shouldn't have to do a
230            // unicodedata lookup just for ascii characters
231            '\x20'..='\x7e' => {
232                // printable ascii range
233                if ch == quote.to_char() || ch == '\\' {
234                    formatter.write_char('\\')?;
235                }
236                formatter.write_char(ch)
237            }
238            ch if ch.is_ascii() => {
239                write!(formatter, "\\x{:02x}", ch as u8)
240            }
241            ch if crate::char::is_printable(ch) => formatter.write_char(ch),
242            '\0'..='\u{ff}' => {
243                write!(formatter, "\\x{:02x}", ch as u32)
244            }
245            '\0'..='\u{ffff}' => {
246                write!(formatter, "\\u{:04x}", ch as u32)
247            }
248            _ => {
249                write!(formatter, "\\U{:08x}", ch as u32)
250            }
251        }
252    }
253}
254
255unsafe impl Escape for UnicodeEscape<'_> {
256    fn source_len(&self) -> usize {
257        self.source.len()
258    }
259
260    fn layout(&self) -> &EscapeLayout {
261        &self.layout
262    }
263
264    unsafe fn write_source(&self, formatter: &mut impl core::fmt::Write) -> core::fmt::Result {
265        formatter.write_str(unsafe {
266            // SAFETY: this function must be called only when source is printable characters (i.e. no surrogates)
267            core::str::from_utf8_unchecked(self.source.as_bytes())
268        })
269    }
270
271    #[cold]
272    fn write_body_slow(&self, formatter: &mut impl core::fmt::Write) -> core::fmt::Result {
273        for ch in self.source.code_points() {
274            Self::write_char(ch, self.layout().quote, formatter)?;
275        }
276        Ok(())
277    }
278}
279
280pub struct AsciiEscape<'a> {
281    source: &'a [u8],
282    layout: EscapeLayout,
283}
284
285impl<'a> AsciiEscape<'a> {
286    #[inline]
287    pub const fn new(source: &'a [u8], layout: EscapeLayout) -> Self {
288        Self { source, layout }
289    }
290    #[inline]
291    pub const fn with_forced_quote(source: &'a [u8], quote: Quote) -> Self {
292        let layout = EscapeLayout { quote, len: None };
293        Self { source, layout }
294    }
295    #[inline]
296    pub fn with_preferred_quote(source: &'a [u8], quote: Quote) -> Self {
297        let layout = Self::repr_layout(source, quote);
298        Self { source, layout }
299    }
300    #[inline]
301    pub fn new_repr(source: &'a [u8]) -> Self {
302        Self::with_preferred_quote(source, Quote::Single)
303    }
304    #[inline]
305    pub const fn bytes_repr<'r>(&'a self) -> BytesRepr<'r, 'a> {
306        BytesRepr(self)
307    }
308}
309
310impl AsciiEscape<'_> {
311    pub fn repr_layout(source: &[u8], preferred_quote: Quote) -> EscapeLayout {
312        Self::output_layout_with_checker(source, preferred_quote, 3, |a, b| {
313            Some((a as isize).checked_add(b as isize)? as usize)
314        })
315    }
316
317    pub fn named_repr_layout(source: &[u8], name: &str) -> EscapeLayout {
318        Self::output_layout_with_checker(source, Quote::Single, name.len() + 2 + 3, |a, b| {
319            Some((a as isize).checked_add(b as isize)? as usize)
320        })
321    }
322
323    fn output_layout_with_checker(
324        source: &[u8],
325        preferred_quote: Quote,
326        reserved_len: usize,
327        length_add: impl Fn(usize, usize) -> Option<usize>,
328    ) -> EscapeLayout {
329        let mut out_len = reserved_len;
330        let mut single_count = 0;
331        let mut double_count = 0;
332
333        for ch in source {
334            let incr = match ch {
335                b'\'' => {
336                    single_count += 1;
337                    1
338                }
339                b'"' => {
340                    double_count += 1;
341                    1
342                }
343                c => Self::escaped_char_len(*c),
344            };
345            let Some(new_len) = length_add(out_len, incr) else {
346                #[cold]
347                const fn stop(
348                    single_count: usize,
349                    double_count: usize,
350                    preferred_quote: Quote,
351                ) -> EscapeLayout {
352                    EscapeLayout {
353                        quote: choose_quote(single_count, double_count, preferred_quote).0,
354                        len: None,
355                    }
356                }
357                return stop(single_count, double_count, preferred_quote);
358            };
359            out_len = new_len;
360        }
361
362        let (quote, num_escaped_quotes) = choose_quote(single_count, double_count, preferred_quote);
363        // we'll be adding backslashes in front of the existing inner quotes
364        let Some(out_len) = length_add(out_len, num_escaped_quotes) else {
365            return EscapeLayout { quote, len: None };
366        };
367
368        EscapeLayout {
369            quote,
370            len: Some(out_len - reserved_len),
371        }
372    }
373
374    const fn escaped_char_len(ch: u8) -> usize {
375        match ch {
376            b'\\' | b'\t' | b'\r' | b'\n' => 2,
377            0x20..=0x7e => 1,
378            _ => 4, // \xHH
379        }
380    }
381
382    fn write_char(
383        ch: u8,
384        quote: Quote,
385        formatter: &mut impl core::fmt::Write,
386    ) -> core::fmt::Result {
387        match ch {
388            b'\t' => formatter.write_str("\\t"),
389            b'\n' => formatter.write_str("\\n"),
390            b'\r' => formatter.write_str("\\r"),
391            0x20..=0x7e => {
392                // printable ascii range
393                if ch == quote.to_byte() || ch == b'\\' {
394                    formatter.write_char('\\')?;
395                }
396                formatter.write_char(ch as char)
397            }
398            ch => write!(formatter, "\\x{ch:02x}"),
399        }
400    }
401}
402
403unsafe impl Escape for AsciiEscape<'_> {
404    fn source_len(&self) -> usize {
405        self.source.len()
406    }
407
408    fn layout(&self) -> &EscapeLayout {
409        &self.layout
410    }
411
412    unsafe fn write_source(&self, formatter: &mut impl core::fmt::Write) -> core::fmt::Result {
413        formatter.write_str(unsafe {
414            // SAFETY: this function must be called only when source is printable ascii characters
415            core::str::from_utf8_unchecked(self.source)
416        })
417    }
418
419    #[cold]
420    fn write_body_slow(&self, formatter: &mut impl core::fmt::Write) -> core::fmt::Result {
421        for ch in self.source {
422            Self::write_char(*ch, self.layout().quote, formatter)?;
423        }
424        Ok(())
425    }
426}
427
428pub struct BytesRepr<'r, 'a>(&'r AsciiEscape<'a>);
429
430impl BytesRepr<'_, '_> {
431    pub fn write(&self, formatter: &mut impl core::fmt::Write) -> core::fmt::Result {
432        let quote = self.0.layout().quote.to_char();
433        formatter.write_char('b')?;
434        formatter.write_char(quote)?;
435        self.0.write_body(formatter)?;
436        formatter.write_char(quote)
437    }
438
439    pub fn to_string(&self) -> Option<String> {
440        let mut s = String::with_capacity(self.0.layout().len?);
441        self.write(&mut s).unwrap();
442        Some(s)
443    }
444}
445
446impl core::fmt::Display for BytesRepr<'_, '_> {
447    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
448        self.write(formatter)
449    }
450}
451
#[cfg(test)]
mod unicode_escape_tests {
    use super::*;

    /// `changed()` must be false exactly when the repr body equals the source.
    #[test]
    fn changed() {
        let is_changed = |text: &str| UnicodeEscape::new_repr(text.as_ref()).changed();

        // Nothing to escape at all.
        assert!(!is_changed("hello"));
        // Only single quotes: the outer quote switches, the body stays as is.
        assert!(!is_changed("'hello'"));
        // Only double quotes: the preferred single quote wraps them untouched.
        assert!(!is_changed("\"hello\""));

        // Both quote kinds: one of them has to be backslash-escaped.
        assert!(is_changed("'\"hello"));
        // A newline is written as an escape sequence.
        assert!(is_changed("hello\n"));
    }
}