1use std::{
17 cell::Cell,
18 fmt,
19 io::{Read, Result as IoResult, Write},
20 marker::PhantomData,
21 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
22};
23
24use serde::{
25 Deserializer,
26 Serializer,
27 de::{self, Error, SeqAccess, Visitor},
28 ser::{self, SerializeTuple},
29};
30use smol_str::SmolStr;
31
thread_local! {
    /// Per-thread switch read by `FromBytesUncheckedDeserializer::deserialize_with_size_encoding`:
    /// while it is `true`, that helper parses with `FromBytes::read_le_unchecked` instead of
    /// `FromBytes::read_le`. `unchecked_deserialize` sets it around its bincode call.
    static UNCHECKED_DESERIALIZE: Cell<bool> = const { Cell::new(false) };
}
35
36#[inline(always)]
40pub fn unchecked_deserialize<T: de::DeserializeOwned>(data: &[u8]) -> Result<T, bincode::Error> {
41 UNCHECKED_DESERIALIZE.set(true);
42 let result = bincode::deserialize(data);
43 UNCHECKED_DESERIALIZE.set(false);
44 result
45}
46
/// Serializes every argument, in order, with `ToBytes::write_le` into one new `Vec<u8>`.
///
/// Evaluates to a `std::io::Result<Vec<u8>>` holding the concatenated encodings. `ToBytes`
/// must be in scope at the call site, because `push_bytes_to_vec!` names it unqualified.
#[macro_export]
macro_rules! to_bytes_le {
    ($($x:expr),*) => {{
        let mut buffer = ::std::vec::Vec::with_capacity(64);
        let written = $crate::push_bytes_to_vec!(buffer, $($x),*);
        written.map(|_| buffer)
    }};
}
57
/// Writes each expression after `$buffer`, in order, into `$buffer` with `ToBytes::write_le`.
///
/// Evaluates to `std::io::Result<()>` and stops at the first write that fails: later
/// expressions are neither evaluated nor written. `ToBytes` is named unqualified, so it must
/// be in scope at the call site.
#[macro_export]
macro_rules! push_bytes_to_vec {
    ($buffer:expr, $y:expr, $($x:expr),*) => ({
        // `match` (not `Result::and`) so the remaining writes are skipped after an error.
        match ToBytes::write_le(&$y, &mut $buffer) {
            ::core::result::Result::Ok(()) => $crate::push_bytes_to_vec!($buffer, $($x),*),
            ::core::result::Result::Err(error) => ::core::result::Result::Err(error),
        }
    });

    ($buffer:expr, $x:expr) => ({
        ToBytes::write_le(&$x, &mut $buffer)
    })
}
68
69pub trait ToBytes {
70 fn write_le<W: Write>(&self, writer: W) -> IoResult<()>
72 where
73 Self: Sized;
74
75 fn to_bytes_le(&self) -> anyhow::Result<Vec<u8>>
77 where
78 Self: Sized,
79 {
80 Ok(to_bytes_le![self]?)
81 }
82}
83
84pub trait FromBytes {
85 fn read_le<R: Read>(reader: R) -> IoResult<Self>
87 where
88 Self: Sized;
89
90 fn from_bytes_le(bytes: &[u8]) -> anyhow::Result<Self>
92 where
93 Self: Sized,
94 {
95 Ok(Self::read_le(bytes)?)
96 }
97
98 fn from_bytes_le_unchecked(bytes: &[u8]) -> anyhow::Result<Self>
103 where
104 Self: Sized,
105 {
106 Ok(Self::read_le_unchecked(bytes)?)
107 }
108
109 fn read_le_unchecked<R: Read>(reader: R) -> IoResult<Self>
114 where
115 Self: Sized,
116 {
117 Self::read_le(reader)
118 }
119
120 fn read_le_with_unchecked<R: Read>(reader: R, unchecked: bool) -> IoResult<Self>
122 where
123 Self: Sized,
124 {
125 if unchecked { Self::read_le_unchecked(reader) } else { Self::read_le(reader) }
126 }
127}
128
129pub struct ToBytesSerializer<T: ToBytes>(PhantomData<T>);
131
132impl<T: ToBytes> ToBytesSerializer<T> {
133 pub fn serialize<S: Serializer>(object: &T, serializer: S) -> Result<S::Ok, S::Error> {
135 let bytes = object.to_bytes_le().map_err(ser::Error::custom)?;
136 let mut tuple = serializer.serialize_tuple(bytes.len())?;
137 for byte in &bytes {
138 tuple.serialize_element(byte)?;
139 }
140 tuple.end()
141 }
142
143 pub fn serialize_with_size_encoding<S: Serializer>(object: &T, serializer: S) -> Result<S::Ok, S::Error> {
145 let bytes = object.to_bytes_le().map_err(ser::Error::custom)?;
146 serializer.serialize_bytes(&bytes)
147 }
148}
149
150pub struct FromBytesDeserializer<T: FromBytes>(PhantomData<T>);
151
152impl<'de, T: FromBytes> FromBytesDeserializer<T> {
153 pub fn deserialize<D: Deserializer<'de>>(deserializer: D, name: &str, size: usize) -> Result<T, D::Error> {
157 let mut buffer = Vec::with_capacity(size);
158 deserializer.deserialize_tuple(size, FromBytesVisitor::new(&mut buffer, name))?;
159 FromBytes::read_le(&*buffer).map_err(de::Error::custom)
160 }
161
162 pub fn deserialize_with_u8<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
164 deserializer.deserialize_tuple(1usize << 8usize, FromBytesWithU8Visitor::<T>::new(name))
165 }
166
167 pub fn deserialize_with_u16<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
169 deserializer.deserialize_tuple(1usize << 16usize, FromBytesWithU16Visitor::<T>::new(name))
170 }
171
172 pub fn deserialize_with_size_encoding<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
174 let mut buffer = Vec::with_capacity(32);
175 deserializer.deserialize_bytes(FromBytesVisitor::new(&mut buffer, name))?;
176 FromBytes::read_le(&*buffer).map_err(de::Error::custom)
177 }
178
179 pub fn deserialize_extended<D: Deserializer<'de>>(
184 deserializer: D,
185 name: &str,
186 size_a: usize,
187 size_b: usize,
188 ) -> Result<T, D::Error> {
189 let (size_a, size_b) = match size_a < size_b {
191 true => (size_a, size_b),
192 false => (size_b, size_a),
193 };
194
195 if size_b > i32::MAX as usize {
197 return Err(D::Error::custom(format!("size_b ({size_b}) exceeds maximum")));
198 }
199
200 let mut buffer = Vec::with_capacity(size_b);
202
203 match deserializer.deserialize_tuple(size_b, FromBytesVisitor::new(&mut buffer, name)) {
205 Ok(()) => FromBytes::read_le(&buffer[..size_b]).map_err(de::Error::custom),
207 Err(error) => match buffer.len() == size_a {
209 true => FromBytes::read_le(&buffer[..size_a]).map_err(de::Error::custom),
210 false => Err(error),
211 },
212 }
213 }
214}
215
216pub struct FromBytesUncheckedDeserializer<T: FromBytes>(PhantomData<T>);
217
218impl<'de, T: FromBytes> FromBytesUncheckedDeserializer<T> {
219 pub fn deserialize_with_size_encoding<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
221 let mut buffer = Vec::with_capacity(32);
222 deserializer.deserialize_bytes(FromBytesVisitor::new(&mut buffer, name))?;
223
224 if UNCHECKED_DESERIALIZE.get() {
225 FromBytes::read_le_unchecked(&*buffer).map_err(de::Error::custom)
226 } else {
227 FromBytes::read_le(&*buffer).map_err(de::Error::custom)
228 }
229 }
230}
231
232pub struct FromBytesVisitor<'a>(&'a mut Vec<u8>, SmolStr);
233
234impl<'a> FromBytesVisitor<'a> {
235 pub fn new(buffer: &'a mut Vec<u8>, name: &str) -> Self {
237 Self(buffer, SmolStr::new(name))
238 }
239}
240
241impl<'de> Visitor<'de> for FromBytesVisitor<'_> {
242 type Value = ();
243
244 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
245 formatter.write_str(&format!("a valid {} ", self.1))
246 }
247
248 fn visit_borrowed_bytes<E: serde::de::Error>(self, bytes: &'de [u8]) -> Result<Self::Value, E> {
249 self.0.extend_from_slice(bytes);
250 Ok(())
251 }
252
253 fn visit_seq<S: SeqAccess<'de>>(self, mut seq: S) -> Result<Self::Value, S::Error> {
254 while let Some(byte) = seq.next_element()? {
255 self.0.push(byte);
256 }
257 Ok(())
258 }
259}
260
261struct FromBytesWithU8Visitor<T: FromBytes>(String, PhantomData<T>);
262
263impl<T: FromBytes> FromBytesWithU8Visitor<T> {
264 pub fn new(name: &str) -> Self {
266 Self(name.to_string(), PhantomData)
267 }
268}
269
270impl<'de, T: FromBytes> Visitor<'de> for FromBytesWithU8Visitor<T> {
271 type Value = T;
272
273 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
274 formatter.write_str(&format!("a valid {} ", self.0))
275 }
276
277 fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<Self::Value, V::Error> {
278 let length: u8 = seq.next_element()?.ok_or_else(|| Error::invalid_length(0, &self))?;
280
281 let mut bytes: Vec<u8> = Vec::with_capacity((length as usize) + 1);
283 bytes.push(length);
285 for i in 0..length {
287 bytes.push(seq.next_element()?.ok_or_else(|| Error::invalid_length((i as usize) + 1, &self))?);
289 }
290 FromBytes::read_le(&*bytes).map_err(de::Error::custom)
292 }
293}
294
295struct FromBytesWithU16Visitor<T: FromBytes>(String, PhantomData<T>);
296
297impl<T: FromBytes> FromBytesWithU16Visitor<T> {
298 pub fn new(name: &str) -> Self {
300 Self(name.to_string(), PhantomData)
301 }
302}
303
304impl<'de, T: FromBytes> Visitor<'de> for FromBytesWithU16Visitor<T> {
305 type Value = T;
306
307 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
308 formatter.write_str(&format!("a valid {} ", self.0))
309 }
310
311 fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<Self::Value, V::Error> {
312 let length: u16 = seq.next_element()?.ok_or_else(|| Error::invalid_length(0, &self))?;
314
315 let mut bytes: Vec<u8> = Vec::with_capacity((length as usize) + 2);
317 bytes.extend(length.to_le_bytes());
319 for i in 0..length {
321 bytes.push(seq.next_element()?.ok_or_else(|| Error::invalid_length((i as usize) + 2, &self))?);
323 }
324 FromBytes::read_le(&*bytes).map_err(de::Error::custom)
326 }
327}
328
329impl ToBytes for () {
330 #[inline]
331 fn write_le<W: Write>(&self, _writer: W) -> IoResult<()> {
332 Ok(())
333 }
334}
335
336impl FromBytes for () {
337 #[inline]
338 fn read_le<R: Read>(_bytes: R) -> IoResult<Self> {
339 Ok(())
340 }
341}
342
343impl ToBytes for bool {
344 #[inline]
345 fn write_le<W: Write>(&self, writer: W) -> IoResult<()> {
346 u8::write_le(&(*self as u8), writer)
347 }
348}
349
350impl FromBytes for bool {
351 #[inline]
352 fn read_le<R: Read>(reader: R) -> IoResult<Self> {
353 match u8::read_le(reader) {
354 Ok(0) => Ok(false),
355 Ok(1) => Ok(true),
356 Ok(_) => Err(std::io::Error::other("FromBytes::read failed")),
357 Err(err) => Err(err),
358 }
359 }
360}
361
362impl ToBytes for SocketAddr {
363 #[inline]
364 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
365 match self.ip() {
367 IpAddr::V4(ipv4) => {
368 0u8.write_le(&mut writer)?;
369 u32::from(ipv4).write_le(&mut writer)?;
370 }
371 IpAddr::V6(ipv6) => {
372 1u8.write_le(&mut writer)?;
373 u128::from(ipv6).write_le(&mut writer)?;
374 }
375 }
376 self.port().write_le(&mut writer)?;
378 Ok(())
379 }
380}
381
382impl FromBytes for SocketAddr {
383 #[inline]
384 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
385 let ip = match u8::read_le(&mut reader)? {
387 0 => IpAddr::V4(Ipv4Addr::from(u32::read_le(&mut reader)?)),
388 1 => IpAddr::V6(Ipv6Addr::from(u128::read_le(&mut reader)?)),
389 _ => return Err(std::io::Error::other("Invalid IP address")),
390 };
391 let port = u16::read_le(&mut reader)?;
393 Ok(SocketAddr::new(ip, port))
394 }
395}
396
397macro_rules! impl_bytes_for_integer {
398 ($int:ty) => {
399 impl ToBytes for $int {
400 #[inline]
401 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
402 writer.write_all(&self.to_le_bytes())
403 }
404 }
405
406 impl FromBytes for $int {
407 #[inline]
408 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
409 let mut bytes = [0u8; core::mem::size_of::<$int>()];
410 reader.read_exact(&mut bytes)?;
411 Ok(<$int>::from_le_bytes(bytes))
412 }
413 }
414 };
415}
416
417impl_bytes_for_integer!(u8);
418impl_bytes_for_integer!(u16);
419impl_bytes_for_integer!(u32);
420impl_bytes_for_integer!(u64);
421impl_bytes_for_integer!(u128);
422
423impl_bytes_for_integer!(i8);
424impl_bytes_for_integer!(i16);
425impl_bytes_for_integer!(i32);
426impl_bytes_for_integer!(i64);
427impl_bytes_for_integer!(i128);
428
429impl<const N: usize> ToBytes for [u8; N] {
430 #[inline]
431 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
432 writer.write_all(self)
433 }
434}
435
436impl<const N: usize> FromBytes for [u8; N] {
437 #[inline]
438 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
439 let mut arr = [0u8; N];
440 reader.read_exact(&mut arr)?;
441 Ok(arr)
442 }
443}
444
445macro_rules! impl_bytes_for_integer_array {
446 ($int:ty) => {
447 impl<const N: usize> ToBytes for [$int; N] {
448 #[inline]
449 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
450 for num in self {
451 writer.write_all(&num.to_le_bytes())?;
452 }
453 Ok(())
454 }
455 }
456
457 impl<const N: usize> FromBytes for [$int; N] {
458 #[inline]
459 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
460 let mut res: [$int; N] = [0; N];
461 for num in res.iter_mut() {
462 let mut bytes = [0u8; core::mem::size_of::<$int>()];
463 reader.read_exact(&mut bytes)?;
464 *num = <$int>::from_le_bytes(bytes);
465 }
466 Ok(res)
467 }
468 }
469 };
470}
471
472impl_bytes_for_integer_array!(u16);
474impl_bytes_for_integer_array!(u32);
475impl_bytes_for_integer_array!(u64);
476
477impl<L: ToBytes, R: ToBytes> ToBytes for (L, R) {
478 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
479 self.0.write_le(&mut writer)?;
480 self.1.write_le(&mut writer)?;
481 Ok(())
482 }
483}
484
485impl<L: FromBytes, R: FromBytes> FromBytes for (L, R) {
486 #[inline]
487 fn read_le<Reader: Read>(mut reader: Reader) -> IoResult<Self> {
488 let left: L = FromBytes::read_le(&mut reader)?;
489 let right: R = FromBytes::read_le(&mut reader)?;
490 Ok((left, right))
491 }
492}
493
494impl<T: ToBytes> ToBytes for Vec<T> {
495 #[inline]
496 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
497 for item in self {
498 item.write_le(&mut writer)?;
499 }
500 Ok(())
501 }
502}
503
504impl<'a, T: 'a + ToBytes> ToBytes for &'a [T] {
505 #[inline]
506 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
507 for item in *self {
508 item.write_le(&mut writer)?;
509 }
510 Ok(())
511 }
512}
513
514impl<'a, T: 'a + ToBytes> ToBytes for &'a T {
515 #[inline]
516 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
517 (*self).write_le(&mut writer)
518 }
519}
520
/// Expands `bytes` into bits, least-significant bit first within each byte.
///
/// Yields `8 * bytes.len()` booleans; item `8 * i + j` is bit `j` of `bytes[i]`.
#[inline]
pub fn bits_from_bytes_le(bytes: &[u8]) -> impl DoubleEndedIterator<Item = bool> + '_ {
    bytes.iter().copied().flat_map(|byte| (0..8u8).map(move |shift| (byte & (1u8 << shift)) != 0))
}
525
/// Packs `bits` into bytes, least-significant bit first within each byte.
///
/// Bit `i` of the input becomes bit `i % 8` of output byte `i / 8`. A trailing partial group
/// is zero-filled in its high bits, so the output holds `bits.len().div_ceil(8)` bytes. This
/// is the inverse of [`bits_from_bytes_le`] for inputs whose length is a multiple of 8.
#[inline]
pub fn bytes_from_bits_le(bits: &[bool]) -> Vec<u8> {
    // `chunks` is exact-size, so `collect` allocates exactly `bits.len().div_ceil(8)` bytes.
    bits.chunks(8)
        .map(|chunk| chunk.iter().enumerate().fold(0u8, |byte, (i, &bit)| byte | (u8::from(bit) << i)))
        .collect()
}
543
/// A [`Write`] adapter that forwards at most `limit` bytes to the wrapped writer.
///
/// A `write` larger than the remaining budget is cut short to what remains. Once the budget
/// is spent, any further non-empty `write` fails with the message
/// `Byte limit exceeded: {limit}`. Empty writes are always forwarded.
pub struct LimitedWriter<W: Write> {
    /// Destination that receives the permitted bytes.
    writer: W,
    /// Total byte budget; reported in the error message.
    limit: usize,
    /// Bytes that may still be written.
    remaining: usize,
}

impl<W: Write> LimitedWriter<W> {
    /// Wraps `writer`, allowing at most `limit` bytes to pass through.
    pub fn new(writer: W, limit: usize) -> Self {
        Self { writer, remaining: limit, limit }
    }
}

impl<W: Write> Write for LimitedWriter<W> {
    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
        let exhausted = self.remaining == 0;
        if exhausted && !buf.is_empty() {
            return Err(std::io::Error::other(format!("Byte limit exceeded: {}", self.limit)));
        }

        let allowed = &buf[..buf.len().min(self.remaining)];
        let written = self.writer.write(allowed)?;
        self.remaining -= written;
        Ok(written)
    }

    fn flush(&mut self) -> IoResult<()> {
        self.writer.flush()
    }
}
577
#[cfg(test)]
mod test {
    use super::*;
    use crate::TestRng;

    use rand::Rng;

    /// Number of randomized iterations used by the round-trip tests.
    const ITERATIONS: usize = 1000;

    #[test]
    fn test_macro_empty() {
        let array: Vec<u8> = vec![];
        let bytes_a: Vec<u8> = to_bytes_le![array].unwrap();
        assert_eq!(&array, &bytes_a);
        assert_eq!(0, bytes_a.len());

        let bytes_b: Vec<u8> = array.to_bytes_le().unwrap();
        assert_eq!(&array, &bytes_b);
        assert_eq!(0, bytes_b.len());
    }

    #[test]
    fn test_macro() {
        let array1 = [1u8; 32];
        let array2 = [2u8; 16];
        let array3 = [3u8; 8];
        let bytes = to_bytes_le![array1, array2, array3].unwrap();
        assert_eq!(bytes.len(), 56);

        let mut actual_bytes = Vec::new();
        actual_bytes.extend_from_slice(&array1);
        actual_bytes.extend_from_slice(&array2);
        actual_bytes.extend_from_slice(&array3);
        assert_eq!(bytes, actual_bytes);
    }

    #[test]
    fn test_bits_from_bytes_le() {
        assert_eq!(bits_from_bytes_le(&[204, 76]).collect::<Vec<bool>>(), [
            false, false, true, true, false, false, true, true, // 204, LSB first
            false, false, true, true, false, false, true, false, // 76, LSB first
        ]);
    }

    #[test]
    fn test_bytes_from_bits_le() {
        let bits = [
            false, false, true, true, false, false, true, true, // 204, LSB first
            false, false, true, true, false, false, true, false, // 76, LSB first
        ];
        assert_eq!(bytes_from_bits_le(&bits), [204, 76]);
    }

    #[test]
    fn test_from_bits_le_to_bytes_le_roundtrip() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let given_bytes: [u8; 32] = rng.r#gen();

            let bits = bits_from_bytes_le(&given_bytes).collect::<Vec<_>>();
            let recovered_bytes = bytes_from_bits_le(&bits);

            assert_eq!(given_bytes.to_vec(), recovered_bytes);
        }
    }

    #[test]
    fn test_socketaddr_bytes() {
        fn random_ipv4_address(rng: &mut TestRng) -> Ipv4Addr {
            Ipv4Addr::new(rng.r#gen(), rng.r#gen(), rng.r#gen(), rng.r#gen())
        }

        fn random_ipv6_address(rng: &mut TestRng) -> Ipv6Addr {
            Ipv6Addr::new(
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
            )
        }

        fn random_port(rng: &mut TestRng) -> u16 {
            rng.gen_range(1025..=65535)
        }

        let rng = &mut TestRng::default();

        // Uses the shared `ITERATIONS` budget like the other randomized tests; the previous
        // hard-coded 1_000_000 made this test far slower than the rest of the module.
        for _ in 0..ITERATIONS {
            let ipv4 = SocketAddr::new(IpAddr::V4(random_ipv4_address(rng)), random_port(rng));
            let bytes = ipv4.to_bytes_le().unwrap();
            let ipv4_2 = SocketAddr::read_le(&bytes[..]).unwrap();
            assert_eq!(ipv4, ipv4_2);

            let ipv6 = SocketAddr::new(IpAddr::V6(random_ipv6_address(rng)), random_port(rng));
            let bytes = ipv6.to_bytes_le().unwrap();
            let ipv6_2 = SocketAddr::read_le(&bytes[..]).unwrap();
            assert_eq!(ipv6, ipv6_2);
        }
    }
}