1use std::{
2 cmp,
3 io::{self, IoSlice, IoSliceMut},
4 ops::Range,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use futures::{pin_mut, AsyncRead, AsyncWrite, AsyncWriteExt};
10use pin_project::pin_project;
11
/// Position of [`EscapedDataReader`] inside the `\r\n.\r\n` end-of-data
/// sequence, based on the bytes it has forwarded so far.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum EscapedDataReaderState {
    /// The last byte seen does not take part in any terminator prefix.
    Start,
    /// The last byte seen was `\r`.
    Cr,
    /// The last bytes seen were `\r\n`. This is also the state a freshly
    /// created reader starts in.
    CrLf,
    /// The last bytes seen were `\r\n.`.
    CrLfDot,
    /// The last bytes seen were `\r\n.\r`.
    CrLfDotCr,
    /// The full terminator has been seen; further reads return `Ok(0)`.
    End,
    /// `complete()` has been called after the terminator was seen.
    Completed,
}
24
/// `AsyncRead` adapter that forwards bytes up to and including the line that
/// consists of a single dot (`.\r\n`), then reports end-of-stream.
///
/// The forwarded bytes are not modified: escaping dots and the terminator
/// line itself are still present in what the caller reads.
#[pin_project]
pub struct EscapedDataReader<'a, R> {
    /// Caller-owned buffer. On entry it may already hold bytes that must be
    /// served first; on exit it receives the bytes that were read past the
    /// terminator.
    buf: &'a mut [u8],

    /// Range of `buf` that holds bytes not yet handed out by this reader.
    unhandled: Range<usize>,

    /// Progress through the terminator sequence.
    state: EscapedDataReaderState,

    /// Underlying reader polled once `unhandled` is empty.
    #[pin]
    read: R,
}
45
46impl<'a, R> EscapedDataReader<'a, R>
47where
48 R: AsyncRead,
49{
50 #[inline]
51 pub fn new(buf: &'a mut [u8], unhandled: Range<usize>, read: R) -> Self {
52 EscapedDataReader {
53 buf,
54 unhandled,
55 state: EscapedDataReaderState::CrLf,
56 read,
57 }
58 }
59
60 #[inline]
63 pub fn is_finished(&self) -> bool {
64 self.state == EscapedDataReaderState::End || self.state == EscapedDataReaderState::Completed
65 }
66
67 #[inline]
76 pub fn complete(&mut self) {
77 assert!(self.is_finished());
78 self.state = EscapedDataReaderState::Completed;
79 }
80
81 #[inline]
86 pub fn get_unhandled(&self) -> Option<Range<usize>> {
87 if self.state == EscapedDataReaderState::Completed {
88 Some(self.unhandled.clone())
89 } else {
90 None
91 }
92 }
93}
94
impl<'a, R> AsyncRead for EscapedDataReader<'a, R>
where
    R: AsyncRead,
{
    /// Single-buffer read; delegates to `poll_read_vectored` with one slice.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        self.poll_read_vectored(cx, &mut [IoSliceMut::new(buf)])
    }

    /// Fills `bufs` with stream bytes, stopping right after the `\r\n.\r\n`
    /// terminator if it is encountered.
    ///
    /// Returns `Ok(0)` once the terminator has been seen, or when `bufs` has
    /// no capacity. Returns a `ConnectionAborted` error if the source yields
    /// zero bytes before the terminator while `bufs` has capacity.
    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context,
        bufs: &mut [IoSliceMut],
    ) -> Poll<io::Result<usize>> {
        // Nothing more to give once the terminator has been delivered.
        if self.is_finished() {
            return Poll::Ready(Ok(0));
        }

        let this = self.project();

        // Step 1: put raw bytes into `bufs`, either from the pending part of
        // `this.buf` or, when that is empty, from the underlying reader.
        let raw_size = {
            let unhandled_len_start = this.unhandled.end - this.unhandled.start;
            if unhandled_len_start > 0 {
                for buf in bufs.iter_mut() {
                    let copy_len = cmp::min(buf.len(), this.unhandled.end - this.unhandled.start);
                    let next_start = this.unhandled.start + copy_len;
                    buf[..copy_len].copy_from_slice(&this.buf[this.unhandled.start..next_start]);
                    this.unhandled.start = next_start;
                }
                // Number of bytes actually moved out of the pending range.
                unhandled_len_start - (this.unhandled.end - this.unhandled.start)
            } else {
                match this.read.poll_read_vectored(cx, bufs) {
                    Poll::Ready(Ok(s)) => s,
                    other => return other,
                }
            }
        };

        // Step 2: zero bytes is only acceptable when the caller offered no
        // room; otherwise the stream ended before the terminator.
        if raw_size == 0 {
            if bufs.iter().map(|b| b.len()).sum::<usize>() == 0 {
                return Poll::Ready(Ok(0));
            } else {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::ConnectionAborted,
                    "connection aborted without finishing the data stream",
                )));
            }
        }

        // Step 3: run the state machine over the `raw_size` fresh bytes.
        // `size` counts the bytes of fully scanned buffers.
        let mut size = 0;
        for b in 0..bufs.len() {
            for i in 0..cmp::min(bufs[b].len(), raw_size - size) {
                use EscapedDataReaderState::*;
                match (*this.state, bufs[b][i]) {
                    (Cr, b'\n') => *this.state = CrLf,
                    (CrLf, b'.') => *this.state = CrLfDot,
                    (CrLfDot, b'\r') => *this.state = CrLfDotCr,
                    (CrLfDotCr, b'\n') => {
                        // Terminator complete: only the bytes up to and
                        // including this `\n` are reported to the caller.
                        *this.state = End;
                        size += i + 1;

                        if this.unhandled.start == this.unhandled.end {
                            // No pending bytes remain in `this.buf`, so the
                            // bytes that follow the terminator in `bufs` are
                            // copied to the start of `this.buf` and become
                            // the new pending range.
                            let remaining = cmp::min(bufs[b].len() - (i + 1), raw_size - size);
                            this.buf[..remaining]
                                .copy_from_slice(&bufs[b][i + 1..i + 1 + remaining]);
                            let mut copied = remaining;
                            for buf in &bufs[b + 1..] {
                                let remaining = cmp::min(buf.len(), raw_size - size - copied);
                                this.buf[copied..copied + remaining]
                                    .copy_from_slice(&buf[..remaining]);
                                copied += remaining;
                            }
                            *this.unhandled = 0..copied;
                        } else {
                            // The bytes came out of the pending range and are
                            // still in `this.buf`: move the range start back
                            // so it covers everything after the terminator.
                            this.unhandled.start -= raw_size - size;
                        }

                        return Poll::Ready(Ok(size));
                    }
                    (_, b'\r') => *this.state = Cr,
                    _ => *this.state = Start,
                }
            }
            size += cmp::min(bufs[b].len(), raw_size - size);
        }

        // No terminator in this batch: everything obtained is message data.
        Poll::Ready(Ok(size))
    }
}
196
/// Outcome of one call to [`DataUnescaper::unescape`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DataUnescapeRes {
    /// Number of unescaped bytes now stored at the start of the buffer, i.e.
    /// `data[..written]` is message content.
    pub written: usize,
    /// Index of the first input byte that was not consumed. When the
    /// terminator has not been seen, `data[unhandled_idx..]` must be placed
    /// at the start of the buffer given to the next call.
    pub unhandled_idx: usize,
}

/// In-place remover of SMTP dot-stuffing for a single `DATA` stream.
#[derive(Clone, Debug)]
pub struct DataUnescaper {
    /// Whether the next buffer should be treated as directly following a
    /// `\r\n`. Cleared once the first bytes have been classified.
    is_preceded_by_crlf: bool,
}

impl DataUnescaper {
    /// Creates an unescaper.
    ///
    /// Pass `true` for `is_preceded_by_crlf` when the first buffer starts at
    /// the beginning of a line, so that a leading `.` is treated as an
    /// escape (or, for `.\r\n`, as the end of the data).
    pub fn new(is_preceded_by_crlf: bool) -> DataUnescaper {
        DataUnescaper {
            is_preceded_by_crlf,
        }
    }

    /// Removes dot-stuffing from `data` in place.
    ///
    /// After the call, `data[..res.written]` holds unescaped content. Bytes
    /// from `res.unhandled_idx` onwards were not consumed:
    ///
    /// - If the end-of-data line was found, `unhandled_idx` points just past
    ///   it and no content after it is touched.
    /// - Otherwise the unconsumed tail is a possible prefix of an escape or
    ///   terminator (at most four bytes, e.g. `\r\n.\r`) and must be fed
    ///   again, followed by new bytes, in the next call.
    pub fn unescape(&mut self, data: &mut [u8]) -> DataUnescapeRes {
        let mut written = 0;
        let mut unhandled_idx = 0;

        if self.is_preceded_by_crlf {
            // Three bytes are needed to tell `.\r\n` (end of data) apart from
            // an escaped dot; exactly three bytes are already enough.
            if data.len() < 3 {
                return DataUnescapeRes {
                    written: 0,
                    unhandled_idx: 0,
                };
            }
            if data.starts_with(b".\r\n") {
                return DataUnescapeRes {
                    written: 0,
                    unhandled_idx: 3,
                };
            }
            if data[0] == b'.' {
                // Escaping dot at the very start: skip it.
                unhandled_idx = 1;
            }
            self.is_preceded_by_crlf = false;
        }

        while let Some(i) = data[unhandled_idx..]
            .windows(3)
            .position(|w| w == b"\r\n.")
        {
            // Index of the `\r` that starts this `\r\n.` occurrence.
            let marker = unhandled_idx + i;

            if data.len() < marker + 5 {
                // The two bytes after the dot are not both available, so it
                // is unknown whether this is an escape or the terminator.
                // Emit what precedes the `\r\n.` and keep the rest for later.
                data.copy_within(unhandled_idx..marker, written);
                return DataUnescapeRes {
                    written: written + i,
                    unhandled_idx: marker,
                };
            }

            let is_terminator = data[marker + 3..].starts_with(b"\r\n");

            // In both cases the `\r\n` before the dot is content, the dot is
            // dropped.
            data.copy_within(unhandled_idx..marker + 2, written);
            written += i + 2;

            if is_terminator {
                return DataUnescapeRes {
                    written,
                    unhandled_idx: marker + 5,
                };
            }
            unhandled_idx = marker + 3;
        }

        // A trailing `\r\n` or `\r` may be the start of `\r\n.`: hold it back.
        let held_back = if data.ends_with(b"\r\n") {
            2
        } else if data.ends_with(b"\r") {
            1
        } else {
            0
        };
        let end = data.len() - held_back;
        data.copy_within(unhandled_idx..end, written);
        DataUnescapeRes {
            written: written + end - unhandled_idx,
            unhandled_idx: end,
        }
    }
}
341
/// Line position of [`EscapingDataWriter`], based on the bytes written so far.
#[derive(Clone, Copy)]
enum EscapingDataWriterState {
    /// The last byte was neither `\r` nor the `\n` of a `\r\n`.
    Start,
    /// The last byte was `\r`.
    Cr,
    /// The last bytes were `\r\n`: a `.` written now must be escaped. This is
    /// also the state a new writer starts in.
    CrLf,
}
348
/// `AsyncWrite` adapter that doubles every `.` found at the start of a line
/// before forwarding the bytes to `write`.
#[pin_project]
pub struct EscapingDataWriter<W> {
    /// Line position after the bytes accepted so far.
    state: EscapingDataWriterState,

    /// Underlying writer that receives the escaped stream.
    #[pin]
    write: W,
}
358
359impl<W> EscapingDataWriter<W>
360where
361 W: AsyncWrite,
362{
363 #[inline]
364 pub fn new(write: W) -> Self {
365 EscapingDataWriter {
366 state: EscapingDataWriterState::CrLf,
367 write,
368 }
369 }
370
371 #[inline]
372 pub async fn finish(self) -> io::Result<()> {
373 let write = self.write;
374 pin_mut!(write);
375 match self.state {
376 EscapingDataWriterState::CrLf => write.write_all(b".\r\n").await,
377 _ => write.write_all(b"\r\n.\r\n").await,
378 }
379 }
380}
381
382impl<W> AsyncWrite for EscapingDataWriter<W>
383where
384 W: AsyncWrite,
385{
386 #[inline]
387 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
388 self.poll_write_vectored(cx, &[IoSlice::new(buf)])
389 }
390
391 #[inline]
392 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
393 self.project().write.poll_flush(cx)
394 }
395
396 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
397 Poll::Ready(Err(io::Error::new(
398 io::ErrorKind::Other,
399 "tried closing a stream during a message",
400 )))
401 }
402
403 fn poll_write_vectored(
404 self: Pin<&mut Self>,
405 cx: &mut Context,
406 bufs: &[IoSlice],
407 ) -> Poll<io::Result<usize>> {
408 fn set_state_until(state: &mut EscapingDataWriterState, bufs: &[IoSlice], n: usize) {
409 use EscapingDataWriterState::*;
410 let mut n = n;
411 for buf in bufs {
412 if n.saturating_sub(2) > buf.len() {
413 n -= buf.len();
414 *state = Start;
415 continue;
416 }
417 for i in n.saturating_sub(2)..cmp::min(buf.len(), n) {
418 n -= 1;
419 match (*state, buf[i]) {
420 (_, b'\r') => *state = Cr,
421 (Cr, b'\n') => *state = CrLf,
422 _ => *state = Start,
424 }
425 }
426 if n == 0 {
427 return;
428 }
429 }
430 }
431
432 let this = self.project();
433
434 let initial_state = *this.state;
435 for b in 0..bufs.len() {
436 for i in 0..bufs[b].len() {
437 use EscapingDataWriterState::*;
438 match (*this.state, bufs[b][i]) {
439 (_, b'\r') => *this.state = Cr,
440 (Cr, b'\n') => *this.state = CrLf,
441 (CrLf, b'.') => {
442 let mut v = Vec::with_capacity(b + 1);
443 let mut writing = 0;
444 for buf in &bufs[0..b] {
445 v.push(IoSlice::new(buf));
446 writing += buf.len();
447 }
448 v.push(IoSlice::new(&bufs[b][..=i]));
449 writing += i + 1;
450 return match this.write.poll_write_vectored(cx, &v) {
451 Poll::Ready(Ok(s)) => {
452 if s == writing {
453 *this.state = Start;
454 Poll::Ready(Ok(s - 1))
455 } else {
456 *this.state = initial_state;
457 set_state_until(this.state, bufs, s);
458 Poll::Ready(Ok(s))
459 }
460 }
461 o => o,
462 };
463 }
464 _ => *this.state = Start,
465 }
466 }
467 }
468
469 match this.write.poll_write_vectored(cx, bufs) {
470 Poll::Ready(Ok(s)) => {
471 if s == bufs.iter().map(|b| b.len()).sum::<usize>() {
472 Poll::Ready(Ok(s))
473 } else {
474 *this.state = initial_state;
475 set_state_until(this.state, bufs, s);
476 Poll::Ready(Ok(s))
477 }
478 }
479 o => o,
480 }
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    // `show_bytes` is not defined in this module; presumably it comes from
    // the crate root through this glob import.
    use crate::*;

    use futures::{
        executor,
        io::{AsyncReadExt, Cursor},
    };

    /// Checks that `EscapedDataReader` forwards exactly the bytes up to and
    /// including the terminator and leaves the rest available.
    ///
    /// Each case is `(input chunks, expected forwarded bytes, expected
    /// remaining bytes)`.
    #[test]
    fn escaped_data_reader() {
        let tests: &[(&[&[u8]], &[u8], &[u8])] = &[
            (
                &[b"foo", b" bar", b"\r\n", b".\r", b"\n"],
                b"foo bar\r\n.\r\n",
                b"",
            ),
            (&[b"\r\n.\r\n", b"\r\n"], b"\r\n.\r\n", b"\r\n"),
            (&[b".\r\n"], b".\r\n", b""),
            (&[b".baz\r\n", b".\r\n", b"foo"], b".baz\r\n.\r\n", b"foo"),
            (&[b" .baz", b"\r\n.", b"\r\nfoo"], b" .baz\r\n.\r\n", b"foo"),
            (&[b".\r\n", b"MAIL FROM"], b".\r\n", b"MAIL FROM"),
            (&[b"..\r\n.\r\n"], b"..\r\n.\r\n", b""),
            (
                &[b"foo\r\n. ", b"bar\r\n.\r\n"],
                b"foo\r\n. bar\r\n.\r\n",
                b"",
            ),
            (&[b".\r\nMAIL FROM"], b".\r\n", b"MAIL FROM"),
            (&[b"..\r\n.\r\nMAIL FROM"], b"..\r\n.\r\n", b"MAIL FROM"),
        ];
        let mut surrounding_buf: [u8; 16] = [0; 16];
        let mut enclosed_buf: [u8; 8] = [0; 8];
        for (i, &(inp, out, rem)) in tests.iter().enumerate() {
            println!(
                "Trying to parse test {} into {:?} with {:?} remaining\n",
                i,
                show_bytes(out),
                show_bytes(rem)
            );

            // All chunks but the first are chained into the underlying reader.
            let mut reader = inp[1..].iter().map(Cursor::new).fold(
                Box::pin(futures::io::empty()) as Pin<Box<dyn 'static + AsyncRead>>,
                |a, b| Box::pin(AsyncReadExt::chain(a, b)),
            );

            // The first chunk is pre-loaded as already-buffered data.
            surrounding_buf[..inp[0].len()].copy_from_slice(inp[0]);
            let mut data_reader =
                EscapedDataReader::new(&mut surrounding_buf, 0..inp[0].len(), reader.as_mut());

            // Drain the data reader until it reports end-of-stream or fails.
            let mut res_out = Vec::<u8>::new();
            while let Ok(r) = executor::block_on(data_reader.read(&mut enclosed_buf)) {
                if r == 0 {
                    break;
                }
                println!(
                    "got out buf (size {}): {:?}",
                    r,
                    show_bytes(&enclosed_buf[..r])
                );
                res_out.extend_from_slice(&enclosed_buf[..r]);
            }
            // Panics if the terminator was never reached.
            data_reader.complete();
            println!(
                "total out is: {:?}, hoping for: {:?}",
                show_bytes(&res_out),
                show_bytes(out)
            );
            assert_eq!(&res_out[..], out);

            // What is left over: the unhandled part of the buffer followed by
            // whatever the underlying reader still has.
            let unhandled = data_reader.get_unhandled().unwrap();
            let mut res_rem = Vec::<u8>::new();
            res_rem.extend_from_slice(&surrounding_buf[unhandled]);

            while let Ok(r) = executor::block_on(reader.read(&mut surrounding_buf)) {
                if r == 0 {
                    break;
                }
                println!("got rem buf: {:?}", show_bytes(&surrounding_buf[..r]));
                res_rem.extend_from_slice(&surrounding_buf[0..r]);
            }
            println!(
                "total rem is: {:?}, hoping for: {:?}",
                show_bytes(&res_rem),
                show_bytes(rem)
            );
            assert_eq!(&res_rem[..], rem);
        }
    }

    /// Feeds chunked input through `DataUnescaper` and checks the unescaped
    /// output. Each case is `(input chunks, expected unescaped bytes)`.
    #[test]
    fn data_unescaper() {
        let tests: &[(&[&[u8]], &[u8])] = &[
            (&[b"foo", b" bar", b"\r\n", b".\r", b"\n"], b"foo bar\r\n"),
            (&[b"\r\n.\r\n"], b"\r\n"),
            (&[b".baz\r\n", b".\r\n"], b"baz\r\n"),
            (&[b" .baz", b"\r\n.", b"\r\n"], b" .baz\r\n"),
            (&[b".\r\n"], b""),
            (&[b"..\r\n.\r\n"], b".\r\n"),
            (&[b"foo\r\n. ", b"bar\r\n.\r\n"], b"foo\r\n bar\r\n"),
            (&[b"\r\r\n.\r\n"], b"\r\r\n"),
        ];
        let mut buf: [u8; 1024] = [0; 1024];
        for &(inp, out) in tests {
            println!(
                "Test: {:?}",
                itertools::concat(
                    inp.iter()
                        .map(|i| show_bytes(i).chars().collect::<Vec<char>>())
                )
                .iter()
                .collect::<String>()
            );
            let mut res = Vec::<u8>::new();
            // Number of carried-over bytes at the start of `buf`.
            let mut end = 0;
            let mut unescaper = DataUnescaper::new(true);
            for i in inp {
                // Append the new chunk after the carried-over bytes.
                buf[end..end + i.len()].copy_from_slice(i);
                let r = unescaper.unescape(&mut buf[..end + i.len()]);
                res.extend_from_slice(&buf[..r.written]);
                // Move the unconsumed tail to the front for the next round.
                buf.copy_within(r.unhandled_idx..end + i.len(), 0);
                end = end + i.len() - r.unhandled_idx;
            }
            println!("Result: {:?}", show_bytes(&res));
            assert_eq!(&res[..], out);
        }
    }

    /// Writes data through `EscapingDataWriter` and checks the bytes that
    /// reach the sink after `finish()`.
    ///
    /// Each case is `(list of vectored writes, expected bytes in the sink)`.
    #[test]
    fn escaping_data_writer() {
        let tests: &[(&[&[&[u8]]], &[u8])] = &[
            (&[&[b"foo", b" bar"], &[b" baz"]], b"foo bar baz\r\n.\r\n"),
            (&[&[b"foo\r\n. bar\r\n"]], b"foo\r\n.. bar\r\n.\r\n"),
            (&[&[b""]], b".\r\n"),
            (&[&[b"."]], b"..\r\n.\r\n"),
            (&[&[b"\r"]], b"\r\r\n.\r\n"),
            (&[&[b"foo\r"]], b"foo\r\r\n.\r\n"),
            (&[&[b"foo bar\r", b"\n"]], b"foo bar\r\n.\r\n"),
            (
                &[&[b"foo bar\r\n"], &[b". baz\n"]],
                b"foo bar\r\n.. baz\n\r\n.\r\n",
            ),
        ];
        for &(inp, out) in tests {
            println!("Expected result: {:?}", show_bytes(out));
            let mut v = Vec::new();
            let c = Cursor::new(&mut v);
            let mut w = EscapingDataWriter::new(c);
            for write in inp {
                let mut written = 0;
                let total_to_write = write.iter().map(|b| b.len()).sum::<usize>();
                // Keep re-submitting the not-yet-accepted suffix of the
                // vectored write until every byte has been accepted.
                while written != total_to_write {
                    let mut i = Vec::new();
                    let mut skipped = 0;
                    for s in *write {
                        // Slice already fully written: drop it.
                        if skipped + s.len() <= written {
                            skipped += s.len();
                            println!("(skipping, skipped = {})", skipped);
                            continue;
                        }
                        if written - skipped != 0 {
                            // Slice partially written: submit only its tail.
                            println!("(skipping first {} chars)", written - skipped);
                            i.push(IoSlice::new(&s[(written - skipped)..]));
                            skipped = written;
                        } else {
                            println!("(skipping nothing)");
                            i.push(IoSlice::new(s));
                        }
                    }
                    println!("Writing: {:?}", i);
                    written += executor::block_on(w.write_vectored(&i)).unwrap();
                    println!("Written: {:?} (out of {:?})", written, total_to_write);
                }
            }
            executor::block_on(w.finish()).unwrap();
            assert_eq!(&v, &out);
        }
    }
}