1use std::io;
23
24use crate::{
25 DecodeError, FieldName, ReadRaw, ReadStruct, ReadTuple, ReadUnion, StrictDecode, StrictEnum,
26 StrictStruct, StrictSum, StrictTuple, StrictUnion, TypedRead, VariantName,
27};
28
/// Pseudo-reader which never produces data and only tallies how many bytes were
/// requested from it.
///
/// Every call to [`io::Read::read`] reports the whole buffer as filled, but leaves
/// the buffer contents untouched.
#[derive(Copy, Clone, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ReadCounter {
    /// Running total of bytes "read" through this counter.
    pub count: usize,
}

impl io::Read for ReadCounter {
    /// Claims to have read `buf.len()` bytes and adds that amount to `count`.
    /// Never fails and never writes into `buf`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.count += buf.len();
        Ok(buf.len())
    }
}
44
/// [`io::Read`] adaptor which fails once more than `limit` bytes in total have
/// been read from the wrapped reader.
///
/// The number of bytes successfully read so far is tracked and exposed via
/// [`ConfinedReader::count`].
#[derive(Clone, Debug)]
pub struct ConfinedReader<R: io::Read> {
    count: usize,
    limit: usize,
    reader: R,
}

impl<R: io::Read> From<R> for ConfinedReader<R> {
    /// Wraps `reader` with the maximal possible limit (`usize::MAX`).
    fn from(reader: R) -> Self { Self::with(usize::MAX, reader) }
}

impl<R: io::Read> ConfinedReader<R> {
    /// Wraps `reader`, allowing at most `limit` bytes to be read through it.
    pub fn with(limit: usize, reader: R) -> Self {
        Self {
            count: 0,
            limit,
            reader,
        }
    }

    /// Number of bytes successfully read through this reader so far.
    pub fn count(&self) -> usize { self.count }

    /// Drops the confinement, returning the wrapped reader.
    pub fn unconfine(self) -> R { self.reader }
}

impl<R: io::Read> io::Read for ConfinedReader<R> {
    /// Reads from the wrapped reader while enforcing the byte limit.
    ///
    /// # Errors
    ///
    /// - [`io::ErrorKind::InvalidInput`] if this read would push the total number
    ///   of bytes read past `limit`;
    /// - [`io::ErrorKind::OutOfMemory`] if the running total overflows `usize`;
    /// - any error returned by the wrapped reader.
    ///
    /// On error the running count is left unchanged.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // `count` is only ever advanced up to `limit` (see below), so this normally
        // cannot saturate; `saturating_sub` just keeps it defensive.
        let remaining = self.limit.saturating_sub(self.count);
        // Ask the inner reader for at most one byte more than is still allowed.
        // That single extra byte is enough to detect a limit violation, without
        // draining an arbitrarily large excess from the underlying stream first.
        let window = buf.len().min(remaining.saturating_add(1));
        let read = self.reader.read(&mut buf[..window])?;
        let total = self
            .count
            .checked_add(read)
            .ok_or_else(|| io::Error::from(io::ErrorKind::OutOfMemory))?;
        if total > self.limit {
            return Err(io::ErrorKind::InvalidInput.into());
        }
        self.count = total;
        Ok(read)
    }
}
88
89#[derive(Clone, Debug)]
90pub struct StreamReader<R: io::Read>(ConfinedReader<R>);
91
92impl<R: io::Read> StreamReader<R> {
93 pub fn new<const MAX: usize>(inner: R) -> Self { Self(ConfinedReader::with(MAX, inner)) }
94 pub fn unconfine(self) -> R { self.0.unconfine() }
95}
96
97impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
98 pub fn cursor<const MAX: usize>(inner: T) -> Self {
99 Self(ConfinedReader::with(MAX, io::Cursor::new(inner)))
100 }
101}
102
103impl<R: io::Read> ReadRaw for StreamReader<R> {
104 fn read_raw<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<Vec<u8>> {
105 use io::Read;
106 let mut buf = vec![0u8; len];
107 self.0.read_exact(&mut buf)?;
108 Ok(buf)
109 }
110
111 fn read_raw_array<const LEN: usize>(&mut self) -> io::Result<[u8; LEN]> {
112 use io::Read;
113 let mut buf = [0u8; LEN];
114 self.0.read_exact(&mut buf)?;
115 Ok(buf)
116 }
117}
118
119impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
120 pub fn in_memory<const MAX: usize>(data: T) -> Self { Self::new::<MAX>(io::Cursor::new(data)) }
121 pub fn into_cursor(self) -> io::Cursor<T> { self.0.unconfine() }
122}
123
124impl StreamReader<ReadCounter> {
125 pub fn counter<const MAX: usize>() -> Self { Self::new::<MAX>(ReadCounter::default()) }
126}
127
128#[derive(Clone, Debug, From)]
129pub struct StrictReader<R: ReadRaw>(R);
130
131impl<T: AsRef<[u8]>> StrictReader<StreamReader<io::Cursor<T>>> {
132 pub fn in_memory<const MAX: usize>(data: T) -> Self {
133 Self(StreamReader::in_memory::<MAX>(data))
134 }
135 pub fn into_cursor(self) -> io::Cursor<T> { self.0.into_cursor() }
136}
137
138impl StrictReader<StreamReader<ReadCounter>> {
139 pub fn counter<const MAX: usize>() -> Self { Self(StreamReader::counter::<MAX>()) }
140}
141
142impl<R: ReadRaw> StrictReader<R> {
143 pub fn with(reader: R) -> Self { Self(reader) }
144
145 pub fn unbox(self) -> R { self.0 }
146}
147
148impl<R: ReadRaw> TypedRead for StrictReader<R> {
149 type TupleReader<'parent>
150 = TupleReader<'parent, R>
151 where Self: 'parent;
152 type StructReader<'parent>
153 = StructReader<'parent, R>
154 where Self: 'parent;
155 type UnionReader = Self;
156 type RawReader = R;
157
158 unsafe fn raw_reader(&mut self) -> &mut Self::RawReader { &mut self.0 }
159
160 fn read_union<T: StrictUnion>(
161 &mut self,
162 inner: impl FnOnce(VariantName, &mut Self::UnionReader) -> Result<T, DecodeError>,
163 ) -> Result<T, DecodeError> {
164 let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
165 let tag = u8::strict_decode(self)?;
166 let variant_name = T::variant_name_by_tag(tag)
167 .ok_or(DecodeError::UnionTagNotKnown(name.to_string(), tag))?;
168 inner(variant_name, self)
169 }
170
171 fn read_enum<T: StrictEnum>(&mut self) -> Result<T, DecodeError>
172 where u8: From<T> {
173 let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
174 let tag = u8::strict_decode(self)?;
175 T::try_from(tag).map_err(|_| DecodeError::EnumTagNotKnown(name.to_string(), tag))
176 }
177
178 fn read_tuple<'parent, 'me, T: StrictTuple>(
179 &'me mut self,
180 inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
181 ) -> Result<T, DecodeError>
182 where
183 Self: 'parent,
184 'me: 'parent,
185 {
186 let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
187 let mut reader = TupleReader {
188 read_fields: 0,
189 parent: self,
190 };
191 let res = inner(&mut reader)?;
192 assert_ne!(reader.read_fields, 0, "you forget to read fields for a tuple {}", name);
193 assert_eq!(
194 reader.read_fields,
195 T::FIELD_COUNT,
196 "the number of fields read for a tuple {} doesn't match type declaration",
197 name
198 );
199 Ok(res)
200 }
201
202 fn read_struct<'parent, 'me, T: StrictStruct>(
203 &'me mut self,
204 inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
205 ) -> Result<T, DecodeError>
206 where
207 Self: 'parent,
208 'me: 'parent,
209 {
210 let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
211 let mut reader = StructReader {
212 named_fields: empty!(),
213 parent: self,
214 };
215 let res = inner(&mut reader)?;
216 assert!(!reader.named_fields.is_empty(), "you forget to read fields for a tuple {}", name);
217
218 for field in T::ALL_FIELDS {
219 let pos = reader
220 .named_fields
221 .iter()
222 .position(|f| f.as_str() == *field)
223 .unwrap_or_else(|| panic!("field {} is not read for {}", field, name));
224 reader.named_fields.remove(pos);
225 }
226 assert!(reader.named_fields.is_empty(), "excessive fields are read for {}", name);
227 Ok(res)
228 }
229}
230
231#[derive(Debug)]
232pub struct TupleReader<'parent, R: ReadRaw> {
233 read_fields: u8,
234 parent: &'parent mut StrictReader<R>,
235}
236
237impl<'parent, R: ReadRaw> ReadTuple for TupleReader<'parent, R> {
238 fn read_field<T: StrictDecode>(&mut self) -> Result<T, DecodeError> {
239 self.read_fields += 1;
240 T::strict_decode(self.parent)
241 }
242}
243
244#[derive(Debug)]
245pub struct StructReader<'parent, R: ReadRaw> {
246 named_fields: Vec<FieldName>,
247 parent: &'parent mut StrictReader<R>,
248}
249
250impl<'parent, R: ReadRaw> ReadStruct for StructReader<'parent, R> {
251 fn read_field<T: StrictDecode>(&mut self, field: FieldName) -> Result<T, DecodeError> {
252 self.named_fields.push(field);
253 T::strict_decode(self.parent)
254 }
255}
256
257impl<R: ReadRaw> ReadUnion for StrictReader<R> {
258 type TupleReader<'parent>
259 = TupleReader<'parent, R>
260 where Self: 'parent;
261 type StructReader<'parent>
262 = StructReader<'parent, R>
263 where Self: 'parent;
264
265 fn read_tuple<'parent, 'me, T: StrictSum>(
266 &'me mut self,
267 inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
268 ) -> Result<T, DecodeError>
269 where
270 Self: 'parent,
271 'me: 'parent,
272 {
273 let mut reader = TupleReader {
274 read_fields: 0,
275 parent: self,
276 };
277 inner(&mut reader)
278 }
279
280 fn read_struct<'parent, 'me, T: StrictSum>(
281 &'me mut self,
282 inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
283 ) -> Result<T, DecodeError>
284 where
285 Self: 'parent,
286 'me: 'parent,
287 {
288 let mut reader = StructReader {
289 named_fields: empty!(),
290 parent: self,
291 };
292 inner(&mut reader)
293 }
294}