1use std::{
17 fmt,
18 io::{Read, Result as IoResult, Write},
19 marker::PhantomData,
20 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
21};
22
23use serde::{
24 Deserializer,
25 Serializer,
26 de::{self, Error, SeqAccess, Visitor},
27 ser::{self, SerializeTuple},
28};
29use smol_str::SmolStr;
30
31use crate::error;
32
/// Serializes each argument, in order, via `ToBytes::write_le` into a freshly allocated `Vec<u8>`.
///
/// Evaluates to `std::io::Result<Vec<u8>>`: the filled buffer on success, otherwise the error
/// produced while writing. `ToBytes` must be in scope at the call site.
#[macro_export]
macro_rules! to_bytes_le {
    ($($x:expr),*) => ({
        // Pre-size the buffer so small encodings do not reallocate repeatedly.
        let mut buffer = ::std::vec::Vec::with_capacity(64);
        let written = $crate::push_bytes_to_vec!(buffer, $($x),*);
        written.map(|_| buffer)
    });
}
43
/// Writes each expression following `$buffer`, in order, into `$buffer` via `ToBytes::write_le`.
///
/// Evaluates to `std::io::Result<()>`. Writing stops at the first element whose `write_le`
/// fails and that error is returned; later elements are neither evaluated nor written, so a
/// streaming writer is never fed bytes that follow a failed field. `ToBytes` must be in scope
/// at the call site.
#[macro_export]
macro_rules! push_bytes_to_vec {
    ($buffer:expr, $y:expr, $($x:expr),*) => ({
        // Bind the result first so temporaries created for `$y` are dropped before recursing.
        let result = ToBytes::write_le(&$y, &mut $buffer);
        match result {
            ::std::result::Result::Ok(()) => $crate::push_bytes_to_vec!($buffer, $($x),*),
            ::std::result::Result::Err(error) => ::std::result::Result::Err(error),
        }
    });

    ($buffer:expr, $x:expr) => ({
        ToBytes::write_le(&$x, &mut $buffer)
    })
}
54
55pub trait ToBytes {
56 fn write_le<W: Write>(&self, writer: W) -> IoResult<()>
58 where
59 Self: Sized;
60
61 fn to_bytes_le(&self) -> anyhow::Result<Vec<u8>>
63 where
64 Self: Sized,
65 {
66 Ok(to_bytes_le![self]?)
67 }
68}
69
70pub trait FromBytes {
71 fn read_le<R: Read>(reader: R) -> IoResult<Self>
73 where
74 Self: Sized;
75
76 fn from_bytes_le(bytes: &[u8]) -> anyhow::Result<Self>
78 where
79 Self: Sized,
80 {
81 Ok(Self::read_le(bytes)?)
82 }
83}
84
85pub struct ToBytesSerializer<T: ToBytes>(PhantomData<T>);
86
87impl<T: ToBytes> ToBytesSerializer<T> {
88 pub fn serialize<S: Serializer>(object: &T, serializer: S) -> Result<S::Ok, S::Error> {
90 let bytes = object.to_bytes_le().map_err(ser::Error::custom)?;
91 let mut tuple = serializer.serialize_tuple(bytes.len())?;
92 for byte in &bytes {
93 tuple.serialize_element(byte)?;
94 }
95 tuple.end()
96 }
97
98 pub fn serialize_with_size_encoding<S: Serializer>(object: &T, serializer: S) -> Result<S::Ok, S::Error> {
100 let bytes = object.to_bytes_le().map_err(ser::Error::custom)?;
101 serializer.serialize_bytes(&bytes)
102 }
103}
104
105pub struct FromBytesDeserializer<T: FromBytes>(PhantomData<T>);
106
107impl<'de, T: FromBytes> FromBytesDeserializer<T> {
108 pub fn deserialize<D: Deserializer<'de>>(deserializer: D, name: &str, size: usize) -> Result<T, D::Error> {
112 let mut buffer = Vec::with_capacity(size);
113 deserializer.deserialize_tuple(size, FromBytesVisitor::new(&mut buffer, name))?;
114 FromBytes::read_le(&*buffer).map_err(de::Error::custom)
115 }
116
117 pub fn deserialize_with_u8<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
119 deserializer.deserialize_tuple(1usize << 8usize, FromBytesWithU8Visitor::<T>::new(name))
120 }
121
122 pub fn deserialize_with_u16<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
124 deserializer.deserialize_tuple(1usize << 16usize, FromBytesWithU16Visitor::<T>::new(name))
125 }
126
127 pub fn deserialize_with_size_encoding<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
129 let mut buffer = Vec::with_capacity(32);
130 deserializer.deserialize_bytes(FromBytesVisitor::new(&mut buffer, name))?;
131 FromBytes::read_le(&*buffer).map_err(de::Error::custom)
132 }
133
134 pub fn deserialize_extended<D: Deserializer<'de>>(
139 deserializer: D,
140 name: &str,
141 size_a: usize,
142 size_b: usize,
143 ) -> Result<T, D::Error> {
144 let (size_a, size_b) = match size_a < size_b {
146 true => (size_a, size_b),
147 false => (size_b, size_a),
148 };
149
150 if size_b > i32::MAX as usize {
152 return Err(D::Error::custom(format!("size_b ({size_b}) exceeds maximum")));
153 }
154
155 let mut buffer = Vec::with_capacity(size_b);
157
158 match deserializer.deserialize_tuple(size_b, FromBytesVisitor::new(&mut buffer, name)) {
160 Ok(()) => FromBytes::read_le(&buffer[..size_b]).map_err(de::Error::custom),
162 Err(error) => match buffer.len() == size_a {
164 true => FromBytes::read_le(&buffer[..size_a]).map_err(de::Error::custom),
165 false => Err(error),
166 },
167 }
168 }
169}
170
171pub struct FromBytesVisitor<'a>(&'a mut Vec<u8>, SmolStr);
172
173impl<'a> FromBytesVisitor<'a> {
174 pub fn new(buffer: &'a mut Vec<u8>, name: &str) -> Self {
176 Self(buffer, SmolStr::new(name))
177 }
178}
179
180impl<'de> Visitor<'de> for FromBytesVisitor<'_> {
181 type Value = ();
182
183 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
184 formatter.write_str(&format!("a valid {} ", self.1))
185 }
186
187 fn visit_borrowed_bytes<E: serde::de::Error>(self, bytes: &'de [u8]) -> Result<Self::Value, E> {
188 self.0.extend_from_slice(bytes);
189 Ok(())
190 }
191
192 fn visit_seq<S: SeqAccess<'de>>(self, mut seq: S) -> Result<Self::Value, S::Error> {
193 while let Some(byte) = seq.next_element()? {
194 self.0.push(byte);
195 }
196 Ok(())
197 }
198}
199
200struct FromBytesWithU8Visitor<T: FromBytes>(String, PhantomData<T>);
201
202impl<T: FromBytes> FromBytesWithU8Visitor<T> {
203 pub fn new(name: &str) -> Self {
205 Self(name.to_string(), PhantomData)
206 }
207}
208
209impl<'de, T: FromBytes> Visitor<'de> for FromBytesWithU8Visitor<T> {
210 type Value = T;
211
212 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
213 formatter.write_str(&format!("a valid {} ", self.0))
214 }
215
216 fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<Self::Value, V::Error> {
217 let length: u8 = seq.next_element()?.ok_or_else(|| Error::invalid_length(0, &self))?;
219
220 let mut bytes: Vec<u8> = Vec::with_capacity((length as usize) + 1);
222 bytes.push(length);
224 for i in 0..length {
226 bytes.push(seq.next_element()?.ok_or_else(|| Error::invalid_length((i as usize) + 1, &self))?);
228 }
229 FromBytes::read_le(&*bytes).map_err(de::Error::custom)
231 }
232}
233
234struct FromBytesWithU16Visitor<T: FromBytes>(String, PhantomData<T>);
235
236impl<T: FromBytes> FromBytesWithU16Visitor<T> {
237 pub fn new(name: &str) -> Self {
239 Self(name.to_string(), PhantomData)
240 }
241}
242
243impl<'de, T: FromBytes> Visitor<'de> for FromBytesWithU16Visitor<T> {
244 type Value = T;
245
246 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
247 formatter.write_str(&format!("a valid {} ", self.0))
248 }
249
250 fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<Self::Value, V::Error> {
251 let length: u16 = seq.next_element()?.ok_or_else(|| Error::invalid_length(0, &self))?;
253
254 let mut bytes: Vec<u8> = Vec::with_capacity((length as usize) + 2);
256 bytes.extend(length.to_le_bytes());
258 for i in 0..length {
260 bytes.push(seq.next_element()?.ok_or_else(|| Error::invalid_length((i as usize) + 2, &self))?);
262 }
263 FromBytes::read_le(&*bytes).map_err(de::Error::custom)
265 }
266}
267
268impl ToBytes for () {
269 #[inline]
270 fn write_le<W: Write>(&self, _writer: W) -> IoResult<()> {
271 Ok(())
272 }
273}
274
275impl FromBytes for () {
276 #[inline]
277 fn read_le<R: Read>(_bytes: R) -> IoResult<Self> {
278 Ok(())
279 }
280}
281
282impl ToBytes for bool {
283 #[inline]
284 fn write_le<W: Write>(&self, writer: W) -> IoResult<()> {
285 u8::write_le(&(*self as u8), writer)
286 }
287}
288
289impl FromBytes for bool {
290 #[inline]
291 fn read_le<R: Read>(reader: R) -> IoResult<Self> {
292 match u8::read_le(reader) {
293 Ok(0) => Ok(false),
294 Ok(1) => Ok(true),
295 Ok(_) => Err(error("FromBytes::read failed")),
296 Err(err) => Err(err),
297 }
298 }
299}
300
301impl ToBytes for SocketAddr {
302 #[inline]
303 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
304 match self.ip() {
306 IpAddr::V4(ipv4) => {
307 0u8.write_le(&mut writer)?;
308 u32::from(ipv4).write_le(&mut writer)?;
309 }
310 IpAddr::V6(ipv6) => {
311 1u8.write_le(&mut writer)?;
312 u128::from(ipv6).write_le(&mut writer)?;
313 }
314 }
315 self.port().write_le(&mut writer)?;
317 Ok(())
318 }
319}
320
321impl FromBytes for SocketAddr {
322 #[inline]
323 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
324 let ip = match u8::read_le(&mut reader)? {
326 0 => IpAddr::V4(Ipv4Addr::from(u32::read_le(&mut reader)?)),
327 1 => IpAddr::V6(Ipv6Addr::from(u128::read_le(&mut reader)?)),
328 _ => return Err(error("Invalid IP address")),
329 };
330 let port = u16::read_le(&mut reader)?;
332 Ok(SocketAddr::new(ip, port))
333 }
334}
335
336macro_rules! impl_bytes_for_integer {
337 ($int:ty) => {
338 impl ToBytes for $int {
339 #[inline]
340 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
341 writer.write_all(&self.to_le_bytes())
342 }
343 }
344
345 impl FromBytes for $int {
346 #[inline]
347 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
348 let mut bytes = [0u8; core::mem::size_of::<$int>()];
349 reader.read_exact(&mut bytes)?;
350 Ok(<$int>::from_le_bytes(bytes))
351 }
352 }
353 };
354}
355
356impl_bytes_for_integer!(u8);
357impl_bytes_for_integer!(u16);
358impl_bytes_for_integer!(u32);
359impl_bytes_for_integer!(u64);
360impl_bytes_for_integer!(u128);
361
362impl_bytes_for_integer!(i8);
363impl_bytes_for_integer!(i16);
364impl_bytes_for_integer!(i32);
365impl_bytes_for_integer!(i64);
366impl_bytes_for_integer!(i128);
367
368impl<const N: usize> ToBytes for [u8; N] {
369 #[inline]
370 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
371 writer.write_all(self)
372 }
373}
374
375impl<const N: usize> FromBytes for [u8; N] {
376 #[inline]
377 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
378 let mut arr = [0u8; N];
379 reader.read_exact(&mut arr)?;
380 Ok(arr)
381 }
382}
383
384macro_rules! impl_bytes_for_integer_array {
385 ($int:ty) => {
386 impl<const N: usize> ToBytes for [$int; N] {
387 #[inline]
388 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
389 for num in self {
390 writer.write_all(&num.to_le_bytes())?;
391 }
392 Ok(())
393 }
394 }
395
396 impl<const N: usize> FromBytes for [$int; N] {
397 #[inline]
398 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
399 let mut res: [$int; N] = [0; N];
400 for num in res.iter_mut() {
401 let mut bytes = [0u8; core::mem::size_of::<$int>()];
402 reader.read_exact(&mut bytes)?;
403 *num = <$int>::from_le_bytes(bytes);
404 }
405 Ok(res)
406 }
407 }
408 };
409}
410
411impl_bytes_for_integer_array!(u16);
413impl_bytes_for_integer_array!(u32);
414impl_bytes_for_integer_array!(u64);
415
416impl<L: ToBytes, R: ToBytes> ToBytes for (L, R) {
417 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
418 self.0.write_le(&mut writer)?;
419 self.1.write_le(&mut writer)?;
420 Ok(())
421 }
422}
423
424impl<L: FromBytes, R: FromBytes> FromBytes for (L, R) {
425 #[inline]
426 fn read_le<Reader: Read>(mut reader: Reader) -> IoResult<Self> {
427 let left: L = FromBytes::read_le(&mut reader)?;
428 let right: R = FromBytes::read_le(&mut reader)?;
429 Ok((left, right))
430 }
431}
432
433impl<T: ToBytes> ToBytes for Vec<T> {
434 #[inline]
435 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
436 for item in self {
437 item.write_le(&mut writer)?;
438 }
439 Ok(())
440 }
441}
442
443impl<'a, T: 'a + ToBytes> ToBytes for &'a [T] {
444 #[inline]
445 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
446 for item in *self {
447 item.write_le(&mut writer)?;
448 }
449 Ok(())
450 }
451}
452
453impl<'a, T: 'a + ToBytes> ToBytes for &'a T {
454 #[inline]
455 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
456 (*self).write_le(&mut writer)
457 }
458}
459
/// Expands `bytes` into individual bits, least-significant bit of each byte first.
///
/// The returned iterator yields exactly `8 * bytes.len()` items and can be reversed.
#[inline]
pub fn bits_from_bytes_le(bytes: &[u8]) -> impl DoubleEndedIterator<Item = bool> + '_ {
    bytes.iter().flat_map(|&byte| (0..8u8).map(move |shift| (byte & (1 << shift)) != 0))
}
464
/// Packs `bits` into bytes, least-significant bit first; this is the inverse of
/// `bits_from_bytes_le`.
///
/// Each chunk of up to 8 bits becomes one byte. A trailing chunk shorter than 8 bits fills
/// the low bits of the final byte and leaves the high bits zero. An empty slice yields an
/// empty vector.
#[inline]
pub fn bytes_from_bits_le(bits: &[bool]) -> Vec<u8> {
    // `chunks` reports an exact length (ceil(len / 8)), so `collect` allocates exactly once.
    bits.chunks(8)
        .map(|chunk| chunk.iter().enumerate().fold(0u8, |byte, (i, &bit)| byte | (u8::from(bit) << i)))
        .collect()
}
482
/// A [`Write`] adapter that forwards at most `limit` bytes to the wrapped writer.
///
/// Once the limit has been reached, every further non-empty write fails with an
/// `ErrorKind::Other` error whose message names the limit. A write that straddles the limit
/// is truncated to the remaining budget and reports the shorter length, so `write_all`
/// surfaces the error on its following call.
pub struct LimitedWriter<W: Write> {
    /// The wrapped writer.
    writer: W,
    /// Total number of bytes this adapter may forward.
    limit: usize,
    /// Bytes that may still be forwarded before the limit is hit.
    remaining: usize,
}

impl<W: Write> LimitedWriter<W> {
    /// Wraps `writer`, allowing at most `limit` bytes to be written through it.
    pub fn new(writer: W, limit: usize) -> Self {
        Self { writer, limit, remaining: limit }
    }
}

impl<W: Write> Write for LimitedWriter<W> {
    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
        if self.remaining == 0 && !buf.is_empty() {
            return Err(std::io::Error::other(format!("Byte limit exceeded: {}", self.limit)));
        }

        let allowed = buf.len().min(self.remaining);
        let written = self.writer.write(&buf[..allowed])?;
        // Saturate so an inner writer that over-reports its count cannot underflow the
        // budget (which would silently disable the limit in release builds).
        self.remaining = self.remaining.saturating_sub(written);
        Ok(written)
    }

    fn flush(&mut self) -> IoResult<()> {
        self.writer.flush()
    }
}
516
#[cfg(test)]
mod test {
    use super::*;
    use crate::TestRng;

    use rand::Rng;

    /// Number of iterations for each randomized round-trip test.
    const ITERATIONS: usize = 1000;

    #[test]
    fn test_macro_empty() {
        let array: Vec<u8> = vec![];
        let bytes_a: Vec<u8> = to_bytes_le![array].unwrap();
        assert_eq!(&array, &bytes_a);
        assert_eq!(0, bytes_a.len());

        let bytes_b: Vec<u8> = array.to_bytes_le().unwrap();
        assert_eq!(&array, &bytes_b);
        assert_eq!(0, bytes_b.len());
    }

    #[test]
    fn test_macro() {
        let array1 = [1u8; 32];
        let array2 = [2u8; 16];
        let array3 = [3u8; 8];
        let bytes = to_bytes_le![array1, array2, array3].unwrap();
        assert_eq!(bytes.len(), 56);

        let mut actual_bytes = Vec::new();
        actual_bytes.extend_from_slice(&array1);
        actual_bytes.extend_from_slice(&array2);
        actual_bytes.extend_from_slice(&array3);
        assert_eq!(bytes, actual_bytes);
    }

    #[test]
    fn test_bits_from_bytes_le() {
        assert_eq!(bits_from_bytes_le(&[204, 76]).collect::<Vec<bool>>(), [
            false, false, true, true, false, false, true, true, // 204
            false, false, true, true, false, false, true, false, // 76
        ]);
    }

    #[test]
    fn test_bytes_from_bits_le() {
        let bits = [
            false, false, true, true, false, false, true, true, // 204
            false, false, true, true, false, false, true, false, // 76
        ];
        assert_eq!(bytes_from_bits_le(&bits), [204, 76]);
    }

    #[test]
    fn test_bytes_from_bits_le_partial_chunk() {
        assert_eq!(bytes_from_bits_le(&[]), Vec::<u8>::new());
        // A short trailing chunk fills the low bits of the last byte.
        assert_eq!(bytes_from_bits_le(&[true, false, true]), [5u8]);
        assert_eq!(bytes_from_bits_le(&[true; 9]), [255u8, 1]);
    }

    #[test]
    fn test_from_bits_le_to_bytes_le_roundtrip() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let given_bytes: [u8; 32] = rng.r#gen();

            let bits = bits_from_bytes_le(&given_bytes).collect::<Vec<_>>();
            let recovered_bytes = bytes_from_bits_le(&bits);

            assert_eq!(given_bytes.to_vec(), recovered_bytes);
        }
    }

    #[test]
    fn test_bool_bytes() {
        assert_eq!(false.to_bytes_le().unwrap(), [0u8]);
        assert_eq!(true.to_bytes_le().unwrap(), [1u8]);
        assert!(!bool::read_le(&[0u8][..]).unwrap());
        assert!(bool::read_le(&[1u8][..]).unwrap());
        // Any byte other than 0 or 1 is rejected.
        assert!(bool::read_le(&[2u8][..]).is_err());
    }

    #[test]
    fn test_socketaddr_bytes() {
        fn random_ipv4_address(rng: &mut TestRng) -> Ipv4Addr {
            Ipv4Addr::new(rng.r#gen(), rng.r#gen(), rng.r#gen(), rng.r#gen())
        }

        fn random_ipv6_address(rng: &mut TestRng) -> Ipv6Addr {
            Ipv6Addr::new(
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
            )
        }

        fn random_port(rng: &mut TestRng) -> u16 {
            rng.gen_range(1025..=65535)
        }

        let rng = &mut TestRng::default();

        for _ in 0..ITERATIONS {
            let ipv4 = SocketAddr::new(IpAddr::V4(random_ipv4_address(rng)), random_port(rng));
            let bytes = ipv4.to_bytes_le().unwrap();
            let ipv4_2 = SocketAddr::read_le(&bytes[..]).unwrap();
            assert_eq!(ipv4, ipv4_2);

            let ipv6 = SocketAddr::new(IpAddr::V6(random_ipv6_address(rng)), random_port(rng));
            let bytes = ipv6.to_bytes_le().unwrap();
            let ipv6_2 = SocketAddr::read_le(&bytes[..]).unwrap();
            assert_eq!(ipv6, ipv6_2);
        }
    }

    #[test]
    fn test_limited_writer() {
        let mut sink: Vec<u8> = Vec::new();
        {
            let mut writer = LimitedWriter::new(&mut sink, 4);
            writer.write_all(b"abc").unwrap();
            // The straddling write is truncated to the remaining budget, then rejected.
            let err = writer.write_all(b"de").unwrap_err();
            assert_eq!(err.to_string(), "Byte limit exceeded: 4");
            // Empty writes are still accepted once the budget is exhausted.
            assert_eq!(writer.write(&[]).unwrap(), 0);
        }
        assert_eq!(sink, b"abcd".to_vec());
    }
}