smtp_message/
data.rs

1use std::{
2    cmp,
3    io::{self, IoSlice, IoSliceMut},
4    ops::Range,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use futures::{pin_mut, AsyncRead, AsyncWrite, AsyncWriteExt};
10use pin_project::pin_project;
11
12// use crate::*;
13
/// Position of an [`EscapedDataReader`] relative to the `\r\n.\r\n`
/// end-of-data marker, updated for every byte that goes through the reader.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum EscapedDataReaderState {
    /// The latest byte does not continue any prefix of the marker.
    Start,
    /// The latest byte was `\r` (and it did not follow `\r\n.`).
    Cr,
    /// The latest bytes were `\r\n`. This is also the state a new reader
    /// starts in.
    CrLf,
    /// The latest bytes were `\r\n.`.
    CrLfDot,
    /// The latest bytes were `\r\n.\r`.
    CrLfDotCr,
    /// The end-of-data marker has been seen.
    End,
    /// The end-of-data marker has been seen and `complete()` has been called.
    Completed,
}
24
/// `AsyncRead` instance that returns a `DATA` stream up to, and including,
/// its end-of-data marker.
///
/// The bytes are forwarded as they were received: the escaping dots and the
/// final `.\r\n` line are still present in what this reader yields.
///
/// Note that:
///  - If a line (as defined by b"\r\n" endings) starts with a b'.', it is an
///    "escaping" dot that is not part of the actual contents of the line.
///  - If a line is exactly b".\r\n", it is the last line of the stream this
///    stream will give. It is not part of the actual contents of the message.
#[pin_project]
pub struct EscapedDataReader<'a, R> {
    // Buffer given to `new`. The bytes in `unhandled` are served from it before
    // anything is read from `read`, and whatever is read past the end-of-data
    // marker is stored back into it.
    buf: &'a mut [u8],

    // This should be another &'a mut [u8], but the issue described in [1] makes it not work
    // [1] https://github.com/rust-lang/rust/issues/72477
    //
    // Range of `buf` that holds bytes which have not been handled yet.
    unhandled: Range<usize>,

    // How much of the end-of-data marker has been seen so far.
    state: EscapedDataReaderState,

    // Underlying reader, polled once `unhandled` is empty.
    #[pin]
    read: R,
}
45
46impl<'a, R> EscapedDataReader<'a, R>
47where
48    R: AsyncRead,
49{
50    #[inline]
51    pub fn new(buf: &'a mut [u8], unhandled: Range<usize>, read: R) -> Self {
52        EscapedDataReader {
53            buf,
54            unhandled,
55            state: EscapedDataReaderState::CrLf,
56            read,
57        }
58    }
59
60    /// Returns `true` iff the message has been successfully streamed
61    /// to completion
62    #[inline]
63    pub fn is_finished(&self) -> bool {
64        self.state == EscapedDataReaderState::End || self.state == EscapedDataReaderState::Completed
65    }
66
67    /// Asserts that the full message has been read (ie.
68    /// [`.is_finished()`](EscapedDataReader::is_finished) would
69    /// return `true`), then marks this reader as complete.
70    ///
71    /// Note that this should be called before saving the stream,
72    /// given that until `.is_finished()` has returned `true` it's not
73    /// yet sure whether the stream ended due to connection loss or
74    /// thanks to the end of data marker being reached.
75    #[inline]
76    pub fn complete(&mut self) {
77        assert!(self.is_finished());
78        self.state = EscapedDataReaderState::Completed;
79    }
80
81    /// Returns the range of data in the `buf` passed to `new` that
82    /// contains data that hasn't been handled yet (ie. what followed
83    /// the end-of-data marker) if `complete()` has been called, and
84    /// `None` otherwise.
85    #[inline]
86    pub fn get_unhandled(&self) -> Option<Range<usize>> {
87        if self.state == EscapedDataReaderState::Completed {
88            Some(self.unhandled.clone())
89        } else {
90            None
91        }
92    }
93}
94
/// Streams the bytes of a `DATA` payload up to and including the
/// `\r\n.\r\n` end-of-data marker, keeping whatever was read past the marker
/// available through `unhandled`.
impl<'a, R> AsyncRead for EscapedDataReader<'a, R>
where
    R: AsyncRead,
{
    /// Single-buffer read, implemented on top of `poll_read_vectored`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        self.poll_read_vectored(cx, &mut [IoSliceMut::new(buf)])
    }

    /// Fills `bufs` with the next bytes of the stream and returns how many of
    /// them belong to the message (end-of-data marker included).
    ///
    /// Returns `Ok(0)` once the marker has been reached, and a
    /// `ConnectionAborted` error if no byte could be obtained for a non-empty
    /// `bufs` before the marker was seen.
    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context,
        bufs: &mut [IoSliceMut],
    ) -> Poll<io::Result<usize>> {
        // If we have already finished, return early
        if self.is_finished() {
            return Poll::Ready(Ok(0));
        }

        let this = self.project();

        // First, fill the bufs with incoming data
        // (`raw_size` is the number of bytes placed in `bufs`, before looking
        // for the end-of-data marker)
        let raw_size = {
            let unhandled_len_start = this.unhandled.end - this.unhandled.start;
            if unhandled_len_start > 0 {
                // Serve the bytes that are already buffered; the underlying
                // reader is not polled in this call.
                for buf in bufs.iter_mut() {
                    let copy_len = cmp::min(buf.len(), this.unhandled.end - this.unhandled.start);
                    let next_start = this.unhandled.start + copy_len;
                    buf[..copy_len].copy_from_slice(&this.buf[this.unhandled.start..next_start]);
                    this.unhandled.start = next_start;
                }
                unhandled_len_start - (this.unhandled.end - this.unhandled.start)
            } else {
                match this.read.poll_read_vectored(cx, bufs) {
                    Poll::Ready(Ok(s)) => s,
                    // `Pending` and errors are forwarded untouched
                    other => return other,
                }
            }
        };

        // If there was nothing to read, return early
        if raw_size == 0 {
            if bufs.iter().map(|b| b.len()).sum::<usize>() == 0 {
                // There was no room to read anything into
                return Poll::Ready(Ok(0));
            } else {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::ConnectionAborted,
                    "connection aborted without finishing the data stream",
                )));
            }
        }

        // Then, look for the end in the bufs
        // (`size` counts the bytes of `bufs` that have been scanned so far)
        let mut size = 0;
        for b in 0..bufs.len() {
            for i in 0..cmp::min(bufs[b].len(), raw_size - size) {
                use EscapedDataReaderState::*;
                match (*this.state, bufs[b][i]) {
                    (Cr, b'\n') => *this.state = CrLf,
                    (CrLf, b'.') => *this.state = CrLfDot,
                    (CrLfDot, b'\r') => *this.state = CrLfDotCr,
                    (CrLfDotCr, b'\n') => {
                        *this.state = End;
                        // `size` now covers everything up to and including the marker
                        size += i + 1;

                        if this.unhandled.start == this.unhandled.end {
                            // The data (most likely) comes from `this.read` -- or, at least, we
                            // know that there can be nothing left in `this.unhandled`.
                            //
                            // Stash the `raw_size - size` bytes that follow the marker at the
                            // start of `this.buf`.
                            // NOTE(review): the slicing below panics if more bytes follow the
                            // marker than `this.buf` can hold; nothing here bounds that count
                            // by `this.buf.len()`.
                            let remaining = cmp::min(bufs[b].len() - (i + 1), raw_size - size);
                            this.buf[..remaining]
                                .copy_from_slice(&bufs[b][i + 1..i + 1 + remaining]);
                            let mut copied = remaining;
                            for buf in &bufs[b + 1..] {
                                let remaining = cmp::min(buf.len(), raw_size - size - copied);
                                this.buf[copied..copied + remaining]
                                    .copy_from_slice(&buf[..remaining]);
                                copied += remaining;
                            }
                            *this.unhandled = 0..copied;
                        } else {
                            // The data comes straight out of `this.unhandled`,
                            // so let's just reuse it
                            // (rewind over the bytes that were handed out past the marker)
                            this.unhandled.start -= raw_size - size;
                        }

                        return Poll::Ready(Ok(size));
                    }
                    (_, b'\r') => *this.state = Cr,
                    _ => *this.state = Start,
                }
            }
            size += cmp::min(bufs[b].len(), raw_size - size);
        }

        // Didn't reach the end, let's return everything found
        Poll::Ready(Ok(size))
    }
}
196
/// Outcome of one call to [`DataUnescaper::unescape`].
pub struct DataUnescapeRes {
    /// Number of unescaped bytes now available at the start of the buffer,
    /// ie. `data[..written]` is unescaped message data.
    pub written: usize,
    /// Index of the first byte of the buffer that has not been handled, ie.
    /// `data[unhandled_idx..]` still has to be dealt with by the caller.
    pub unhandled_idx: usize,
}

/// Helper struct to unescape a data stream.
///
/// Note that one unescaper should be used for a single data stream. Creating a
/// `DataUnescaper` is basically free, and not creating a new one would probably
/// lead to initial `\r\n` being handled incorrectly.
pub struct DataUnescaper {
    // Whether the next byte fed to `unescape` is to be treated as the first
    // byte of a line.
    is_preceded_by_crlf: bool,
}

impl DataUnescaper {
    /// Creates a `DataUnescaper`.
    ///
    /// The `is_preceded_by_crlf` argument is used to indicate whether, before
    /// the first buffer that is fed into `unescape`, the unescaper should
    /// assume that a `\r\n` was present.
    ///
    /// Usually, one will want to set `true` as an argument, as starting a
    /// `DataUnescaper` mid-line is a rare use case.
    pub fn new(is_preceded_by_crlf: bool) -> DataUnescaper {
        DataUnescaper {
            is_preceded_by_crlf,
        }
    }

    /// Unescapes data coming from an [`EscapedDataReader`](EscapedDataReader).
    ///
    /// This takes a `data` argument. It will modify the `data` argument,
    /// removing the escaping that could happen with it, and then returns a
    /// [`DataUnescapeRes`](DataUnescapeRes).
    ///
    /// It is possible that the end of `data` does not land on a boundary that
    /// allows yet to know whether data should be output or not. This is the
    /// reason why this returns a [`DataUnescapeRes`](DataUnescapeRes). The
    /// returned value will contain:
    ///  - `.written`, which is the number of unescaped bytes that have been
    ///    written in `data` — that is, `data[..res.written]` is the unescaped
    ///    data, and
    ///  - `.unhandled_idx`, which is the index in `data` of the first byte
    ///    that could not be handled yet for lack of more information — that
    ///    is, `data[res.unhandled_idx..]` is data that should be at the
    ///    beginning of the next call to `unescape`.
    ///
    /// When the end-of-data marker is found, `.unhandled_idx` points right
    /// after it and the marker itself is not part of the written data. A
    /// buffer that consists of just `.\r\n` at the start of a line is
    /// recognized as the end-of-data marker.
    ///
    /// Note that, as long as the end-of-data marker has not been found, the
    /// unhandled data's length is never going to be longer than 4 bytes long
    /// ("\r\n.\r", the longest sequence that can't be interpreted yet), so it
    /// should not be an issue to just copy it to the next buffer's start.
    pub fn unescape(&mut self, data: &mut [u8]) -> DataUnescapeRes {
        // TODO: this could be optimized by having a state machine we handle ourselves.
        // Unfortunately, neither regex nor regex_automata provide tooling for
        // noalloc replacements when the replacement is guaranteed to be shorter than
        // the match

        /// Moves `data[from..to]` so that it starts at `dest`, and returns the
        /// number of bytes this represents.
        fn compact(data: &mut [u8], from: usize, to: usize, dest: usize) -> usize {
            if from != dest {
                data.copy_within(from..to, dest);
            }
            to - from
        }

        let mut written = 0;
        let mut unhandled_idx = 0;

        if self.is_preceded_by_crlf {
            if data.starts_with(b".\r\n") {
                // It is the end already. This has to be checked before the length
                // check below: a buffer that is exactly the 3-byte marker (an
                // empty message) would otherwise never be recognized.
                return DataUnescapeRes {
                    written: 0,
                    unhandled_idx: 3,
                };
            } else if data.len() <= 3 {
                // Don't have enough information to know whether it's the end or just an escape.
                // Maybe it's nothing special, but let's not make an effort to check it, as
                // asking for 4-byte buffers should hopefully not be too much.
                return DataUnescapeRes {
                    written: 0,
                    unhandled_idx: 0,
                };
            } else if data[0] == b'.' {
                // It is just an escape, skip the dot
                unhandled_idx += 1;
            }
            // Otherwise it is nothing special, just go the regular path

            self.is_preceded_by_crlf = false;
        }

        // First, look for "\r\n."
        while let Some(i) = data[unhandled_idx..].windows(3).position(|s| s == b"\r\n.") {
            // Index of the `\r` that starts the "\r\n." sequence
            let marker = unhandled_idx + i;

            if data.len() <= marker + 4 {
                // Don't have enough information to know whether it's the end or just an
                // escape: output what comes before the "\r\n." and leave it unhandled
                written += compact(data, unhandled_idx, marker, written);
                return DataUnescapeRes {
                    written,
                    unhandled_idx: marker,
                };
            }

            // In both remaining cases the "\r\n" is part of the message and the dot
            // is not. The move never reaches past `marker + 2`, so the two bytes
            // inspected below are still the original ones.
            written += compact(data, unhandled_idx, marker + 2, written);

            if &data[marker + 3..marker + 5] == b"\r\n" {
                // It is the end
                return DataUnescapeRes {
                    written,
                    unhandled_idx: marker + 5,
                };
            }

            // It is just an escape, carry on right after the dot
            unhandled_idx = marker + 3;
        }

        // There is no "\r\n." any longer, let's handle the remaining bytes by simply
        // checking whether they end with something that needs handling.
        let held_back = if data.ends_with(b"\r\n") {
            2
        } else if data.ends_with(b"\r") {
            1
        } else {
            0
        };
        let stop = data.len() - held_back;
        written += compact(data, unhandled_idx, stop, written);
        DataUnescapeRes {
            written,
            unhandled_idx: stop,
        }
    }
}
341
/// What an [`EscapingDataWriter`] remembers about the latest bytes it wrote,
/// so as to know whether a `.` starts a line and has to be escaped.
#[derive(Clone, Copy)]
enum EscapingDataWriterState {
    /// The latest bytes are neither `\r` nor `\r\n`.
    Start,
    /// The latest byte was `\r`.
    Cr,
    /// The latest bytes were `\r\n`. This is also the state a new writer
    /// starts in.
    CrLf,
}
348
/// `AsyncWrite` instance that takes an unescaped `DATA` stream and
/// escapes it.
///
/// The end-of-data marker is only sent by `finish`.
#[pin_project]
pub struct EscapingDataWriter<W> {
    // What the latest written bytes were, as far as escaping is concerned.
    state: EscapingDataWriterState,

    // Underlying writer that receives the escaped stream.
    #[pin]
    write: W,
}
358
359impl<W> EscapingDataWriter<W>
360where
361    W: AsyncWrite,
362{
363    #[inline]
364    pub fn new(write: W) -> Self {
365        EscapingDataWriter {
366            state: EscapingDataWriterState::CrLf,
367            write,
368        }
369    }
370
371    #[inline]
372    pub async fn finish(self) -> io::Result<()> {
373        let write = self.write;
374        pin_mut!(write);
375        match self.state {
376            EscapingDataWriterState::CrLf => write.write_all(b".\r\n").await,
377            _ => write.write_all(b"\r\n.\r\n").await,
378        }
379    }
380}
381
382impl<W> AsyncWrite for EscapingDataWriter<W>
383where
384    W: AsyncWrite,
385{
386    #[inline]
387    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
388        self.poll_write_vectored(cx, &[IoSlice::new(buf)])
389    }
390
391    #[inline]
392    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
393        self.project().write.poll_flush(cx)
394    }
395
396    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
397        Poll::Ready(Err(io::Error::new(
398            io::ErrorKind::Other,
399            "tried closing a stream during a message",
400        )))
401    }
402
403    fn poll_write_vectored(
404        self: Pin<&mut Self>,
405        cx: &mut Context,
406        bufs: &[IoSlice],
407    ) -> Poll<io::Result<usize>> {
408        fn set_state_until(state: &mut EscapingDataWriterState, bufs: &[IoSlice], n: usize) {
409            use EscapingDataWriterState::*;
410            let mut n = n;
411            for buf in bufs {
412                if n.saturating_sub(2) > buf.len() {
413                    n -= buf.len();
414                    *state = Start;
415                    continue;
416                }
417                for i in n.saturating_sub(2)..cmp::min(buf.len(), n) {
418                    n -= 1;
419                    match (*state, buf[i]) {
420                        (_, b'\r') => *state = Cr,
421                        (Cr, b'\n') => *state = CrLf,
422                        // We know that this function can't be called with an escape happening
423                        _ => *state = Start,
424                    }
425                }
426                if n == 0 {
427                    return;
428                }
429            }
430        }
431
432        let this = self.project();
433
434        let initial_state = *this.state;
435        for b in 0..bufs.len() {
436            for i in 0..bufs[b].len() {
437                use EscapingDataWriterState::*;
438                match (*this.state, bufs[b][i]) {
439                    (_, b'\r') => *this.state = Cr,
440                    (Cr, b'\n') => *this.state = CrLf,
441                    (CrLf, b'.') => {
442                        let mut v = Vec::with_capacity(b + 1);
443                        let mut writing = 0;
444                        for buf in &bufs[0..b] {
445                            v.push(IoSlice::new(buf));
446                            writing += buf.len();
447                        }
448                        v.push(IoSlice::new(&bufs[b][..=i]));
449                        writing += i + 1;
450                        return match this.write.poll_write_vectored(cx, &v) {
451                            Poll::Ready(Ok(s)) => {
452                                if s == writing {
453                                    *this.state = Start;
454                                    Poll::Ready(Ok(s - 1))
455                                } else {
456                                    *this.state = initial_state;
457                                    set_state_until(this.state, bufs, s);
458                                    Poll::Ready(Ok(s))
459                                }
460                            }
461                            o => o,
462                        };
463                    }
464                    _ => *this.state = Start,
465                }
466            }
467        }
468
469        match this.write.poll_write_vectored(cx, bufs) {
470            Poll::Ready(Ok(s)) => {
471                if s == bufs.iter().map(|b| b.len()).sum::<usize>() {
472                    Poll::Ready(Ok(s))
473                } else {
474                    *this.state = initial_state;
475                    set_state_until(this.state, bufs, s);
476                    Poll::Ready(Ok(s))
477                }
478            }
479            o => o,
480        }
481    }
482}
483
// Unit tests for the reader, the unescaper and the writer of this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    use futures::{
        executor,
        io::{AsyncReadExt, Cursor},
    };

    // Checks that `EscapedDataReader` stops right after the end-of-data marker
    // and leaves what follows it available.
    //
    // Each case is (input chunks, bytes expected out of the reader, bytes
    // expected to remain afterwards). The first chunk is placed in the buffer
    // handed to the reader as already-read data; the other chunks are served
    // by the underlying reader.
    //
    // TODO: actually test the vectored version of the function
    #[test]
    fn escaped_data_reader() {
        let tests: &[(&[&[u8]], &[u8], &[u8])] = &[
            (
                &[b"foo", b" bar", b"\r\n", b".\r", b"\n"],
                b"foo bar\r\n.\r\n",
                b"",
            ),
            (&[b"\r\n.\r\n", b"\r\n"], b"\r\n.\r\n", b"\r\n"),
            (&[b".\r\n"], b".\r\n", b""),
            (&[b".baz\r\n", b".\r\n", b"foo"], b".baz\r\n.\r\n", b"foo"),
            (&[b" .baz", b"\r\n.", b"\r\nfoo"], b" .baz\r\n.\r\n", b"foo"),
            (&[b".\r\n", b"MAIL FROM"], b".\r\n", b"MAIL FROM"),
            (&[b"..\r\n.\r\n"], b"..\r\n.\r\n", b""),
            (
                &[b"foo\r\n. ", b"bar\r\n.\r\n"],
                b"foo\r\n. bar\r\n.\r\n",
                b"",
            ),
            (&[b".\r\nMAIL FROM"], b".\r\n", b"MAIL FROM"),
            (&[b"..\r\n.\r\nMAIL FROM"], b"..\r\n.\r\n", b"MAIL FROM"),
        ];
        let mut surrounding_buf: [u8; 16] = [0; 16];
        let mut enclosed_buf: [u8; 8] = [0; 8];
        for (i, &(inp, out, rem)) in tests.iter().enumerate() {
            println!(
                "Trying to parse test {} into {:?} with {:?} remaining\n",
                i,
                show_bytes(out),
                show_bytes(rem)
            );

            // Chain all the chunks but the first one into a single reader
            let mut reader = inp[1..].iter().map(Cursor::new).fold(
                Box::pin(futures::io::empty()) as Pin<Box<dyn 'static + AsyncRead>>,
                |a, b| Box::pin(AsyncReadExt::chain(a, b)),
            );

            surrounding_buf[..inp[0].len()].copy_from_slice(inp[0]);
            let mut data_reader =
                EscapedDataReader::new(&mut surrounding_buf, 0..inp[0].len(), reader.as_mut());

            // Read until the reader reports the end (or fails)
            let mut res_out = Vec::<u8>::new();
            while let Ok(r) = executor::block_on(data_reader.read(&mut enclosed_buf)) {
                if r == 0 {
                    break;
                }
                println!(
                    "got out buf (size {}): {:?}",
                    r,
                    show_bytes(&enclosed_buf[..r])
                );
                res_out.extend_from_slice(&enclosed_buf[..r]);
            }
            // Panics if the end-of-data marker has not been reached
            data_reader.complete();
            println!(
                "total out is: {:?}, hoping for: {:?}",
                show_bytes(&res_out),
                show_bytes(out)
            );
            assert_eq!(&res_out[..], out);

            // What remains is what the reader left unhandled, followed by
            // whatever the underlying reader still has to give
            let unhandled = data_reader.get_unhandled().unwrap();
            let mut res_rem = Vec::<u8>::new();
            res_rem.extend_from_slice(&surrounding_buf[unhandled]);

            while let Ok(r) = executor::block_on(reader.read(&mut surrounding_buf)) {
                if r == 0 {
                    break;
                }
                println!("got rem buf: {:?}", show_bytes(&surrounding_buf[..r]));
                res_rem.extend_from_slice(&surrounding_buf[0..r]);
            }
            println!(
                "total rem is: {:?}, hoping for: {:?}",
                show_bytes(&res_rem),
                show_bytes(rem)
            );
            assert_eq!(&res_rem[..], rem);
        }
    }

    // Checks `DataUnescaper` on streams that are fed chunk by chunk.
    //
    // Each case is (input chunks, expected unescaped message).
    #[test]
    fn data_unescaper() {
        let tests: &[(&[&[u8]], &[u8])] = &[
            (&[b"foo", b" bar", b"\r\n", b".\r", b"\n"], b"foo bar\r\n"),
            (&[b"\r\n.\r\n"], b"\r\n"),
            (&[b".baz\r\n", b".\r\n"], b"baz\r\n"),
            (&[b" .baz", b"\r\n.", b"\r\n"], b" .baz\r\n"),
            (&[b".\r\n"], b""),
            (&[b"..\r\n.\r\n"], b".\r\n"),
            (&[b"foo\r\n. ", b"bar\r\n.\r\n"], b"foo\r\n bar\r\n"),
            (&[b"\r\r\n.\r\n"], b"\r\r\n"),
        ];
        let mut buf: [u8; 1024] = [0; 1024];
        for &(inp, out) in tests {
            println!(
                "Test: {:?}",
                itertools::concat(
                    inp.iter()
                        .map(|i| show_bytes(i).chars().collect::<Vec<char>>())
                )
                .iter()
                .collect::<String>()
            );
            let mut res = Vec::<u8>::new();
            // Number of bytes at the start of `buf` left unhandled by the previous call
            let mut end = 0;
            let mut unescaper = DataUnescaper::new(true);
            for i in inp {
                // Append the chunk after the unhandled bytes, unescape, collect what
                // was written and move the unhandled bytes back to the start
                buf[end..end + i.len()].copy_from_slice(i);
                let r = unescaper.unescape(&mut buf[..end + i.len()]);
                res.extend_from_slice(&buf[..r.written]);
                buf.copy_within(r.unhandled_idx..end + i.len(), 0);
                end = end + i.len() - r.unhandled_idx;
            }
            println!("Result: {:?}", show_bytes(&res));
            assert_eq!(&res[..], out);
        }
    }

    // Checks what `EscapingDataWriter` sends to the underlying writer.
    //
    // Each case is (list of vectored writes, expected bytes on the wire once
    // `finish` has been called).
    #[test]
    fn escaping_data_writer() {
        let tests: &[(&[&[&[u8]]], &[u8])] = &[
            (&[&[b"foo", b" bar"], &[b" baz"]], b"foo bar baz\r\n.\r\n"),
            (&[&[b"foo\r\n. bar\r\n"]], b"foo\r\n.. bar\r\n.\r\n"),
            (&[&[b""]], b".\r\n"),
            (&[&[b"."]], b"..\r\n.\r\n"),
            (&[&[b"\r"]], b"\r\r\n.\r\n"),
            (&[&[b"foo\r"]], b"foo\r\r\n.\r\n"),
            (&[&[b"foo bar\r", b"\n"]], b"foo bar\r\n.\r\n"),
            (
                &[&[b"foo bar\r\n"], &[b". baz\n"]],
                b"foo bar\r\n.. baz\n\r\n.\r\n",
            ),
        ];
        for &(inp, out) in tests {
            println!("Expected result: {:?}", show_bytes(out));
            let mut v = Vec::new();
            let c = Cursor::new(&mut v);
            let mut w = EscapingDataWriter::new(c);
            for write in inp {
                // Submit the same vectored write again, minus what has already been
                // consumed, until all of it has been consumed
                let mut written = 0;
                let total_to_write = write.iter().map(|b| b.len()).sum::<usize>();
                while written != total_to_write {
                    let mut i = Vec::new();
                    let mut skipped = 0;
                    for s in *write {
                        if skipped + s.len() <= written {
                            skipped += s.len();
                            println!("(skipping, skipped = {})", skipped);
                            continue;
                        }
                        if written - skipped != 0 {
                            println!("(skipping first {} chars)", written - skipped);
                            i.push(IoSlice::new(&s[(written - skipped)..]));
                            skipped = written;
                        } else {
                            println!("(skipping nothing)");
                            i.push(IoSlice::new(s));
                        }
                    }
                    println!("Writing: {:?}", i);
                    written += executor::block_on(w.write_vectored(&i)).unwrap();
                    println!("Written: {:?} (out of {:?})", written, total_to_write);
                }
            }
            executor::block_on(w.finish()).unwrap();
            assert_eq!(&v, &out);
        }
    }
}