strict_encoding/
reader.rs

1// Strict encoding library for deterministic binary serialization.
2//
3// SPDX-License-Identifier: Apache-2.0
4//
5// Written in 2019-2024 by
6//     Dr. Maxim Orlovsky <orlovsky@ubideco.org>
7//
8// Copyright 2022-2024 UBIDECO Labs
9//
10// Licensed under the Apache License, Version 2.0 (the "License");
11// you may not use this file except in compliance with the License.
12// You may obtain a copy of the License at
13//
14//     http://www.apache.org/licenses/LICENSE-2.0
15//
16// Unless required by applicable law or agreed to in writing, software
17// distributed under the License is distributed on an "AS IS" BASIS,
18// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19// See the License for the specific language governing permissions and
20// limitations under the License.
21
22use std::io;
23
24use crate::{
25    DecodeError, FieldName, ReadRaw, ReadStruct, ReadTuple, ReadUnion, StrictDecode, StrictEnum,
26    StrictStruct, StrictSum, StrictTuple, StrictUnion, TypedRead, VariantName,
27};
28
// TODO: Move to amplify crate
/// Dummy [`io::Read`] source that only tallies how many bytes were requested
/// from it.
///
/// Each read reports the whole destination buffer as filled, without writing
/// anything into it, so the buffer keeps whatever it contained before.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)]
pub struct ReadCounter {
    /// Total number of bytes reported as read so far.
    pub count: usize,
}
36
37impl io::Read for ReadCounter {
38    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
39        let count = buf.len();
40        self.count += count;
41        Ok(count)
42    }
43}
44
// TODO: Move to amplify crate
/// [`io::Read`] adaptor that fails once more than `limit` bytes in total have
/// been pulled through it from the wrapped `reader`.
#[derive(Clone, Debug)]
pub struct ConfinedReader<R: io::Read> {
    /// Bytes successfully read through this adaptor so far.
    count: usize,
    /// Maximum total number of bytes allowed to be read.
    limit: usize,
    /// The wrapped byte source.
    reader: R,
}
52
53impl<R: io::Read> From<R> for ConfinedReader<R> {
54    fn from(reader: R) -> Self {
55        Self {
56            count: 0,
57            limit: usize::MAX,
58            reader,
59        }
60    }
61}
62
63impl<R: io::Read> ConfinedReader<R> {
64    pub fn with(limit: usize, reader: R) -> Self {
65        Self {
66            count: 0,
67            limit,
68            reader,
69        }
70    }
71
72    pub fn count(&self) -> usize { self.count }
73
74    pub fn unconfine(self) -> R { self.reader }
75}
76
77impl<R: io::Read> io::Read for ConfinedReader<R> {
78    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
79        let len = self.reader.read(buf)?;
80        match self.count.checked_add(len) {
81            None => return Err(io::ErrorKind::OutOfMemory.into()),
82            Some(len) if len > self.limit => return Err(io::ErrorKind::InvalidInput.into()),
83            Some(len) => self.count = len,
84        };
85        Ok(len)
86    }
87}
88
89#[derive(Clone, Debug)]
90pub struct StreamReader<R: io::Read>(ConfinedReader<R>);
91
92impl<R: io::Read> StreamReader<R> {
93    pub fn new<const MAX: usize>(inner: R) -> Self { Self(ConfinedReader::with(MAX, inner)) }
94    pub fn unconfine(self) -> R { self.0.unconfine() }
95}
96
97impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
98    pub fn cursor<const MAX: usize>(inner: T) -> Self {
99        Self(ConfinedReader::with(MAX, io::Cursor::new(inner)))
100    }
101}
102
103impl<R: io::Read> ReadRaw for StreamReader<R> {
104    fn read_raw<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<Vec<u8>> {
105        use io::Read;
106        let mut buf = vec![0u8; len];
107        self.0.read_exact(&mut buf)?;
108        Ok(buf)
109    }
110
111    fn read_raw_array<const LEN: usize>(&mut self) -> io::Result<[u8; LEN]> {
112        use io::Read;
113        let mut buf = [0u8; LEN];
114        self.0.read_exact(&mut buf)?;
115        Ok(buf)
116    }
117}
118
119impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
120    pub fn in_memory<const MAX: usize>(data: T) -> Self { Self::new::<MAX>(io::Cursor::new(data)) }
121    pub fn into_cursor(self) -> io::Cursor<T> { self.0.unconfine() }
122}
123
124impl StreamReader<ReadCounter> {
125    pub fn counter<const MAX: usize>() -> Self { Self::new::<MAX>(ReadCounter::default()) }
126}
127
128#[derive(Clone, Debug, From)]
129pub struct StrictReader<R: ReadRaw>(R);
130
131impl<T: AsRef<[u8]>> StrictReader<StreamReader<io::Cursor<T>>> {
132    pub fn in_memory<const MAX: usize>(data: T) -> Self {
133        Self(StreamReader::in_memory::<MAX>(data))
134    }
135    pub fn into_cursor(self) -> io::Cursor<T> { self.0.into_cursor() }
136}
137
138impl StrictReader<StreamReader<ReadCounter>> {
139    pub fn counter<const MAX: usize>() -> Self { Self(StreamReader::counter::<MAX>()) }
140}
141
142impl<R: ReadRaw> StrictReader<R> {
143    pub fn with(reader: R) -> Self { Self(reader) }
144
145    pub fn unbox(self) -> R { self.0 }
146}
147
148impl<R: ReadRaw> TypedRead for StrictReader<R> {
149    type TupleReader<'parent>
150        = TupleReader<'parent, R>
151    where Self: 'parent;
152    type StructReader<'parent>
153        = StructReader<'parent, R>
154    where Self: 'parent;
155    type UnionReader = Self;
156    type RawReader = R;
157
158    unsafe fn raw_reader(&mut self) -> &mut Self::RawReader { &mut self.0 }
159
160    fn read_union<T: StrictUnion>(
161        &mut self,
162        inner: impl FnOnce(VariantName, &mut Self::UnionReader) -> Result<T, DecodeError>,
163    ) -> Result<T, DecodeError> {
164        let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
165        let tag = u8::strict_decode(self)?;
166        let variant_name = T::variant_name_by_tag(tag)
167            .ok_or(DecodeError::UnionTagNotKnown(name.to_string(), tag))?;
168        inner(variant_name, self)
169    }
170
171    fn read_enum<T: StrictEnum>(&mut self) -> Result<T, DecodeError>
172    where u8: From<T> {
173        let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
174        let tag = u8::strict_decode(self)?;
175        T::try_from(tag).map_err(|_| DecodeError::EnumTagNotKnown(name.to_string(), tag))
176    }
177
178    fn read_tuple<'parent, 'me, T: StrictTuple>(
179        &'me mut self,
180        inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
181    ) -> Result<T, DecodeError>
182    where
183        Self: 'parent,
184        'me: 'parent,
185    {
186        let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
187        let mut reader = TupleReader {
188            read_fields: 0,
189            parent: self,
190        };
191        let res = inner(&mut reader)?;
192        assert_ne!(reader.read_fields, 0, "you forget to read fields for a tuple {}", name);
193        assert_eq!(
194            reader.read_fields,
195            T::FIELD_COUNT,
196            "the number of fields read for a tuple {} doesn't match type declaration",
197            name
198        );
199        Ok(res)
200    }
201
202    fn read_struct<'parent, 'me, T: StrictStruct>(
203        &'me mut self,
204        inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
205    ) -> Result<T, DecodeError>
206    where
207        Self: 'parent,
208        'me: 'parent,
209    {
210        let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
211        let mut reader = StructReader {
212            named_fields: empty!(),
213            parent: self,
214        };
215        let res = inner(&mut reader)?;
216        assert!(!reader.named_fields.is_empty(), "you forget to read fields for a tuple {}", name);
217
218        for field in T::ALL_FIELDS {
219            let pos = reader
220                .named_fields
221                .iter()
222                .position(|f| f.as_str() == *field)
223                .unwrap_or_else(|| panic!("field {} is not read for {}", field, name));
224            reader.named_fields.remove(pos);
225        }
226        assert!(reader.named_fields.is_empty(), "excessive fields are read for {}", name);
227        Ok(res)
228    }
229}
230
231#[derive(Debug)]
232pub struct TupleReader<'parent, R: ReadRaw> {
233    read_fields: u8,
234    parent: &'parent mut StrictReader<R>,
235}
236
237impl<'parent, R: ReadRaw> ReadTuple for TupleReader<'parent, R> {
238    fn read_field<T: StrictDecode>(&mut self) -> Result<T, DecodeError> {
239        self.read_fields += 1;
240        T::strict_decode(self.parent)
241    }
242}
243
244#[derive(Debug)]
245pub struct StructReader<'parent, R: ReadRaw> {
246    named_fields: Vec<FieldName>,
247    parent: &'parent mut StrictReader<R>,
248}
249
250impl<'parent, R: ReadRaw> ReadStruct for StructReader<'parent, R> {
251    fn read_field<T: StrictDecode>(&mut self, field: FieldName) -> Result<T, DecodeError> {
252        self.named_fields.push(field);
253        T::strict_decode(self.parent)
254    }
255}
256
257impl<R: ReadRaw> ReadUnion for StrictReader<R> {
258    type TupleReader<'parent>
259        = TupleReader<'parent, R>
260    where Self: 'parent;
261    type StructReader<'parent>
262        = StructReader<'parent, R>
263    where Self: 'parent;
264
265    fn read_tuple<'parent, 'me, T: StrictSum>(
266        &'me mut self,
267        inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
268    ) -> Result<T, DecodeError>
269    where
270        Self: 'parent,
271        'me: 'parent,
272    {
273        let mut reader = TupleReader {
274            read_fields: 0,
275            parent: self,
276        };
277        inner(&mut reader)
278    }
279
280    fn read_struct<'parent, 'me, T: StrictSum>(
281        &'me mut self,
282        inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
283    ) -> Result<T, DecodeError>
284    where
285        Self: 'parent,
286        'me: 'parent,
287    {
288        let mut reader = StructReader {
289            named_fields: empty!(),
290            parent: self,
291        };
292        inner(&mut reader)
293    }
294}