spacetimedb_sats/
buffer.rs

1//! Minimal utility for reading & writing the types we need to internal storage,
2//! without relying on types in third party libraries like `bytes::Bytes`, etc.
3//! Meant to be kept slim and trim for use across both native and WASM.
4
5use bytes::{BufMut, BytesMut};
6
7use crate::{i256, u256};
8use core::cell::Cell;
9use core::fmt;
10use core::str::Utf8Error;
11
/// An error that occurred when decoding.
///
/// In this file, `BufferLength` is produced by the [`BufReader`] accessors and
/// `InvalidUtf8` by the `From<Utf8Error>` conversion; the remaining variants are
/// presumably constructed by decoders elsewhere in the crate.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum DecodeError {
    /// Not enough data was provided in the input.
    BufferLength {
        /// Name of the type that was being read when the input ran short.
        for_type: &'static str,
        /// Number of bytes that were required for the read.
        expected: usize,
        /// Number of bytes that were still left in the input.
        given: usize,
    },
    /// Length did not match the statically expected length.
    ///
    /// `expected` is the length that was required and `given` the length found.
    InvalidLen { expected: usize, given: usize },
    /// The tag does not exist for the sum.
    ///
    /// `sum_name` is the name of the sum type when one is known;
    /// it is displayed as `<unknown>` otherwise.
    InvalidTag { tag: u8, sum_name: Option<String> },
    /// Expected data to be UTF-8 but it wasn't.
    InvalidUtf8,
    /// Expected the byte to be 0 or 1 to be a valid bool.
    ///
    /// The payload is the offending byte.
    InvalidBool(u8),
    /// Custom error not in the other variants of `DecodeError`.
    ///
    /// The payload is the message, displayed verbatim.
    Other(String),
}
32
33impl fmt::Display for DecodeError {
34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35        match self {
36            DecodeError::BufferLength {
37                for_type,
38                expected,
39                given,
40            } => write!(f, "data too short for {for_type}: Expected {expected}, given {given}"),
41            DecodeError::InvalidLen { expected, given } => {
42                write!(f, "unexpected data length: Expected {expected}, given {given}")
43            }
44            DecodeError::InvalidTag { tag, sum_name } => {
45                write!(
46                    f,
47                    "unknown tag {tag:#x} for sum type {}",
48                    sum_name.as_deref().unwrap_or("<unknown>")
49                )
50            }
51            DecodeError::InvalidUtf8 => f.write_str("invalid utf8"),
52            DecodeError::InvalidBool(byte) => write!(f, "byte {byte} not valid as `bool` (must be 0 or 1)"),
53            DecodeError::Other(err) => f.write_str(err),
54        }
55    }
56}
57impl From<DecodeError> for String {
58    fn from(err: DecodeError) -> Self {
59        err.to_string()
60    }
61}
// No variant of `DecodeError` stores an underlying error value,
// so the provided (empty) `Error` implementation is used as-is.
impl std::error::Error for DecodeError {}
63
64impl From<Utf8Error> for DecodeError {
65    fn from(_: Utf8Error) -> Self {
66        DecodeError::InvalidUtf8
67    }
68}
69
70/// A buffered writer of some kind.
71pub trait BufWriter {
72    /// Writes the `slice` to the buffer.
73    ///
74    /// This is the only method implementations are required to define.
75    /// All other methods are provided.
76    fn put_slice(&mut self, slice: &[u8]);
77
78    /// Writes a `u8` to the buffer in little-endian (LE) encoding.
79    fn put_u8(&mut self, val: u8) {
80        self.put_slice(&val.to_le_bytes())
81    }
82
83    /// Writes a `u16` to the buffer in little-endian (LE) encoding.
84    fn put_u16(&mut self, val: u16) {
85        self.put_slice(&val.to_le_bytes())
86    }
87
88    /// Writes a `u32` to the buffer in little-endian (LE) encoding.
89    fn put_u32(&mut self, val: u32) {
90        self.put_slice(&val.to_le_bytes())
91    }
92
93    /// Writes a `u64` to the buffer in little-endian (LE) encoding.
94    fn put_u64(&mut self, val: u64) {
95        self.put_slice(&val.to_le_bytes())
96    }
97
98    /// Writes a `u128` to the buffer in little-endian (LE) encoding.
99    fn put_u128(&mut self, val: u128) {
100        self.put_slice(&val.to_le_bytes())
101    }
102
103    /// Writes a `u256` to the buffer in little-endian (LE) encoding.
104    fn put_u256(&mut self, val: u256) {
105        self.put_slice(&val.to_le_bytes())
106    }
107
108    /// Writes an `i8` to the buffer in little-endian (LE) encoding.
109    fn put_i8(&mut self, val: i8) {
110        self.put_slice(&val.to_le_bytes())
111    }
112
113    /// Writes an `i16` to the buffer in little-endian (LE) encoding.
114    fn put_i16(&mut self, val: i16) {
115        self.put_slice(&val.to_le_bytes())
116    }
117
118    /// Writes an `i32` to the buffer in little-endian (LE) encoding.
119    fn put_i32(&mut self, val: i32) {
120        self.put_slice(&val.to_le_bytes())
121    }
122
123    /// Writes an `i64` to the buffer in little-endian (LE) encoding.
124    fn put_i64(&mut self, val: i64) {
125        self.put_slice(&val.to_le_bytes())
126    }
127
128    /// Writes an `i128` to the buffer in little-endian (LE) encoding.
129    fn put_i128(&mut self, val: i128) {
130        self.put_slice(&val.to_le_bytes())
131    }
132
133    /// Writes an `i256` to the buffer in little-endian (LE) encoding.
134    fn put_i256(&mut self, val: i256) {
135        self.put_slice(&val.to_le_bytes())
136    }
137}
138
/// Reads a little-endian integer of type `$int` from the reader named by `$self`.
///
/// Expands to an expression of type `Result<$int, DecodeError>`:
/// - `Ok(value)` when `$self.get_array_chunk()` yields the required bytes,
/// - `Err(DecodeError::BufferLength { .. })` otherwise, reporting the type name,
///   its size in bytes, and what `$self.remaining()` returns.
macro_rules! get_int {
    ($self:ident, $int:ident) => {
        $self
            .get_array_chunk()
            // The array length is inferred from `from_le_bytes`'s parameter type.
            .map(|&bytes| $int::from_le_bytes(bytes))
            .ok_or_else(|| DecodeError::BufferLength {
                for_type: stringify!($int),
                expected: std::mem::size_of::<$int>(),
                given: $self.remaining(),
            })
    };
}
151
152/// A buffered reader of some kind.
153///
154/// The lifetime `'de` allows the output of deserialization to borrow from the input.
155pub trait BufReader<'de> {
156    /// Reads and returns a chunk of `.len() = size` advancing the cursor iff `self.remaining() >= size`.
157    fn get_chunk(&mut self, size: usize) -> Option<&'de [u8]>;
158
159    /// Returns the number of bytes left to read in the input.
160    fn remaining(&self) -> usize;
161
162    /// Reads and returns a chunk of `.len() = N` as an array, advancing the cursor.
163    #[inline]
164    fn get_array_chunk<const N: usize>(&mut self) -> Option<&'de [u8; N]> {
165        self.get_chunk(N)?.try_into().ok()
166    }
167
168    /// Reads and returns a byte slice of `.len() = size` advancing the cursor.
169    #[inline]
170    fn get_slice(&mut self, size: usize) -> Result<&'de [u8], DecodeError> {
171        self.get_chunk(size).ok_or_else(|| DecodeError::BufferLength {
172            for_type: "[u8]",
173            expected: size,
174            given: self.remaining(),
175        })
176    }
177
178    /// Reads an array of type `[u8; N]` from the input.
179    #[inline]
180    fn get_array<const N: usize>(&mut self) -> Result<&'de [u8; N], DecodeError> {
181        self.get_array_chunk().ok_or_else(|| DecodeError::BufferLength {
182            for_type: "[u8; _]",
183            expected: N,
184            given: self.remaining(),
185        })
186    }
187
188    /// Reads a `u8` in little endian (LE) encoding from the input.
189    ///
190    /// This method is provided for convenience
191    /// and is derived from [`get_chunk`](BufReader::get_chunk)'s definition.
192    #[inline]
193    fn get_u8(&mut self) -> Result<u8, DecodeError> {
194        get_int!(self, u8)
195    }
196
197    /// Reads a `u16` in little endian (LE) encoding from the input.
198    ///
199    /// This method is provided for convenience
200    /// and is derived from [`get_chunk`](BufReader::get_chunk)'s definition.
201    #[inline]
202    fn get_u16(&mut self) -> Result<u16, DecodeError> {
203        get_int!(self, u16)
204    }
205
206    /// Reads a `u32` in little endian (LE) encoding from the input.
207    ///
208    /// This method is provided for convenience
209    /// and is derived from [`get_chunk`](BufReader::get_chunk)'s definition.
210    #[inline]
211    fn get_u32(&mut self) -> Result<u32, DecodeError> {
212        get_int!(self, u32)
213    }
214
215    /// Reads a `u64` in little endian (LE) encoding from the input.
216    ///
217    /// This method is provided for convenience
218    /// and is derived from [`get_chunk`](BufReader::get_chunk)'s definition.
219    #[inline]
220    fn get_u64(&mut self) -> Result<u64, DecodeError> {
221        get_int!(self, u64)
222    }
223
224    /// Reads a `u128` in little endian (LE) encoding from the input.
225    ///
226    /// This method is provided for convenience
227    /// and is derived from [`get_chunk`](BufReader::get_chunk)'s definition.
228    #[inline]
229    fn get_u128(&mut self) -> Result<u128, DecodeError> {
230        get_int!(self, u128)
231    }
232
233    /// Reads a `u256` in little endian (LE) encoding from the input.
234    ///
235    /// This method is provided for convenience
236    /// and is derived from [`get_chunk`](BufReader::get_chunk)'s definition.
237    #[inline]
238    fn get_u256(&mut self) -> Result<u256, DecodeError> {
239        get_int!(self, u256)
240    }
241
242    /// Reads an `i8` in little endian (LE) encoding from the input.
243    ///
244    /// This method is provided for convenience
245    /// and is derived from [`get_chunk`](BufReader::get_chunk)'s definition.
246    #[inline]
247    fn get_i8(&mut self) -> Result<i8, DecodeError> {
248        get_int!(self, i8)
249    }
250
251    /// Reads an `i16` in little endian (LE) encoding from the input.
252    ///
253    /// This method is provided for convenience
254    /// and is derived from [`get_chunk`](BufReader::get_chunk)'s definition.
255    #[inline]
256    fn get_i16(&mut self) -> Result<i16, DecodeError> {
257        get_int!(self, i16)
258    }
259
260    /// Reads an `i32` in little endian (LE) encoding from the input.
261    ///
262    /// This method is provided for convenience
263    /// and is derived from [`get_chunk`](BufReader::get_chunk)'s definition.
264    #[inline]
265    fn get_i32(&mut self) -> Result<i32, DecodeError> {
266        get_int!(self, i32)
267    }
268
269    /// Reads an `i64` in little endian (LE) encoding from the input.
270    ///
271    /// This method is provided for convenience
272    /// and is derived from [`get_chunk`](BufReader::get_chunk)'s definition.
273    #[inline]
274    fn get_i64(&mut self) -> Result<i64, DecodeError> {
275        get_int!(self, i64)
276    }
277
278    /// Reads an `i128` in little endian (LE) encoding from the input.
279    ///
280    /// This method is provided for convenience
281    /// and is derived from [`get_chunk`](BufReader::get_chunk)'s definition.
282    #[inline]
283    fn get_i128(&mut self) -> Result<i128, DecodeError> {
284        get_int!(self, i128)
285    }
286
287    /// Reads an `i256` in little endian (LE) encoding from the input.
288    ///
289    /// This method is provided for convenience
290    /// and is derived from [`get_chunk`](BufReader::get_chunk)'s definition.
291    #[inline]
292    fn get_i256(&mut self) -> Result<i256, DecodeError> {
293        get_int!(self, i256)
294    }
295}
296
297impl BufWriter for Vec<u8> {
298    fn put_slice(&mut self, slice: &[u8]) {
299        self.extend_from_slice(slice);
300    }
301}
302
303impl BufWriter for &mut [u8] {
304    fn put_slice(&mut self, slice: &[u8]) {
305        if self.len() < slice.len() {
306            panic!("not enough buffer space")
307        }
308        let (buf, rest) = std::mem::take(self).split_at_mut(slice.len());
309        buf.copy_from_slice(slice);
310        *self = rest;
311    }
312}
313
314impl BufWriter for BytesMut {
315    fn put_slice(&mut self, slice: &[u8]) {
316        BufMut::put_slice(self, slice);
317    }
318}
319
320/// A [`BufWriter`] that only counts the bytes.
321#[derive(Default)]
322pub struct CountWriter {
323    /// The number of bytes counted thus far.
324    num_bytes: usize,
325}
326
327impl CountWriter {
328    /// Consumes the counter and returns the final count.
329    pub fn finish(self) -> usize {
330        self.num_bytes
331    }
332}
333
334impl BufWriter for CountWriter {
335    fn put_slice(&mut self, slice: &[u8]) {
336        self.num_bytes += slice.len();
337    }
338}
339
340/// A [`BufWriter`] that writes the bytes to two writers `W1` and `W2`.
341pub struct TeeWriter<W1, W2> {
342    pub w1: W1,
343    pub w2: W2,
344}
345
346impl<W1: BufWriter, W2: BufWriter> TeeWriter<W1, W2> {
347    pub fn new(w1: W1, w2: W2) -> Self {
348        Self { w1, w2 }
349    }
350}
351
352impl<W1: BufWriter, W2: BufWriter> BufWriter for TeeWriter<W1, W2> {
353    fn put_slice(&mut self, slice: &[u8]) {
354        self.w1.put_slice(slice);
355        self.w2.put_slice(slice);
356    }
357}
358
359impl<'de> BufReader<'de> for &'de [u8] {
360    #[inline]
361    fn get_chunk(&mut self, size: usize) -> Option<&'de [u8]> {
362        let (ret, rest) = self.split_at_checked(size)?;
363        *self = rest;
364        Some(ret)
365    }
366
367    #[inline]
368    fn get_array_chunk<const N: usize>(&mut self) -> Option<&'de [u8; N]> {
369        let (ret, rest) = self.split_first_chunk()?;
370        *self = rest;
371        Some(ret)
372    }
373
374    #[inline(always)]
375    fn remaining(&self) -> usize {
376        self.len()
377    }
378}
379
380/// A cursor based [`BufReader<'de>`] implementation.
381#[derive(Debug)]
382pub struct Cursor<I> {
383    /// The underlying input read from.
384    pub buf: I,
385    /// The position within the reader.
386    pub pos: Cell<usize>,
387}
388
389impl<I> Cursor<I> {
390    /// Returns a new cursor on the provided `buf` input.
391    ///
392    /// The cursor starts at the beginning of `buf`.
393    pub fn new(buf: I) -> Self {
394        Self { buf, pos: 0.into() }
395    }
396}
397
398impl<'de, I: AsRef<[u8]>> BufReader<'de> for &'de Cursor<I> {
399    #[inline]
400    fn get_chunk(&mut self, size: usize) -> Option<&'de [u8]> {
401        // "Read" the slice `buf[pos..size]`.
402        let buf = &self.buf.as_ref()[self.pos.get()..];
403        let ret = buf.get(..size)?;
404
405        // Advance the cursor by `size` bytes.
406        self.pos.set(self.pos.get() + size);
407
408        Some(ret)
409    }
410
411    #[inline]
412    fn get_array_chunk<const N: usize>(&mut self) -> Option<&'de [u8; N]> {
413        // "Read" the slice `buf[pos..size]`.
414        let buf = &self.buf.as_ref()[self.pos.get()..];
415        let ret = buf.first_chunk()?;
416
417        // Advance the cursor by `size` bytes.
418        self.pos.set(self.pos.get() + N);
419
420        Some(ret)
421    }
422
423    fn remaining(&self) -> usize {
424        self.buf.as_ref().len() - self.pos.get()
425    }
426}
427
#[cfg(test)]
mod tests {
    use crate::buffer::{BufReader, BufWriter};

    /// Writes a few integers and a raw slice into a `Vec<u8>`, reads them back
    /// through a `&[u8]`, and checks that reading past the end is an error.
    #[test]
    fn test_simple_encode_decode() {
        const PAYLOAD: [u8; 4] = [1, 2, 3, 4];

        // Encode.
        let mut encoded = Vec::<u8>::new();
        encoded.put_u8(5);
        encoded.put_u32(6);
        encoded.put_u64(7);
        encoded.put_slice(&PAYLOAD[..]);

        // Decode in the same order.
        let mut input: &[u8] = encoded.as_slice();
        assert_eq!(input.get_u8().unwrap(), 5);
        assert_eq!(input.get_u32().unwrap(), 6);
        assert_eq!(input.get_u64().unwrap(), 7);
        assert_eq!(input.get_slice(4).unwrap(), PAYLOAD);

        // reading beyond buffer end should fail
        assert!(input.get_slice(1).is_err());
        assert!(input.get_slice(123).is_err());
        assert!(input.get_u64().is_err());
        assert!(input.get_u32().is_err());
        assert!(input.get_u8().is_err());
    }
}
457}