1use crate::varint::VarInt;
2use octets::Octets;
3use octets::OctetsMut;
4use std::ops::Deref;
5use std::ops::DerefMut;
6
/// Error returned when a write (or a reader `skip`) needs more space or data
/// than the underlying fixed-size buffer has left.
///
/// It carries no payload, so it is cheap to copy and compare
/// (e.g. `assert_eq!(result, Err(EndOfBuffer))`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EndOfBuffer;
11
12pub trait BytesReader<'a> {
14 fn get_varint(&mut self) -> Option<VarInt>;
19
20 fn get_bytes(&mut self, len: usize) -> Option<&'a [u8]>;
25}
26
27impl<'a> BytesReader<'a> for &'a [u8] {
28 fn get_varint(&mut self) -> Option<VarInt> {
29 let varint_size = VarInt::parse_size(*self.first()?);
30 let buffer = self.get(..varint_size)?;
31 let varint = BufferReader::new(buffer)
32 .get_varint()
33 .expect("Varint parsable");
34 *self = &self[varint_size..];
35 Some(varint)
36 }
37
38 fn get_bytes(&mut self, len: usize) -> Option<&'a [u8]> {
39 let buffer = self.get(..len)?;
40 *self = &self[len..];
41 Some(buffer)
42 }
43}
44
45pub trait BytesWriter {
47 fn put_varint(&mut self, varint: VarInt) -> Result<(), EndOfBuffer>;
52
53 fn put_bytes(&mut self, bytes: &[u8]) -> Result<(), EndOfBuffer>;
57}
58
59impl BytesWriter for Vec<u8> {
60 fn put_varint(&mut self, varint: VarInt) -> Result<(), EndOfBuffer> {
61 let offset = self.len();
62
63 self.resize(offset + varint.size(), 0);
64
65 BufferWriter::new(&mut self[offset..])
66 .put_varint(varint)
67 .expect("Enough capacity pre-allocated");
68
69 Ok(())
70 }
71
72 fn put_bytes(&mut self, bytes: &[u8]) -> Result<(), EndOfBuffer> {
73 self.extend_from_slice(bytes);
74 Ok(())
75 }
76}
77
78pub struct BufferReader<'a>(Octets<'a>);
82
83impl<'a> BufferReader<'a> {
84 #[inline(always)]
88 pub fn new(buffer: &'a [u8]) -> Self {
89 Self(Octets::with_slice(buffer))
90 }
91
92 #[inline(always)]
94 pub fn capacity(&self) -> usize {
95 self.0.cap()
96 }
97
98 #[inline(always)]
100 pub fn offset(&self) -> usize {
101 self.0.off()
102 }
103
104 #[inline(always)]
108 pub fn skip(&mut self, len: usize) -> Result<(), EndOfBuffer> {
109 self.0
110 .skip(len)
111 .map_err(|octets::BufferTooShortError| EndOfBuffer)
112 }
113
114 #[inline(always)]
118 pub fn buffer(&self) -> &'a [u8] {
119 self.0.buf()
120 }
121
122 #[inline(always)]
124 pub fn buffer_remaining(&mut self) -> &'a [u8] {
125 &self.buffer()[self.offset()..]
126 }
127
128 #[inline(always)]
130 pub fn child(&mut self) -> BufferReaderChild<'a, '_> {
131 BufferReaderChild::with_parent(self)
132 }
133}
134
135impl<'a> BytesReader<'a> for BufferReader<'a> {
136 #[inline(always)]
137 fn get_varint(&mut self) -> Option<VarInt> {
138 match self.0.get_varint() {
139 Ok(value) => {
140 Some(unsafe {
142 debug_assert!(value <= VarInt::MAX.into_inner());
143 VarInt::from_u64_unchecked(value)
144 })
145 }
146 Err(octets::BufferTooShortError) => None,
147 }
148 }
149
150 #[inline(always)]
151 fn get_bytes(&mut self, len: usize) -> Option<&'a [u8]> {
152 self.0.get_bytes(len).ok().map(|o| o.buf())
153 }
154}
155
156pub struct BufferReaderChild<'a, 'b> {
166 reader: BufferReader<'a>,
167 parent: &'b mut BufferReader<'a>,
168}
169
170impl<'a, 'b> BufferReaderChild<'a, 'b> {
171 #[inline(always)]
173 pub fn commit(self) {
174 self.parent
175 .skip(self.reader.offset())
176 .expect("Child offset is bounded to parent");
177 }
178
179 #[inline(always)]
180 fn with_parent(parent: &'b mut BufferReader<'a>) -> Self {
181 Self {
182 reader: BufferReader::new(parent.buffer_remaining()),
183 parent,
184 }
185 }
186}
187
188impl<'a, 'b> Deref for BufferReaderChild<'a, 'b> {
189 type Target = BufferReader<'a>;
190
191 #[inline(always)]
192 fn deref(&self) -> &Self::Target {
193 &self.reader
194 }
195}
196
197impl<'a, 'b> DerefMut for BufferReaderChild<'a, 'b> {
198 #[inline(always)]
199 fn deref_mut(&mut self) -> &mut Self::Target {
200 &mut self.reader
201 }
202}
203
204pub struct BufferWriter<'a>(OctetsMut<'a>);
206
207impl<'a> BufferWriter<'a> {
208 #[inline(always)]
212 pub fn new(bytes: &'a mut [u8]) -> Self {
213 Self(OctetsMut::with_slice(bytes))
214 }
215
216 #[inline(always)]
218 pub fn capacity(&self) -> usize {
219 self.0.cap()
220 }
221
222 #[inline(always)]
224 pub fn offset(&self) -> usize {
225 self.0.off()
226 }
227
228 #[inline(always)]
230 pub fn buffer_written(&self) -> &[u8] {
231 &self.0.buf()[..self.offset()]
232 }
233}
234
235impl<'a> BytesWriter for BufferWriter<'a> {
236 #[inline(always)]
237 fn put_varint(&mut self, varint: VarInt) -> Result<(), EndOfBuffer> {
238 self.0
239 .put_varint(varint.into_inner())
240 .map_err(|octets::BufferTooShortError| EndOfBuffer)?;
241
242 Ok(())
243 }
244
245 #[inline(always)]
246 fn put_bytes(&mut self, bytes: &[u8]) -> Result<(), EndOfBuffer> {
247 self.0
248 .put_bytes(bytes)
249 .map_err(|octets::BufferTooShortError| EndOfBuffer)
250 }
251}
252
253#[cfg(feature = "async")]
255#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
256pub mod r#async {
257 use super::*;
258 use std::future::Future;
259 use std::io::ErrorKind as IoErrorKind;
260 use std::pin::Pin;
261 use std::task::ready;
262 use std::task::Context;
263 use std::task::Poll;
264
265 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
267 #[derive(Debug)]
268 pub enum IoReadError {
269 ImmediateFin,
273
274 UnexpectedFin,
278
279 Reset,
283
284 NotConnected,
288 }
289
290 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
292 #[derive(Debug)]
293 pub enum IoWriteError {
294 Stopped,
298
299 NotConnected,
303 }
304
305 impl From<std::io::Error> for IoReadError {
306 fn from(error: std::io::Error) -> Self {
307 match error.kind() {
308 IoErrorKind::ConnectionReset => IoReadError::Reset,
309 _ => IoReadError::NotConnected,
310 }
311 }
312 }
313
314 impl From<std::io::Error> for IoWriteError {
315 fn from(error: std::io::Error) -> Self {
316 match error.kind() {
317 IoErrorKind::ConnectionReset => IoWriteError::Stopped,
318 _ => IoWriteError::NotConnected,
319 }
320 }
321 }
322
323 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
325 pub trait AsyncRead {
326 fn poll_read(
343 self: Pin<&mut Self>,
344 cx: &mut Context<'_>,
345 buf: &mut [u8],
346 ) -> Poll<std::io::Result<usize>>;
347 }
348
349 impl AsyncRead for &[u8] {
350 fn poll_read(
351 mut self: Pin<&mut Self>,
352 _cx: &mut Context<'_>,
353 buf: &mut [u8],
354 ) -> Poll<std::io::Result<usize>> {
355 let amt = std::cmp::min(self.len(), buf.len());
356 let (a, b) = self.split_at(amt);
357 buf[..amt].copy_from_slice(a);
358 *self = b;
359 Poll::Ready(Ok(amt))
360 }
361 }
362
363 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
365 pub trait AsyncWrite {
366 fn poll_write(
383 self: Pin<&mut Self>,
384 cx: &mut Context<'_>,
385 buf: &[u8],
386 ) -> Poll<std::io::Result<usize>>;
387 }
388
389 impl AsyncWrite for Vec<u8> {
390 fn poll_write(
391 mut self: Pin<&mut Self>,
392 _cx: &mut Context<'_>,
393 buf: &[u8],
394 ) -> Poll<std::io::Result<usize>> {
395 self.extend_from_slice(buf);
396 Poll::Ready(std::io::Result::Ok(buf.len()))
397 }
398 }
399
400 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
402 pub trait BytesReaderAsync {
403 fn get_varint(&mut self) -> GetVarint<Self>;
405
406 fn get_buffer<'a>(&'a mut self, buffer: &'a mut [u8]) -> GetBuffer<Self>;
408 }
409
410 impl<T> BytesReaderAsync for T
411 where
412 T: AsyncRead + ?Sized,
413 {
414 fn get_varint(&mut self) -> GetVarint<Self> {
415 GetVarint::new(self)
416 }
417
418 fn get_buffer<'a>(&'a mut self, buffer: &'a mut [u8]) -> GetBuffer<Self> {
419 GetBuffer::new(self, buffer)
420 }
421 }
422
423 impl<T> BytesWriterAsync for T
424 where
425 T: AsyncWrite + ?Sized,
426 {
427 fn put_varint(&mut self, varint: VarInt) -> PutVarint<Self> {
428 PutVarint::new(self, varint)
429 }
430
431 fn put_buffer<'a>(&'a mut self, buffer: &'a [u8]) -> PutBuffer<Self> {
432 PutBuffer::new(self, buffer)
433 }
434 }
435
436 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
438 pub trait BytesWriterAsync {
439 fn put_varint(&mut self, varint: VarInt) -> PutVarint<Self>;
442
443 fn put_buffer<'a>(&'a mut self, buffer: &'a [u8]) -> PutBuffer<Self>;
445 }
446
447 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
451 pub struct GetVarint<'a, R: ?Sized> {
452 reader: &'a mut R,
453 buffer: [u8; VarInt::MAX_SIZE],
454 offset: usize,
455 varint_size: usize,
456 }
457
458 impl<'a, R> GetVarint<'a, R>
459 where
460 R: AsyncRead + ?Sized,
461 {
462 fn new(reader: &'a mut R) -> Self {
463 Self {
464 reader,
465 buffer: [0; VarInt::MAX_SIZE],
466 offset: 0,
467 varint_size: 0,
468 }
469 }
470 }
471
472 impl<'a, R> Future for GetVarint<'a, R>
473 where
474 R: AsyncRead + Unpin + ?Sized,
475 {
476 type Output = Result<VarInt, IoReadError>;
477
478 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
479 let this = self.get_mut();
480
481 if this.offset == 0 {
482 debug_assert_eq!(this.varint_size, 0);
483
484 let read = ready!(AsyncRead::poll_read(
485 Pin::new(this.reader),
486 cx,
487 &mut this.buffer[0..1]
488 ))?;
489
490 debug_assert!(read == 0 || read == 1);
491
492 if read == 1 {
493 this.offset = 1;
494 this.varint_size = VarInt::parse_size(this.buffer[0]);
495 debug_assert_ne!(this.varint_size, 0);
496 } else {
497 return Poll::Ready(Err(IoReadError::ImmediateFin));
498 }
499 }
500
501 while this.offset < this.varint_size {
502 let read = ready!(AsyncRead::poll_read(
503 Pin::new(this.reader),
504 cx,
505 &mut this.buffer[this.offset..this.varint_size]
506 ))?;
507
508 debug_assert!(read <= this.varint_size - this.offset);
509
510 if read > 0 {
511 this.offset += read;
512 } else {
513 return Poll::Ready(Err(IoReadError::UnexpectedFin));
514 }
515 }
516
517 let varint = BufferReader::new(&this.buffer[..this.varint_size])
518 .get_varint()
519 .expect("Varint is parsable");
520
521 Poll::Ready(Ok(varint))
522 }
523 }
524
525 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
529 pub struct GetBuffer<'a, R: ?Sized> {
530 reader: &'a mut R,
531 buffer: &'a mut [u8],
532 offset: usize,
533 }
534
535 impl<'a, R> GetBuffer<'a, R>
536 where
537 R: AsyncRead + ?Sized,
538 {
539 fn new(reader: &'a mut R, buffer: &'a mut [u8]) -> Self {
540 Self {
541 reader,
542 buffer,
543 offset: 0,
544 }
545 }
546 }
547
548 impl<'a, R> Future for GetBuffer<'a, R>
549 where
550 R: AsyncRead + Unpin + ?Sized,
551 {
552 type Output = Result<(), IoReadError>;
553
554 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
555 let this = self.get_mut();
556
557 while this.offset < this.buffer.len() {
558 let read = ready!(AsyncRead::poll_read(
559 Pin::new(this.reader),
560 cx,
561 &mut this.buffer[this.offset..],
562 ))?;
563
564 debug_assert!(read <= this.buffer.len() - this.offset);
565
566 if read > 0 {
567 this.offset += read;
568 } else if this.offset > 0 {
569 return Poll::Ready(Err(IoReadError::UnexpectedFin));
570 } else {
571 return Poll::Ready(Err(IoReadError::ImmediateFin));
572 }
573 }
574
575 Poll::Ready(Ok(()))
576 }
577 }
578
579 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
583 pub struct PutVarint<'a, W: ?Sized> {
584 writer: &'a mut W,
585 buffer: [u8; VarInt::MAX_SIZE],
586 offset: usize,
587 varint_size: usize,
588 }
589
590 impl<'a, W> PutVarint<'a, W>
591 where
592 W: AsyncWrite + ?Sized,
593 {
594 fn new(writer: &'a mut W, varint: VarInt) -> Self {
595 let mut this = Self {
596 writer,
597 buffer: [0; VarInt::MAX_SIZE],
598 offset: 0,
599 varint_size: 0,
600 };
601
602 let mut buffer_writer = BufferWriter::new(&mut this.buffer);
603 buffer_writer
604 .put_varint(varint)
605 .expect("Inner buffer is enough for max varint");
606
607 this.varint_size = buffer_writer.offset();
608
609 this
610 }
611 }
612
613 impl<'a, W> Future for PutVarint<'a, W>
614 where
615 W: AsyncWrite + Unpin + ?Sized,
616 {
617 type Output = Result<(), IoWriteError>;
618
619 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
620 let this = self.get_mut();
621
622 while this.offset < this.varint_size {
623 let written = ready!(AsyncWrite::poll_write(
624 Pin::new(this.writer),
625 cx,
626 &this.buffer[this.offset..this.varint_size]
627 ))?;
628
629 debug_assert!(written > 0);
631
632 this.offset += written;
633 }
634
635 Poll::Ready(Ok(()))
636 }
637 }
638
639 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
643 pub struct PutBuffer<'a, W: ?Sized> {
644 writer: &'a mut W,
645 buffer: &'a [u8],
646 offset: usize,
647 }
648
649 impl<'a, W> PutBuffer<'a, W>
650 where
651 W: AsyncWrite + ?Sized,
652 {
653 fn new(writer: &'a mut W, buffer: &'a [u8]) -> Self {
654 Self {
655 writer,
656 buffer,
657 offset: 0,
658 }
659 }
660 }
661
662 impl<'a, W> Future for PutBuffer<'a, W>
663 where
664 W: AsyncWrite + Unpin + ?Sized,
665 {
666 type Output = Result<(), IoWriteError>;
667
668 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
669 let this = self.get_mut();
670
671 while this.offset < this.buffer.len() {
672 let written = ready!(AsyncWrite::poll_write(
673 Pin::new(this.writer),
674 cx,
675 &this.buffer[this.offset..]
676 ))?;
677
678 debug_assert!(written > 0);
680
681 this.offset += written;
682 }
683
684 Poll::Ready(Ok(()))
685 }
686 }
687}
688
689#[cfg(feature = "async")]
690pub use r#async::AsyncRead;
691
692#[cfg(feature = "async")]
693pub use r#async::AsyncWrite;
694
695#[cfg(feature = "async")]
696pub use r#async::BytesReaderAsync;
697
698#[cfg(feature = "async")]
699pub use r#async::BytesWriterAsync;
700
701#[cfg(feature = "async")]
702pub use r#async::IoReadError;
703
704#[cfg(feature = "async")]
705pub use r#async::IoWriteError;
706
707#[cfg(test)]
708mod tests {
709 use super::*;
710
711 #[test]
712 fn read_varint() {
713 for (varint_buffer, value_expect) in utils::VARINT_TEST_CASES {
714 let mut buffer_reader = BufferReader::new(varint_buffer);
715
716 assert_eq!(buffer_reader.offset(), 0);
717 assert_eq!(buffer_reader.capacity(), varint_buffer.len());
718
719 let value = buffer_reader.get_varint().unwrap();
720
721 assert_eq!(value, value_expect);
722 assert_eq!(buffer_reader.offset(), varint_buffer.len());
723 assert_eq!(buffer_reader.capacity(), 0);
724 }
725 }
726
727 #[tokio::test]
728 async fn read_varint_async() {
729 for (varint_buffer, value_expect) in utils::VARINT_TEST_CASES {
730 let mut reader = utils::StepReader::new(varint_buffer);
731 let value = reader.get_varint().await.unwrap();
732 assert_eq!(value, value_expect);
733 }
734 }
735
736 #[test]
737 fn read_buffer() {
738 let mut buffer_reader = BufferReader::new(utils::BUFFER_TEST);
739 let value = buffer_reader.get_bytes(utils::BUFFER_TEST.len()).unwrap();
740 assert_eq!(value, utils::BUFFER_TEST);
741 }
742
743 #[tokio::test]
744 async fn read_buffer_async() {
745 let mut value = [0; utils::BUFFER_TEST.len()];
746 let mut reader = utils::StepReader::new(utils::BUFFER_TEST);
747 reader.get_buffer(&mut value).await.unwrap();
748 assert_eq!(value, utils::BUFFER_TEST);
749 }
750
751 #[test]
752 fn write_varint() {
753 let mut buffer = [0; VarInt::MAX_SIZE];
754 for (varint_buffer, value) in utils::VARINT_TEST_CASES {
755 let mut buffer_writer = BufferWriter::new(&mut buffer);
756
757 assert_eq!(buffer_writer.offset(), 0);
758 assert_eq!(buffer_writer.capacity(), VarInt::MAX_SIZE);
759
760 buffer_writer.put_varint(value).unwrap();
761
762 assert_eq!(buffer_writer.offset(), varint_buffer.len());
763 assert_eq!(buffer_writer.buffer_written(), varint_buffer);
764 }
765 }
766
767 #[tokio::test]
768 async fn write_varint_async() {
769 for (varint_buffer, value) in utils::VARINT_TEST_CASES {
770 let mut writer = utils::StepWriter::new(Some(8));
771
772 writer.put_varint(value).await.unwrap();
773 assert_eq!(value.size(), writer.written().len());
774 assert_eq!(writer.written(), varint_buffer);
775 }
776 }
777
778 #[test]
779 fn child_commit() {
780 let mut buffer_reader = BufferReader::new(&[0x1, 0x2]);
781
782 buffer_reader.skip(1).unwrap();
783 assert_eq!(buffer_reader.offset(), 1);
784 assert_eq!(buffer_reader.capacity(), 1);
785
786 let mut buffer_reader_child = buffer_reader.child();
787 assert_eq!(buffer_reader_child.offset(), 0);
788 assert_eq!(buffer_reader_child.capacity(), 1);
789
790 assert!(matches!(buffer_reader_child.get_bytes(1), Some([0x2])));
791 assert_eq!(buffer_reader_child.offset(), 1);
792
793 buffer_reader_child.commit();
794
795 assert_eq!(buffer_reader.offset(), 2);
796 assert_eq!(buffer_reader.capacity(), 0);
797 }
798
799 #[test]
800 fn child_drop() {
801 let mut buffer_reader = BufferReader::new(&[0x1, 0x2]);
802
803 buffer_reader.skip(1).unwrap();
804 assert_eq!(buffer_reader.offset(), 1);
805 assert_eq!(buffer_reader.capacity(), 1);
806
807 let mut buffer_reader_child = buffer_reader.child();
808 assert_eq!(buffer_reader_child.offset(), 0);
809 assert_eq!(buffer_reader_child.capacity(), 1);
810
811 assert!(matches!(buffer_reader_child.get_bytes(1), Some([0x2])));
812 assert_eq!(buffer_reader_child.offset(), 1);
813
814 assert_eq!(buffer_reader.offset(), 1);
815 assert_eq!(buffer_reader.capacity(), 1);
816 }
817
818 #[test]
819 fn none() {
820 let mut buffer_reader = BufferReader::new(&[]);
821 assert!(buffer_reader.get_varint().is_none());
822 assert!(buffer_reader.get_bytes(1).is_none());
823
824 let mut buffer_writer = BufferWriter::new(&mut []);
825 assert!(buffer_writer.put_varint(VarInt::from_u32(0)).is_err());
826 assert!(buffer_writer.put_bytes(&[0x0]).is_err());
827 }
828
829 #[tokio::test]
830 async fn none_async() {
831 let mut reader = utils::StepReader::new(vec![]);
832 assert!(reader.get_varint().await.is_err());
833 assert!(reader.get_buffer(&mut [0x0]).await.is_err());
834
835 let mut writer = utils::StepWriter::new(Some(0));
836 assert!(writer.put_varint(VarInt::from_u32(0)).await.is_err());
837 assert!(writer.put_buffer(&[0x0]).await.is_err());
838 }
839
840 #[tokio::test]
841 async fn fin_varint() {
842 for (buffer, _) in utils::VARINT_TEST_CASES {
843 for len in 0..buffer.len() {
844 let result = BytesReaderAsync::get_varint(&mut &buffer[..len]).await;
845
846 match len {
847 0 => assert!(matches!(result, Err(IoReadError::ImmediateFin))),
848 _ => assert!(matches!(result, Err(IoReadError::UnexpectedFin))),
849 }
850 }
851 }
852 }
853
854 #[tokio::test]
855 async fn fin_buffer() {
856 let mut buffer = [0; utils::BUFFER_TEST.len()];
857
858 for len in 0..utils::BUFFER_TEST.len() {
859 let result = (&mut &utils::BUFFER_TEST[..len])
860 .get_buffer(&mut buffer)
861 .await;
862
863 match len {
864 0 => assert!(matches!(result, Err(IoReadError::ImmediateFin))),
865 _ => assert!(matches!(result, Err(IoReadError::UnexpectedFin))),
866 }
867 }
868 }
869
870 mod utils {
871 use super::*;
872
873 pub const VARINT_TEST_CASES: [(&[u8], VarInt); 4] = [
874 (&[0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c], unsafe {
875 VarInt::from_u64_unchecked(151_288_809_941_952_652)
876 }),
877 (&[0x9d, 0x7f, 0x3e, 0x7d], VarInt::from_u32(494_878_333)),
878 (&[0x7b, 0xbd], VarInt::from_u32(15_293)),
879 (&[0x25], VarInt::from_u32(37)),
880 ];
881
882 pub const BUFFER_TEST: &[u8] = b"WebTransport";
883
884 #[cfg(feature = "async")]
885 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
886 pub mod r#async {
887 use super::*;
888 use std::pin::Pin;
889 use std::task::Context;
890 use std::task::Poll;
891
892 pub struct StepReader {
893 data: Box<[u8]>,
894 offset: usize,
895 to_pending: bool,
896 }
897
898 impl StepReader {
899 pub fn new<T>(data: T) -> Self
900 where
901 T: Into<Box<[u8]>>,
902 {
903 Self {
904 data: data.into(),
905 offset: 0,
906 to_pending: true,
907 }
908 }
909 }
910
911 impl AsyncRead for StepReader {
912 fn poll_read(
913 mut self: Pin<&mut Self>,
914 cx: &mut Context<'_>,
915 buf: &mut [u8],
916 ) -> Poll<std::io::Result<usize>> {
917 let new_pending = !self.to_pending;
918 let to_pending = std::mem::replace(&mut self.to_pending, new_pending);
919
920 if buf.is_empty() {
921 return Poll::Ready(Ok(0));
922 }
923
924 if to_pending {
925 cx.waker().wake_by_ref();
926 Poll::Pending
927 } else if let Some(&byte) = self.data.get(self.offset) {
928 buf[0] = byte;
929 self.offset += 1;
930 Poll::Ready(Ok(1))
931 } else {
932 Poll::Ready(Ok(0))
933 }
934 }
935 }
936
937 pub struct StepWriter {
938 buffer: Vec<u8>,
939 max_len: Option<usize>,
940 to_pending: bool,
941 }
942
943 impl StepWriter {
944 pub fn new(max_len: Option<usize>) -> Self {
945 Self {
946 buffer: Vec::new(),
947 max_len,
948 to_pending: true,
949 }
950 }
951
952 pub fn written(&self) -> &[u8] {
953 &self.buffer
954 }
955 }
956
957 impl AsyncWrite for StepWriter {
958 fn poll_write(
959 mut self: Pin<&mut Self>,
960 cx: &mut Context<'_>,
961 buf: &[u8],
962 ) -> Poll<Result<usize, std::io::Error>> {
963 let new_pending = !self.to_pending;
964 let to_pending = std::mem::replace(&mut self.to_pending, new_pending);
965
966 if buf.is_empty() {
967 return Poll::Ready(Ok(0));
968 }
969
970 if to_pending {
971 cx.waker().wake_by_ref();
972 Poll::Pending
973 } else if self.buffer.len() < self.max_len.unwrap_or(usize::MAX) {
974 let byte = buf[0];
975 self.buffer.push(byte);
976 Poll::Ready(Ok(1))
977 } else {
978 Poll::Ready(Err(std::io::Error::new(
979 std::io::ErrorKind::ConnectionReset,
980 "Reached max len",
981 )))
982 }
983 }
984 }
985 }
986
987 #[cfg(feature = "async")]
988 pub use r#async::StepReader;
989
990 #[cfg(feature = "async")]
991 pub use r#async::StepWriter;
992 }
993}