1use memchr::memchr;
7
8#[inline]
12pub fn escape(s: &str) -> std::borrow::Cow<'_, str> {
13 let bytes = s.as_bytes();
14
15 let needs_escape = bytes.iter().any(|&b| matches!(b, b'<' | b'>' | b'&' | b'"' | b'\''));
17
18 if !needs_escape {
19 return std::borrow::Cow::Borrowed(s);
20 }
21
22 let mut result = String::with_capacity(s.len() + s.len() / 8);
23 escape_to_inner(bytes, &mut result);
24 std::borrow::Cow::Owned(result)
25}
26
27#[inline]
29pub fn escape_to(s: &str, out: &mut String) {
30 escape_to_inner(s.as_bytes(), out);
31}
32
/// Appends `bytes` to `out`, replacing each XML special character with its
/// predefined entity:
///
/// | char | entity   |
/// |------|----------|
/// | `<`  | `&lt;`   |
/// | `>`  | `&gt;`   |
/// | `&`  | `&amp;`  |
/// | `"`  | `&quot;` |
/// | `'`  | `&apos;` |
///
/// Runs of bytes that need no escaping are copied with a single `push_str`.
///
/// `bytes` must be valid UTF-8; every caller in this module passes
/// `str::as_bytes()`.
#[inline(always)]
fn escape_to_inner(bytes: &[u8], out: &mut String) {
    // SAFETY: every caller passes the bytes of a `&str`, so `bytes` is valid UTF-8.
    // All further slicing goes through checked `str` indexing.
    let text = unsafe { std::str::from_utf8_unchecked(bytes) };
    let mut start = 0;

    for (i, &byte) in bytes.iter().enumerate() {
        let escaped = match byte {
            b'<' => "&lt;",
            b'>' => "&gt;",
            b'&' => "&amp;",
            b'"' => "&quot;",
            b'\'' => "&apos;",
            _ => continue,
        };
        // `i` sits on an ASCII byte, and `start` is either 0 or just past one, so
        // both are char boundaries. An empty run is a harmless no-op push.
        out.push_str(&text[start..i]);
        out.push_str(escaped);
        start = i + 1;
    }

    // Trailing run after the last special character (or the whole input if none).
    out.push_str(&text[start..]);
}
62
63#[inline]
65pub fn escape_attr(s: &str) -> std::borrow::Cow<'_, str> {
66 escape(s)
67}
68
69#[inline]
73pub fn unescape(s: &str) -> Result<std::borrow::Cow<'_, str>, UnescapeError> {
74 let bytes = s.as_bytes();
75
76 match memchr(b'&', bytes) {
78 None => Ok(std::borrow::Cow::Borrowed(s)),
79 Some(first_amp) => {
80 let mut result = String::with_capacity(s.len());
81 if first_amp > 0 {
83 result.push_str(unsafe {
84 std::str::from_utf8_unchecked(&bytes[..first_amp])
85 });
86 }
87 unescape_from(bytes, first_amp, &mut result)?;
88 Ok(std::borrow::Cow::Owned(result))
89 }
90 }
91}
92
/// Error produced by [`unescape`] and [`unescape_to`] when the input contains
/// a reference that cannot be decoded.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UnescapeError {
    /// The offending reference as written, e.g. `&invalid;`.
    ///
    /// This is just `&` when the reference could not be delimited: there is no
    /// `;`, the name is empty, or the name is longer than the accepted limit.
    pub entity: String,
    /// Byte offset of the reference's leading `&` within the input string.
    pub position: usize,
}

impl std::fmt::Display for UnescapeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let UnescapeError { entity, position } = self;
        write!(f, "invalid XML entity '{}' at position {}", entity, position)
    }
}

impl std::error::Error for UnescapeError {}
109
110#[inline]
112pub fn unescape_to(s: &str, out: &mut String) -> Result<(), UnescapeError> {
113 let bytes = s.as_bytes();
114 match memchr(b'&', bytes) {
115 None => {
116 out.push_str(s);
117 Ok(())
118 }
119 Some(first_amp) => {
120 if first_amp > 0 {
121 out.push_str(unsafe {
122 std::str::from_utf8_unchecked(&bytes[..first_amp])
123 });
124 }
125 unescape_from(bytes, first_amp, out)
126 }
127 }
128}
129
130#[inline(always)]
132fn unescape_from(bytes: &[u8], start: usize, out: &mut String) -> Result<(), UnescapeError> {
133 let mut i = start;
134
135 while i < bytes.len() {
136 if bytes[i] == b'&' {
137 let entity_start = i;
138 i += 1;
139
140 match memchr(b';', &bytes[i..]) {
142 Some(len) if len > 0 && len <= 10 => {
143 let entity = unsafe {
144 std::str::from_utf8_unchecked(&bytes[i..i + len])
145 };
146
147 if let Some(c) = decode_entity_fast(entity) {
148 out.push(c);
149 i += len + 1;
150
151 if let Some(next_amp) = memchr(b'&', &bytes[i..]) {
153 if next_amp > 0 {
154 out.push_str(unsafe {
155 std::str::from_utf8_unchecked(&bytes[i..i + next_amp])
156 });
157 }
158 i += next_amp;
159 } else {
160 out.push_str(unsafe {
162 std::str::from_utf8_unchecked(&bytes[i..])
163 });
164 return Ok(());
165 }
166 } else {
167 return Err(UnescapeError {
168 entity: format!("&{};", entity),
169 position: entity_start,
170 });
171 }
172 }
173 _ => {
174 return Err(UnescapeError {
175 entity: String::from("&"),
176 position: entity_start,
177 });
178 }
179 }
180 } else {
181 i += 1;
182 }
183 }
184
185 Ok(())
186}
187
188#[inline(always)]
190fn decode_entity_fast(entity: &str) -> Option<char> {
191 match entity.len() {
193 2 => match entity {
194 "lt" => Some('<'),
195 "gt" => Some('>'),
196 _ => decode_numeric_entity(entity),
197 },
198 3 => match entity {
199 "amp" => Some('&'),
200 _ => decode_numeric_entity(entity),
201 },
202 4 => match entity {
203 "quot" => Some('"'),
204 "apos" => Some('\''),
205 _ => decode_numeric_entity(entity),
206 },
207 _ => decode_numeric_entity(entity),
208 }
209}
210
/// Decodes the body of a numeric character reference, such as `#65` or `#x41`.
///
/// `entity` is the text between `&` and `;`. After the leading `#`, an `x` or
/// `X` selects hexadecimal; otherwise the digits are decimal.
///
/// Every payload character must be an ASCII digit of the chosen radix. This is
/// checked explicitly because `u32::from_str_radix` alone also accepts a
/// leading `+` (e.g. `#+65`), which is not a valid XML character reference.
///
/// Returns `None` when:
/// - the `#` is missing;
/// - the digit payload is empty or contains a non-digit;
/// - the value overflows `u32`;
/// - the value is not a Unicode scalar value (e.g. the surrogate `#xD800`).
#[inline]
fn decode_numeric_entity(entity: &str) -> Option<char> {
    let body = entity.strip_prefix('#')?;
    let (radix, digits) = match body.strip_prefix('x').or_else(|| body.strip_prefix('X')) {
        Some(hex) => (16, hex),
        None => (10, body),
    };

    if digits.is_empty() || !digits.chars().all(|c| c.is_digit(radix)) {
        return None;
    }

    let code = u32::from_str_radix(digits, radix).ok()?;
    char::from_u32(code)
}
232
233#[cfg(test)]
234mod tests {
235 use super::*;
236
237 #[test]
238 fn test_escape_no_special_chars() {
239 let s = "Hello, World!";
240 let escaped = escape(s);
241 assert!(matches!(escaped, std::borrow::Cow::Borrowed(_)));
242 assert_eq!(escaped, s);
243 }
244
245 #[test]
246 fn test_escape_lt() {
247 assert_eq!(escape("<"), "<");
248 }
249
250 #[test]
251 fn test_escape_gt() {
252 assert_eq!(escape(">"), ">");
253 }
254
255 #[test]
256 fn test_escape_amp() {
257 assert_eq!(escape("&"), "&");
258 }
259
260 #[test]
261 fn test_escape_quot() {
262 assert_eq!(escape("\""), """);
263 }
264
265 #[test]
266 fn test_escape_apos() {
267 assert_eq!(escape("'"), "'");
268 }
269
270 #[test]
271 fn test_escape_mixed() {
272 assert_eq!(
273 escape("<div class=\"foo\">Hello & goodbye</div>"),
274 "<div class="foo">Hello & goodbye</div>"
275 );
276 }
277
278 #[test]
279 fn test_unescape_no_entities() {
280 let s = "Hello, World!";
281 let unescaped = unescape(s).unwrap();
282 assert!(matches!(unescaped, std::borrow::Cow::Borrowed(_)));
283 assert_eq!(unescaped, s);
284 }
285
286 #[test]
287 fn test_unescape_lt() {
288 assert_eq!(unescape("<").unwrap(), "<");
289 }
290
291 #[test]
292 fn test_unescape_gt() {
293 assert_eq!(unescape(">").unwrap(), ">");
294 }
295
296 #[test]
297 fn test_unescape_amp() {
298 assert_eq!(unescape("&").unwrap(), "&");
299 }
300
301 #[test]
302 fn test_unescape_quot() {
303 assert_eq!(unescape(""").unwrap(), "\"");
304 }
305
306 #[test]
307 fn test_unescape_apos() {
308 assert_eq!(unescape("'").unwrap(), "'");
309 }
310
311 #[test]
312 fn test_unescape_mixed() {
313 assert_eq!(
314 unescape("<div class="foo">Hello & goodbye</div>").unwrap(),
315 "<div class=\"foo\">Hello & goodbye</div>"
316 );
317 }
318
319 #[test]
320 fn test_unescape_numeric_decimal() {
321 assert_eq!(unescape("A").unwrap(), "A");
322 assert_eq!(unescape("a").unwrap(), "a");
323 assert_eq!(unescape("€").unwrap(), "€");
324 }
325
326 #[test]
327 fn test_unescape_numeric_hex() {
328 assert_eq!(unescape("A").unwrap(), "A");
329 assert_eq!(unescape("a").unwrap(), "a");
330 assert_eq!(unescape("€").unwrap(), "€");
331 }
332
333 #[test]
334 fn test_unescape_invalid_entity() {
335 let result = unescape("&invalid;");
336 assert!(result.is_err());
337 let err = result.unwrap_err();
338 assert_eq!(err.entity, "&invalid;");
339 assert_eq!(err.position, 0);
340 }
341
342 #[test]
343 fn test_unescape_unterminated_entity() {
344 let result = unescape("<");
345 assert!(result.is_err());
346 }
347
348 #[test]
349 fn test_escape_to() {
350 let mut out = String::new();
351 escape_to("<test>", &mut out);
352 assert_eq!(out, "<test>");
353 }
354
355 #[test]
356 fn test_roundtrip() {
357 let original = "<div class=\"foo\">Hello & goodbye</div>";
358 let escaped = escape(original);
359 let unescaped = unescape(&escaped).unwrap();
360 assert_eq!(unescaped, original);
361 }
362}