s2n_quic_core/buffer/reader/checked.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    buffer::{
6        reader::{Reader, Storage},
7        writer,
8    },
9    varint::VarInt,
10};
11
12#[cfg(debug_assertions)]
13use crate::buffer::reader::storage::Infallible;
14
15/// Ensures [`Reader`] invariants are held as each trait function is called
16pub struct Checked<'a, R>
17where
18    R: Reader + ?Sized,
19{
20    inner: &'a mut R,
21    #[cfg(debug_assertions)]
22    chunk: alloc::vec::Vec<u8>,
23}
24
25impl<'a, R> Checked<'a, R>
26where
27    R: Reader + ?Sized,
28{
29    #[inline(always)]
30    pub fn new(inner: &'a mut R) -> Self {
31        Self {
32            inner,
33            #[cfg(debug_assertions)]
34            chunk: Default::default(),
35        }
36    }
37}
38
/// Release builds perform no invariant checks: every [`Storage`] call is passed
/// straight through to the wrapped reader.
#[cfg(not(debug_assertions))]
impl<R> Storage for Checked<'_, R>
where
    R: Reader + ?Sized,
{
    type Error = R::Error;

    #[inline(always)]
    fn buffered_len(&self) -> usize {
        Storage::buffered_len(&*self.inner)
    }

    #[inline(always)]
    fn buffer_is_empty(&self) -> bool {
        Storage::buffer_is_empty(&*self.inner)
    }

    #[inline(always)]
    fn read_chunk(&mut self, watermark: usize) -> Result<super::storage::Chunk<'_>, Self::Error> {
        Storage::read_chunk(&mut *self.inner, watermark)
    }

    #[inline(always)]
    fn partial_copy_into<Dest>(
        &mut self,
        dest: &mut Dest,
    ) -> Result<super::storage::Chunk<'_>, Self::Error>
    where
        Dest: writer::Storage + ?Sized,
    {
        Storage::partial_copy_into(&mut *self.inner, dest)
    }

    #[inline(always)]
    fn copy_into<Dest>(&mut self, dest: &mut Dest) -> Result<(), Self::Error>
    where
        Dest: writer::Storage + ?Sized,
    {
        Storage::copy_into(&mut *self.inner, dest)
    }
}
81
82#[cfg(debug_assertions)]
83impl<R> Storage for Checked<'_, R>
84where
85    R: Reader + ?Sized,
86{
87    type Error = R::Error;
88
89    #[inline]
90    fn buffered_len(&self) -> usize {
91        self.inner.buffered_len()
92    }
93
94    #[inline]
95    fn buffer_is_empty(&self) -> bool {
96        self.inner.buffer_is_empty()
97    }
98
99    #[inline]
100    fn read_chunk(&mut self, watermark: usize) -> Result<super::storage::Chunk<'_>, Self::Error> {
101        let snapshot = Snapshot::new(self.inner, watermark);
102
103        let mut chunk = self.inner.read_chunk(watermark)?;
104
105        // copy the returned chunk into another buffer so we can read the `inner` state
106        self.chunk.clear();
107        chunk.infallible_copy_into(&mut self.chunk);
108
109        snapshot.check(self.inner, 0, self.chunk.len());
110
111        Ok(self.chunk[..].into())
112    }
113
114    #[inline]
115    fn partial_copy_into<Dest>(
116        &mut self,
117        dest: &mut Dest,
118    ) -> Result<super::storage::Chunk<'_>, Self::Error>
119    where
120        Dest: writer::Storage + ?Sized,
121    {
122        let snapshot = Snapshot::new(self.inner, dest.remaining_capacity());
123        let mut dest = dest.track_write();
124
125        let mut chunk = self.inner.partial_copy_into(&mut dest)?;
126
127        // copy the returned chunk into another buffer so we can read the `inner` state
128        self.chunk.clear();
129        chunk.infallible_copy_into(&mut self.chunk);
130
131        snapshot.check(self.inner, dest.written_len(), self.chunk.len());
132
133        Ok(self.chunk[..].into())
134    }
135
136    #[inline]
137    fn copy_into<Dest>(&mut self, dest: &mut Dest) -> Result<(), Self::Error>
138    where
139        Dest: writer::Storage + ?Sized,
140    {
141        let snapshot = Snapshot::new(self.inner, dest.remaining_capacity());
142        let mut dest = dest.track_write();
143
144        self.inner.copy_into(&mut dest)?;
145
146        snapshot.check(self.inner, dest.written_len(), 0);
147
148        Ok(())
149    }
150}
151
152impl<R> Reader for Checked<'_, R>
153where
154    R: Reader + ?Sized,
155{
156    #[inline(always)]
157    fn current_offset(&self) -> VarInt {
158        self.inner.current_offset()
159    }
160
161    #[inline(always)]
162    fn final_offset(&self) -> Option<VarInt> {
163        self.inner.final_offset()
164    }
165
166    #[inline(always)]
167    fn has_buffered_fin(&self) -> bool {
168        self.inner.has_buffered_fin()
169    }
170
171    #[inline(always)]
172    fn is_consumed(&self) -> bool {
173        self.inner.is_consumed()
174    }
175}
176
177#[cfg(debug_assertions)]
178struct Snapshot {
179    current_offset: VarInt,
180    final_offset: Option<VarInt>,
181    buffered_len: usize,
182    dest_capacity: usize,
183}
184
185#[cfg(debug_assertions)]
186impl Snapshot {
187    #[inline]
188    fn new<R: Reader + ?Sized>(reader: &R, dest_capacity: usize) -> Self {
189        let current_offset = reader.current_offset();
190        let final_offset = reader.final_offset();
191        let buffered_len = reader.buffered_len();
192        Self {
193            current_offset,
194            final_offset,
195            buffered_len,
196            dest_capacity,
197        }
198    }
199
200    #[inline]
201    fn check<R: Reader + ?Sized>(&self, reader: &R, dest_written_len: usize, chunk_len: usize) {
202        assert!(
203            chunk_len <= self.dest_capacity,
204            "chunk exceeded destination"
205        );
206
207        let write_len = reader.current_offset() - self.current_offset;
208
209        assert_eq!(
210            dest_written_len as u64 + chunk_len as u64,
211            write_len.as_u64(),
212            "{} reader misreporting offsets",
213            core::any::type_name::<R>(),
214        );
215
216        assert!(write_len <= self.buffered_len as u64);
217
218        if self.final_offset.is_some() {
219            assert_eq!(
220                reader.final_offset(),
221                self.final_offset,
222                "{} reader changed final offset",
223                core::any::type_name::<R>(),
224            )
225        }
226    }
227}