wtransport_proto_lightyear_patch/
bytes.rs

1use crate::varint::VarInt;
2use octets::Octets;
3use octets::OctetsMut;
4use std::ops::Deref;
5use std::ops::DerefMut;
6
/// An error indicating that a buffer operation could not complete because the
/// end of the buffer has been reached.
///
/// It is produced by write operations that run out of space as well as by
/// read-side operations such as [`BufferReader::skip`] that run out of data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EndOfBuffer;
11
12/// Reads bytes or varint from a source.
13pub trait BytesReader<'a> {
14    /// Reads an unsigned variable-length integer in network byte-order from
15    /// the current offset and advances the offset.
16    ///
17    /// Returns [`None`] if not enough capacity (offset is not advanced in that case).
18    fn get_varint(&mut self) -> Option<VarInt>;
19
20    /// Reads `len` bytes from the current offset **without copying** and advances
21    /// the offset.
22    ///
23    /// Returns [`None`] if not enough capacity (offset is not advanced in that case).
24    fn get_bytes(&mut self, len: usize) -> Option<&'a [u8]>;
25}
26
27impl<'a> BytesReader<'a> for &'a [u8] {
28    fn get_varint(&mut self) -> Option<VarInt> {
29        let varint_size = VarInt::parse_size(*self.first()?);
30        let buffer = self.get(..varint_size)?;
31        let varint = BufferReader::new(buffer)
32            .get_varint()
33            .expect("Varint parsable");
34        *self = &self[varint_size..];
35        Some(varint)
36    }
37
38    fn get_bytes(&mut self, len: usize) -> Option<&'a [u8]> {
39        let buffer = self.get(..len)?;
40        *self = &self[len..];
41        Some(buffer)
42    }
43}
44
45/// Writes bytes or varint on a source.
46pub trait BytesWriter {
47    /// Writes an unsigned variable-length integer in network byte-order at the
48    /// current offset and advances the offset.
49    ///
50    /// Returns [`Err`] if source is exhausted and no space is available.
51    fn put_varint(&mut self, varint: VarInt) -> Result<(), EndOfBuffer>;
52
53    /// Writes (by **copy**) all `bytes` at the current offset and advances it.
54    ///
55    /// Returns [`Err`] if source is exhausted and no space is available.
56    fn put_bytes(&mut self, bytes: &[u8]) -> Result<(), EndOfBuffer>;
57}
58
59impl BytesWriter for Vec<u8> {
60    fn put_varint(&mut self, varint: VarInt) -> Result<(), EndOfBuffer> {
61        let offset = self.len();
62
63        self.resize(offset + varint.size(), 0);
64
65        BufferWriter::new(&mut self[offset..])
66            .put_varint(varint)
67            .expect("Enough capacity pre-allocated");
68
69        Ok(())
70    }
71
72    fn put_bytes(&mut self, bytes: &[u8]) -> Result<(), EndOfBuffer> {
73        self.extend_from_slice(bytes);
74        Ok(())
75    }
76}
77
78/// A zero-copy immutable byte-buffer reader.
79///
80/// Internally, it stores an offset that is increased during reading.
81pub struct BufferReader<'a>(Octets<'a>);
82
83impl<'a> BufferReader<'a> {
84    /// Creates a [`BufferReader`] from the given slice, without copying.
85    ///
86    /// Inner offset is initialized to zero.
87    #[inline(always)]
88    pub fn new(buffer: &'a [u8]) -> Self {
89        Self(Octets::with_slice(buffer))
90    }
91
92    /// Returns the remaining capacity in the buffer.
93    #[inline(always)]
94    pub fn capacity(&self) -> usize {
95        self.0.cap()
96    }
97
98    /// Returns the current offset of the buffer.
99    #[inline(always)]
100    pub fn offset(&self) -> usize {
101        self.0.off()
102    }
103
104    /// Advances the offset.
105    ///
106    /// In case of [`Err`] the offset is not advanced.
107    #[inline(always)]
108    pub fn skip(&mut self, len: usize) -> Result<(), EndOfBuffer> {
109        self.0
110            .skip(len)
111            .map_err(|octets::BufferTooShortError| EndOfBuffer)
112    }
113
114    /// Returns a reference to the internal buffer.
115    ///
116    /// **Note**: this is the entire buffer (despite offset).
117    #[inline(always)]
118    pub fn buffer(&self) -> &'a [u8] {
119        self.0.buf()
120    }
121
122    /// Returns the inner buffer starting from the current offset.
123    #[inline(always)]
124    pub fn buffer_remaining(&mut self) -> &'a [u8] {
125        &self.buffer()[self.offset()..]
126    }
127
128    /// Creates a [`BufferReaderChild`] with this parent.
129    #[inline(always)]
130    pub fn child(&mut self) -> BufferReaderChild<'a, '_> {
131        BufferReaderChild::with_parent(self)
132    }
133}
134
135impl<'a> BytesReader<'a> for BufferReader<'a> {
136    #[inline(always)]
137    fn get_varint(&mut self) -> Option<VarInt> {
138        match self.0.get_varint() {
139            Ok(value) => {
140                // SAFETY: octets returns a legit varint
141                Some(unsafe {
142                    debug_assert!(value <= VarInt::MAX.into_inner());
143                    VarInt::from_u64_unchecked(value)
144                })
145            }
146            Err(octets::BufferTooShortError) => None,
147        }
148    }
149
150    #[inline(always)]
151    fn get_bytes(&mut self, len: usize) -> Option<&'a [u8]> {
152        self.0.get_bytes(len).ok().map(|o| o.buf())
153    }
154}
155
156/// It acts like a copy of a parent [`BufferReader`].
157///
158/// You can create this from [`BufferReader::child`]. The child offset will be set
159/// to `0`, but its underlying buffer will start from the current parent's offset.
160///
161/// Having a copy it allows reading the buffer preserving the parent's original offset.
162///
163/// If you want to commit the number of bytes read to the parent, use [`BufferReaderChild::commit`].
164/// Instead, dropping this without committing, it will not alter the parent.
165pub struct BufferReaderChild<'a, 'b> {
166    reader: BufferReader<'a>,
167    parent: &'b mut BufferReader<'a>,
168}
169
170impl<'a, 'b> BufferReaderChild<'a, 'b> {
171    /// Advances the parent [`BufferReader`] offset of the amount read with this child.
172    #[inline(always)]
173    pub fn commit(self) {
174        self.parent
175            .skip(self.reader.offset())
176            .expect("Child offset is bounded to parent");
177    }
178
179    #[inline(always)]
180    fn with_parent(parent: &'b mut BufferReader<'a>) -> Self {
181        Self {
182            reader: BufferReader::new(parent.buffer_remaining()),
183            parent,
184        }
185    }
186}
187
188impl<'a, 'b> Deref for BufferReaderChild<'a, 'b> {
189    type Target = BufferReader<'a>;
190
191    #[inline(always)]
192    fn deref(&self) -> &Self::Target {
193        &self.reader
194    }
195}
196
197impl<'a, 'b> DerefMut for BufferReaderChild<'a, 'b> {
198    #[inline(always)]
199    fn deref_mut(&mut self) -> &mut Self::Target {
200        &mut self.reader
201    }
202}
203
204/// A zero-copy mutable buffer writer.
205pub struct BufferWriter<'a>(OctetsMut<'a>);
206
207impl<'a> BufferWriter<'a> {
208    /// Creates an [`BufferWriter`] by using `bytes` as inner buffer.
209    ///
210    /// Inner offset is initialized to zero.
211    #[inline(always)]
212    pub fn new(bytes: &'a mut [u8]) -> Self {
213        Self(OctetsMut::with_slice(bytes))
214    }
215
216    /// Returns the remaining capacity in the buffer.
217    #[inline(always)]
218    pub fn capacity(&self) -> usize {
219        self.0.cap()
220    }
221
222    /// Returns the current offset of the buffer.
223    #[inline(always)]
224    pub fn offset(&self) -> usize {
225        self.0.off()
226    }
227
228    /// Returns the portion of the inner buffer written so far.
229    #[inline(always)]
230    pub fn buffer_written(&self) -> &[u8] {
231        &self.0.buf()[..self.offset()]
232    }
233}
234
235impl<'a> BytesWriter for BufferWriter<'a> {
236    #[inline(always)]
237    fn put_varint(&mut self, varint: VarInt) -> Result<(), EndOfBuffer> {
238        self.0
239            .put_varint(varint.into_inner())
240            .map_err(|octets::BufferTooShortError| EndOfBuffer)?;
241
242        Ok(())
243    }
244
245    #[inline(always)]
246    fn put_bytes(&mut self, bytes: &[u8]) -> Result<(), EndOfBuffer> {
247        self.0
248            .put_bytes(bytes)
249            .map_err(|octets::BufferTooShortError| EndOfBuffer)
250    }
251}
252
253/// Async operations.
254#[cfg(feature = "async")]
255#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
256pub mod r#async {
257    use super::*;
258    use std::future::Future;
259    use std::io::ErrorKind as IoErrorKind;
260    use std::pin::Pin;
261    use std::task::ready;
262    use std::task::Context;
263    use std::task::Poll;
264
265    /// Error during read operations.
266    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
267    #[derive(Debug)]
268    pub enum IoReadError {
269        /// Read failed because immediate EOF (attempt reading the first byte).
270        ///
271        /// In this case, *zero* bytes have been read during the operation.
272        ImmediateFin,
273
274        /// Read failed because EOF reached in the middle of operation.
275        ///
276        /// In this case, *at least* one byte has been read during the operation.
277        UnexpectedFin,
278
279        /// Read failed because peer interrupted operation (at any point).
280        ///
281        /// In this case, zero or more bytes might be have read during the operation.
282        Reset,
283
284        /// Read failed because peer not is not connected, or disconnected (at any point).
285        ///
286        /// In this case, zero or more bytes might be have read during the operation.
287        NotConnected,
288    }
289
290    /// Error during write operation.
291    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
292    #[derive(Debug)]
293    pub enum IoWriteError {
294        /// Write failed because peer stopped operation.
295        ///
296        /// In this case, zero or more bytes might be have written during the operation.
297        Stopped,
298
299        /// Write failed because peer not is not connected.
300        ///
301        /// In this case, zero or more bytes might be have written during the operation.
302        NotConnected,
303    }
304
305    impl From<std::io::Error> for IoReadError {
306        fn from(error: std::io::Error) -> Self {
307            match error.kind() {
308                IoErrorKind::ConnectionReset => IoReadError::Reset,
309                _ => IoReadError::NotConnected,
310            }
311        }
312    }
313
314    impl From<std::io::Error> for IoWriteError {
315        fn from(error: std::io::Error) -> Self {
316            match error.kind() {
317                IoErrorKind::ConnectionReset => IoWriteError::Stopped,
318                _ => IoWriteError::NotConnected,
319            }
320        }
321    }
322
323    /// Reads bytes from a source.
324    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
325    pub trait AsyncRead {
326        /// Attempt to read from the source into `buf`.
327        ///
328        /// Generally, an implementation will perform a **copy**.
329        ///
330        /// On success, it returns `Ok(num_bytes_read)`, that is the
331        /// length of bytes written into `buf`.
332        ///
333        /// It returns `0` if and only if:
334        ///   * `buf` is empty; or
335        ///   * The source reached its end (the stream is exhausted / EOF).
336        ///
337        /// An implementation SHOULD only generates the following errors:
338        ///   * [`std::io::ErrorKind::ConnectionReset`] if the read operation was explicitly truncated
339        ///      by the source.
340        ///   * [`std::io::ErrorKind::NotConnected`] if the read operation aborted at any point because
341        ///      lack of communication with the source.
342        fn poll_read(
343            self: Pin<&mut Self>,
344            cx: &mut Context<'_>,
345            buf: &mut [u8],
346        ) -> Poll<std::io::Result<usize>>;
347    }
348
349    impl AsyncRead for &[u8] {
350        fn poll_read(
351            mut self: Pin<&mut Self>,
352            _cx: &mut Context<'_>,
353            buf: &mut [u8],
354        ) -> Poll<std::io::Result<usize>> {
355            let amt = std::cmp::min(self.len(), buf.len());
356            let (a, b) = self.split_at(amt);
357            buf[..amt].copy_from_slice(a);
358            *self = b;
359            Poll::Ready(Ok(amt))
360        }
361    }
362
363    /// Writes bytes into a destination.
364    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
365    pub trait AsyncWrite {
366        /// Attempt to write `buf` into the destination.
367        ///
368        /// Generally, an implementation will perform a **copy**.
369        ///
370        /// On success, it returns `Ok(num_bytes_written)`, that is the number
371        /// of bytes written.
372        /// Note that, it is possible that not the entire `buf` will be written (for instance,
373        /// because of a mechanism of flow controller or limited capacity).
374        ///
375        /// An implementation SHOULD never return `Ok(0)` if `buf` is not empty.
376        ///
377        /// An implementation SHOULD only generates the following errors:
378        ///   * [`std::io::ErrorKind::ConnectionReset`] if the write operation was explicitly stopped
379        ///      by the destination.
380        ///   * [`std::io::ErrorKind::NotConnected`] if the write operation aborted at any point because
381        ///      lack of communication with the destination.
382        fn poll_write(
383            self: Pin<&mut Self>,
384            cx: &mut Context<'_>,
385            buf: &[u8],
386        ) -> Poll<std::io::Result<usize>>;
387    }
388
389    impl AsyncWrite for Vec<u8> {
390        fn poll_write(
391            mut self: Pin<&mut Self>,
392            _cx: &mut Context<'_>,
393            buf: &[u8],
394        ) -> Poll<std::io::Result<usize>> {
395            self.extend_from_slice(buf);
396            Poll::Ready(std::io::Result::Ok(buf.len()))
397        }
398    }
399
400    /// Reads bytes or varints asynchronously.
401    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
402    pub trait BytesReaderAsync {
403        /// Reads an unsigned variable-length integer in network byte-order from a source.
404        fn get_varint(&mut self) -> GetVarint<Self>;
405
406        /// Reads the source until `buffer` is completely filled.
407        fn get_buffer<'a>(&'a mut self, buffer: &'a mut [u8]) -> GetBuffer<Self>;
408    }
409
410    impl<T> BytesReaderAsync for T
411    where
412        T: AsyncRead + ?Sized,
413    {
414        fn get_varint(&mut self) -> GetVarint<Self> {
415            GetVarint::new(self)
416        }
417
418        fn get_buffer<'a>(&'a mut self, buffer: &'a mut [u8]) -> GetBuffer<Self> {
419            GetBuffer::new(self, buffer)
420        }
421    }
422
423    impl<T> BytesWriterAsync for T
424    where
425        T: AsyncWrite + ?Sized,
426    {
427        fn put_varint(&mut self, varint: VarInt) -> PutVarint<Self> {
428            PutVarint::new(self, varint)
429        }
430
431        fn put_buffer<'a>(&'a mut self, buffer: &'a [u8]) -> PutBuffer<Self> {
432            PutBuffer::new(self, buffer)
433        }
434    }
435
436    /// Writes bytes or varints asynchronously.
437    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
438    pub trait BytesWriterAsync {
439        /// Writes an unsigned variable-length integer in network byte-order to
440        /// the source advancing the buffer's internal cursor.
441        fn put_varint(&mut self, varint: VarInt) -> PutVarint<Self>;
442
443        /// Pushes some bytes into ths source advancing the buffer’s internal cursor.
444        fn put_buffer<'a>(&'a mut self, buffer: &'a [u8]) -> PutBuffer<Self>;
445    }
446
447    /// [`Future`] for reading a varint.
448    ///
449    /// Created by [`BytesReaderAsync::get_varint`].
450    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
451    pub struct GetVarint<'a, R: ?Sized> {
452        reader: &'a mut R,
453        buffer: [u8; VarInt::MAX_SIZE],
454        offset: usize,
455        varint_size: usize,
456    }
457
458    impl<'a, R> GetVarint<'a, R>
459    where
460        R: AsyncRead + ?Sized,
461    {
462        fn new(reader: &'a mut R) -> Self {
463            Self {
464                reader,
465                buffer: [0; VarInt::MAX_SIZE],
466                offset: 0,
467                varint_size: 0,
468            }
469        }
470    }
471
472    impl<'a, R> Future for GetVarint<'a, R>
473    where
474        R: AsyncRead + Unpin + ?Sized,
475    {
476        type Output = Result<VarInt, IoReadError>;
477
478        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
479            let this = self.get_mut();
480
481            if this.offset == 0 {
482                debug_assert_eq!(this.varint_size, 0);
483
484                let read = ready!(AsyncRead::poll_read(
485                    Pin::new(this.reader),
486                    cx,
487                    &mut this.buffer[0..1]
488                ))?;
489
490                debug_assert!(read == 0 || read == 1);
491
492                if read == 1 {
493                    this.offset = 1;
494                    this.varint_size = VarInt::parse_size(this.buffer[0]);
495                    debug_assert_ne!(this.varint_size, 0);
496                } else {
497                    return Poll::Ready(Err(IoReadError::ImmediateFin));
498                }
499            }
500
501            while this.offset < this.varint_size {
502                let read = ready!(AsyncRead::poll_read(
503                    Pin::new(this.reader),
504                    cx,
505                    &mut this.buffer[this.offset..this.varint_size]
506                ))?;
507
508                debug_assert!(read <= this.varint_size - this.offset);
509
510                if read > 0 {
511                    this.offset += read;
512                } else {
513                    return Poll::Ready(Err(IoReadError::UnexpectedFin));
514                }
515            }
516
517            let varint = BufferReader::new(&this.buffer[..this.varint_size])
518                .get_varint()
519                .expect("Varint is parsable");
520
521            Poll::Ready(Ok(varint))
522        }
523    }
524
525    /// [`Future`] for reading a buffer of bytes.
526    ///
527    /// Created by [`BytesReaderAsync::get_buffer`].
528    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
529    pub struct GetBuffer<'a, R: ?Sized> {
530        reader: &'a mut R,
531        buffer: &'a mut [u8],
532        offset: usize,
533    }
534
535    impl<'a, R> GetBuffer<'a, R>
536    where
537        R: AsyncRead + ?Sized,
538    {
539        fn new(reader: &'a mut R, buffer: &'a mut [u8]) -> Self {
540            Self {
541                reader,
542                buffer,
543                offset: 0,
544            }
545        }
546    }
547
548    impl<'a, R> Future for GetBuffer<'a, R>
549    where
550        R: AsyncRead + Unpin + ?Sized,
551    {
552        type Output = Result<(), IoReadError>;
553
554        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
555            let this = self.get_mut();
556
557            while this.offset < this.buffer.len() {
558                let read = ready!(AsyncRead::poll_read(
559                    Pin::new(this.reader),
560                    cx,
561                    &mut this.buffer[this.offset..],
562                ))?;
563
564                debug_assert!(read <= this.buffer.len() - this.offset);
565
566                if read > 0 {
567                    this.offset += read;
568                } else if this.offset > 0 {
569                    return Poll::Ready(Err(IoReadError::UnexpectedFin));
570                } else {
571                    return Poll::Ready(Err(IoReadError::ImmediateFin));
572                }
573            }
574
575            Poll::Ready(Ok(()))
576        }
577    }
578
579    /// [`Future`] for writing a varint.
580    ///
581    /// Created by [`BytesWriterAsync::put_varint`].
582    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
583    pub struct PutVarint<'a, W: ?Sized> {
584        writer: &'a mut W,
585        buffer: [u8; VarInt::MAX_SIZE],
586        offset: usize,
587        varint_size: usize,
588    }
589
590    impl<'a, W> PutVarint<'a, W>
591    where
592        W: AsyncWrite + ?Sized,
593    {
594        fn new(writer: &'a mut W, varint: VarInt) -> Self {
595            let mut this = Self {
596                writer,
597                buffer: [0; VarInt::MAX_SIZE],
598                offset: 0,
599                varint_size: 0,
600            };
601
602            let mut buffer_writer = BufferWriter::new(&mut this.buffer);
603            buffer_writer
604                .put_varint(varint)
605                .expect("Inner buffer is enough for max varint");
606
607            this.varint_size = buffer_writer.offset();
608
609            this
610        }
611    }
612
613    impl<'a, W> Future for PutVarint<'a, W>
614    where
615        W: AsyncWrite + Unpin + ?Sized,
616    {
617        type Output = Result<(), IoWriteError>;
618
619        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
620            let this = self.get_mut();
621
622            while this.offset < this.varint_size {
623                let written = ready!(AsyncWrite::poll_write(
624                    Pin::new(this.writer),
625                    cx,
626                    &this.buffer[this.offset..this.varint_size]
627                ))?;
628
629                // TODO(bfesta): what if AsyncWrite returns Ok(0)? maybe wake and pending?
630                debug_assert!(written > 0);
631
632                this.offset += written;
633            }
634
635            Poll::Ready(Ok(()))
636        }
637    }
638
639    /// [`Future`] for writing a buffer of bytes.
640    ///
641    /// Created by [`BytesWriterAsync::put_buffer`].
642    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
643    pub struct PutBuffer<'a, W: ?Sized> {
644        writer: &'a mut W,
645        buffer: &'a [u8],
646        offset: usize,
647    }
648
649    impl<'a, W> PutBuffer<'a, W>
650    where
651        W: AsyncWrite + ?Sized,
652    {
653        fn new(writer: &'a mut W, buffer: &'a [u8]) -> Self {
654            Self {
655                writer,
656                buffer,
657                offset: 0,
658            }
659        }
660    }
661
662    impl<'a, W> Future for PutBuffer<'a, W>
663    where
664        W: AsyncWrite + Unpin + ?Sized,
665    {
666        type Output = Result<(), IoWriteError>;
667
668        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
669            let this = self.get_mut();
670
671            while this.offset < this.buffer.len() {
672                let written = ready!(AsyncWrite::poll_write(
673                    Pin::new(this.writer),
674                    cx,
675                    &this.buffer[this.offset..]
676                ))?;
677
678                // TODO(bfesta): what if AsyncWrite returns Ok(0)? maybe wake and pending?
679                debug_assert!(written > 0);
680
681                this.offset += written;
682            }
683
684            Poll::Ready(Ok(()))
685        }
686    }
687}
688
689#[cfg(feature = "async")]
690pub use r#async::AsyncRead;
691
692#[cfg(feature = "async")]
693pub use r#async::AsyncWrite;
694
695#[cfg(feature = "async")]
696pub use r#async::BytesReaderAsync;
697
698#[cfg(feature = "async")]
699pub use r#async::BytesWriterAsync;
700
701#[cfg(feature = "async")]
702pub use r#async::IoReadError;
703
704#[cfg(feature = "async")]
705pub use r#async::IoWriteError;
706
#[cfg(test)]
mod tests {
    // Tests exercising the async API (and the `utils::StepReader` /
    // `utils::StepWriter` helpers) only exist with the `async` feature, so they
    // are gated on it; this keeps the suite compiling without that feature.
    use super::*;

    #[test]
    fn read_varint() {
        for (varint_buffer, value_expect) in utils::VARINT_TEST_CASES {
            let mut buffer_reader = BufferReader::new(varint_buffer);

            assert_eq!(buffer_reader.offset(), 0);
            assert_eq!(buffer_reader.capacity(), varint_buffer.len());

            let value = buffer_reader.get_varint().unwrap();

            assert_eq!(value, value_expect);
            assert_eq!(buffer_reader.offset(), varint_buffer.len());
            assert_eq!(buffer_reader.capacity(), 0);
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_varint_async() {
        for (varint_buffer, value_expect) in utils::VARINT_TEST_CASES {
            let mut reader = utils::StepReader::new(varint_buffer);
            let value = reader.get_varint().await.unwrap();
            assert_eq!(value, value_expect);
        }
    }

    #[test]
    fn read_buffer() {
        let mut buffer_reader = BufferReader::new(utils::BUFFER_TEST);
        let value = buffer_reader.get_bytes(utils::BUFFER_TEST.len()).unwrap();
        assert_eq!(value, utils::BUFFER_TEST);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_buffer_async() {
        let mut value = [0; utils::BUFFER_TEST.len()];
        let mut reader = utils::StepReader::new(utils::BUFFER_TEST);
        reader.get_buffer(&mut value).await.unwrap();
        assert_eq!(value, utils::BUFFER_TEST);
    }

    #[test]
    fn write_varint() {
        let mut buffer = [0; VarInt::MAX_SIZE];
        for (varint_buffer, value) in utils::VARINT_TEST_CASES {
            let mut buffer_writer = BufferWriter::new(&mut buffer);

            assert_eq!(buffer_writer.offset(), 0);
            assert_eq!(buffer_writer.capacity(), VarInt::MAX_SIZE);

            buffer_writer.put_varint(value).unwrap();

            assert_eq!(buffer_writer.offset(), varint_buffer.len());
            assert_eq!(buffer_writer.buffer_written(), varint_buffer);
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn write_varint_async() {
        for (varint_buffer, value) in utils::VARINT_TEST_CASES {
            let mut writer = utils::StepWriter::new(Some(8));

            writer.put_varint(value).await.unwrap();
            assert_eq!(value.size(), writer.written().len());
            assert_eq!(writer.written(), varint_buffer);
        }
    }

    #[test]
    fn child_commit() {
        let mut buffer_reader = BufferReader::new(&[0x1, 0x2]);

        buffer_reader.skip(1).unwrap();
        assert_eq!(buffer_reader.offset(), 1);
        assert_eq!(buffer_reader.capacity(), 1);

        let mut buffer_reader_child = buffer_reader.child();
        assert_eq!(buffer_reader_child.offset(), 0);
        assert_eq!(buffer_reader_child.capacity(), 1);

        assert!(matches!(buffer_reader_child.get_bytes(1), Some([0x2])));
        assert_eq!(buffer_reader_child.offset(), 1);

        buffer_reader_child.commit();

        assert_eq!(buffer_reader.offset(), 2);
        assert_eq!(buffer_reader.capacity(), 0);
    }

    #[test]
    fn child_drop() {
        let mut buffer_reader = BufferReader::new(&[0x1, 0x2]);

        buffer_reader.skip(1).unwrap();
        assert_eq!(buffer_reader.offset(), 1);
        assert_eq!(buffer_reader.capacity(), 1);

        let mut buffer_reader_child = buffer_reader.child();
        assert_eq!(buffer_reader_child.offset(), 0);
        assert_eq!(buffer_reader_child.capacity(), 1);

        assert!(matches!(buffer_reader_child.get_bytes(1), Some([0x2])));
        assert_eq!(buffer_reader_child.offset(), 1);

        assert_eq!(buffer_reader.offset(), 1);
        assert_eq!(buffer_reader.capacity(), 1);
    }

    #[test]
    fn none() {
        let mut buffer_reader = BufferReader::new(&[]);
        assert!(buffer_reader.get_varint().is_none());
        assert!(buffer_reader.get_bytes(1).is_none());

        let mut buffer_writer = BufferWriter::new(&mut []);
        assert!(buffer_writer.put_varint(VarInt::from_u32(0)).is_err());
        assert!(buffer_writer.put_bytes(&[0x0]).is_err());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn none_async() {
        let mut reader = utils::StepReader::new(vec![]);
        assert!(reader.get_varint().await.is_err());
        assert!(reader.get_buffer(&mut [0x0]).await.is_err());

        let mut writer = utils::StepWriter::new(Some(0));
        assert!(writer.put_varint(VarInt::from_u32(0)).await.is_err());
        assert!(writer.put_buffer(&[0x0]).await.is_err());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn fin_varint() {
        for (buffer, _) in utils::VARINT_TEST_CASES {
            for len in 0..buffer.len() {
                let result = BytesReaderAsync::get_varint(&mut &buffer[..len]).await;

                match len {
                    0 => assert!(matches!(result, Err(IoReadError::ImmediateFin))),
                    _ => assert!(matches!(result, Err(IoReadError::UnexpectedFin))),
                }
            }
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn fin_buffer() {
        let mut buffer = [0; utils::BUFFER_TEST.len()];

        for len in 0..utils::BUFFER_TEST.len() {
            let result = (&mut &utils::BUFFER_TEST[..len])
                .get_buffer(&mut buffer)
                .await;

            match len {
                0 => assert!(matches!(result, Err(IoReadError::ImmediateFin))),
                _ => assert!(matches!(result, Err(IoReadError::UnexpectedFin))),
            }
        }
    }

    mod utils {
        use super::*;

        pub const VARINT_TEST_CASES: [(&[u8], VarInt); 4] = [
            (&[0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c], unsafe {
                VarInt::from_u64_unchecked(151_288_809_941_952_652)
            }),
            (&[0x9d, 0x7f, 0x3e, 0x7d], VarInt::from_u32(494_878_333)),
            (&[0x7b, 0xbd], VarInt::from_u32(15_293)),
            (&[0x25], VarInt::from_u32(37)),
        ];

        pub const BUFFER_TEST: &[u8] = b"WebTransport";

        #[cfg(feature = "async")]
        #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
        pub mod r#async {
            use super::*;
            use std::pin::Pin;
            use std::task::Context;
            use std::task::Poll;

            pub struct StepReader {
                data: Box<[u8]>,
                offset: usize,
                to_pending: bool,
            }

            impl StepReader {
                pub fn new<T>(data: T) -> Self
                where
                    T: Into<Box<[u8]>>,
                {
                    Self {
                        data: data.into(),
                        offset: 0,
                        to_pending: true,
                    }
                }
            }

            impl AsyncRead for StepReader {
                fn poll_read(
                    mut self: Pin<&mut Self>,
                    cx: &mut Context<'_>,
                    buf: &mut [u8],
                ) -> Poll<std::io::Result<usize>> {
                    let new_pending = !self.to_pending;
                    let to_pending = std::mem::replace(&mut self.to_pending, new_pending);

                    if buf.is_empty() {
                        return Poll::Ready(Ok(0));
                    }

                    if to_pending {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    } else if let Some(&byte) = self.data.get(self.offset) {
                        buf[0] = byte;
                        self.offset += 1;
                        Poll::Ready(Ok(1))
                    } else {
                        Poll::Ready(Ok(0))
                    }
                }
            }

            pub struct StepWriter {
                buffer: Vec<u8>,
                max_len: Option<usize>,
                to_pending: bool,
            }

            impl StepWriter {
                pub fn new(max_len: Option<usize>) -> Self {
                    Self {
                        buffer: Vec::new(),
                        max_len,
                        to_pending: true,
                    }
                }

                pub fn written(&self) -> &[u8] {
                    &self.buffer
                }
            }

            impl AsyncWrite for StepWriter {
                fn poll_write(
                    mut self: Pin<&mut Self>,
                    cx: &mut Context<'_>,
                    buf: &[u8],
                ) -> Poll<Result<usize, std::io::Error>> {
                    let new_pending = !self.to_pending;
                    let to_pending = std::mem::replace(&mut self.to_pending, new_pending);

                    if buf.is_empty() {
                        return Poll::Ready(Ok(0));
                    }

                    if to_pending {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    } else if self.buffer.len() < self.max_len.unwrap_or(usize::MAX) {
                        let byte = buf[0];
                        self.buffer.push(byte);
                        Poll::Ready(Ok(1))
                    } else {
                        Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::ConnectionReset,
                            "Reached max len",
                        )))
                    }
                }
            }
        }

        #[cfg(feature = "async")]
        pub use r#async::StepReader;

        #[cfg(feature = "async")]
        pub use r#async::StepWriter;
    }
}