serde_xml/
escape.rs

1//! XML escape and unescape utilities.
2//!
3//! This module provides fast, allocation-minimizing functions for escaping
4//! and unescaping XML special characters.
5
6use memchr::memchr;
7
8/// Escapes XML special characters in a string.
9///
10/// Returns a `Cow<str>` to avoid allocation when no escaping is needed.
11#[inline]
12pub fn escape(s: &str) -> std::borrow::Cow<'_, str> {
13    let bytes = s.as_bytes();
14
15    // Fast path: scan for any character needing escape
16    let needs_escape = bytes.iter().any(|&b| matches!(b, b'<' | b'>' | b'&' | b'"' | b'\''));
17
18    if !needs_escape {
19        return std::borrow::Cow::Borrowed(s);
20    }
21
22    let mut result = String::with_capacity(s.len() + s.len() / 8);
23    escape_to_inner(bytes, &mut result);
24    std::borrow::Cow::Owned(result)
25}
26
27/// Escapes XML special characters and appends to the given string.
28#[inline]
29pub fn escape_to(s: &str, out: &mut String) {
30    escape_to_inner(s.as_bytes(), out);
31}
32
/// Shared escaping routine: appends the escaped form of `bytes` to `out`.
///
/// Runs of bytes that need no escaping are copied with a single `push_str`
/// instead of one byte at a time. `bytes` must be valid UTF-8; the callers in
/// this file always pass the result of `str::as_bytes()`.
#[inline(always)]
fn escape_to_inner(bytes: &[u8], out: &mut String) {
    // Maps a byte to its entity reference, or `None` when it is copied as-is.
    let entity_for = |b: u8| -> Option<&'static str> {
        match b {
            b'<' => Some("&lt;"),
            b'>' => Some("&gt;"),
            b'&' => Some("&amp;"),
            b'"' => Some("&quot;"),
            b'\'' => Some("&apos;"),
            _ => None,
        }
    };

    // The part of the input that has not been written to `out` yet.
    let mut rest = bytes;

    while let Some((offset, entity)) = rest
        .iter()
        .enumerate()
        .find_map(|(idx, &b)| entity_for(b).map(|e| (idx, e)))
    {
        let (plain, tail) = rest.split_at(offset);
        if !plain.is_empty() {
            // SAFETY: the input is valid UTF-8 and the cut is made right before
            // an ASCII byte, which is always a character boundary.
            out.push_str(unsafe { std::str::from_utf8_unchecked(plain) });
        }
        out.push_str(entity);
        // Step over the special byte that was just replaced.
        rest = &tail[1..];
    }

    if !rest.is_empty() {
        // SAFETY: `rest` begins either at the start of the input or right after
        // an ASCII byte, and runs to the end of valid UTF-8 input.
        out.push_str(unsafe { std::str::from_utf8_unchecked(rest) });
    }
}
62
63/// Escapes XML special characters for attribute values.
64#[inline]
65pub fn escape_attr(s: &str) -> std::borrow::Cow<'_, str> {
66    escape(s)
67}
68
69/// Unescapes XML entities in a string.
70///
71/// Returns a `Cow<str>` to avoid allocation when no unescaping is needed.
72#[inline]
73pub fn unescape(s: &str) -> Result<std::borrow::Cow<'_, str>, UnescapeError> {
74    let bytes = s.as_bytes();
75
76    // Fast path: check if any unescaping is needed using memchr
77    match memchr(b'&', bytes) {
78        None => Ok(std::borrow::Cow::Borrowed(s)),
79        Some(first_amp) => {
80            let mut result = String::with_capacity(s.len());
81            // Add everything before the first &
82            if first_amp > 0 {
83                result.push_str(unsafe {
84                    std::str::from_utf8_unchecked(&bytes[..first_amp])
85                });
86            }
87            unescape_from(bytes, first_amp, &mut result)?;
88            Ok(std::borrow::Cow::Owned(result))
89        }
90    }
91}
92
/// Describes an entity reference that could not be unescaped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnescapeError {
    /// Text of the offending entity reference as reported by the unescaper.
    pub entity: String,
    /// Offset into the input at which the problem was detected.
    pub position: usize,
}

impl std::fmt::Display for UnescapeError {
    /// Renders the error as `invalid XML entity '<entity>' at position <position>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { entity, position } = self;
        write!(f, "invalid XML entity '{}' at position {}", entity, position)
    }
}

// No underlying cause is carried, so the trait's default methods are enough.
impl std::error::Error for UnescapeError {}
109
110/// Unescapes XML entities and appends to the given string.
111#[inline]
112pub fn unescape_to(s: &str, out: &mut String) -> Result<(), UnescapeError> {
113    let bytes = s.as_bytes();
114    match memchr(b'&', bytes) {
115        None => {
116            out.push_str(s);
117            Ok(())
118        }
119        Some(first_amp) => {
120            if first_amp > 0 {
121                out.push_str(unsafe {
122                    std::str::from_utf8_unchecked(&bytes[..first_amp])
123                });
124            }
125            unescape_from(bytes, first_amp, out)
126        }
127    }
128}
129
130/// Internal unescape starting from a position known to have '&'.
131#[inline(always)]
132fn unescape_from(bytes: &[u8], start: usize, out: &mut String) -> Result<(), UnescapeError> {
133    let mut i = start;
134
135    while i < bytes.len() {
136        if bytes[i] == b'&' {
137            let entity_start = i;
138            i += 1;
139
140            // Find semicolon using memchr for speed
141            match memchr(b';', &bytes[i..]) {
142                Some(len) if len > 0 && len <= 10 => {
143                    let entity = unsafe {
144                        std::str::from_utf8_unchecked(&bytes[i..i + len])
145                    };
146
147                    if let Some(c) = decode_entity_fast(entity) {
148                        out.push(c);
149                        i += len + 1;
150
151                        // Find and append text until next &
152                        if let Some(next_amp) = memchr(b'&', &bytes[i..]) {
153                            if next_amp > 0 {
154                                out.push_str(unsafe {
155                                    std::str::from_utf8_unchecked(&bytes[i..i + next_amp])
156                                });
157                            }
158                            i += next_amp;
159                        } else {
160                            // No more entities
161                            out.push_str(unsafe {
162                                std::str::from_utf8_unchecked(&bytes[i..])
163                            });
164                            return Ok(());
165                        }
166                    } else {
167                        return Err(UnescapeError {
168                            entity: format!("&{};", entity),
169                            position: entity_start,
170                        });
171                    }
172                }
173                _ => {
174                    return Err(UnescapeError {
175                        entity: String::from("&"),
176                        position: entity_start,
177                    });
178                }
179            }
180        } else {
181            i += 1;
182        }
183    }
184
185    Ok(())
186}
187
188/// Fast entity decoder with common cases first.
189#[inline(always)]
190fn decode_entity_fast(entity: &str) -> Option<char> {
191    // Check length first to avoid string comparisons
192    match entity.len() {
193        2 => match entity {
194            "lt" => Some('<'),
195            "gt" => Some('>'),
196            _ => decode_numeric_entity(entity),
197        },
198        3 => match entity {
199            "amp" => Some('&'),
200            _ => decode_numeric_entity(entity),
201        },
202        4 => match entity {
203            "quot" => Some('"'),
204            "apos" => Some('\''),
205            _ => decode_numeric_entity(entity),
206        },
207        _ => decode_numeric_entity(entity),
208    }
209}
210
/// Decodes the body of a numeric character reference.
///
/// `entity` is the text between `&` and `;`, i.e. `#NNN` (decimal) or `#xHHH`
/// (hexadecimal; an uppercase `X` marker is tolerated as well).
///
/// Returns `None` when:
/// * `entity` does not start with `#`,
/// * there are no digits, or any character is not a digit of the radix
///   (this includes a leading sign such as `#+65`),
/// * the value overflows `u32` or is not a Unicode scalar value
///   (surrogates, values above `0x10FFFF`).
///
/// The value is not checked against the XML `Char` production, so code points
/// such as `#0` are still decoded.
#[inline]
fn decode_numeric_entity(entity: &str) -> Option<char> {
    // The marker bytes are ASCII, so slicing after them stays on a boundary.
    let (radix, digits) = match entity.as_bytes() {
        [b'#', b'x', ..] | [b'#', b'X', ..] => (16, &entity[2..]),
        [b'#', ..] => (10, &entity[1..]),
        _ => return None,
    };

    // `u32::from_str_radix` would accept a leading '+', which is not part of
    // an XML character reference, so insist on digits only before parsing.
    if digits.is_empty() || !digits.bytes().all(|b| (b as char).is_digit(radix)) {
        return None;
    }

    let code = u32::from_str_radix(digits, radix).ok()?;
    char::from_u32(code)
}
232
233#[cfg(test)]
234mod tests {
235    use super::*;
236
237    #[test]
238    fn test_escape_no_special_chars() {
239        let s = "Hello, World!";
240        let escaped = escape(s);
241        assert!(matches!(escaped, std::borrow::Cow::Borrowed(_)));
242        assert_eq!(escaped, s);
243    }
244
245    #[test]
246    fn test_escape_lt() {
247        assert_eq!(escape("<"), "&lt;");
248    }
249
250    #[test]
251    fn test_escape_gt() {
252        assert_eq!(escape(">"), "&gt;");
253    }
254
255    #[test]
256    fn test_escape_amp() {
257        assert_eq!(escape("&"), "&amp;");
258    }
259
260    #[test]
261    fn test_escape_quot() {
262        assert_eq!(escape("\""), "&quot;");
263    }
264
265    #[test]
266    fn test_escape_apos() {
267        assert_eq!(escape("'"), "&apos;");
268    }
269
270    #[test]
271    fn test_escape_mixed() {
272        assert_eq!(
273            escape("<div class=\"foo\">Hello & goodbye</div>"),
274            "&lt;div class=&quot;foo&quot;&gt;Hello &amp; goodbye&lt;/div&gt;"
275        );
276    }
277
278    #[test]
279    fn test_unescape_no_entities() {
280        let s = "Hello, World!";
281        let unescaped = unescape(s).unwrap();
282        assert!(matches!(unescaped, std::borrow::Cow::Borrowed(_)));
283        assert_eq!(unescaped, s);
284    }
285
286    #[test]
287    fn test_unescape_lt() {
288        assert_eq!(unescape("&lt;").unwrap(), "<");
289    }
290
291    #[test]
292    fn test_unescape_gt() {
293        assert_eq!(unescape("&gt;").unwrap(), ">");
294    }
295
296    #[test]
297    fn test_unescape_amp() {
298        assert_eq!(unescape("&amp;").unwrap(), "&");
299    }
300
301    #[test]
302    fn test_unescape_quot() {
303        assert_eq!(unescape("&quot;").unwrap(), "\"");
304    }
305
306    #[test]
307    fn test_unescape_apos() {
308        assert_eq!(unescape("&apos;").unwrap(), "'");
309    }
310
311    #[test]
312    fn test_unescape_mixed() {
313        assert_eq!(
314            unescape("&lt;div class=&quot;foo&quot;&gt;Hello &amp; goodbye&lt;/div&gt;").unwrap(),
315            "<div class=\"foo\">Hello & goodbye</div>"
316        );
317    }
318
319    #[test]
320    fn test_unescape_numeric_decimal() {
321        assert_eq!(unescape("&#65;").unwrap(), "A");
322        assert_eq!(unescape("&#97;").unwrap(), "a");
323        assert_eq!(unescape("&#8364;").unwrap(), "€");
324    }
325
326    #[test]
327    fn test_unescape_numeric_hex() {
328        assert_eq!(unescape("&#x41;").unwrap(), "A");
329        assert_eq!(unescape("&#x61;").unwrap(), "a");
330        assert_eq!(unescape("&#x20AC;").unwrap(), "€");
331    }
332
333    #[test]
334    fn test_unescape_invalid_entity() {
335        let result = unescape("&invalid;");
336        assert!(result.is_err());
337        let err = result.unwrap_err();
338        assert_eq!(err.entity, "&invalid;");
339        assert_eq!(err.position, 0);
340    }
341
342    #[test]
343    fn test_unescape_unterminated_entity() {
344        let result = unescape("&lt");
345        assert!(result.is_err());
346    }
347
348    #[test]
349    fn test_escape_to() {
350        let mut out = String::new();
351        escape_to("<test>", &mut out);
352        assert_eq!(out, "&lt;test&gt;");
353    }
354
355    #[test]
356    fn test_roundtrip() {
357        let original = "<div class=\"foo\">Hello & goodbye</div>";
358        let escaped = escape(original);
359        let unescaped = unescape(&escaped).unwrap();
360        assert_eq!(unescaped, original);
361    }
362}