1use bytes::{BufMut, BytesMut};
6
7use crate::{i256, u256};
8use core::cell::Cell;
9use core::fmt;
10use core::str::Utf8Error;
11
12#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
14pub enum DecodeError {
15 BufferLength {
17 for_type: &'static str,
18 expected: usize,
19 given: usize,
20 },
21 InvalidLen { expected: usize, given: usize },
23 InvalidTag { tag: u8, sum_name: Option<String> },
25 InvalidUtf8,
27 InvalidBool(u8),
29 Other(String),
31}
32
33impl fmt::Display for DecodeError {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 match self {
36 DecodeError::BufferLength {
37 for_type,
38 expected,
39 given,
40 } => write!(f, "data too short for {for_type}: Expected {expected}, given {given}"),
41 DecodeError::InvalidLen { expected, given } => {
42 write!(f, "unexpected data length: Expected {expected}, given {given}")
43 }
44 DecodeError::InvalidTag { tag, sum_name } => {
45 write!(
46 f,
47 "unknown tag {tag:#x} for sum type {}",
48 sum_name.as_deref().unwrap_or("<unknown>")
49 )
50 }
51 DecodeError::InvalidUtf8 => f.write_str("invalid utf8"),
52 DecodeError::InvalidBool(byte) => write!(f, "byte {byte} not valid as `bool` (must be 0 or 1)"),
53 DecodeError::Other(err) => f.write_str(err),
54 }
55 }
56}
57impl From<DecodeError> for String {
58 fn from(err: DecodeError) -> Self {
59 err.to_string()
60 }
61}
62impl std::error::Error for DecodeError {}
63
64impl From<Utf8Error> for DecodeError {
65 fn from(_: Utf8Error) -> Self {
66 DecodeError::InvalidUtf8
67 }
68}
69
70pub trait BufWriter {
72 fn put_slice(&mut self, slice: &[u8]);
77
78 fn put_u8(&mut self, val: u8) {
80 self.put_slice(&val.to_le_bytes())
81 }
82
83 fn put_u16(&mut self, val: u16) {
85 self.put_slice(&val.to_le_bytes())
86 }
87
88 fn put_u32(&mut self, val: u32) {
90 self.put_slice(&val.to_le_bytes())
91 }
92
93 fn put_u64(&mut self, val: u64) {
95 self.put_slice(&val.to_le_bytes())
96 }
97
98 fn put_u128(&mut self, val: u128) {
100 self.put_slice(&val.to_le_bytes())
101 }
102
103 fn put_u256(&mut self, val: u256) {
105 self.put_slice(&val.to_le_bytes())
106 }
107
108 fn put_i8(&mut self, val: i8) {
110 self.put_slice(&val.to_le_bytes())
111 }
112
113 fn put_i16(&mut self, val: i16) {
115 self.put_slice(&val.to_le_bytes())
116 }
117
118 fn put_i32(&mut self, val: i32) {
120 self.put_slice(&val.to_le_bytes())
121 }
122
123 fn put_i64(&mut self, val: i64) {
125 self.put_slice(&val.to_le_bytes())
126 }
127
128 fn put_i128(&mut self, val: i128) {
130 self.put_slice(&val.to_le_bytes())
131 }
132
133 fn put_i256(&mut self, val: i256) {
135 self.put_slice(&val.to_le_bytes())
136 }
137}
138
139macro_rules! get_int {
140 ($self:ident, $int:ident) => {
141 match $self.get_array_chunk() {
142 Some(&arr) => Ok($int::from_le_bytes(arr)),
143 None => Err(DecodeError::BufferLength {
144 for_type: stringify!($int),
145 expected: std::mem::size_of::<$int>(),
146 given: $self.remaining(),
147 }),
148 }
149 };
150}
151
152pub trait BufReader<'de> {
156 fn get_chunk(&mut self, size: usize) -> Option<&'de [u8]>;
158
159 fn remaining(&self) -> usize;
161
162 #[inline]
164 fn get_array_chunk<const N: usize>(&mut self) -> Option<&'de [u8; N]> {
165 self.get_chunk(N)?.try_into().ok()
166 }
167
168 #[inline]
170 fn get_slice(&mut self, size: usize) -> Result<&'de [u8], DecodeError> {
171 self.get_chunk(size).ok_or_else(|| DecodeError::BufferLength {
172 for_type: "[u8]",
173 expected: size,
174 given: self.remaining(),
175 })
176 }
177
178 #[inline]
180 fn get_array<const N: usize>(&mut self) -> Result<&'de [u8; N], DecodeError> {
181 self.get_array_chunk().ok_or_else(|| DecodeError::BufferLength {
182 for_type: "[u8; _]",
183 expected: N,
184 given: self.remaining(),
185 })
186 }
187
188 #[inline]
193 fn get_u8(&mut self) -> Result<u8, DecodeError> {
194 get_int!(self, u8)
195 }
196
197 #[inline]
202 fn get_u16(&mut self) -> Result<u16, DecodeError> {
203 get_int!(self, u16)
204 }
205
206 #[inline]
211 fn get_u32(&mut self) -> Result<u32, DecodeError> {
212 get_int!(self, u32)
213 }
214
215 #[inline]
220 fn get_u64(&mut self) -> Result<u64, DecodeError> {
221 get_int!(self, u64)
222 }
223
224 #[inline]
229 fn get_u128(&mut self) -> Result<u128, DecodeError> {
230 get_int!(self, u128)
231 }
232
233 #[inline]
238 fn get_u256(&mut self) -> Result<u256, DecodeError> {
239 get_int!(self, u256)
240 }
241
242 #[inline]
247 fn get_i8(&mut self) -> Result<i8, DecodeError> {
248 get_int!(self, i8)
249 }
250
251 #[inline]
256 fn get_i16(&mut self) -> Result<i16, DecodeError> {
257 get_int!(self, i16)
258 }
259
260 #[inline]
265 fn get_i32(&mut self) -> Result<i32, DecodeError> {
266 get_int!(self, i32)
267 }
268
269 #[inline]
274 fn get_i64(&mut self) -> Result<i64, DecodeError> {
275 get_int!(self, i64)
276 }
277
278 #[inline]
283 fn get_i128(&mut self) -> Result<i128, DecodeError> {
284 get_int!(self, i128)
285 }
286
287 #[inline]
292 fn get_i256(&mut self) -> Result<i256, DecodeError> {
293 get_int!(self, i256)
294 }
295}
296
297impl BufWriter for Vec<u8> {
298 fn put_slice(&mut self, slice: &[u8]) {
299 self.extend_from_slice(slice);
300 }
301}
302
303impl BufWriter for &mut [u8] {
304 fn put_slice(&mut self, slice: &[u8]) {
305 if self.len() < slice.len() {
306 panic!("not enough buffer space")
307 }
308 let (buf, rest) = std::mem::take(self).split_at_mut(slice.len());
309 buf.copy_from_slice(slice);
310 *self = rest;
311 }
312}
313
314impl BufWriter for BytesMut {
315 fn put_slice(&mut self, slice: &[u8]) {
316 BufMut::put_slice(self, slice);
317 }
318}
319
320#[derive(Default)]
322pub struct CountWriter {
323 num_bytes: usize,
325}
326
327impl CountWriter {
328 pub fn finish(self) -> usize {
330 self.num_bytes
331 }
332}
333
334impl BufWriter for CountWriter {
335 fn put_slice(&mut self, slice: &[u8]) {
336 self.num_bytes += slice.len();
337 }
338}
339
340pub struct TeeWriter<W1, W2> {
342 pub w1: W1,
343 pub w2: W2,
344}
345
346impl<W1: BufWriter, W2: BufWriter> TeeWriter<W1, W2> {
347 pub fn new(w1: W1, w2: W2) -> Self {
348 Self { w1, w2 }
349 }
350}
351
352impl<W1: BufWriter, W2: BufWriter> BufWriter for TeeWriter<W1, W2> {
353 fn put_slice(&mut self, slice: &[u8]) {
354 self.w1.put_slice(slice);
355 self.w2.put_slice(slice);
356 }
357}
358
359impl<'de> BufReader<'de> for &'de [u8] {
360 #[inline]
361 fn get_chunk(&mut self, size: usize) -> Option<&'de [u8]> {
362 let (ret, rest) = self.split_at_checked(size)?;
363 *self = rest;
364 Some(ret)
365 }
366
367 #[inline]
368 fn get_array_chunk<const N: usize>(&mut self) -> Option<&'de [u8; N]> {
369 let (ret, rest) = self.split_first_chunk()?;
370 *self = rest;
371 Some(ret)
372 }
373
374 #[inline(always)]
375 fn remaining(&self) -> usize {
376 self.len()
377 }
378}
379
380#[derive(Debug)]
382pub struct Cursor<I> {
383 pub buf: I,
385 pub pos: Cell<usize>,
387}
388
389impl<I> Cursor<I> {
390 pub fn new(buf: I) -> Self {
394 Self { buf, pos: 0.into() }
395 }
396}
397
398impl<'de, I: AsRef<[u8]>> BufReader<'de> for &'de Cursor<I> {
399 #[inline]
400 fn get_chunk(&mut self, size: usize) -> Option<&'de [u8]> {
401 let buf = &self.buf.as_ref()[self.pos.get()..];
403 let ret = buf.get(..size)?;
404
405 self.pos.set(self.pos.get() + size);
407
408 Some(ret)
409 }
410
411 #[inline]
412 fn get_array_chunk<const N: usize>(&mut self) -> Option<&'de [u8; N]> {
413 let buf = &self.buf.as_ref()[self.pos.get()..];
415 let ret = buf.first_chunk()?;
416
417 self.pos.set(self.pos.get() + N);
419
420 Some(ret)
421 }
422
423 fn remaining(&self) -> usize {
424 self.buf.as_ref().len() - self.pos.get()
425 }
426}
427
428#[cfg(test)]
429mod tests {
430 use crate::buffer::{BufReader, BufWriter};
431
432 #[test]
433 fn test_simple_encode_decode() {
434 let mut writer: Vec<u8> = vec![];
435 writer.put_u8(5);
436 writer.put_u32(6);
437 writer.put_u64(7);
438
439 let arr_val = [1, 2, 3, 4];
440 writer.put_slice(&arr_val[..]);
441
442 let mut reader = writer.as_slice();
443 assert_eq!(reader.get_u8().unwrap(), 5);
444 assert_eq!(reader.get_u32().unwrap(), 6);
445 assert_eq!(reader.get_u64().unwrap(), 7);
446
447 let slice = reader.get_slice(4).unwrap();
448 assert_eq!(slice, arr_val);
449
450 assert!(reader.get_slice(1).is_err());
452 assert!(reader.get_slice(123).is_err());
453 assert!(reader.get_u64().is_err());
454 assert!(reader.get_u32().is_err());
455 assert!(reader.get_u8().is_err());
456 }
457}