1use std::{
4 fmt::{Debug, Display},
5 io::Write,
6};
7
8use crate::{
9 escape::{comment_escape, content_escape},
10 lut::{is_invalid_attribute_name, is_invalid_name},
11 reader::{
12 self, AttributeEvent, AttributeQuote, CDataEvent, CommentEvent, DoctypeEvent, TextEvent,
13 },
14};
15
/// Configuration for a [`Writer`].
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build it with a
/// struct literal; start from [`Options::default`] and set the public fields instead.
#[derive(Clone, Default)]
#[non_exhaustive]
pub struct Options {
    /// When `true`, [`Writer::write_comment`] and [`Writer::write_raw_comment`] produce no
    /// output. Defaults to `false`.
    pub omit_comments: bool,
}
23
24pub struct Writer<W: Write> {
26 writer: W,
27 options: Options,
28 depth_and_flags: u32,
29}
30
/// Failures reported by the writer's methods.
pub enum Error {
    /// An element prefix contained a byte rejected by the element-name validator.
    InvalidElementPrefix,
    /// An element name contained a byte rejected by the element-name validator.
    InvalidElementName,
    /// An attribute name contained a byte rejected by the attribute-name validator.
    InvalidAttributeName,
    /// A raw attribute value was rejected by validation.
    InvalidAttributeValue,
    /// An attribute was written while no start or empty tag was open.
    AttributeOutsideTag,
    /// Raw content contained a sequence that should have been escaped (`<` in raw text,
    /// `-->` in a raw comment).
    ImproperlyEscaped,
    /// CDATA content contained the `]]>` terminator.
    InvalidCData,
    /// Raw text contained a NUL byte.
    InvalidValue,
    /// The underlying writer returned an I/O error.
    Io(std::io::Error),
}
55
56impl std::error::Error for Error {
57 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
58 match self {
59 Error::Io(error) => Some(error),
60 _ => None,
61 }
62 }
63}
64
65impl Debug for Error {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 <Self as Display>::fmt(self, f)
68 }
69}
70
71impl Display for Error {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.write_str(match self {
74 Error::InvalidElementPrefix => "invalid element prefix",
75 Error::InvalidElementName => "invalid element name",
76 Error::InvalidAttributeName => "invalid attribute name",
77 Error::InvalidAttributeValue => "invalid attribute value",
78 Error::AttributeOutsideTag => "attributes are only allowed inside tags",
79 Error::ImproperlyEscaped => "improperly escaped content",
80 Error::InvalidCData => "cdata content cannot contain `]]>`",
81 Error::InvalidValue => "value contains null byte",
82 Error::Io(error) => return <std::io::Error as Display>::fmt(error, f),
83 })
84 }
85}
86
87impl From<std::io::Error> for Error {
88 fn from(value: std::io::Error) -> Self {
89 Self::Io(value)
90 }
91}
92
93impl<W: Write> Writer<W> {
94 #[inline]
96 pub fn new(writer: W) -> Self {
97 Self::with_options(writer, Options::default())
98 }
99
100 #[inline]
102 pub fn with_options(writer: W, options: Options) -> Self {
103 Self {
104 writer,
105 options,
106 depth_and_flags: 0,
107 }
108 }
109
110 fn in_empty_tag(&self) -> bool {
111 self.depth_and_flags & 0b10 > 0
112 }
113
114 fn ensure_tag_closed(&mut self) -> Result<(), std::io::Error> {
115 if self.depth_and_flags & 1 > 0 {
116 if self.in_empty_tag() {
117 self.writer.write_all(b"/>")?;
118 self.depth_and_flags += 0b001;
119 } else {
120 self.writer.write_all(b">")?;
121 self.depth_and_flags += 0b011;
122 }
123 }
124
125 Ok(())
126 }
127
128 pub fn write_start(&mut self, prefix: Option<&str>, name: &str) -> Result<(), Error> {
134 if prefix.is_some_and(|pfx| pfx.bytes().any(is_invalid_name)) {
135 return Err(Error::InvalidElementPrefix);
136 }
137
138 if name.bytes().any(is_invalid_name) {
139 return Err(Error::InvalidElementName);
140 }
141
142 self.ensure_tag_closed()?;
143
144 self.depth_and_flags += 0b1;
145 self.writer.write_all(b"<")?;
147 if let Some(prefix) = prefix {
148 self.writer.write_all(prefix.as_bytes())?;
149 self.writer.write_all(b":")?;
150 }
151 self.writer.write_all(name.as_bytes())?;
152
153 Ok(())
154 }
155
156 pub fn write_empty(&mut self, prefix: Option<&str>, name: &str) -> Result<(), Error> {
162 if name.bytes().any(is_invalid_name) {
163 return Err(Error::InvalidElementName);
164 }
165
166 self.ensure_tag_closed()?;
167
168 self.depth_and_flags += 0b11;
169 self.writer.write_all(b"<")?;
171 if let Some(prefix) = prefix {
172 self.writer.write_all(prefix.as_bytes())?;
173 self.writer.write_all(b":")?;
174 }
175 self.writer.write_all(name.as_bytes())?;
176
177 Ok(())
178 }
179
180 pub fn write_raw_attribute(
191 &mut self,
192 name: &str,
193 quote: AttributeQuote,
194 value: &str,
195 ) -> Result<(), Error> {
196 if self.depth_and_flags & 1 == 0 {
197 return Err(Error::AttributeOutsideTag);
198 }
199
200 if name.bytes().any(is_invalid_attribute_name) {
201 return Err(Error::InvalidAttributeName);
202 }
203
204 let quote = quote as u8;
205 if name.bytes().any(|b| [b'\0', quote].contains(&b)) {
206 return Err(Error::InvalidAttributeValue);
207 }
208
209 self.writer.write_all(b" ")?;
210 self.writer.write_all(name.as_bytes())?;
211 self.writer.write_all(b"=")?;
212 self.writer.write_all(&[quote])?;
213 self.writer.write_all(value.as_bytes())?;
214 self.writer.write_all(&[quote])?;
215
216 Ok(())
217 }
218
219 pub fn write_attribute(&mut self, name: &str, value: &str) -> Result<(), Error> {
229 let escaped = content_escape(value);
230 self.write_raw_attribute(name, AttributeQuote::Double, &escaped)
231 }
232
233 pub fn write_end(&mut self, prefix: Option<&str>, name: &str) -> Result<(), Error> {
239 if prefix.is_some_and(|pfx| pfx.bytes().any(is_invalid_name)) {
240 return Err(Error::InvalidElementPrefix);
241 }
242
243 if name.bytes().any(is_invalid_name) {
244 return Err(Error::InvalidElementName);
245 }
246
247 self.ensure_tag_closed()?;
248
249 self.writer.write_all(b"</")?;
251 if let Some(prefix) = prefix {
252 self.writer.write_all(prefix.as_bytes())?;
253 self.writer.write_all(b":")?;
254 }
255 self.writer.write_all(name.as_bytes())?;
256 self.writer.write_all(b">")?;
257
258 self.depth_and_flags -= 0b100;
259
260 Ok(())
261 }
262
263 fn write_raw_text_unchecked(&mut self, text: &str) -> std::io::Result<()> {
264 self.ensure_tag_closed()?;
265
266 self.writer.write_all(text.as_bytes())
267 }
268
269 pub fn write_raw_text(&mut self, text: &str) -> Result<(), Error> {
275 if let Some(idx) = memchr::memchr2(b'\0', b'<', text.as_bytes()) {
276 return Err(if text.as_bytes()[idx] == b'<' {
277 Error::ImproperlyEscaped
278 } else {
279 Error::InvalidValue
280 });
281 }
282
283 self.write_raw_text_unchecked(text).map_err(Into::into)
284 }
285
286 pub fn write_text(&mut self, content: &str) -> Result<(), Error> {
296 let escaped = content_escape(content);
297 self.write_raw_text_unchecked(&escaped).map_err(Into::into)
298 }
299
300 fn write_cdata_unchecked(&mut self, text: &str) -> std::io::Result<()> {
301 self.ensure_tag_closed()?;
302
303 self.writer.write_all(b"<![CDATA[")?;
304 self.writer.write_all(text.as_bytes())?;
305 self.writer.write_all(b"]]>")
306 }
307
308 pub fn write_cdata(&mut self, text: &str) -> Result<(), Error> {
318 if memchr::memmem::find(text.as_bytes(), b"]]>").is_some() {
319 return Err(Error::InvalidCData);
320 }
321
322 self.write_cdata_unchecked(text).map_err(Into::into)
323 }
324
325 fn write_raw_comment_unchecked(&mut self, text: &str) -> std::io::Result<()> {
326 self.ensure_tag_closed()?;
327
328 self.writer.write_all(b"<!--")?;
329 self.writer.write_all(text.as_bytes())?;
330 self.writer.write_all(b"-->")?;
331
332 Ok(())
333 }
334
335 pub fn write_raw_comment(&mut self, text: &str) -> Result<(), Error> {
345 if memchr::memmem::find(text.as_bytes(), b"-->").is_some() {
346 return Err(Error::ImproperlyEscaped);
347 }
348
349 if !self.options.omit_comments {
350 self.write_raw_comment_unchecked(text)?
351 }
352
353 Ok(())
354 }
355
356 pub fn write_comment(&mut self, content: &str) -> Result<(), Error> {
366 if !self.options.omit_comments {
367 let escaped = comment_escape(content);
368 self.write_raw_comment_unchecked(&escaped)?
369 }
370
371 Ok(())
372 }
373
374 pub fn write_attribute_event(&mut self, attr: &AttributeEvent) -> Result<(), Error> {
384 if self.depth_and_flags & 1 == 0 {
385 return Err(Error::AttributeOutsideTag);
386 }
387
388 self.writer.write_all(b" ")?;
389 self.writer.write_all(attr.name().as_bytes())?;
390 self.writer.write_all(b"=")?;
391 self.writer.write_all(&[attr.quote() as u8])?;
392 self.writer.write_all(attr.raw_value().as_bytes())?;
393 self.writer.write_all(&[attr.quote() as u8])?;
394
395 Ok(())
396 }
397
398 pub fn write_event(&mut self, event: &reader::Event) -> Result<(), Error> {
404 match event {
405 reader::Event::Start(start) | reader::Event::Empty(start) => {
406 if start.is_empty() {
407 self.write_empty(start.prefix(), start.name())?;
408 } else {
409 self.write_start(start.prefix(), start.name())?;
410 }
411
412 for attr in start.attributes() {
413 self.write_attribute_event(&attr)?;
414 }
415
416 Ok(())
417 }
418 reader::Event::End(end) => self.write_end(end.prefix(), end.name()),
419 &reader::Event::Comment(CommentEvent { text })
420 | &reader::Event::CData(CDataEvent { text })
421 | &reader::Event::Doctype(DoctypeEvent { text })
422 | &reader::Event::Text(TextEvent { text }) => {
423 self.ensure_tag_closed()?;
424
425 self.writer.write_all(text.as_bytes())?;
426
427 Ok(())
428 }
429 }
430 }
431
432 pub fn inner_ref(&self) -> &W {
434 &self.writer
435 }
436
437 pub fn inner_mut(&mut self) -> &mut W {
439 &mut self.writer
440 }
441
442 pub fn finish(mut self) -> std::io::Result<W> {
448 self.ensure_tag_closed()?;
449
450 Ok(self.writer)
451 }
452
453 pub fn flush(&mut self) -> std::io::Result<()> {
459 self.ensure_tag_closed()?;
460
461 self.writer.flush()
462 }
463}
464
/// Parses each sample with the reader, replays every event through the writer, and checks
/// that the bytes produced are identical to the input.
#[test]
fn reader_writer_roundtrip() {
    const CASES: &[&str] = &[
        "hello world",
        "<some xml='text'/>",
        r#"more stuff<then a_tag="here">with content and <![CDATA[value]]></end>"#,
        "text <!-- something with comments --> text text",
    ];

    for &input in CASES {
        let mut writer = Writer::new(std::io::Cursor::new(Vec::new()));
        let mut reader = reader::Reader::with_options(
            input,
            reader::Options::default().allow_top_level_text(true),
        );

        // Events are only borrowed here; no debug printing takes them by value first.
        while let Some(event) = reader.next().transpose().unwrap() {
            writer.write_event(&event).unwrap();
        }

        let result = writer.finish().unwrap().into_inner();
        let output = std::str::from_utf8(&result).unwrap();

        assert_eq!(output, input, "round trip altered input {input:?}");
    }
}